diff --git a/README.md b/README.md index 527e5f0..0799cbd 100644 --- a/README.md +++ b/README.md @@ -113,14 +113,11 @@ All four commands accept: ## Library The `wormhole` module makes it possible for other applications to use these -code-protected channels. This includes blocking/synchronous support (for an -asymmetric pair of "initiator" and "receiver" endpoints), and async/Twisted -support (for a symmetric scheme). The main module is named +code-protected channels. This includes blocking/synchronous support and +async/Twisted support, both for a symmetric scheme. The main module is named `wormhole.blocking.transcribe`, to reflect that it is for synchronous/blocking code, and uses a PAKE mode whereby one user transcribes -their code to the other. (internal names may change in the future). The -synchronous support uses distinctive sides: one `Initiator`, and one -`Receiver`. +their code to the other. (internal names may change in the future). The file-transfer tools use a second module named `wormhole.blocking.transit`, which provides an encrypted record-pipe. It diff --git a/docs/api.md b/docs/api.md index 130566d..123184d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -12,26 +12,48 @@ server" that relays information from one machine to the other. ## Modes -This library will eventually offer multiple modes. +This library will eventually offer multiple modes. For now, only "transcribe +mode" is available. -The first mode provided is "transcribe" mode. In this mode, one machine goes -first, and is called the "initiator". The initiator contacts the rendezvous -server and allocates a "channel ID", which is a small integer. The initiator -then displays the "invitation code", which is the channel-ID plus a few -secret words. The user copies the invitation code to the second machine, -called the "receiver". The receiver connects to the rendezvous server, and -uses the invitation code to contact the initiator. They agree upon an -encryption key, and exchange a small encrypted+authenticated data message. +Transcribe mode has two variants. In the "machine-generated" variant, the +"initiator" machine creates the invitation code, displays it to the first +user, they convey it (somehow) to the second user, who transcribes it into +the second ("receiver") machine. In the "human-generated" variant, the two +humans come up with the code (possibly without computers), then later +transcribe it into both machines. + +When the initator machine generates the invitation code, the initiator +contacts the rendezvous server and allocates a "channel ID", which is a small +integer. The initiator then displays the invitation code, which is the +channel-ID plus a few secret words. The user copies the code to the second +machine. The receiver machine connects to the rendezvous server, and uses the +invitation code to contact the initiator. They agree upon an encryption key, +and exchange a small encrypted+authenticated data message. + +When the humans create an invitation code out-of-band, they are responsible +for choosing an unused channel-ID (simply picking a random 3-or-more digit +number is probably enough), and some random words. The invitation code uses +the same format in either variant: channel-ID, a hyphen, and an arbitrary +string. + +The two machines participating in the wormhole setup are not distinguished: +it doesn't matter which one goes first, and both use the same Wormhole class. +In the first variant, one side calls `get_code()` while the other calls +`set_code()`. In the second variant, both sides call `set_code()`. Note that +this is not true for the "Transit" protocol used for bulk data-transfer: the +Transit class currently distinguishes "Sender" from "Receiver", so the +programs on each side must have some way to decide (ahead of time) which is +which. ## Examples The synchronous+blocking flow looks like this: ```python -from wormhole.transcribe import Initiator +from wormhole.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"initiator's data" -i = Initiator("appid", RENDEZVOUS_RELAY) +i = Wormhole("appid", RENDEZVOUS_RELAY) code = i.get_code() print("Invitation Code: %s" % code) theirdata = i.get_data(mydata) @@ -40,11 +62,11 @@ print("Their data: %s" % theirdata.decode("ascii")) ```python import sys -from wormhole.transcribe import Receiver +from wormhole.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"receiver's data" code = sys.argv[1] -r = Receiver("appid", RENDEZVOUS_RELAY) +r = Wormhole("appid", RENDEZVOUS_RELAY) r.set_code(code) theirdata = r.get_data(mydata) print("Their data: %s" % theirdata.decode("ascii")) @@ -57,9 +79,9 @@ The Twisted-friendly flow looks like this: ```python from twisted.internet import reactor from wormhole.public_relay import RENDEZVOUS_RELAY -from wormhole.twisted.transcribe import SymmetricWormhole +from wormhole.twisted.transcribe import Wormhole outbound_message = b"outbound data" -w1 = SymmetricWormhole("appid", RENDEZVOUS_RELAY) +w1 = Wormhole("appid", RENDEZVOUS_RELAY) d = w1.get_code() def _got_code(code): print "Invitation Code:", code @@ -75,9 +97,10 @@ reactor.run() On the other side, you call `set_code()` instead of waiting for `get_code()`: ```python -w2 = SymmetricWormhole("appid", RENDEZVOUS_RELAY) +w2 = Wormhole("appid", RENDEZVOUS_RELAY) w2.set_code(code) d = w2.get_data(my_message) +... ``` You can call `d=w.get_verifier()` before `get_data()`: this will perform the @@ -90,14 +113,14 @@ pausing. ## Generating the Invitation Code -In most situations, the Initiator will call `i.get_code()` to generate the -invitation code. This returns a string in the form `NNN-code-words`. The -numeric "NNN" prefix is the "channel id", and is a short integer allocated by -talking to the rendezvous server. The rest is a randomly-generated selection -from the PGP wordlist, providing a default of 16 bits of entropy. The -initiating program should display this code to the user, who should -transcribe it to the receiving user, who gives it to the Receiver object by -calling `r.set_code()`. The receiving program can also use +In most situations, the "sending" or "initiating" side will call +`i.get_code()` to generate the invitation code. This returns a string in the +form `NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a +short integer allocated by talking to the rendezvous server. The rest is a +randomly-generated selection from the PGP wordlist, providing a default of 16 +bits of entropy. The initiating program should display this code to the user, +who should transcribe it to the receiving user, who gives it to the Receiver +object by calling `r.set_code()`. The receiving program can also use `input_code_with_completion()` to use a readline-based input function: this offers tab completion of allocated channel-ids and known codewords. @@ -168,12 +191,12 @@ Both have defaults suitable for face-to-face realtime setup environments. TODO: only the Twisted form supports serialization so far -You may not be able to hold the Initiator/Receiver object in memory for the -whole sync process: maybe you allow it to wait for several days, but the -program will be restarted during that time. To support this, you can persist -the state of the object by calling `data = w.serialize()`, which will return -a printable bytestring (the JSON-encoding of a small dictionary). To restore, -use the `from_serialized(data)` classmethod (e.g. `w = +You may not be able to hold the Wormhole object in memory for the whole sync +process: maybe you allow it to wait for several days, but the program will be +restarted during that time. To support this, you can persist the state of the +object by calling `data = w.serialize()`, which will return a printable +bytestring (the JSON-encoding of a small dictionary). To restore, use the +`from_serialized(data)` classmethod (e.g. `w = SymmetricWormhole.from_serialized(data)`). There is exactly one point at which you can serialize the wormhole: *after* diff --git a/setup.py b/setup.py index 8b23c5d..7488beb 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ setup(name="magic-wormhole", package_data={"wormhole": ["db-schemas/*.sql"]}, entry_points={"console_scripts": ["wormhole = wormhole.scripts.runner:entry"]}, - install_requires=["spake2==0.2", "pynacl", "requests", "argparse"], + install_requires=["spake2==0.3", "pynacl", "requests", "argparse"], test_suite="wormhole.test", cmdclass=commands, ) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index eb76576..6a6c250 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -1,45 +1,20 @@ from __future__ import print_function -import sys, time, re, requests, json, textwrap +import os, sys, time, re, requests, json from binascii import hexlify, unhexlify -from spake2 import SPAKE2_A, SPAKE2_B +from spake2 import SPAKE2_Symmetric from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from .eventsource import EventSourceFollower from .. import __version__ from .. import codes -from ..errors import ServerError +from ..errors import (ServerError, Timeout, WrongPasswordError, + ReflectionAttack, UsageError) from ..util.hkdf import HKDF SECOND = 1 MINUTE = 60*SECOND -class Timeout(Exception): - pass - -class WrongPasswordError(Exception): - """ - Key confirmation failed. - """ - # or the data blob was corrupted, and that's why decrypt failed - def explain(self): - return textwrap.dedent(self.__doc__) - -class InitiatorWrongPasswordError(WrongPasswordError): - """ - Key confirmation failed. Either your correspondent typed the code wrong, - or a would-be man-in-the-middle attacker guessed incorrectly. You could - try again, giving both your correspondent and the attacker another - chance. - """ - -class ReceiverWrongPasswordError(WrongPasswordError): - """ - Key confirmation failed. Either you typed the code wrong, or a would-be - man-in-the-middle attacker guessed incorrectly. You could try again, - giving both you and the attacker another chance. - """ - # relay URLs are: # GET /list -> {channel-ids: [INT..]} # POST /allocate/SIDE -> {channel-id: INT} @@ -49,16 +24,28 @@ class ReceiverWrongPasswordError(WrongPasswordError): # GET /CHANNEL-ID/SIDE/poll/MSGNUM (eventsource) -> STR, STR, .. # POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted -class Common: - def url(self, verb, msgnum=None): +class Wormhole: + motd_displayed = False + version_warning_displayed = False + + def __init__(self, appid, relay): + self.appid = appid + self.relay = relay + if not self.relay.endswith("/"): raise UsageError + self.started = time.time() + self.wait = 0.5*SECOND + self.timeout = 3*MINUTE + self.side = None + self.code = None + self.key = None + self.verifier = None + + def _url(self, verb, msgnum=None): url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) if msgnum is not None: url += "/" + msgnum return url - motd_displayed = False - version_warning_displayed = False - def handle_welcome(self, welcome): if ("motd" in welcome and not self.motd_displayed): @@ -81,33 +68,17 @@ class Common: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def get(self, old_msgs, verb, msgnum): - # For now, server errors cause the client to fail. TODO: don't. This - # will require changing the client to re-post messages when the - # server comes back up. + def _post_json(self, url, post_json=None): + # POST to a URL, parsing the response as JSON. Optionally include a + # JSON request body. + data = None + if post_json: + data = json.dumps(post_json).encode("utf-8") + r = requests.post(url, data=data) + r.raise_for_status() + return r.json() - # note: while this passes around msgs (plural), our callers really - # only care about the first one. we use "WHICH" and "SIDE" so that we - # only expect to see a single message (not our own, where "SIDE" is - # our own, and not messages for earlier stages, where "WHICH" is - # different) - msgs = old_msgs - while not msgs: - remaining = self.started + self.timeout - time.time() - if remaining < 0: - raise Timeout - #time.sleep(self.wait) - f = EventSourceFollower(self.url(verb, msgnum), remaining) - for (eventtype, data) in f.iter_events(): - if eventtype == "welcome": - self.handle_welcome(json.loads(data)) - if eventtype == "message": - msgs = [json.loads(data)["message"]] - break - f.close() - return msgs - - def _allocate(self): + def _allocate_channel(self): r = requests.post(self.relay + "allocate/%s" % self.side) r.raise_for_status() data = r.json() @@ -116,124 +87,15 @@ class Common: channel_id = data["channel-id"] return channel_id - def _post_pake(self): - msg = self.sp.start() - post_data = {"message": hexlify(msg).decode("ascii")} - r = requests.post(self.url("post", "pake"), data=json.dumps(post_data)) - r.raise_for_status() - other_msgs = r.json()["messages"] - return other_msgs - - def _get_pake(self, other_msgs): - msgs = self.get(other_msgs, "poll", "pake") - pake_msg = unhexlify(msgs[0].encode("ascii")) - key = self.sp.finish(pake_msg) - return key - - def _encrypt_data(self, key, data): - assert len(key) == SecretBox.KEY_SIZE - box = SecretBox(key) - nonce = utils.random(SecretBox.NONCE_SIZE) - return box.encrypt(data, nonce) - - def _post_data(self, data): - post_data = json.dumps({"message": hexlify(data).decode("ascii")}) - r = requests.post(self.url("post", "data"), data=post_data) - r.raise_for_status() - other_msgs = r.json()["messages"] - return other_msgs - - def _get_data(self, other_msgs): - msgs = self.get(other_msgs, "poll", "data") - data = unhexlify(msgs[0].encode("ascii")) - return data - - def _decrypt_data(self, key, encrypted): - assert len(key) == SecretBox.KEY_SIZE - box = SecretBox(key) - data = box.decrypt(encrypted) - return data - - def _deallocate(self): - r = requests.post(self.url("deallocate")) - r.raise_for_status() - - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - assert type(purpose) == type(b"") - return HKDF(self.key, length, CTXinfo=purpose) - -class Initiator(Common): - def __init__(self, appid, relay): - self.appid = appid - self.relay = relay - assert self.relay.endswith("/") - self.started = time.time() - self.wait = 0.5*SECOND - self.timeout = 3*MINUTE - self.side = "initiator" - self.key = None - self.verifier = None - - def set_code(self, code): # used for human-made pre-generated codes - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - self.channel_id = int(mo.group(1)) - self.code = code - self.sp = SPAKE2_A(self.code.encode("ascii"), - idA=self.appid+":Initiator", - idB=self.appid+":Receiver") - self._post_pake() - def get_code(self, code_length=2): - channel_id = self._allocate() # allocate channel + if self.code is not None: raise UsageError + self.side = hexlify(os.urandom(5)) + channel_id = self._allocate_channel() # allocate channel code = codes.make_code(channel_id, code_length) - self.set_code(code) + self._set_code_and_channel_id(code) + self._start() return code - def _wait_for_key(self): - if not self.key: - key = self._get_pake([]) - self.key = key - self.verifier = self.derive_key(self.appid+b":Verifier") - - def get_verifier(self): - self._wait_for_key() - return self.verifier - - def get_data(self, outbound_data): - self._wait_for_key() - try: - outbound_key = self.derive_key(b"sender") - outbound_encrypted = self._encrypt_data(outbound_key, outbound_data) - other_msgs = self._post_data(outbound_encrypted) - - inbound_encrypted = self._get_data(other_msgs) - inbound_key = self.derive_key(b"receiver") - try: - inbound_data = self._decrypt_data(inbound_key, - inbound_encrypted) - except CryptoError: - raise InitiatorWrongPasswordError - finally: - self._deallocate() - return inbound_data - - -class Receiver(Common): - def __init__(self, appid, relay): - self.appid = appid - self.relay = relay - assert self.relay.endswith("/") - self.started = time.time() - self.wait = 0.5*SECOND - self.timeout = 3*MINUTE - self.side = "receiver" - self.code = None - self.channel_id = None - self.key = None - self.verifier = None - def list_channels(self): r = requests.get(self.relay + "list") r.raise_for_status() @@ -245,43 +107,117 @@ class Receiver(Common): code_length) return code - def set_code(self, code): - assert self.code is None - assert self.channel_id is None - self.code = code - self.channel_id = codes.extract_channel_id(code) - self.sp = SPAKE2_B(code.encode("ascii"), - idA=self.appid+":Initiator", - idB=self.appid+":Receiver") + def set_code(self, code): # used for human-made pre-generated codes + if self.code is not None: raise UsageError + if self.side is not None: raise UsageError + self._set_code_and_channel_id(code) + self.side = hexlify(os.urandom(5)) + self._start() - def _wait_for_key(self): + def _set_code_and_channel_id(self, code): + if self.code is not None: raise UsageError + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + self.channel_id = int(mo.group(1)) + self.code = code + + def _start(self): + # allocate the rest now too, so it can be serialized + self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), + idSymmetric=self.appid) + self.msg1 = self.sp.start() + + def _post_message(self, url, msg): + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + if not isinstance(msg, type(b"")): raise UsageError(type(msg)) + resp = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) + return resp["messages"] # other_msgs + + def _get_message(self, old_msgs, verb, msgnum): + # For now, server errors cause the client to fail. TODO: don't. This + # will require changing the client to re-post messages when the + # server comes back up. + + # fire with a bytestring of the first message that matches + # verb/msgnum, which either came from old_msgs, or from an + # EventSource that we attached to the corresponding URL + msgs = old_msgs + while not msgs: + remaining = self.started + self.timeout - time.time() + if remaining < 0: + raise Timeout + #time.sleep(self.wait) + f = EventSourceFollower(self._url(verb, msgnum), remaining) + for (eventtype, data) in f.iter_events(): + if eventtype == "welcome": + self.handle_welcome(json.loads(data)) + if eventtype == "message": + msgs = [json.loads(data)["message"]] + break + f.close() + return unhexlify(msgs[0].encode("ascii")) + + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if not isinstance(purpose, type(b"")): raise UsageError + return HKDF(self.key, length, CTXinfo=purpose) + + def _encrypt_data(self, key, data): + if len(key) != SecretBox.KEY_SIZE: raise UsageError + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + def _decrypt_data(self, key, encrypted): + if len(key) != SecretBox.KEY_SIZE: raise UsageError + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + + + def _get_key(self): if not self.key: - other_msgs = self._post_pake() - key = self._get_pake(other_msgs) - self.key = key + old_msgs = self._post_message(self._url("post", "pake"), self.msg1) + pake_msg = self._get_message(old_msgs, "poll", "pake") + self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(self.appid+b":Verifier") def get_verifier(self): - self._wait_for_key() + if self.code is None: raise UsageError + if self.channel_id is None: raise UsageError + self._get_key() return self.verifier def get_data(self, outbound_data): - assert self.code is not None - assert self.channel_id is not None - self._wait_for_key() - + # only call this once + if self.code is None: raise UsageError + if self.channel_id is None: raise UsageError try: - outbound_key = self.derive_key(b"receiver") - outbound_encrypted = self._encrypt_data(outbound_key, outbound_data) - other_msgs = self._post_data(outbound_encrypted) - - inbound_encrypted = self._get_data(other_msgs) - inbound_key = self.derive_key(b"sender") - try: - inbound_data = self._decrypt_data(inbound_key, - inbound_encrypted) - except CryptoError: - raise ReceiverWrongPasswordError + self._get_key() + return self._get_data2(outbound_data) finally: self._deallocate() - return inbound_data + + def _get_data2(self, outbound_data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and check for reflection. + data_key = self.derive_key(b"data-key") + + outbound_encrypted = self._encrypt_data(data_key, outbound_data) + msgs = self._post_message(self._url("post", "data"), outbound_encrypted) + + inbound_encrypted = self._get_message(msgs, "poll", "data") + if inbound_encrypted == outbound_encrypted: + raise ReflectionAttack + try: + inbound_data = self._decrypt_data(data_key, inbound_encrypted) + return inbound_data + except CryptoError: + raise WrongPasswordError + + def _deallocate(self): + # only try once, no retries + requests.post(self._url("deallocate")) + # ignore POST failure, don't call r.raise_for_status() diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 2e8de3e..41d23ad 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -1,4 +1,4 @@ -import functools +import functools, textwrap class ServerError(Exception): def __init__(self, message, relay): @@ -16,3 +16,23 @@ def handle_server_error(func): print("Server error (from %s):\n%s" % (e.relay, e.message)) return 1 return _wrap + +class Timeout(Exception): + pass + +class WrongPasswordError(Exception): + """ + Key confirmation failed. Either you or your correspondent typed the code + wrong, or a would-be man-in-the-middle attacker guessed incorrectly. You + could try again, giving both your correspondent and the attacker another + chance. + """ + # or the data blob was corrupted, and that's why decrypt failed + def explain(self): + return textwrap.dedent(self.__doc__) + +class ReflectionAttack(Exception): + """An attacker (or bug) reflected our outgoing message back to us.""" + +class UsageError(Exception): + """The programmer did something wrong.""" diff --git a/src/wormhole/scripts/cmd_receive_file.py b/src/wormhole/scripts/cmd_receive_file.py index aefdff6..c215da1 100644 --- a/src/wormhole/scripts/cmd_receive_file.py +++ b/src/wormhole/scripts/cmd_receive_file.py @@ -7,24 +7,24 @@ APPID = "lothar.com/wormhole/file-xfer" @handle_server_error def receive_file(args): # we're receiving - from ..blocking.transcribe import Receiver, WrongPasswordError + from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transit import TransitReceiver, TransitError from .progress import start_progress, update_progress, finish_progress transit_receiver = TransitReceiver(args.transit_helper) - r = Receiver(APPID, args.relay_url) + w = Wormhole(APPID, args.relay_url) if args.zeromode: assert not args.code args.code = "0-" code = args.code if not code: - code = r.input_code("Enter receive-file wormhole code: ", + code = w.input_code("Enter receive-file wormhole code: ", args.code_length) - r.set_code(code) + w.set_code(code) if args.verify: - verifier = binascii.hexlify(r.get_verifier()) + verifier = binascii.hexlify(w.get_verifier()) print("Verifier %s." % verifier) mydata = json.dumps({ @@ -34,7 +34,7 @@ def receive_file(args): }, }).encode("utf-8") try: - data = json.loads(r.get_data(mydata).decode("utf-8")) + data = json.loads(w.get_data(mydata).decode("utf-8")) except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 @@ -50,7 +50,7 @@ def receive_file(args): # now receive the rest of the owl tdata = data["transit"] - transit_key = r.derive_key(APPID+"/transit-key") + transit_key = w.derive_key(APPID+"/transit-key") transit_receiver.set_transit_key(transit_key) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) diff --git a/src/wormhole/scripts/cmd_receive_text.py b/src/wormhole/scripts/cmd_receive_text.py index 64ba2da..ea374c2 100644 --- a/src/wormhole/scripts/cmd_receive_text.py +++ b/src/wormhole/scripts/cmd_receive_text.py @@ -7,25 +7,25 @@ APPID = "lothar.com/wormhole/text-xfer" @handle_server_error def receive_text(args): # we're receiving - from ..blocking.transcribe import Receiver, WrongPasswordError + from ..blocking.transcribe import Wormhole, WrongPasswordError - r = Receiver(APPID, args.relay_url) + w = Wormhole(APPID, args.relay_url) if args.zeromode: assert not args.code args.code = "0-" code = args.code if not code: - code = r.input_code("Enter receive-text wormhole code: ", + code = w.input_code("Enter receive-text wormhole code: ", args.code_length) - r.set_code(code) + w.set_code(code) if args.verify: - verifier = binascii.hexlify(r.get_verifier()) + verifier = binascii.hexlify(w.get_verifier()) print("Verifier %s." % verifier) data = json.dumps({"message": "ok"}).encode("utf-8") try: - them_bytes = r.get_data(data) + them_bytes = w.get_data(data) except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 diff --git a/src/wormhole/scripts/cmd_send_file.py b/src/wormhole/scripts/cmd_send_file.py index 8858bc4..3d3046d 100644 --- a/src/wormhole/scripts/cmd_send_file.py +++ b/src/wormhole/scripts/cmd_send_file.py @@ -7,7 +7,7 @@ APPID = "lothar.com/wormhole/file-xfer" @handle_server_error def send_file(args): # we're sending - from ..blocking.transcribe import Initiator, WrongPasswordError + from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transit import TransitSender from .progress import start_progress, update_progress, finish_progress @@ -15,15 +15,15 @@ def send_file(args): assert os.path.isfile(filename) transit_sender = TransitSender(args.transit_helper) - i = Initiator(APPID, args.relay_url) + w = Wormhole(APPID, args.relay_url) if args.zeromode: assert not args.code args.code = "0-" if args.code: - i.set_code(args.code) + w.set_code(args.code) code = args.code else: - code = i.get_code(args.code_length) + code = w.get_code(args.code_length) other_cmd = "wormhole receive-file" if args.verify: other_cmd = "wormhole --verify receive-file" @@ -35,7 +35,7 @@ def send_file(args): print() if args.verify: - verifier = binascii.hexlify(i.get_verifier()) + verifier = binascii.hexlify(w.get_verifier()) while True: ok = raw_input("Verifier %s. ok? (yes/no): " % verifier) if ok.lower() == "yes": @@ -45,7 +45,7 @@ def send_file(args): file=sys.stderr) reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") - i.get_data(reject_data) + w.get_data(reject_data) return 1 filesize = os.stat(filename).st_size @@ -61,7 +61,7 @@ def send_file(args): }).encode("utf-8") try: - them_bytes = i.get_data(data) + them_bytes = w.get_data(data) except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 @@ -70,7 +70,7 @@ def send_file(args): tdata = them_d["transit"] - transit_key = i.derive_key(APPID+"/transit-key") + transit_key = w.derive_key(APPID+"/transit-key") transit_sender.set_transit_key(transit_key) transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) diff --git a/src/wormhole/scripts/cmd_send_text.py b/src/wormhole/scripts/cmd_send_text.py index c25620f..7651cb1 100644 --- a/src/wormhole/scripts/cmd_send_text.py +++ b/src/wormhole/scripts/cmd_send_text.py @@ -7,17 +7,17 @@ APPID = "lothar.com/wormhole/text-xfer" @handle_server_error def send_text(args): # we're sending - from ..blocking.transcribe import Initiator, WrongPasswordError + from ..blocking.transcribe import Wormhole, WrongPasswordError - i = Initiator(APPID, args.relay_url) + w = Wormhole(APPID, args.relay_url) if args.zeromode: assert not args.code args.code = "0-" if args.code: - i.set_code(args.code) + w.set_code(args.code) code = args.code else: - code = i.get_code(args.code_length) + code = w.get_code(args.code_length) other_cmd = "wormhole receive-text" if args.verify: other_cmd = "wormhole --verify receive-text" @@ -29,7 +29,7 @@ def send_text(args): print("") if args.verify: - verifier = binascii.hexlify(i.get_verifier()) + verifier = binascii.hexlify(w.get_verifier()) while True: ok = raw_input("Verifier %s. ok? (yes/no): " % verifier) if ok.lower() == "yes": @@ -39,14 +39,14 @@ def send_text(args): file=sys.stderr) reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") - i.get_data(reject_data) + w.get_data(reject_data) return 1 message = args.text data = json.dumps({"message": message, }).encode("utf-8") try: - them_bytes = i.get_data(data) + them_bytes = w.get_data(data) except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index fe016f4..8e267fc 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -1,16 +1,18 @@ import json from twisted.trial import unittest from twisted.internet import defer +from twisted.internet.threads import deferToThread from twisted.application import service from ..servers.relay import RelayServer -from ..twisted.transcribe import SymmetricWormhole, UsageError +from ..twisted.transcribe import Wormhole, UsageError from ..twisted.util import allocate_ports +from ..blocking.transcribe import Wormhole as BlockingWormhole from .. import __version__ #from twisted.python import log #import sys #log.startLogging(sys.stdout) -class Basic(unittest.TestCase): +class ServerBase: def setUp(self): self.sp = service.MultiService() self.sp.startService() @@ -29,10 +31,11 @@ class Basic(unittest.TestCase): def tearDown(self): return self.sp.stopService() +class Basic(ServerBase, unittest.TestCase): def test_basic(self): appid = "appid" - w1 = SymmetricWormhole(appid, self.relayurl) - w2 = SymmetricWormhole(appid, self.relayurl) + w1 = Wormhole(appid, self.relayurl) + w2 = Wormhole(appid, self.relayurl) d = w1.get_code() def _got_code(code): w2.set_code(code) @@ -43,8 +46,8 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") d.addCallback(_done) @@ -52,8 +55,8 @@ class Basic(unittest.TestCase): def test_fixed_code(self): appid = "appid" - w1 = SymmetricWormhole(appid, self.relayurl) - w2 = SymmetricWormhole(appid, self.relayurl) + w1 = Wormhole(appid, self.relayurl) + w2 = Wormhole(appid, self.relayurl) w1.set_code("123-purple-elephant") w2.set_code("123-purple-elephant") d1 = w1.get_data("data1") @@ -62,8 +65,8 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") d.addCallback(_done) @@ -71,22 +74,22 @@ class Basic(unittest.TestCase): def test_errors(self): appid = "appid" - w1 = SymmetricWormhole(appid, self.relayurl) + w1 = Wormhole(appid, self.relayurl) self.assertRaises(UsageError, w1.get_verifier) self.assertRaises(UsageError, w1.get_data, "data") w1.set_code("123-purple-elephant") self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.get_code) - w2 = SymmetricWormhole(appid, self.relayurl) + w2 = Wormhole(appid, self.relayurl) d = w2.get_code() self.assertRaises(UsageError, w2.get_code) return d def test_serialize(self): appid = "appid" - w1 = SymmetricWormhole(appid, self.relayurl) + w1 = Wormhole(appid, self.relayurl) self.assertRaises(UsageError, w1.serialize) # too early - w2 = SymmetricWormhole(appid, self.relayurl) + w2 = Wormhole(appid, self.relayurl) d = w1.get_code() def _got_code(code): self.assertRaises(UsageError, w2.serialize) # too early @@ -96,7 +99,7 @@ class Basic(unittest.TestCase): self.assertEqual(type(s), type("")) unpacked = json.loads(s) # this is supposed to be JSON self.assertEqual(type(unpacked), dict) - new_w1 = SymmetricWormhole.from_serialized(s) + new_w1 = Wormhole.from_serialized(s) d1 = new_w1.get_data("data1") d2 = w2.get_data("data2") return defer.DeferredList([d1,d2], fireOnOneErrback=False) @@ -104,10 +107,100 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") self.assertRaises(UsageError, w2.serialize) # too late d.addCallback(_done) return d + +class Blocking(ServerBase, unittest.TestCase): + # we need Twisted to run the server, but we run the sender and receiver + # with deferToThread() + + def test_basic(self): + appid = "appid" + w1 = BlockingWormhole(appid, self.relayurl) + w2 = BlockingWormhole(appid, self.relayurl) + d = deferToThread(w1.get_code) + def _got_code(code): + w2.set_code(code) + d1 = deferToThread(w1.get_data, "data1") + d2 = deferToThread(w2.get_data, "data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_fixed_code(self): + appid = "appid" + w1 = BlockingWormhole(appid, self.relayurl) + w2 = BlockingWormhole(appid, self.relayurl) + w1.set_code("123-purple-elephant") + w2.set_code("123-purple-elephant") + d1 = deferToThread(w1.get_data, "data1") + d2 = deferToThread(w2.get_data, "data2") + d = defer.DeferredList([d1,d2], fireOnOneErrback=False) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + d.addCallback(_done) + return d + + def test_errors(self): + appid = "appid" + w1 = BlockingWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.get_verifier) + self.assertRaises(UsageError, w1.get_data, "data") + w1.set_code("123-purple-elephant") + self.assertRaises(UsageError, w1.set_code, "123-nope") + self.assertRaises(UsageError, w1.get_code) + w2 = BlockingWormhole(appid, self.relayurl) + d = deferToThread(w2.get_code) + def _done(code): + self.assertRaises(UsageError, w2.get_code) + d.addCallback(_done) + return d + + def test_serialize(self): + appid = "appid" + w1 = BlockingWormhole(appid, self.relayurl) + self.assertRaises(UsageError, w1.serialize) # too early + w2 = BlockingWormhole(appid, self.relayurl) + d = deferToThread(w1.get_code) + def _got_code(code): + self.assertRaises(UsageError, w2.serialize) # too early + w2.set_code(code) + w2.serialize() # ok + s = w1.serialize() + self.assertEqual(type(s), type("")) + unpacked = json.loads(s) # this is supposed to be JSON + self.assertEqual(type(unpacked), dict) + new_w1 = BlockingWormhole.from_serialized(s) + d1 = deferToThread(new_w1.get_data, "data1") + d2 = deferToThread(w2.get_data, "data2") + return defer.DeferredList([d1,d2], fireOnOneErrback=False) + d.addCallback(_got_code) + def _done(dl): + ((success1, dataX), (success2, dataY)) = dl + r1,r2 = dl + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) + self.assertEqual(dataX, "data2") + self.assertEqual(dataY, "data1") + self.assertRaises(UsageError, w2.serialize) # too late + d.addCallback(_done) + return d + test_serialize.skip = "not yet implemented for the blocking flavor" diff --git a/src/wormhole/twisted/demo.py b/src/wormhole/twisted/demo.py index adc3b74..a9313d6 100644 --- a/src/wormhole/twisted/demo.py +++ b/src/wormhole/twisted/demo.py @@ -1,28 +1,38 @@ -import sys +import sys, json from twisted.internet import reactor -from .transcribe import SymmetricWormhole +from .transcribe import Wormhole from .. import public_relay APPID = "lothar.com/wormhole/text-xfer" -w = SymmetricWormhole(APPID, public_relay.RENDEZVOUS_RELAY) +w = Wormhole(APPID, public_relay.RENDEZVOUS_RELAY) if sys.argv[1] == "send-text": message = sys.argv[2] + data = json.dumps({"message": message}).encode("utf-8") d = w.get_code() def _got_code(code): print "code is:", code - return w.get_data(message) + return w.get_data(data) d.addCallback(_got_code) - def _got_data(their_data): - print "ack:", their_data + def _got_data(them_bytes): + them_d = json.loads(them_bytes.decode("utf-8")) + if them_d["message"] == "ok": + print "text sent" + else: + print "error sending text: %r" % (them_d,) d.addCallback(_got_data) elif sys.argv[1] == "receive-text": code = sys.argv[2] w.set_code(code) - d = w.get_data("ok") - def _got_data(their_data): - print their_data + data = json.dumps({"message": "ok"}).encode("utf-8") + d = w.get_data(data) + def _got_data(them_bytes): + them_d = json.loads(them_bytes.decode("utf-8")) + if "error" in them_d: + print >>sys.stderr, "ERROR: " + them_d["error"] + return 1 + print them_d["message"] d.addCallback(_got_data) else: raise ValueError("bad command") diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource_twisted.py similarity index 100% rename from src/wormhole/twisted/eventsource.py rename to src/wormhole/twisted/eventsource_twisted.py diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 775dd67..ee047dd 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -10,23 +10,13 @@ from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric -from .eventsource import ReconnectingEventSource +from .eventsource_twisted import ReconnectingEventSource from .. import __version__ from .. import codes -from ..errors import ServerError +from ..errors import (ServerError, WrongPasswordError, + ReflectionAttack, UsageError) from ..util.hkdf import HKDF -class WrongPasswordError(Exception): - """ - Key confirmation failed. - """ - -class ReflectionAttack(Exception): - """An attacker (or bug) reflected our outgoing message back to us.""" - -class UsageError(Exception): - """The programmer did something wrong.""" - @implementer(IBodyProducer) class DataProducer: def __init__(self, data): @@ -43,7 +33,10 @@ class DataProducer: pass -class SymmetricWormhole: +class Wormhole: + motd_displayed = False + version_warning_displayed = False + def __init__(self, appid, relay): self.appid = appid self.relay = relay @@ -53,6 +46,61 @@ class SymmetricWormhole: self.key = None self._started_get_code = False + def _url(self, verb, msgnum=None): + url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) + if msgnum is not None: + url += "/" + msgnum + return url + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self.motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % (self.relay, motd_formatted), + file=sys.stderr) + self.motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("-" not in __version__ and + not self.version_warning_displayed and + welcome["current_version"] != __version__): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], __version__), file=sys.stderr) + self.version_warning_displayed = True + + if "error" in welcome: + raise ServerError(welcome["error"], self.relay) + + def _post_json(self, url, post_json=None): + # POST to a URL, parsing the response as JSON. Optionally include a + # JSON request body. + p = None + if post_json: + data = json.dumps(post_json).encode("utf-8") + p = DataProducer(data) + d = self.agent.request("POST", url, bodyProducer=p) + def _check_error(resp): + if resp.code != 200: + raise web_error.Error(resp.code, resp.phrase) + return resp + d.addCallback(_check_error) + d.addCallback(web_client.readBody) + d.addCallback(lambda data: json.loads(data)) + return d + + def _allocate_channel(self): + url = self.relay + "allocate/%s" % self.side + d = self._post_json(url) + def _got_channel(data): + if "welcome" in data: + self.handle_welcome(data["welcome"]) + return data["channel-id"] + d.addCallback(_got_channel) + return d + def get_code(self, code_length=2): if self.code is not None: raise UsageError if self._started_get_code: raise UsageError @@ -67,16 +115,6 @@ class SymmetricWormhole: d.addCallback(_got_channel_id) return d - def _allocate_channel(self): - url = self.relay + "allocate/%s" % self.side - d = self.post(url) - def _got_channel(data): - if "welcome" in data: - self.handle_welcome(data["welcome"]) - return data["channel-id"] - d.addCallback(_got_channel) - return d - def set_code(self, code): if self.code is not None: raise UsageError if self.side is not None: raise UsageError @@ -95,8 +133,7 @@ class SymmetricWormhole: def _start(self): # allocate the rest now too, so it can be serialized self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), - idA=self.appid+":SymmetricA", - idB=self.appid+":SymmetricB") + idSymmetric=self.appid) self.msg1 = self.sp.start() def serialize(self): @@ -124,60 +161,21 @@ class SymmetricWormhole: self.msg1 = d["msg1"].decode("hex") return self - motd_displayed = False - version_warning_displayed = False - - def handle_welcome(self, welcome): - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self.relay, motd_formatted), - file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - raise ServerError(welcome["error"], self.relay) - - def url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url - - def post(self, url, post_json=None): + def _post_message(self, url, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. - p = None - if post_json: - data = json.dumps(post_json).encode("utf-8") - p = DataProducer(data) - d = self.agent.request("POST", url, bodyProducer=p) - def _check_error(resp): - if resp.code != 200: - raise web_error.Error(resp.code, resp.phrase) - return resp - d.addCallback(_check_error) - d.addCallback(web_client.readBody) - d.addCallback(lambda data: json.loads(data)) + if not isinstance(msg, type(b"")): raise UsageError(type(msg)) + d = self._post_json(url, {"message": hexlify(msg).decode("ascii")}) + d.addCallback(lambda resp: resp["messages"]) # other_msgs return d - def _get_msgs(self, old_msgs, verb, msgnum): - # fire with a list of messages that match verb/msgnum, which either - # came from old_msgs, or from an EventSource that we attached to the - # corresponding URL + def _get_message(self, old_msgs, verb, msgnum): + # fire with a bytestring of the first message that matches + # verb/msgnum, which either came from old_msgs, or from an + # EventSource that we attached to the corresponding URL if old_msgs: - return defer.succeed(old_msgs) + msg = unhexlify(old_msgs[0].encode("ascii")) + return defer.succeed(msg) d = defer.Deferred() msgs = [] def _handle(name, data): @@ -186,28 +184,30 @@ class SymmetricWormhole: if name == "message": msgs.append(json.loads(data)["message"]) d.callback(None) - es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum), + es = ReconnectingEventSource(None, lambda: self._url(verb, msgnum), _handle)#, agent=self.agent) es.startService() # TODO: .setServiceParent(self) es.activate() d.addCallback(lambda _: es.deactivate()) d.addCallback(lambda _: es.stopService()) - d.addCallback(lambda _: msgs) + d.addCallback(lambda _: unhexlify(msgs[0].encode("ascii"))) return d def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - assert self.key is not None # call after get_verifier() or get_data() - assert type(purpose) == type(b"") + if self.key is None: + # call after get_verifier() or get_data() + raise UsageError + if not isinstance(purpose, type(b"")): raise UsageError return HKDF(self.key, length, CTXinfo=purpose) def _encrypt_data(self, key, data): - assert len(key) == SecretBox.KEY_SIZE + if len(key) != SecretBox.KEY_SIZE: raise UsageError box = SecretBox(key) nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(data, nonce) def _decrypt_data(self, key, encrypted): - assert len(key) == SecretBox.KEY_SIZE + if len(key) != SecretBox.KEY_SIZE: raise UsageError box = SecretBox(key) data = box.decrypt(encrypted) return data @@ -217,11 +217,9 @@ class SymmetricWormhole: # TODO: prevent multiple invocation if self.key: return defer.succeed(self.key) - data = {"message": hexlify(self.msg1).decode("ascii")} - d = self.post(self.url("post", "pake"), data) - d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake")) - def _got_pake(msgs): - pake_msg = unhexlify(msgs[0].encode("ascii")) + d = self._post_message(self._url("post", "pake"), self.msg1) + d.addCallback(lambda msgs: self._get_message(msgs, "poll", "pake")) + def _got_pake(pake_msg): key = self.sp.finish(pake_msg) self.key = key self.verifier = self.derive_key(self.appid+b":Verifier") @@ -240,6 +238,7 @@ class SymmetricWormhole: if self.code is None: raise UsageError d = self._get_key() d.addCallback(self._get_data2, outbound_data) + d.addBoth(self._deallocate) return d def _get_data2(self, key, outbound_data): @@ -247,12 +246,12 @@ class SymmetricWormhole: # for each side, so we use the same key for both. We use random # nonces to keep the messages distinct, and check for reflection. data_key = self.derive_key(b"data-key") + outbound_encrypted = self._encrypt_data(data_key, outbound_data) - data = {"message": hexlify(outbound_encrypted).decode("ascii")} - d = self.post(self.url("post", "data"), data) - d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data")) - def _got_data(msgs): - inbound_encrypted = unhexlify(msgs[0].encode("ascii")) + d = self._post_message(self._url("post", "data"), outbound_encrypted) + + d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data")) + def _got_data(inbound_encrypted): if inbound_encrypted == outbound_encrypted: raise ReflectionAttack try: @@ -261,11 +260,10 @@ class SymmetricWormhole: except CryptoError: raise WrongPasswordError d.addCallback(_got_data) - d.addBoth(self._deallocate) return d def _deallocate(self, res): # only try once, no retries - d = self.agent.request("POST", self.url("deallocate")) + d = self.agent.request("POST", self._url("deallocate")) d.addBoth(lambda _: res) # ignore POST failure, pass-through result return d