From 4dfa569769e37a44cc9a6fd66eef68ac684e437d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 15:42:40 -0700 Subject: [PATCH 01/51] docs: remove named phases, Wormhole is now a record pipe --- docs/api.md | 137 ++++++++++++++++++++++++++-------------------------- 1 file changed, 69 insertions(+), 68 deletions(-) diff --git a/docs/api.md b/docs/api.md index fd7e6cb..53d7079 100644 --- a/docs/api.md +++ b/docs/api.md @@ -10,6 +10,15 @@ short string that is transcribed from one machine to the other by the users at the keyboard. This works in conjunction with a baked-in "rendezvous server" that relays information from one machine to the other. +The "Wormhole" object provides a secure record pipe between any two programs +that use the same wormhole code (and are configured with the same application +ID and rendezvous server). Each side can send multiple messages to the other, +but the encrypted data for all messages must pass through (and be temporarily +stored on) the rendezvous server, which is a shared resource. For this +reason, larger data (including bulk file transfers) should use the Transit +class instead. The Wormhole object has a method to create a Transit object +for this purpose. + ## Modes This library will eventually offer multiple modes. For now, only "transcribe @@ -39,22 +48,32 @@ string. The two machines participating in the wormhole setup are not distinguished: it doesn't matter which one goes first, and both use the same Wormhole class. In the first variant, one side calls `get_code()` while the other calls -`set_code()`. In the second variant, both sides call `set_code()`. Note that +`set_code()`. In the second variant, both sides call `set_code()`. (Note that this is not true for the "Transit" protocol used for bulk data-transfer: the Transit class currently distinguishes "Sender" from "Receiver", so the -programs on each side must have some way to decide (ahead of time) which is -which. +programs on each side must have some way to decide ahead of time which is +which). -Each side gets to do one `send_data()` call and one `get_data()` call per -phase (see below). `get_data` will wait until the other side has done -`send_data`, so the application developer must be careful to avoid deadlocks -(don't get before you send on both sides in the same protocol). When both -sides are done, they must call `close()`, to let the library know that the -connection is complete and it can deallocate the channel. If you forget to -call `close()`, the server will not free the channel, and other users will -suffer longer invitation codes as a result. To encourage `close()`, the -library will log an error if a Wormhole object is destroyed before being -closed. +Each side can then do an arbitrary number of `send()` and `get()` calls. +`send()` writes a message into the channel. `get()` waits for a new message +to be available, then returns it. The Wormhole is not meant as a long-term +communication channel, but some protocols work better if they can exchange an +initial pair of messages (perhaps offering some set of negotiable +capabilities), and then follow up with a second pair (to reveal the results +of the negotiation). Another use case is for an ACK that gets sent at the end +of a file transfer: the Wormhole is held open until the Transit object +reports completion, and the last message is a hash of the file contents to +prove it was received correctly. + +Note: the application developer must be careful to avoid deadlocks (if both +sides want to `get()`, somebody has to `send()` first). + +When both sides are done, they must call `close()`, to let the library know +that the connection is complete and it can deallocate the channel. If you +forget to call `close()`, the server will not free the channel, and other +users will suffer longer invitation codes as a result. To encourage +`close()`, the library will log an error if a Wormhole object is destroyed +before being closed. To make it easier to call `close()`, the blocking Wormhole objects can be used as a context manager. Just put your code in the body of a `with @@ -72,8 +91,8 @@ mydata = b"initiator's data" with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: code = i.get_code() print("Invitation Code: %s" % code) - i.send_data(mydata) - theirdata = i.get_data() + i.send(mydata) + theirdata = i.get() print("Their data: %s" % theirdata.decode("ascii")) ``` @@ -85,8 +104,8 @@ mydata = b"receiver's data" code = sys.argv[1] with Wormhole(u"appid", RENDEZVOUS_RELAY) as r: r.set_code(code) - r.send_data(mydata) - theirdata = r.get_data() + r.send(mydata) + theirdata = r.get() print("Their data: %s" % theirdata.decode("ascii")) ``` @@ -103,12 +122,12 @@ w1 = Wormhole(u"appid", RENDEZVOUS_RELAY) d = w1.get_code() def _got_code(code): print "Invitation Code:", code - return w1.send_data(outbound_message) + return w1.send(outbound_message) d.addCallback(_got_code) -d.addCallback(lambda _: w1.get_data()) -def _got_data(inbound_message): +d.addCallback(lambda _: w1.get()) +def _got(inbound_message): print "Inbound message:", inbound_message -d.addCallback(_got_data) +d.addCallback(_got) d.addCallback(w1.close) d.addBoth(lambda _: reactor.stop()) reactor.run() @@ -119,7 +138,7 @@ On the other side, you call `set_code()` instead of waiting for `get_code()`: ```python w2 = Wormhole(u"appid", RENDEZVOUS_RELAY) w2.set_code(code) -d = w2.send_data(my_message) +d = w2.send(my_message) ... ``` @@ -127,56 +146,45 @@ Note that the Twisted-form `close()` accepts (and returns) an optional argument, so you can use `d.addCallback(w.close)` instead of `d.addCallback(lambda _: w.close())`. -## Phases - -If necessary, more than one message can be exchanged through the relay -server. It is not meant as a long-term communication channel, but some -protocols work better if they can exchange an initial pair of messages -(perhaps offering some set of negotiable capabilities), and then follow up -with a second pair (to reveal the results of the negotiation). - -To support this, `send_data()/get_data()` accept a "phase" argument: an -arbitrary (unicode) string. It must match the other side: calling -`send_data(data, phase=u"offer")` on one side will deliver that data to -`get_data(phase=u"offer")` on the other. - -It is a UsageError to call `send_data()` or `get_data()` twice with the same -phase name. The relay server may limit the number of phases that may be -exchanged, however it will always allow at least two. - ## Verifier -You can call `w.get_verifier()` before `send_data()/get_data()`: this will -perform the first half of the PAKE negotiation, then return a verifier object -(bytes) which can be converted into a printable representation and manually -compared. When the users are convinced that `get_verifier()` from both sides -are the same, call `send_data()/get_data()` to continue the transfer. If you -call `send_data()/get_data()` before `get_verifier()`, it will perform the -complete transfer without pausing. +For extra protection against guessing attacks, Wormhole can provide a +"Verifier". This is a moderate-length series of bytes (a SHA256 hash) that is +derived from the supposedly-shared session key. If desired, both sides can +display this value, and the humans can manually compare them before allowing +the rest of the protocol to proceed. If they do not match, then the two +programs are not talking to each other (they may both be talking to a +man-in-the-middle attacker), and the protocol should be abandoned. + +To retrieve the verifier, you call `w.get_verifier()` before any calls to +`send()/get()`. Turn this into hex or Base64 to print it, or render it as +ASCII-art, etc. Once the users are convinced that `get_verifier()` from both +sides are the same, call `send()/get()` to continue the protocol. If you call +`send()/get()` before `get_verifier()`, it will perform the complete protocol +without pausing. The Twisted form of `get_verifier()` returns a Deferred that fires with the verifier bytes. ## Generating the Invitation Code -In most situations, the "sending" or "initiating" side will call -`i.get_code()` to generate the invitation code. This returns a string in the -form `NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a +In most situations, the "sending" or "initiating" side will call `get_code()` +to generate the invitation code. This returns a string in the form +`NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a short integer allocated by talking to the rendezvous server. The rest is a randomly-generated selection from the PGP wordlist, providing a default of 16 bits of entropy. The initiating program should display this code to the user, who should transcribe it to the receiving user, who gives it to the Receiver -object by calling `r.set_code()`. The receiving program can also use +object by calling `set_code()`. The receiving program can also use `input_code_with_completion()` to use a readline-based input function: this offers tab completion of allocated channel-ids and known codewords. Alternatively, the human users can agree upon an invitation code themselves, -and provide it to both programs later (with `i.set_code()` and -`r.set_code()`). They should choose a channel-id that is unlikely to already -be in use (3 or more digits are recommended), append a hyphen, and then -include randomly-selected words or characters. Dice, coin flips, shuffled -cards, or repeated sampling of a high-resolution stopwatch are all useful -techniques. +and provide it to both programs later (both sides call `set_code()`). They +should choose a channel-id that is unlikely to already be in use (3 or more +digits are recommended), append a hyphen, and then include randomly-selected +words or characters. Dice, coin flips, shuffled cards, or repeated sampling +of a high-resolution stopwatch are all useful techniques. Note that the code is a human-readable string (the python "unicode" type in python2, "str" in python3). @@ -192,8 +200,8 @@ invitation codes are scoped to the app-id. Note that the app-id must be unicode, not bytes, so on python2 use `u"appid"`. Distinct app-ids reduce the size of the connection-id numbers. If fewer than -ten initiators are active for a given app-id, the connection-id will only -need to contain a single digit, even if some other app-id is currently using +ten Wormholes are active for a given app-id, the connection-id will only need +to contain a single digit, even if some other app-id is currently using thousands of concurrent sessions. ## Rendezvous Relays @@ -251,10 +259,9 @@ Wormhole.from_serialized(data)`). There is exactly one point at which you can serialize the wormhole: *after* establishing the invitation code, but before waiting for `get_verifier()` or -`get_data()`, or calling `send_data()`. If you are creating a new invitation -code, the correct time is during the callback fired by `get_code()`. If you -are accepting a pre-generated code, the time is just after calling -`set_code()`. +`get()`, or calling `send()`. If you are creating a new invitation code, the +correct time is during the callback fired by `get_code()`. If you are +accepting a pre-generated code, the time is just after calling `set_code()`. To properly checkpoint the process, you should store the first message (returned by `start()`) next to the serialized wormhole instance, so you can @@ -278,9 +285,3 @@ in python3): * transit connection hints (e.g. "host:port") * application identifier * derived-key "purpose" string: `w.derive_key(PURPOSE)` - -## Detailed Example - -```python - -``` From 49785008bbaed748c2b420bc4c3b9b633d7ff9d1 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 15:50:44 -0700 Subject: [PATCH 02/51] remove blocking implementation: it will return It will return as a crochet-based wrapper around the Twisted implementation. --- setup.py | 3 +- src/wormhole/blocking/__init__.py | 0 src/wormhole/blocking/eventsource.py | 49 --- src/wormhole/blocking/transcribe.py | 413 ------------------------- src/wormhole/test/test_blocking.py | 446 --------------------------- src/wormhole/test/test_interop.py | 57 ---- 6 files changed, 1 insertion(+), 967 deletions(-) delete mode 100644 src/wormhole/blocking/__init__.py delete mode 100644 src/wormhole/blocking/eventsource.py delete mode 100644 src/wormhole/blocking/transcribe.py delete mode 100644 src/wormhole/test/test_blocking.py delete mode 100644 src/wormhole/test/test_interop.py diff --git a/setup.py b/setup.py index 0e4e4e3..803c131 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,6 @@ setup(name="magic-wormhole", url="https://github.com/warner/magic-wormhole", package_dir={"": "src"}, packages=["wormhole", - "wormhole.blocking", "wormhole.cli", "wormhole.server", "wormhole.test", @@ -25,7 +24,7 @@ setup(name="magic-wormhole", ["wormhole = wormhole.cli.runner:entry", "wormhole-server = wormhole.server.runner:entry", ]}, - install_requires=["spake2==0.3", "pynacl", "requests", "argparse", + install_requires=["spake2==0.3", "pynacl", "argparse", "six", "twisted >= 16.1.0", "hkdf", "tqdm", "autobahn[twisted]", "pytrie", # autobahn seems to have a bug, and one plugin throws diff --git a/src/wormhole/blocking/__init__.py b/src/wormhole/blocking/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/wormhole/blocking/eventsource.py b/src/wormhole/blocking/eventsource.py deleted file mode 100644 index fd7a4a0..0000000 --- a/src/wormhole/blocking/eventsource.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import print_function, unicode_literals -import requests - -class EventSourceFollower: - def __init__(self, url, timeout): - self._resp = requests.get(url, - headers={"accept": "text/event-stream"}, - stream=True, - timeout=timeout) - self._resp.raise_for_status() - self._lines_iter = self._resp.iter_lines(chunk_size=1, - decode_unicode=True) - - def close(self): - self._resp.close() - - def iter_events(self): - # I think Request.iter_lines and .iter_content use chunk_size= in a - # funny way, and nothing happens until at least that much data has - # arrived. So unless we set chunk_size=1, we won't hear about lines - # for a long time. I'd prefer that chunk_size behaved like - # read(size), and gave you 1<=x<=size bytes in response. - eventtype = "message" - current_lines = [] - for line in self._lines_iter: - assert isinstance(line, type(u"")), type(line) - if not line: - # blank line ends the field: deliver event, reset for next - yield (eventtype, "\n".join(current_lines)) - eventtype = "message" - current_lines[:] = [] - continue - if ":" in line: - fieldname, data = line.split(":", 1) - if data.startswith(" "): - data = data[1:] - else: - fieldname = line - data = "" - if fieldname == "event": - eventtype = data - elif fieldname == "data": - current_lines.append(data) - elif fieldname in ("id", "retry"): - # documented but unhandled - pass - else: - #log.msg("weird fieldname", fieldname, data) - pass diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py deleted file mode 100644 index d2a9920..0000000 --- a/src/wormhole/blocking/transcribe.py +++ /dev/null @@ -1,413 +0,0 @@ -from __future__ import print_function -import os, sys, time, re, requests, json, unicodedata -from six.moves.urllib_parse import urlencode -from binascii import hexlify, unhexlify -from spake2 import SPAKE2_Symmetric -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils -from .eventsource import EventSourceFollower -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming -from hkdf import Hkdf -from ..channel_monitor import monitor - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - -SECOND = 1 -MINUTE = 60*SECOND - -CONFMSG_NONCE_LENGTH = 128//8 -CONFMSG_MAC_LENGTH = 256//8 -def make_confmsg(confkey, nonce): - return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) - -def to_bytes(u): - return unicodedata.normalize("NFC", u).encode("utf-8") - -class Channel: - def __init__(self, relay_url, appid, channelid, side, handle_welcome, - wait, timeout, timing): - self._relay_url = relay_url - self._appid = appid - self._channelid = channelid - self._side = side - self._handle_welcome = handle_welcome - self._messages = set() # (phase,body) , body is bytes - self._sent_messages = set() # (phase,body) - self._started = time.time() - self._wait = wait - self._timeout = timeout - self._timing = timing - - def _add_inbound_messages(self, messages): - for msg in messages: - phase = msg["phase"] - body = unhexlify(msg["body"].encode("ascii")) - self._messages.add( (phase, body) ) - - def _find_inbound_message(self, phases): - their_messages = self._messages - self._sent_messages - for phase in phases: - for (their_phase,body) in their_messages: - if their_phase == phase: - return (phase, body) - return None - - def send(self, phase, msg): - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if not isinstance(msg, type(b"")): raise TypeError(type(msg)) - self._sent_messages.add( (phase,msg) ) - payload = {"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "phase": phase, - "body": hexlify(msg).decode("ascii")} - data = json.dumps(payload).encode("utf-8") - with self._timing.add("send %s" % phase): - r = requests.post(self._relay_url+"add", data=data, - timeout=self._timeout) - r.raise_for_status() - resp = r.json() - if "welcome" in resp: - self._handle_welcome(resp["welcome"]) - self._add_inbound_messages(resp["messages"]) - - def get_first_of(self, phases): - if not isinstance(phases, (list, set)): raise TypeError(type(phases)) - for phase in phases: - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - - # For now, server errors cause the client to fail. TODO: don't. This - # will require changing the client to re-post messages when the - # server comes back up. - - # fire with a bytestring of the first message for any 'phase' that - # wasn't one of our own messages. It will either come from - # previously-received messages, or from an EventSource that we attach - # to the corresponding URL - with self._timing.add("get %s" % "/".join(sorted(phases))): - phase_and_body = self._find_inbound_message(phases) - while phase_and_body is None: - remaining = self._started + self._timeout - time.time() - if remaining < 0: - raise Timeout - queryargs = urlencode([("appid", self._appid), - ("channelid", self._channelid)]) - f = EventSourceFollower(self._relay_url+"watch?%s" % queryargs, - remaining) - # we loop here until the connection is lost, or we see the - # message we want - for (eventtype, line) in f.iter_events(): - if eventtype == "welcome": - self._handle_welcome(json.loads(line)) - if eventtype == "message": - data = json.loads(line) - self._add_inbound_messages([data]) - phase_and_body = self._find_inbound_message(phases) - if phase_and_body: - f.close() - break - if not phase_and_body: - time.sleep(self._wait) - return phase_and_body - - def get(self, phase): - (got_phase, body) = self.get_first_of([phase]) - assert got_phase == phase - return body - - def deallocate(self, mood=None): - # only try once, no retries - data = json.dumps({"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "mood": mood}).encode("utf-8") - try: - # ignore POST failure, don't call r.raise_for_status(), set a - # short timeout and ignore failures - with self._timing.add("close"): - r = requests.post(self._relay_url+"deallocate", data=data, - timeout=5) - r.json() - except requests.exceptions.RequestException: - pass - -class ChannelManager: - def __init__(self, relay_url, appid, side, handle_welcome, timing=None, - wait=0.5*SECOND, timeout=3*MINUTE): - self._relay_url = relay_url - self._appid = appid - self._side = side - self._handle_welcome = handle_welcome - self._timing = timing or DebugTiming() - self._wait = wait - self._timeout = timeout - - def list_channels(self): - queryargs = urlencode([("appid", self._appid)]) - with self._timing.add("list"): - r = requests.get(self._relay_url+"list?%s" % queryargs, - timeout=self._timeout) - r.raise_for_status() - data = r.json() - if "welcome" in data: - self._handle_welcome(data["welcome"]) - channelids = data["channelids"] - return channelids - - def allocate(self): - data = json.dumps({"appid": self._appid, - "side": self._side}).encode("utf-8") - with self._timing.add("allocate"): - r = requests.post(self._relay_url+"allocate", data=data, - timeout=self._timeout) - r.raise_for_status() - data = r.json() - if "welcome" in data: - self._handle_welcome(data["welcome"]) - channelid = data["channelid"] - return channelid - - def connect(self, channelid): - return Channel(self._relay_url, self._appid, channelid, self._side, - self._handle_welcome, self._wait, self._timeout, - self._timing) - -def close_on_error(f): # method decorator - # Clients report certain errors as "moods", so the server can make a - # rough count failed connections (due to mismatched passwords, attacks, - # or timeouts). We don't report precondition failures, as those are the - # responsibility/fault of the local application code. We count - # non-precondition errors in case they represent server-side problems. - def _f(self, *args, **kwargs): - try: - return f(self, *args, **kwargs) - except Timeout: - self.close(u"lonely") - raise - except WrongPasswordError: - self.close(u"scary") - raise - except (TypeError, UsageError): - # preconditions don't warrant _close_with_error() - raise - except: - self.close(u"errory") - raise - return _f - -class Wormhole: - motd_displayed = False - version_warning_displayed = False - _send_confirm = True - - def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE, - timing=None): - if not isinstance(appid, type(u"")): raise TypeError(type(appid)) - if not isinstance(relay_url, type(u"")): - raise TypeError(type(relay_url)) - if not relay_url.endswith(u"/"): raise UsageError - self._appid = appid - self._relay_url = relay_url - self._wait = wait - self._timeout = timeout - self._timing = timing or DebugTiming() - side = hexlify(os.urandom(5)).decode("ascii") - self._channel_manager = ChannelManager(relay_url, appid, side, - self.handle_welcome, - self._timing, - self._wait, self._timeout) - self._channel = None - self.code = None - self.key = None - self.verifier = None - self._sent_data = set() # phases - self._got_data = set() - self._got_confirmation = False - self._closed = False - self._timing_started = self._timing.add("wormhole") - - def __enter__(self): - return self - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - def handle_welcome(self, welcome): - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted), - file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - raise ServerError(welcome["error"], self._relay_url) - - def get_code(self, code_length=2): - if self.code is not None: raise UsageError - channelid = self._channel_manager.allocate() - code = codes.make_code(channelid, code_length) - assert isinstance(code, type(u"")), type(code) - self._set_code_and_channelid(code) - self._start() - return code - - def input_code(self, prompt="Enter wormhole code: ", code_length=2): - lister = self._channel_manager.list_channels - # fetch the list of channels ahead of time, to give us a chance to - # discover the welcome message (and warn the user about an obsolete - # client) - initial_channelids = lister() - with self._timing.add("input code", waiting="user"): - code = codes.input_code_with_completion(prompt, - initial_channelids, lister, - code_length) - return code - - def set_code(self, code): # used for human-made pre-generated codes - if not isinstance(code, type(u"")): raise TypeError(type(code)) - if self.code is not None: raise UsageError - self._set_code_and_channelid(code) - self._start() - - def _set_code_and_channelid(self, code): - if self.code is not None: raise UsageError - self._timing.add("code established") - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - self.code = code - channelid = int(mo.group(1)) - self._channel = self._channel_manager.connect(channelid) - monitor.add(self._channel) - - def _start(self): - # allocate the rest now too, so it can be serialized - self.sp = SPAKE2_Symmetric(to_bytes(self.code), - idSymmetric=to_bytes(self._appid)) - self.msg1 = self.sp.start() - - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) - - def _encrypt_data(self, key, data): - assert isinstance(key, type(b"")), type(key) - assert isinstance(data, type(b"")), type(data) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - nonce = utils.random(SecretBox.NONCE_SIZE) - return box.encrypt(data, nonce) - - def _decrypt_data(self, key, encrypted): - assert isinstance(key, type(b"")), type(key) - assert isinstance(encrypted, type(b"")), type(encrypted) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - data = box.decrypt(encrypted) - return data - - - def _get_key(self): - if not self.key: - self._channel.send(u"pake", self.msg1) - pake_msg = self._channel.get(u"pake") - - self.key = self.sp.finish(pake_msg) - self.verifier = self.derive_key(u"wormhole:verifier") - self._timing.add("key established") - - if not self._send_confirm: - return - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - self._channel.send(u"_confirm", confmsg) - - @close_on_error - def get_verifier(self): - if self._closed: raise UsageError - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - self._get_key() - return self.verifier - - @close_on_error - def send_data(self, outbound_data, phase=u"data"): - if not isinstance(outbound_data, type(b"")): - raise TypeError(type(outbound_data)) - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if self._closed: raise UsageError - if phase in self._sent_data: raise UsageError # only call this once - if phase.startswith(u"_"): raise UsageError # reserved for internals - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - with self._timing.add("API send data", phase=phase): - # Without predefined roles, we can't derive predictably unique - # keys for each side, so we use the same key for both. We use - # random nonces to keep the messages distinct, and the Channel - # automatically ignores reflections. - self._sent_data.add(phase) - self._get_key() - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - self._channel.send(phase, outbound_encrypted) - - @close_on_error - def get_data(self, phase=u"data"): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if phase in self._got_data: raise UsageError # only call this once - if phase.startswith(u"_"): raise UsageError # reserved for internals - if self._closed: raise UsageError - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - with self._timing.add("API get data", phase=phase): - self._got_data.add(phase) - self._get_key() - phases = [] - if not self._got_confirmation: - phases.append(u"_confirm") - phases.append(phase) - (got_phase, body) = self._channel.get_first_of(phases) - if got_phase == u"_confirm": - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - raise WrongPasswordError - self._got_confirmation = True - (got_phase, body) = self._channel.get_first_of([phase]) - assert got_phase == phase - try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) - return inbound_data - except CryptoError: - raise WrongPasswordError - - def close(self, mood=u"happy"): - if not isinstance(mood, (type(None), type(u""))): - raise TypeError(type(mood)) - self._closed = True - if self._channel: - self._timing_started.finish(mood=mood) - c, self._channel = self._channel, None - monitor.close(c) - c.deallocate(mood) diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py deleted file mode 100644 index 7ba8222..0000000 --- a/src/wormhole/test/test_blocking.py +++ /dev/null @@ -1,446 +0,0 @@ -from __future__ import print_function -import json -from twisted.trial import unittest -from twisted.internet.defer import gatherResults, succeed -from twisted.internet.threads import deferToThread -from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, - WrongPasswordError) -from ..blocking.eventsource import EventSourceFollower -from .common import ServerBase - -APPID = u"appid" - -class Channel(ServerBase, unittest.TestCase): - def ignore(self, welcome): - pass - - def test_allocate(self): - cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) - d = deferToThread(cm.list_channels) - def _got_channels(channels): - self.failUnlessEqual(channels, []) - d.addCallback(_got_channels) - d.addCallback(lambda _: deferToThread(cm.allocate)) - def _allocated(channelid): - self.failUnlessEqual(type(channelid), int) - self._channelid = channelid - d.addCallback(_allocated) - d.addCallback(lambda _: deferToThread(cm.connect, self._channelid)) - def _connected(c): - self._channel = c - d.addCallback(_connected) - d.addCallback(lambda _: deferToThread(self._channel.deallocate, - u"happy")) - return d - - def test_messages(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) - d.addCallback(lambda _: deferToThread(c2.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) - d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2")) - d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # it's legal to fetch a phase multiple times, should be idempotent - d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # deallocating one side is not enough to destroy the channel - d.addCallback(lambda _: deferToThread(c2.deallocate)) - def _not_yet(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 1) - d.addCallback(_not_yet) - # but deallocating both will make the messages go away - d.addCallback(lambda _: deferToThread(c1.deallocate, u"sad")) - def _gone(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 0) - d.addCallback(_gone) - - return d - - def test_get_multiple_phases(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1") - self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7]) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) - - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", - u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", - u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - - d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2")) - d.addCallback(lambda _: deferToThread(c2.get, u"phase2")) - - # if both are present, it should prefer the first one we asked for - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", - u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", - u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase2", b"msg2"))) - - return d - - def test_appid_independence(self): - APPID_A = u"appid_A" - APPID_B = u"appid_B" - cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) - cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) - c1a = cm1a.connect(1) - c2a = cm2a.connect(1) - cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) - cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) - c1b = cm1b.connect(1) - c2b = cm2b.connect(1) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a")) - d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b")) - d.addCallback(lambda _: deferToThread(c2a.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) - d.addCallback(lambda _: deferToThread(c2b.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) - return d - -class _DoBothMixin: - def doBoth(self, call1, call2): - f1 = call1[0] - f1args = call1[1:] - f2 = call2[0] - f2args = call2[1:] - return gatherResults([deferToThread(f1, *f1args), - deferToThread(f2, *f2args)], True) - -class Blocking(_DoBothMixin, ServerBase, unittest.TestCase): - # we need Twisted to run the server, but we run the sender and receiver - # with deferToThread() - - def test_basic(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_same_message(self): - # the two sides use random nonces for their messages, so it's ok for - # both to try and send the same body: they'll result in distinct - # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data"], - [w2.send_data, b"data"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data") - self.assertEqual(dataY, b"data") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data1"], - [w2.get_data]) - d.addCallback(_got_code) - def _sent(res): - (_, dataY) = res - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.get_data], [w2.send_data, b"data2"]) - d.addCallback(_sent) - def _done(dl): - (dataX, _) = dl - self.assertEqual(dataX, b"data2") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - d = self.doBoth([w1.send_data, b"data1"], [w2.send_data, b"data2"]) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - d = self.doBoth([w1.send_data, b"data1", u"p1"], - [w2.send_data, b"data2", u"p1"]) - d.addCallback(lambda _: - self.doBoth([w1.send_data, b"data3", u"p2"], - [w2.send_data, b"data4", u"p2"])) - d.addCallback(lambda _: - self.doBoth([w1.get_data, u"p2"], - [w2.get_data, u"p1"])) - def _got_1(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data4") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.get_data, u"p1"], - [w2.get_data, u"p2"]) - d.addCallback(_got_1) - def _got_2(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data3") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_got_2) - return d - - def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - - # make sure we can detect WrongPasswordError even if one side only - # does get_data() and not send_data(), like "wormhole receive" does - d = deferToThread(w1.get_code) - d.addCallback(lambda code: w2.set_code(code+"not")) - - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get_data. So we need both sides to be - # running at the same time for this test. - def _w1_sends(): - w1.send_data(b"data1") - def _w2_gets(): - self.assertRaises(WrongPasswordError, w2.get_data) - d.addCallback(lambda _: self.doBoth([_w1_sends], [_w2_gets])) - - # and now w1 should have enough information to throw too - d.addCallback(lambda _: deferToThread(self.assertRaises, - WrongPasswordError, w1.get_data)) - def _done(_): - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - d = deferToThread(w1.get_code) - d.addCallback(lambda code: w2.set_code(code)) - d.addCallback(lambda _: self.doBoth([w1.send_data, b"data1"], - [w2.get_data])) - d.addCallback(lambda dl: self.assertEqual(dl[1], b"data1")) - d.addCallback(lambda _: self.doBoth([w1.get_data], - [w2.send_data, b"data2"])) - d.addCallback(lambda dl: self.assertEqual(dl[0], b"data2")) - d.addCallback(lambda _: self.doBoth([w1.close], [w2.close])) - return d - - def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(_got_code) - def _check_verifier(res): - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failUnlessEqual(v1, v2) - return self.doBoth([w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_check_verifier) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_verifier_mismatch(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code+"not") - return self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(_got_code) - def _check_verifier(res): - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failIfEqual(v1, v2) - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_check_verifier) - return d - - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.get_verifier) - self.assertRaises(UsageError, w1.get_data) - self.assertRaises(UsageError, w1.send_data, b"data") - w1.set_code(u"123-purple-elephant") - self.assertRaises(UsageError, w1.set_code, u"123-nope") - self.assertRaises(UsageError, w1.get_code) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w2.get_code) - def _done(code): - self.assertRaises(UsageError, w2.get_code) - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_repeat_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2 = Wormhole(APPID, self.relayurl) - w2.set_code(u"123-purple-elephant") - # we must let them establish a key before we can send data - d = self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(lambda _: - deferToThread(w1.send_data, b"data1", phase=u"1")) - def _sent(res): - # underscore-prefixed phases are reserved - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1") - self.assertRaises(UsageError, w1.get_data, phase=u"_1") - # you can't send twice to the same phase - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1") - # but you can send to a different one - return deferToThread(w1.send_data, b"data2", phase=u"2") - d.addCallback(_sent) - d.addCallback(lambda _: deferToThread(w2.get_data, phase=u"1")) - def _got1(res): - self.failUnlessEqual(res, b"data1") - # and you can't read twice from the same phase - self.assertRaises(UsageError, w2.get_data, phase=u"1") - # but you can read from a different one - return deferToThread(w2.get_data, phase=u"2") - d.addCallback(_got1) - def _got2(res): - self.failUnlessEqual(res, b"data2") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_got2) - return d - - def test_serialize(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.serialize) # too early - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - self.assertRaises(UsageError, w2.serialize) # too early - w2.set_code(code) - w2.serialize() # ok - s = w1.serialize() - self.assertEqual(type(s), type("")) - unpacked = json.loads(s) # this is supposed to be JSON - self.assertEqual(type(unpacked), dict) - self.new_w1 = Wormhole.from_serialized(s) - return self.doBoth([self.new_w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth(self.new_w1.get_data(), w2.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - self.assertRaises(UsageError, w2.serialize) # too late - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - test_serialize.skip = "not yet implemented for the blocking flavor" - -data1 = u"""\ -event: welcome -data: one and a -data: two -data:. - -data: three - -: this line is ignored -event: e2 -: this line is ignored too -i am a dataless field name -data: four - -""" - -class NoNetworkESF(EventSourceFollower): - def __init__(self, text): - self._lines_iter = iter(text.splitlines()) - -class EventSourceClient(unittest.TestCase): - def test_parser(self): - events = [] - f = NoNetworkESF(data1) - events = list(f.iter_events()) - self.failUnlessEqual(events, - [(u"welcome", u"one and a\ntwo\n."), - (u"message", u"three"), - (u"e2", u"four"), - ]) diff --git a/src/wormhole/test/test_interop.py b/src/wormhole/test/test_interop.py deleted file mode 100644 index 23daa9c..0000000 --- a/src/wormhole/test/test_interop.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import print_function -from twisted.trial import unittest -from twisted.internet.defer import gatherResults -from twisted.internet.threads import deferToThread -from ..twisted.transcribe import Wormhole as twisted_Wormhole -from ..blocking.transcribe import Wormhole as blocking_Wormhole -from .common import ServerBase - -# make sure the two implementations (Twisted-style and blocking-style) can -# interoperate - -APPID = u"appid" - -class Basic(ServerBase, unittest.TestCase): - - def doBoth(self, call1, d2): - f1 = call1[0] - f1args = call1[1:] - return gatherResults([deferToThread(f1, *f1args), d2], True) - - def test_twisted_to_blocking(self): - tw = twisted_Wormhole(APPID, self.relayurl) - bw = blocking_Wormhole(APPID, self.relayurl) - d = tw.get_code() - def _got_code(code): - bw.set_code(code) - return self.doBoth([bw.send_data, b"data2"], tw.send_data(b"data1")) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([bw.get_data], tw.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data1") - self.assertEqual(dataY, b"data2") - return self.doBoth([bw.close], tw.close()) - d.addCallback(_done) - return d - - def test_blocking_to_twisted(self): - bw = blocking_Wormhole(APPID, self.relayurl) - tw = twisted_Wormhole(APPID, self.relayurl) - d = deferToThread(bw.get_code) - def _got_code(code): - tw.set_code(code) - return self.doBoth([bw.send_data, b"data1"], tw.send_data(b"data2")) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([bw.get_data], tw.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([bw.close], tw.close()) - d.addCallback(_done) - return d From 501af4b4ec8c7e8f4164ba4390885f6b58172211 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 15:52:22 -0700 Subject: [PATCH 03/51] rename send_data/get_data to just send/get --- src/wormhole/cli/cmd_receive.py | 8 +-- src/wormhole/cli/cmd_send.py | 6 +-- src/wormhole/test/test_twisted.py | 84 ++++++++++++++---------------- src/wormhole/twisted/transcribe.py | 14 ++--- 4 files changed, 52 insertions(+), 60 deletions(-) diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 0de21cd..b39e3c6 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -90,7 +90,7 @@ class TwistedReceiver: raise RespondError("unknown offer type") except RespondError as r: data = json.dumps({"error": r.response}).encode("utf-8") - yield w.send_data(data) + yield w.send(data) raise TransferError(r.response) returnValue(None) @@ -113,7 +113,7 @@ class TwistedReceiver: @inlineCallbacks def get_data(self, w): # this may raise WrongPasswordError - them_bytes = yield w.get_data() + them_bytes = yield w.get() them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: raise TransferError(them_d["error"]) @@ -124,7 +124,7 @@ class TwistedReceiver: # we're receiving a text message self.msg(them_d["message"]) data = json.dumps({"message_ack": "ok"}).encode("utf-8") - yield w.send_data(data, wait=True) + yield w.send(data, wait=True) def handle_file(self, them_d): file_data = them_d["file"] @@ -199,7 +199,7 @@ class TwistedReceiver: "relay_connection_hints": relay_hints, }, }).encode("utf-8") - yield w.send_data(data) + yield w.send(data) # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 9904c58..cdc5cec 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -94,17 +94,17 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager): if ok.lower() == "no": err = "sender rejected verification check, abandoned transfer" reject_data = json.dumps({"error": err}).encode("utf-8") - yield w.send_data(reject_data) + yield w.send(reject_data) raise TransferError(err) if fd_to_send is not None: transit_key = w.derive_key(APPID+"/transit-key") transit_sender.set_transit_key(transit_key) my_phase1_bytes = json.dumps(phase1).encode("utf-8") - yield w.send_data(my_phase1_bytes) + yield w.send(my_phase1_bytes) # this may raise WrongPasswordError - them_phase1_bytes = yield w.get_data() + them_phase1_bytes = yield w.get() them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 1c186da..2d4941a 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -18,8 +18,8 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(APPID, self.relayurl) code = yield w1.get_code() w2.set_code(code) - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") @@ -34,8 +34,8 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(APPID, self.relayurl) code = yield w1.get_code() w2.set_code(code) - yield self.doBoth(w1.send_data(b"data"), w2.send_data(b"data")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) + yield self.doBoth(w1.send(b"data"), w2.send(b"data")) + dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data") self.assertEqual(dataY, b"data") @@ -47,10 +47,10 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(APPID, self.relayurl) code = yield w1.get_code() w2.set_code(code) - res = yield self.doBoth(w1.send_data(b"data1"), w2.get_data()) + res = yield self.doBoth(w1.send(b"data1"), w2.get()) (_, dataY) = res self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2")) + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) (dataX, _) = dl self.assertEqual(dataX, b"data2") yield self.doBoth(w1.close(), w2.close()) @@ -61,8 +61,8 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(APPID, self.relayurl) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") @@ -75,17 +75,13 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(APPID, self.relayurl) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send_data(b"data1", u"p1"), - w2.send_data(b"data2", u"p1")) - yield self.doBoth(w1.send_data(b"data3", u"p2"), - w2.send_data(b"data4", u"p2")) - dl = yield self.doBoth(w1.get_data(u"p2"), - w2.get_data(u"p1")) + yield self.doBoth(w1.send(b"data1", u"p1"), w2.send(b"data2", u"p1")) + yield self.doBoth(w1.send(b"data3", u"p2"), w2.send(b"data4", u"p2")) + dl = yield self.doBoth(w1.get(u"p2"), w2.get(u"p1")) (dataX, dataY) = dl self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get_data(u"p1"), - w2.get_data(u"p2")) + dl = yield self.doBoth(w1.get(u"p1"), w2.get(u"p2")) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data3") @@ -100,21 +96,21 @@ class Basic(ServerBase, unittest.TestCase): # w2 can't throw WrongPasswordError until it sees a CONFIRM message, # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get_data. So we need both sides to be + # won't send until we call get. So we need both sides to be # running at the same time for this test. - d1 = w1.send_data(b"data1") + d1 = w1.send(b"data1") # at this point, w1 should be waiting for w2.PAKE - yield self.assertFailure(w2.get_data(), WrongPasswordError) + yield self.assertFailure(w2.get(), WrongPasswordError) # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, # send w2.CONFIRM, then wait for w1.DATA. # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. # * w1 might also get w2.CONFIRM, and may notice the error before it # sends w1.CONFIRM, in which case the wait=True will signal an - # error inside _get_master_key() (inside send_data), and d1 will + # error inside _get_master_key() (inside send), and d1 will # errback. # * but w1 might not see w2.CONFIRM yet, in which case it won't - # errback until we do w1.get_data() + # errback until we do w1.get() # * w2 gets w1.CONFIRM, notices the error, records it. # * w2 (waiting for w1.DATA) wakes up, sees the error, throws # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not @@ -124,20 +120,20 @@ class Basic(ServerBase, unittest.TestCase): except WrongPasswordError: pass - # When we ask w1 to get_data(), one of two things might happen: + # When we ask w1 to get(), one of two things might happen: # * if w2.CONFIRM arrived already, it will have recorded the error. - # When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the + # When w1.get() sleeps (waiting for w2.DATA), we'll notice the # error before sleeping, and throw WrongPasswordError # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM # arrives, we notice and record the error, and wake up, and throw - # Note that we didn't do w2.send_data(), so we're hoping that w1 will + # Note that we didn't do w2.send(), so we're hoping that w1 will # have enough information to detect the error before it sleeps # (waiting for w2.DATA). Checking for the error both before sleeping # and after waking up makes this happen. # so now w1 should have enough information to throw too - yield self.assertFailure(w1.get_data(), WrongPasswordError) + yield self.assertFailure(w1.get(), WrongPasswordError) # both sides are closed automatically upon error, but it's still # legal to call .close(), and should be idempotent @@ -153,9 +149,9 @@ class Basic(ServerBase, unittest.TestCase): code = yield w1.get_code() w2.set_code(code) - dl = yield self.doBoth(w1.send_data(b"data1"), w2.get_data()) + dl = yield self.doBoth(w1.send(b"data1"), w2.get()) self.assertEqual(dl[1], b"data1") - dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2")) + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) self.assertEqual(dl[0], b"data2") yield self.doBoth(w1.close(), w2.close()) @@ -169,8 +165,8 @@ class Basic(ServerBase, unittest.TestCase): v1, v2 = res self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") @@ -180,8 +176,8 @@ class Basic(ServerBase, unittest.TestCase): def test_errors(self): w1 = Wormhole(APPID, self.relayurl) yield self.assertFailure(w1.get_verifier(), UsageError) - yield self.assertFailure(w1.send_data(b"data"), UsageError) - yield self.assertFailure(w1.get_data(), UsageError) + yield self.assertFailure(w1.send(b"data"), UsageError) + yield self.assertFailure(w1.get(), UsageError) w1.set_code(u"123-purple-elephant") yield self.assertRaises(UsageError, w1.set_code, u"123-nope") yield self.assertFailure(w1.get_code(), UsageError) @@ -198,22 +194,20 @@ class Basic(ServerBase, unittest.TestCase): w2.set_code(u"123-purple-elephant") # we must let them establish a key before we can send data yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - yield w1.send_data(b"data1", phase=u"1") + yield w1.send(b"data1", phase=u"1") # underscore-prefixed phases are reserved - yield self.assertFailure(w1.send_data(b"data1", phase=u"_1"), - UsageError) - yield self.assertFailure(w1.get_data(phase=u"_1"), UsageError) + yield self.assertFailure(w1.send(b"data1", phase=u"_1"), UsageError) + yield self.assertFailure(w1.get(phase=u"_1"), UsageError) # you can't send twice to the same phase - yield self.assertFailure(w1.send_data(b"data1", phase=u"1"), - UsageError) + yield self.assertFailure(w1.send(b"data1", phase=u"1"), UsageError) # but you can send to a different one - yield w1.send_data(b"data2", phase=u"2") - res = yield w2.get_data(phase=u"1") + yield w1.send(b"data2", phase=u"2") + res = yield w2.get(phase=u"1") self.failUnlessEqual(res, b"data1") # and you can't read twice from the same phase - yield self.assertFailure(w2.get_data(phase=u"1"), UsageError) + yield self.assertFailure(w2.get(phase=u"1"), UsageError) # but you can read from a different one - res = yield w2.get_data(phase=u"2") + res = yield w2.get(phase=u"2") self.failUnlessEqual(res, b"data2") yield self.doBoth(w1.close(), w2.close()) @@ -232,12 +226,10 @@ class Basic(ServerBase, unittest.TestCase): self.assertEqual(type(unpacked), dict) self.new_w1 = Wormhole.from_serialized(s) - yield self.doBoth(self.new_w1.send_data(b"data1"), - w2.send_data(b"data2")) - dl = yield self.doBoth(self.new_w1.get_data(), w2.get_data()) + yield self.doBoth(self.new_w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(self.new_w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual((dataX, dataY), (b"data2", b"data1")) self.assertRaises(UsageError, w2.serialize) # too late - yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], - True) + yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], True) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 199fa50..63f3e6b 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -332,7 +332,7 @@ class Wormhole: def serialize(self): # I can only be serialized after get_code/set_code and before - # get_verifier/get_data + # get_verifier/get if self._code is None: raise UsageError if self._key is not None: raise UsageError if self._sent_phases: raise UsageError @@ -448,7 +448,7 @@ class Wormhole: def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: - # call after get_verifier() or get_data() + # call after get_verifier() or get() raise UsageError return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) @@ -469,17 +469,17 @@ class Wormhole: return data @inlineCallbacks - def send_data(self, outbound_data, phase=u"data", wait=False): + def send(self, outbound_data, phase=u"data", wait=False): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if self._closed: raise UsageError if self._code is None: - raise UsageError("You must set_code() before send_data()") + raise UsageError("You must set_code() before send()") if phase.startswith(u"_"): raise UsageError # reserved for internals if phase in self._sent_phases: raise UsageError # only call this once self._sent_phases.add(phase) - with self._timing.add("API send_data", phase=phase, wait=wait): + with self._timing.add("API send", phase=phase, wait=wait): # Without predefined roles, we can't derive predictably unique # keys for each side, so we use the same key for both. We use # random nonces to keep the messages distinct, and we @@ -490,14 +490,14 @@ class Wormhole: yield self._msg_send(phase, outbound_encrypted, wait) @inlineCallbacks - def get_data(self, phase=u"data"): + def get(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if self._closed: raise UsageError if self._code is None: raise UsageError if phase.startswith(u"_"): raise UsageError # reserved for internals if phase in self._got_phases: raise UsageError # only call this once self._got_phases.add(phase) - with self._timing.add("API get_data", phase=phase): + with self._timing.add("API get", phase=phase): yield self._get_master_key() body = yield self._msg_get(phase) # we can wait a long time here try: From d0ef53fc4d154e533be8c69dc0f7da129e446af7 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 16:16:05 -0700 Subject: [PATCH 04/51] remove phase= from the Wormhole API Phase are now implicit and numbered. --- src/wormhole/test/test_twisted.py | 61 ++++++++++++++---------------- src/wormhole/twisted/transcribe.py | 38 +++++++++---------- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 2d4941a..aa1c38d 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -70,23 +70,43 @@ class Basic(ServerBase, unittest.TestCase): @inlineCallbacks - def test_phases(self): + def test_multiple_messages(self): w1 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1", u"p1"), w2.send(b"data2", u"p1")) - yield self.doBoth(w1.send(b"data3", u"p2"), w2.send(b"data4", u"p2")) - dl = yield self.doBoth(w1.get(u"p2"), w2.get(u"p1")) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data4") - self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get(u"p1"), w2.get(u"p2")) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) + dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data3") yield self.doBoth(w1.close(), w2.close()) + @inlineCallbacks + def test_multiple_messages_2(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + # TODO: set_code should be sufficient to kick things off, but for now + # we must also let both sides do at least one send() or get() + yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) + yield w1.get() + yield w1.send(b"data2") + yield w1.send(b"data3") + data = yield w2.get() + self.assertEqual(data, b"data1") + data = yield w2.get() + self.assertEqual(data, b"data2") + data = yield w2.get() + self.assertEqual(data, b"data3") + yield self.doBoth(w1.close(), w2.close()) + @inlineCallbacks def test_wrong_password(self): w1 = Wormhole(APPID, self.relayurl) @@ -186,31 +206,6 @@ class Basic(ServerBase, unittest.TestCase): yield self.assertFailure(w2.get_code(), UsageError) yield self.doBoth(w1.close(), w2.close()) - @inlineCallbacks - def test_repeat_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2 = Wormhole(APPID, self.relayurl) - w2.set_code(u"123-purple-elephant") - # we must let them establish a key before we can send data - yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - yield w1.send(b"data1", phase=u"1") - # underscore-prefixed phases are reserved - yield self.assertFailure(w1.send(b"data1", phase=u"_1"), UsageError) - yield self.assertFailure(w1.get(phase=u"_1"), UsageError) - # you can't send twice to the same phase - yield self.assertFailure(w1.send(b"data1", phase=u"1"), UsageError) - # but you can send to a different one - yield w1.send(b"data2", phase=u"2") - res = yield w2.get(phase=u"1") - self.failUnlessEqual(res, b"data1") - # and you can't read twice from the same phase - yield self.assertFailure(w2.get(phase=u"1"), UsageError) - # but you can read from a different one - res = yield w2.get(phase=u"2") - self.failUnlessEqual(res, b"data2") - yield self.doBoth(w1.close(), w2.close()) - @inlineCallbacks def test_serialize(self): w1 = Wormhole(APPID, self.relayurl) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 63f3e6b..b85456d 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -75,10 +75,11 @@ class Wormhole: self._channelid = None self._key = None self._started_get_code = False - self._sent_messages = set() # (phase, body_bytes) - self._delivered_messages = set() # (phase, body_bytes) + self._next_outbound_phase = 0 + self._sent_messages = {} # phase -> body_bytes + self._delivered_messages = set() # phase + self._next_inbound_phase = 0 self._received_messages = {} # phase -> body_bytes - self._sent_phases = set() # phases, to prohibit double-send self._got_phases = set() # phases, to prohibit double-read self._sleepers = [] self._confirmation_failed = False @@ -335,7 +336,7 @@ class Wormhole: # get_verifier/get if self._code is None: raise UsageError if self._key is not None: raise UsageError - if self._sent_phases: raise UsageError + if self._sent_messages: raise UsageError if self._got_phases: raise UsageError data = { "appid": self._appid, @@ -400,14 +401,15 @@ class Wormhole: @inlineCallbacks def _msg_send(self, phase, body, wait=False): - self._sent_messages.add( (phase, body) ) + if phase in self._sent_messages: raise UsageError + self._sent_messages[phase] = body # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. t = self._timing.add("add", phase=phase, wait=wait) yield self._ws_send(u"add", phase=phase, body=hexlify(body).decode("ascii")) if wait: - while (phase, body) not in self._delivered_messages: + while phase not in self._delivered_messages: yield self._sleep() t.finish() @@ -415,8 +417,8 @@ class Wormhole: m = msg["message"] phase = m["phase"] body = unhexlify(m["body"].encode("ascii")) - if (phase, body) in self._sent_messages: - self._delivered_messages.add( (phase, body) ) # ack by server + if phase in self._sent_messages and self._sent_messages[phase] == body: + self._delivered_messages.add(phase) # ack by server self._wakeup() return # ignore echoes of our outbound messages if phase in self._received_messages: @@ -469,39 +471,35 @@ class Wormhole: return data @inlineCallbacks - def send(self, outbound_data, phase=u"data", wait=False): + def send(self, outbound_data, wait=False): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if self._closed: raise UsageError if self._code is None: raise UsageError("You must set_code() before send()") - if phase.startswith(u"_"): raise UsageError # reserved for internals - if phase in self._sent_phases: raise UsageError # only call this once - self._sent_phases.add(phase) + phase = self._next_outbound_phase + self._next_outbound_phase += 1 with self._timing.add("API send", phase=phase, wait=wait): # Without predefined roles, we can't derive predictably unique # keys for each side, so we use the same key for both. We use # random nonces to keep the messages distinct, and we # automatically ignore reflections. yield self._get_master_key() - data_key = self.derive_key(u"wormhole:phase:%s" % phase) + data_key = self.derive_key(u"wormhole:phase:%d" % phase) outbound_encrypted = self._encrypt_data(data_key, outbound_data) yield self._msg_send(phase, outbound_encrypted, wait) @inlineCallbacks - def get(self, phase=u"data"): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + def get(self): if self._closed: raise UsageError if self._code is None: raise UsageError - if phase.startswith(u"_"): raise UsageError # reserved for internals - if phase in self._got_phases: raise UsageError # only call this once - self._got_phases.add(phase) + phase = self._next_inbound_phase + self._next_inbound_phase += 1 with self._timing.add("API get", phase=phase): yield self._get_master_key() body = yield self._msg_get(phase) # we can wait a long time here try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) + data_key = self.derive_key(u"wormhole:phase:%d" % phase) inbound_data = self._decrypt_data(data_key, body) returnValue(inbound_data) except CryptoError: From d87aba40e498bc9489fd281ac39ecdd65bb91cbe Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 16:18:34 -0700 Subject: [PATCH 05/51] rename _confirm message to just "confirm" --- src/wormhole/twisted/transcribe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index b85456d..8b8fc47 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -376,7 +376,7 @@ class Wormhole: # receiver will see this a round-trip after they send their PAKE # (because the sender is using wait=True inside _get_master_key, # below: otherwise the sender might go do some blocking call). - yield self._msg_get(u"_confirm") + yield self._msg_get(u"confirm") returnValue(self._verifier) @inlineCallbacks @@ -397,7 +397,7 @@ class Wormhole: confkey = self.derive_key(u"wormhole:confirmation") nonce = os.urandom(CONFMSG_NONCE_LENGTH) confmsg = make_confmsg(confkey, nonce) - yield self._msg_send(u"_confirm", confmsg, wait=True) + yield self._msg_send(u"confirm", confmsg, wait=True) @inlineCallbacks def _msg_send(self, phase, body, wait=False): @@ -426,7 +426,7 @@ class Wormhole: err = ServerError("got duplicate phase %s" % phase, self._ws_url) return self._signal_error(err) self._received_messages[phase] = body - if phase == u"_confirm": + if phase == u"confirm": # TODO: we might not have a master key yet, if the caller wasn't # waiting in _get_master_key() when a back-to-back pake+_confirm # message pair arrived. From 104ef44d537810e9eb0da250f1bb592d9d720406 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 16:36:48 -0700 Subject: [PATCH 06/51] provide wormhole() as a function, rather than a class constructor You must always provide a reactor= argument. In the future, omitting the reactor= argument is how you ask for a blocking Wormhole. --- docs/api.md | 21 ++++++++++----------- src/wormhole/cli/cmd_receive.py | 7 +++---- src/wormhole/cli/cmd_send.py | 6 +++--- src/wormhole/test/test_twisted.py | 9 +++++++-- src/wormhole/twisted/transcribe.py | 20 ++++++++++++++------ 5 files changed, 37 insertions(+), 26 deletions(-) diff --git a/docs/api.md b/docs/api.md index 53d7079..2efb436 100644 --- a/docs/api.md +++ b/docs/api.md @@ -77,7 +77,7 @@ before being closed. To make it easier to call `close()`, the blocking Wormhole objects can be used as a context manager. Just put your code in the body of a `with -Wormhole(ARGS) as w:` statement, and `close()` will automatically be called +wormhole(ARGS) as w:` statement, and `close()` will automatically be called when the block exits (either successfully or due to an exception). ## Examples @@ -85,10 +85,10 @@ when the block exits (either successfully or due to an exception). The synchronous+blocking flow looks like this: ```python -from wormhole.blocking.transcribe import Wormhole +from wormhole.blocking.transcribe import wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"initiator's data" -with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: +with wormhole(u"appid", RENDEZVOUS_RELAY) as i: code = i.get_code() print("Invitation Code: %s" % code) i.send(mydata) @@ -98,11 +98,11 @@ with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: ```python import sys -from wormhole.blocking.transcribe import Wormhole +from wormhole.blocking.transcribe import wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"receiver's data" code = sys.argv[1] -with Wormhole(u"appid", RENDEZVOUS_RELAY) as r: +with wormhole(u"appid", RENDEZVOUS_RELAY) as r: r.set_code(code) r.send(mydata) theirdata = r.get() @@ -116,9 +116,9 @@ The Twisted-friendly flow looks like this: ```python from twisted.internet import reactor from wormhole.public_relay import RENDEZVOUS_RELAY -from wormhole.twisted.transcribe import Wormhole +from wormhole.twisted.transcribe import wormhole outbound_message = b"outbound data" -w1 = Wormhole(u"appid", RENDEZVOUS_RELAY) +w1 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor) d = w1.get_code() def _got_code(code): print "Invitation Code:", code @@ -136,7 +136,7 @@ reactor.run() On the other side, you call `set_code()` instead of waiting for `get_code()`: ```python -w2 = Wormhole(u"appid", RENDEZVOUS_RELAY) +w2 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor) w2.set_code(code) d = w2.send(my_message) ... @@ -253,9 +253,8 @@ You may not be able to hold the Wormhole object in memory for the whole sync process: maybe you allow it to wait for several days, but the program will be restarted during that time. To support this, you can persist the state of the object by calling `data = w.serialize()`, which will return a printable -bytestring (the JSON-encoding of a small dictionary). To restore, use the -`from_serialized(data)` classmethod (e.g. `w = -Wormhole.from_serialized(data)`). +bytestring (the JSON-encoding of a small dictionary). To restore, use `w = +wormhole_from_serialized(data, reactor)`. There is exactly one point at which you can serialize the wormhole: *after* establishing the invitation code, but before waiting for `get_verifier()` or diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index b39e3c6..5d308a3 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -3,7 +3,7 @@ import os, sys, json, binascii, six, tempfile, zipfile from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue -from ..twisted.transcribe import Wormhole +from ..twisted.transcribe import wormhole from ..twisted.transit import TransitReceiver from ..errors import TransferError @@ -45,9 +45,8 @@ class TwistedReceiver: # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code yield tor_manager.start() - w = Wormhole(APPID, self.args.relay_url, tor_manager, - timing=self.args.timing, - reactor=self._reactor) + w = wormhole(APPID, self.args.relay_url, self._reactor, + tor_manager, timing=self.args.timing) # I wanted to do this instead: # # try: diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index cdc5cec..8ad6973 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -5,7 +5,7 @@ from twisted.protocols import basic from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from ..errors import TransferError -from ..twisted.transcribe import Wormhole +from ..twisted.transcribe import wormhole from ..twisted.transit import TransitSender APPID = u"lothar.com/wormhole/text-or-file-xfer" @@ -49,8 +49,8 @@ def send(args, reactor=reactor): # user handing off the wormhole code yield tor_manager.start() - w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing, - reactor=reactor) + w = wormhole(APPID, args.relay_url, reactor, tor_manager, + timing=args.timing) d = _send(reactor, w, args, phase1, fd_to_send, tor_manager) d.addBoth(w.close) diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index aa1c38d..da87c97 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -1,12 +1,17 @@ from __future__ import print_function import json from twisted.trial import unittest +from twisted.internet import reactor from twisted.internet.defer import gatherResults, inlineCallbacks -from ..twisted.transcribe import Wormhole, UsageError, WrongPasswordError +from ..twisted.transcribe import (wormhole, wormhole_from_serialized, + UsageError, WrongPasswordError) from .common import ServerBase APPID = u"appid" +def Wormhole(appid, relayurl): + return wormhole(appid, relayurl, reactor) + class Basic(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): @@ -220,7 +225,7 @@ class Basic(ServerBase, unittest.TestCase): unpacked = json.loads(s) # this is supposed to be JSON self.assertEqual(type(unpacked), dict) - self.new_w1 = Wormhole.from_serialized(s) + self.new_w1 = wormhole_from_serialized(s, reactor) yield self.doBoth(self.new_w1.send(b"data1"), w2.send(b"data2")) dl = yield self.doBoth(self.new_w1.get(), w2.get()) (dataX, dataY) = dl diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 8b8fc47..248d1b9 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -2,7 +2,7 @@ from __future__ import print_function import os, sys, json, re, unicodedata from six.moves.urllib_parse import urlparse from binascii import hexlify, unhexlify -from twisted.internet import reactor, defer, endpoints, error +from twisted.internet import defer, endpoints, error from twisted.internet.threads import deferToThread, blockingCallFromThread from twisted.internet.defer import inlineCallbacks, returnValue from twisted.python import log @@ -53,13 +53,13 @@ class WSFactory(websocket.WebSocketClientFactory): proto.wormhole_open = False return proto -class Wormhole: +class _Wormhole: motd_displayed = False version_warning_displayed = False _send_confirm = True - def __init__(self, appid, relay_url, tor_manager=None, timing=None, - reactor=reactor): + def __init__(self, appid, relay_url, reactor, + tor_manager=None, timing=None): if not isinstance(appid, type(u"")): raise TypeError(type(appid)) if not isinstance(relay_url, type(u"")): raise TypeError(type(relay_url)) @@ -351,9 +351,9 @@ class Wormhole: # entry point 3: resume a previously-serialized session @classmethod - def from_serialized(klass, data): + def from_serialized(klass, data, reactor): d = json.loads(data) - self = klass(d["appid"], d["relay_url"]) + self = klass(d["appid"], d["relay_url"], reactor) self._side = d["side"] self._channelid = d["channelid"] self._set_code(d["code"]) @@ -556,3 +556,11 @@ class Wormhole: def _ws_handle_deallocated(self, msg): self._deallocated_status = msg["status"] self._wakeup() + +def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): + w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) + return w + +def wormhole_from_serialized(data, reactor): + w = _Wormhole.from_serialized(data, reactor) + return w From a34fb2a98ba601a9332be3ec094d3de9d70cbc37 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 16:56:19 -0700 Subject: [PATCH 07/51] remove plain-HTTP (non-WebSocket) rendezvous server --- src/wormhole/server/rendezvous_web.py | 223 ----------------- src/wormhole/server/server.py | 4 +- src/wormhole/test/test_server.py | 337 -------------------------- src/wormhole/twisted/eventsource.py | 238 ------------------ 4 files changed, 1 insertion(+), 801 deletions(-) delete mode 100644 src/wormhole/server/rendezvous_web.py delete mode 100644 src/wormhole/twisted/eventsource.py diff --git a/src/wormhole/server/rendezvous_web.py b/src/wormhole/server/rendezvous_web.py deleted file mode 100644 index c89654e..0000000 --- a/src/wormhole/server/rendezvous_web.py +++ /dev/null @@ -1,223 +0,0 @@ -import json, time -from twisted.web import server, resource -from twisted.python import log - -def json_response(request, data): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - return (json.dumps(data)+"\n").encode("utf-8") - -class EventsProtocol: - def __init__(self, request): - self.request = request - - def sendComment(self, comment): - # this is ignored by clients, but can keep the connection open in the - # face of firewall/NAT timeouts. It also helps unit tests, since - # apparently twisted.web.client.Agent doesn't consider the connection - # to be established until it sees the first byte of the reponse body. - self.request.write(b": " + comment + b"\n\n") - - def sendEvent(self, data, name=None, id=None, retry=None): - if name: - self.request.write(b"event: " + name.encode("utf-8") + b"\n") - # e.g. if name=foo, then the client web page should do: - # (new EventSource(url)).addEventListener("foo", handlerfunc) - # Note that this basically defaults to "message". - if id: - self.request.write(b"id: " + id.encode("utf-8") + b"\n") - if retry: - self.request.write(b"retry: " + retry + b"\n") # milliseconds - for line in data.splitlines(): - self.request.write(b"data: " + line.encode("utf-8") + b"\n") - self.request.write(b"\n") - - def stop(self): - self.request.finish() - - def send_rendezvous_event(self, data): - data = data.copy() - data["sent"] = time.time() - self.sendEvent(json.dumps(data)) - def stop_rendezvous_watcher(self): - self.stop() - -# note: no versions of IE (including the current IE11) support EventSource - -# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) -# ("-" indicates a deprecated URL) -# GET /list?appid= -> {channelids: [INT..]} -# POST /allocate {appid:,side:} -> {channelid: INT} -# these return all messages (base64) for appid=/channelid= : -# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} -# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} -#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. -# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}.. -# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} -# all JSON responses include a "welcome:{..}" key - -class RelayResource(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self._welcome = rendezvous.get_welcome() - -class ChannelLister(RelayResource): - def render_GET(self, request): - if b"appid" not in request.args: - e = NeedToUpgradeErrorResource(self._welcome) - return e.get_message() - appid = request.args[b"appid"][0].decode("utf-8") - #print("LIST", appid) - app = self._rendezvous.get_app(appid) - allocated = app.get_allocated() - data = {"welcome": self._welcome, "channelids": sorted(allocated), - "sent": time.time()} - return json_response(request, data) - -class Allocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - #print("ALLOCATE", appid, side) - app = self._rendezvous.get_app(appid) - channelid = app.find_available_channelid() - app.allocate_channel(channelid, side) - if self._rendezvous.get_log_requests(): - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(app.get_allocated()))) - response = {"welcome": self._welcome, "channelid": channelid, - "sent": time.time()} - return json_response(request, response) - - def getChild(self, path, req): - # wormhole-0.4.0 "send" started with "POST /allocate/SIDE". - # wormhole-0.5.0 changed that to "POST /allocate". We catch the old - # URL here to deliver a nicer error message (with upgrade - # instructions) than an ugly 404. - return NeedToUpgradeErrorResource(self._welcome) - -class NeedToUpgradeErrorResource(resource.Resource): - def __init__(self, welcome): - resource.Resource.__init__(self) - w = welcome.copy() - w["error"] = "Sorry, you must upgrade your client to use this server." - message = {"welcome": w} - self._message = (json.dumps(message)+"\n").encode("utf-8") - def get_message(self): - return self._message - def render_POST(self, request): - return self._message - def render_GET(self, request): - return self._message - def getChild(self, path, req): - return self - -class Adder(RelayResource): - def render_POST(self, request): - #content = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - phase = data["phase"] - if not isinstance(phase, type(u"")): - raise TypeError("phase must be string, not %s" % type(phase)) - body = data["body"] - #print("ADD", appid, channelid, side, phase, body) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - messages = channel.add_message(side, phase, body, time.time(), None) - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - -class GetterOrWatcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - #print("GET", appid, channelid) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - messages = channel.get_messages() - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Watcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - raise TypeError("/watch is for EventSource only") - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Deallocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - mood = data.get("mood") - #print("DEALLOCATE", appid, channelid, side) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - deleted = channel.deallocate(side, mood) - response = {"status": "waiting", "sent": time.time()} - if deleted: - response = {"status": "deleted", "sent": time.time()} - return json_response(request, response) - - -class WebRendezvous(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self.putChild(b"list", ChannelLister(rendezvous)) - self.putChild(b"allocate", Allocator(rendezvous)) - self.putChild(b"add", Adder(rendezvous)) - self.putChild(b"get", GetterOrWatcher(rendezvous)) - self.putChild(b"watch", Watcher(rendezvous)) - self.putChild(b"deallocate", Deallocator(rendezvous)) - - def getChild(self, path, req): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - # give a nicer error message to old clients - if (len(req.postpath) >= 2 - and req.postpath[1] in (b"post", b"poll", b"deallocate")): - welcome = self._rendezvous.get_welcome() - return NeedToUpgradeErrorResource(welcome) - return resource.NoResource("No such child resource.") diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 166a6ba..694ed19 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -8,7 +8,6 @@ from .endpoint_service import ServerEndpointService from .. import __version__ from .database import get_db from .rendezvous import Rendezvous -from .rendezvous_web import WebRendezvous from .rendezvous_websocket import WebSocketRendezvousFactory from .transit_server import Transit @@ -47,7 +46,7 @@ class RelayServer(service.MultiService): rendezvous.setServiceParent(self) # for the pruning timer root = Root() - wr = WebRendezvous(rendezvous) + wr = resource.Resource() root.putChild(b"wormhole-relay", wr) wsrf = WebSocketRendezvousFactory(None, rendezvous) @@ -72,7 +71,6 @@ class RelayServer(service.MultiService): self._db = db self._rendezvous = rendezvous self._root = root - self._rendezvous_web = wr self._rendezvous_web_service = rendezvous_web_service self._rendezvous_websocket = wsrf if transit_port: diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d9010f1..c52bd50 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,19 +1,15 @@ from __future__ import print_function import json, itertools from binascii import hexlify -import requests -from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import protocol, reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.threads import deferToThread from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase from ..server import rendezvous, transit_server -from ..twisted.eventsource import EventSource class Reachable(ServerBase, unittest.TestCase): @@ -39,22 +35,6 @@ class Reachable(ServerBase, unittest.TestCase): d.addCallback(_got) return d - def test_requests(self): - # requests requires bytes URL, returns unicode - url = self.relayurl.replace("wormhole-relay/", "") - def _get(url): - r = requests.get(url) - r.raise_for_status() - return r.text - d = deferToThread(_get, url) - def _got(res): - self.failUnlessEqual(res, "Wormhole Relay\n") - d.addCallback(_got) - return d - -def unjson(data): - return json.loads(data.decode("utf-8")) - def strip_message(msg): m2 = msg.copy() m2.pop("id", None) @@ -64,323 +44,6 @@ def strip_message(msg): def strip_messages(messages): return [strip_message(m) for m in messages] -class WebAPI(ServerBase, unittest.TestCase): - def build_url(self, path, appid, channelid): - url = self.relayurl+path - queryargs = [] - if appid: - queryargs.append(("appid", appid)) - if channelid: - queryargs.append(("channelid", channelid)) - if queryargs: - url += "?" + urlencode(queryargs) - return url - - def get(self, path, appid=None, channelid=None): - url = self.build_url(path, appid, channelid) - d = getPage(url.encode("ascii")) - d.addCallback(unjson) - return d - - def post(self, path, data): - url = self.relayurl+path - d = getPage(url.encode("ascii"), method=b"POST", - postdata=json.dumps(data).encode("utf-8")) - d.addCallback(unjson) - return d - - def check_welcome(self, data): - self.failUnlessIn("welcome", data) - self.failUnlessEqual(data["welcome"], {"current_version": __version__}) - - def test_allocate_1(self): - d = self.get("list", "app1") - def _check_list_1(data): - self.check_welcome(data) - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_1) - - d.addCallback(lambda _: self.post("allocate", {"appid": "app1", - "side": "abc"})) - def _allocated(data): - data.pop("sent", None) - self.failUnlessEqual(set(data.keys()), - set(["welcome", "channelid"])) - self.failUnlessIsInstance(data["channelid"], int) - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_2(data): - self.failUnlessEqual(data["channelids"], [self.cid]) - d.addCallback(_check_list_2) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - def _check_deallocate(res): - self.failUnlessEqual(res["status"], "deleted") - d.addCallback(_check_deallocate) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_3(data): - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_3) - - return d - - def test_allocate_2(self): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - # second caller increases the number of known sides to 2 - d.addCallback(lambda _: self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def", - "phase": "1", - "body": ""})) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [self.cid])) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "NOT"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "deleted")) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [])) - - return d - - UPGRADE_ERROR = "Sorry, you must upgrade your client to use this server." - def test_old_allocate(self): - # 0.4.0 used "POST /allocate/SIDE". - # 0.5.0 replaced it with "POST /allocate". - # test that an old client gets a useful error message, not a 404. - d = self.post("allocate/abc", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_list(self): - # 0.4.0 used "GET /list". - # 0.5.0 replaced it with "GET /list?appid=" - d = self.get("list", {}) # no appid - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_post(self): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - d = self.post("1/abc/post/pake", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def add_message(self, message, side="abc", phase="1"): - return self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": side, - "phase": phase, - "body": message}) - - def parse_messages(self, messages): - out = set() - for m in messages: - self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"])) - self.failUnlessIsInstance(m["phase"], type(u"")) - self.failUnlessIsInstance(m["body"], type(u"")) - out.add( (m["phase"], m["body"]) ) - return out - - def check_messages(self, one, two): - # Comparing lists-of-dicts is non-trivial in python3 because we can - # neither sort them (dicts are uncomparable), nor turn them into sets - # (dicts are unhashable). This is close enough. - self.failUnlessEqual(len(one), len(two), (one,two)) - for d in one: - self.failUnlessIn(d, two) - - def test_message(self): - # exercise POST /add - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.add_message("msg1A")) - def _check1(data): - self.check_welcome(data) - self.failUnlessEqual(strip_messages(data["messages"]), - [{"phase": "1", "body": "msg1A"}]) - d.addCallback(_check1) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check1) - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check2(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check2) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check2) - - # adding a duplicate message is not an error, is ignored by clients - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check3(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check3) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check3) - - d.addCallback(lambda _: self.add_message("msg2A", side="abc", - phase="2")) - def _check4(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B"), - ("2", "msg2A"), - ])) - d.addCallback(_check4) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check4) - - return d - - def test_watch_message(self): - # exercise GET /get (the EventSource version) - # this API is scheduled to be removed after 0.6.0 - return self._do_watch("get") - - def test_watch(self): - # exercise GET /watch (the EventSource version) - return self._do_watch("watch") - - def _do_watch(self, endpoint_name): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - url = self.build_url(endpoint_name, "app1", self.cid) - self.o = OneEventAtATime(url, parser=json.loads) - return self.o.wait_for_connection() - d.addCallback(_allocated) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_welcome(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "welcome") - self.failUnlessEqual(data, {"current_version": __version__}) - d.addCallback(_check_welcome) - d.addCallback(lambda _: self.add_message("msg1A")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg1(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1A"}) - d.addCallback(_check_msg1) - - d.addCallback(lambda _: self.add_message("msg1B")) - d.addCallback(lambda _: self.add_message("msg2A", phase="2")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg2(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1B"}) - d.addCallback(_check_msg2) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg3(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "2", "body": "msg2A"}) - d.addCallback(_check_msg3) - - d.addCallback(lambda _: self.o.close()) - d.addCallback(lambda _: self.o.wait_for_disconnection()) - return d - -class OneEventAtATime: - def __init__(self, url, parser=lambda e: e): - self.parser = parser - self.d = None - self._connected = False - self.connected_d = defer.Deferred() - self.disconnected_d = defer.Deferred() - self.events = [] - self.es = EventSource(url, self.handler, when_connected=self.connected) - d = self.es.start() - d.addBoth(self.disconnected) - - def close(self): - self.es.cancel() - - def wait_for_next_event(self): - assert not self.d - if self.events: - event = self.events.pop(0) - return defer.succeed(event) - self.d = defer.Deferred() - return self.d - - def handler(self, eventtype, data): - event = (eventtype, self.parser(data)) - if self.d: - assert not self.events - d,self.d = self.d,None - d.callback(event) - return - self.events.append(event) - - def wait_for_connection(self): - return self.connected_d - def connected(self): - self._connected = True - self.connected_d.callback(None) - - def wait_for_disconnection(self): - return self.disconnected_d - def disconnected(self, why): - if not self._connected: - self.connected_d.errback(why) - self.disconnected_d.callback((why,)) - class WSClient(websocket.WebSocketClientProtocol): def __init__(self): websocket.WebSocketClientProtocol.__init__(self) diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource.py deleted file mode 100644 index 19272ca..0000000 --- a/src/wormhole/twisted/eventsource.py +++ /dev/null @@ -1,238 +0,0 @@ -#import sys -from twisted.python import log, failure -from twisted.internet import reactor, defer, protocol -from twisted.application import service -from twisted.protocols import basic -from twisted.web.client import Agent, ResponseDone -from twisted.web.http_headers import Headers -from cgi import parse_header -from .eventual import eventually - -#if sys.version_info[0] == 2: -# to_unicode = unicode -#else: -# to_unicode = str - -class EventSourceParser(basic.LineOnlyReceiver): - # http://www.w3.org/TR/eventsource/ - delimiter = b"\n" - - def __init__(self, handler): - self.current_field = None - self.current_lines = [] - self.handler = handler - self.done_deferred = defer.Deferred() - self.eventtype = u"message" - self.encoding = "utf-8" - - def set_encoding(self, encoding): - self.encoding = encoding - - def connectionLost(self, why): - if why.check(ResponseDone): - why = None - self.done_deferred.callback(why) - - def dataReceived(self, data): - # exceptions here aren't being logged properly, and tests will hang - # rather than halt. I suspect twisted.web._newclient's - # HTTP11ClientProtocol.dataReceived(), which catches everything and - # responds with self._giveUp() but doesn't log.err. - try: - basic.LineOnlyReceiver.dataReceived(self, data) - except: - log.err() - raise - - def lineReceived(self, line): - #line = to_unicode(line, self.encoding) - line = line.decode(self.encoding) - if not line: - # blank line ends the field: deliver event, reset for next - self.eventReceived(self.eventtype, "\n".join(self.current_lines)) - self.eventtype = u"message" - self.current_lines[:] = [] - return - if u":" in line: - fieldname, data = line.split(u":", 1) - if data.startswith(u" "): - data = data[1:] - else: - fieldname = line - data = u"" - if fieldname == u"event": - self.eventtype = data - elif fieldname == u"data": - self.current_lines.append(data) - elif fieldname in (u"id", u"retry"): - # documented but unhandled - pass - else: - log.msg("weird fieldname", fieldname, data) - - def eventReceived(self, eventtype, data): - self.handler(eventtype, data) - -class EventSourceError(Exception): - pass - -# es = EventSource(url, handler) -# d = es.start() -# es.cancel() - -class EventSource: # TODO: service.Service - def __init__(self, url, handler, when_connected=None, agent=None): - assert isinstance(url, type(u"")) - self.url = url - self.handler = handler - self.when_connected = when_connected - self.started = False - self.cancelled = False - self.proto = EventSourceParser(self.handler) - if not agent: - agent = Agent(reactor) - self.agent = agent - - def start(self): - assert not self.started, "single-use" - self.started = True - assert self.url - d = self.agent.request(b"GET", self.url.encode("utf-8"), - Headers({b"accept": [b"text/event-stream"]})) - d.addCallback(self._connected) - return d - - def _connected(self, resp): - if resp.code != 200: - raise EventSourceError("%d: %s" % (resp.code, resp.phrase)) - if self.when_connected: - self.when_connected() - default_ct = "text/event-stream; charset=utf-8" - ct_headers = resp.headers.getRawHeaders("content-type", [default_ct]) - ct, ct_params = parse_header(ct_headers[0]) - assert ct == "text/event-stream", ct - self.proto.set_encoding(ct_params.get("charset", "utf-8")) - resp.deliverBody(self.proto) - if self.cancelled: - self.kill_connection() - return self.proto.done_deferred - - def cancel(self): - self.cancelled = True - if not self.proto.transport: - # _connected hasn't been called yet, but that self.cancelled - # should take care of it when the connection is established - def kill(data): - # this should kill it as soon as any data is delivered - raise ValueError("dead") - self.proto.dataReceived = kill # just in case - return - self.kill_connection() - - def kill_connection(self): - if (hasattr(self.proto.transport, "_producer") - and self.proto.transport._producer): - # This is gross and fragile. We need a clean way to stop the - # client connection. p.transport is a - # twisted.web._newclient.TransportProxyProducer , and its - # ._producer is the tcp.Port. - self.proto.transport._producer.loseConnection() - else: - log.err("get_events: unable to stop connection") - # oh well - #err = EventSourceError("unable to cancel") - try: - self.proto.done_deferred.callback(None) - except defer.AlreadyCalledError: - pass - - -class Connector: - # behave enough like an IConnector to appease ReconnectingClientFactory - def __init__(self, res): - self.res = res - def connect(self): - self.res._maybeStart() - def stopConnecting(self): - self.res._stop_eventsource() - -class ReconnectingEventSource(service.MultiService, - protocol.ReconnectingClientFactory): - def __init__(self, url, handler, agent=None): - service.MultiService.__init__(self) - # we don't use any of the basic Factory/ClientFactory methods of - # this, just the ReconnectingClientFactory.retry, stopTrying, and - # resetDelay methods. - - self.url = url - self.handler = handler - self.agent = agent - # IService provides self.running, toggled by {start,stop}Service. - # self.active is toggled by {,de}activate. If both .running and - # .active are True, then we want to have an outstanding EventSource - # and will start one if necessary. If either is False, then we don't - # want one to be outstanding, and will initiate shutdown. - self.active = False - self.connector = Connector(self) - self.es = None # set we have an outstanding EventSource - self.when_stopped = [] # list of Deferreds - - def isStopped(self): - return not self.es - - def startService(self): - service.MultiService.startService(self) # sets self.running - self._maybeStart() - - def stopService(self): - # clears self.running - d = defer.maybeDeferred(service.MultiService.stopService, self) - d.addCallback(self._maybeStop) - return d - - def activate(self): - assert not self.active - self.active = True - self._maybeStart() - - def deactivate(self): - assert self.active # XXX - self.active = False - return self._maybeStop() - - def _maybeStart(self): - if not (self.active and self.running): - return - self.continueTrying = True - self.es = EventSource(self.url, self.handler, self.resetDelay, - agent=self.agent) - d = self.es.start() - d.addBoth(self._stopped) - - def _stopped(self, res): - self.es = None - # we might have stopped because of a connection error, or because of - # an intentional shutdown. - if self.active and self.running: - # we still want to be connected, so schedule a reconnection - if isinstance(res, failure.Failure): - log.err(res) - self.retry() # will eventually call _maybeStart - return - # intentional shutdown - self.stopTrying() - for d in self.when_stopped: - eventually(d.callback, None) - self.when_stopped = [] - - def _stop_eventsource(self): - if self.es: - eventually(self.es.cancel) - - def _maybeStop(self, _=None): - self.stopTrying() # cancels timer, calls _stop_eventsource() - if not self.es: - return defer.succeed(None) - d = defer.Deferred() - self.when_stopped.append(d) - return d From bdc9066c2354303e77c9e5d8fc62c3cf83c0d328 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 17:03:57 -0700 Subject: [PATCH 08/51] rendezvous: change add_listener signature Pass in a handle and a pair of functions, rather than an object with two well-known methods. This should make it easier to subscribe to multiple channels in the future. --- src/wormhole/server/rendezvous.py | 28 ++++++++++----------- src/wormhole/server/rendezvous_websocket.py | 8 ++++-- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 122f780..3401008 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -24,9 +24,9 @@ class Channel: self._log_requests = log_requests self._appid = appid self._channelid = channelid - self._listeners = set() # instances with .send_rendezvous_event (that - # takes a JSONable object) and - # .stop_rendezvous_watcher() + self._listeners = {} # handle -> (send_f, stop_f) + # "handle" is a hashable object, for deregistration + # send_f() takes a JSONable object, stop_f() has no args def get_channelid(self): return self._channelid @@ -44,17 +44,17 @@ class Channel: "server_rx": row["server_rx"], "id": row["msgid"]}) return messages - def add_listener(self, ep): - self._listeners.add(ep) + def add_listener(self, handle, send_f, stop_f): + self._listeners[handle] = (send_f, stop_f) return self.get_messages() - def remove_listener(self, ep): - self._listeners.discard(ep) + def remove_listener(self, handle): + self._listeners.pop(handle) def broadcast_message(self, phase, body, server_rx, msgid): - for ep in self._listeners: - ep.send_rendezvous_event({"phase": phase, "body": body, - "server_rx": server_rx, "id": msgid}) + for (send_f, stop_f) in self._listeners.values(): + send_f({"phase": phase, "body": body, + "server_rx": server_rx, "id": msgid}) def _add_message(self, side, phase, body, server_rx, msgid): db = self._db @@ -183,15 +183,15 @@ class Channel: # Shut down any listeners, just in case they're still lingering # around. - for ep in self._listeners: - ep.stop_rendezvous_watcher() + for (send_f, stop_f) in self._listeners.values(): + stop_f() self._app.free_channel(self._channelid) def _shutdown(self): # used at test shutdown to accelerate client disconnects - for ep in self._listeners: - ep.stop_rendezvous_watcher() + for (send_f, stop_f) in self._listeners.values(): + stop_f() class AppNamespace: def __init__(self, db, welcome, blur_usage, log_requests, appid): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 93abf2c..6bd291b 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -150,8 +150,12 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if self._watching: raise Error("already watching") self._watching = True - for old_message in channel.add_listener(self): - self.send_rendezvous_event(old_message) + def _send(event): + self.send_rendezvous_event(event) + def _stop(): + self.stop_rendezvous_watcher() + for old_message in channel.add_listener(self, _send, _stop): + _send(old_message) def handle_add(self, channel, msg, server_rx): if "phase" not in msg: From 2c2cf29564ecd369b3b113d42b8daa3ee9f074a0 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 17:12:04 -0700 Subject: [PATCH 09/51] update comment: sent -> server_tx --- src/wormhole/server/rendezvous_websocket.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 6bd291b..6f21b44 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -17,9 +17,9 @@ from autobahn.twisted import websocket # that some time after A is received, at least one message of type B will be # sent out. -# All outbound messages include a "sent" key, which is a float (seconds since -# epoch) with the server clock just before the outbound message was written -# to the socket. +# All outbound messages include a "server_tx" key, which is a float (seconds +# since epoch) with the server clock just before the outbound message was +# written to the socket. # connection -> welcome # <- {type: "welcome", welcome: {}} # .welcome keys are all optional: From 85dc0fd41b4a3c99be58ff47f970e79ee8072666 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 17:46:15 -0700 Subject: [PATCH 10/51] change server API: "release" instead of "deallocate" --- src/wormhole/server/rendezvous.py | 40 ++++++++++----------- src/wormhole/server/rendezvous_websocket.py | 16 ++++----- src/wormhole/test/test_scripts.py | 4 +-- src/wormhole/test/test_server.py | 28 +++++++-------- src/wormhole/twisted/transcribe.py | 18 +++++----- 5 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 3401008..9f20b61 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -12,8 +12,8 @@ MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY EXPIRATION_CHECK_PERIOD = 2*HOUR -ALLOCATE = u"_allocate" -DEALLOCATE = u"_deallocate" +CLAIM = u"_claim" +RELEASE = u"_release" class Channel: def __init__(self, app, db, welcome, blur_usage, log_requests, @@ -38,7 +38,7 @@ class Channel: " WHERE `appid`=? AND `channelid`=?" " ORDER BY `server_rx` ASC", (self._appid, self._channelid)).fetchall(): - if row["phase"] in (u"_allocate", u"_deallocate"): + if row["phase"] in (CLAIM, RELEASE): continue messages.append({"phase": row["phase"], "body": row["body"], "server_rx": row["server_rx"], "id": row["msgid"]}) @@ -66,16 +66,16 @@ class Channel: server_rx, msgid)) db.commit() - def allocate(self, side): - self._add_message(side, ALLOCATE, None, time.time(), None) + def claim(self, side): + self._add_message(side, CLAIM, None, time.time(), None) def add_message(self, side, phase, body, server_rx, msgid): self._add_message(side, phase, body, server_rx, msgid) self.broadcast_message(phase, body, server_rx, msgid) return self.get_messages() # for rendezvous_web.py POST /add - def deallocate(self, side, mood): - self._add_message(side, DEALLOCATE, mood, time.time(), None) + def release(self, side, mood): + self._add_message(side, RELEASE, mood, time.time(), None) db = self._db seen = set([row["side"] for row in db.execute("SELECT `side` FROM `messages`" @@ -85,7 +85,7 @@ class Channel: db.execute("SELECT `side` FROM `messages`" " WHERE `appid`=? AND `channelid`=?" " AND `phase`=?", - (self._appid, self._channelid, DEALLOCATE))]) + (self._appid, self._channelid, RELEASE))]) if seen - freed: return False self.delete_and_summarize() @@ -127,7 +127,7 @@ class Channel: started = min([m["server_rx"] for m in messages]) # 'total_time' is how long the channel was occupied. That ends now, # both for channels that got pruned for inactivity, and for channels - # that got pruned because of two DEALLOCATE messages + # that got pruned because of two RELEASE messages total_time = delete_time - started if len(all_sides) == 1: @@ -147,9 +147,9 @@ class Channel: # now, were all sides closed? If not, this is "pruney" A_deallocs = [m for m in messages - if m["phase"] == DEALLOCATE and m["side"] == A_side] + if m["phase"] == RELEASE and m["side"] == A_side] B_deallocs = [m for m in messages - if m["phase"] == DEALLOCATE and m["side"] == B_side] + if m["phase"] == RELEASE and m["side"] == B_side] if not A_deallocs or not B_deallocs: return (started, "pruney", total_time, None) @@ -202,31 +202,31 @@ class AppNamespace: self._appid = appid self._channels = {} - def get_allocated(self): + def get_claimed(self): db = self._db c = db.execute("SELECT DISTINCT `channelid` FROM `messages`" " WHERE `appid`=?", (self._appid,)) return set([row["channelid"] for row in c.fetchall()]) def find_available_channelid(self): - allocated = self.get_allocated() + claimed = self.get_claimed() for size in range(1,4): # stick to 1-999 for now available = set() for cid in range(10**(size-1), 10**size): - if cid not in allocated: + if cid not in claimed: available.add(cid) if available: return random.choice(list(available)) - # ouch, 999 currently allocated. Try random ones for a while. + # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): cid = random.randrange(1000, 1000*1000) - if cid not in allocated: + if cid not in claimed: return cid raise ValueError("unable to find a free channel-id") - def allocate_channel(self, channelid, side): + def claim_channel(self, channelid, side): channel = self.get_channel(channelid) - channel.allocate(side) + channel.claim(side) return channel def get_channel(self, channelid): @@ -248,7 +248,7 @@ class AppNamespace: self._channels.pop(channelid) if self._log_requests: log.msg("freed+killed #%d, now have %d DB channels, %d live" % - (channelid, len(self.get_allocated()), len(self._channels))) + (channelid, len(self.get_claimed()), len(self._channels))) def prune_old_channels(self): # For now, pruning is logged even if log_requests is False, to debug @@ -259,7 +259,7 @@ class AppNamespace: log.msg(" channel prune begins") # a channel is deleted when there are no listeners and there have # been no messages added in CHANNEL_EXPIRATION_TIME seconds - channels = set(self.get_allocated()) # these have messages + channels = set(self.get_claimed()) # these have messages channels.update(self._channels) # these might have listeners for channelid in channels: log.msg(" channel prune checking %d" % channelid) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 6f21b44..3cecd12 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -97,8 +97,8 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_watch(self._channel, msg) if mtype == "add": return self.handle_add(self._channel, msg, server_rx) - if mtype == "deallocate": - return self.handle_deallocate(self._channel, msg) + if mtype == "release": + return self.handle_release(self._channel, msg) raise Error("Unknown type") except Error as e: @@ -126,14 +126,14 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._side = msg["side"] def handle_list(self): - channelids = sorted(self._app.get_allocated()) + channelids = sorted(self._app.get_claimed()) self.send("channelids", channelids=channelids) def handle_allocate(self): if self._channel: raise Error("Already bound to a channelid") channelid = self._app.find_available_channelid() - self._channel = self._app.allocate_channel(channelid, self._side) + self._channel = self._app.claim_channel(channelid, self._side) self.send("allocated", channelid=channelid) def handle_claim(self, msg): @@ -144,7 +144,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): old_cid = self._channel.get_channelid() if msg["channelid"] != old_cid: raise Error("Already bound to channelid %d" % old_cid) - self._channel = self._app.allocate_channel(msg["channelid"], self._side) + self._channel = self._app.claim_channel(msg["channelid"], self._side) def handle_watch(self, channel, msg): if self._watching: @@ -166,9 +166,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channel.add_message(self._side, msg["phase"], msg["body"], server_rx, msgid) - def handle_deallocate(self, channel, msg): - deleted = channel.deallocate(self._side, msg.get("mood")) - self.send("deallocated", status="deleted" if deleted else "waiting") + def handle_release(self, channel, msg): + deleted = channel.release(self._side, msg.get("mood")) + self.send("released", status="deleted" if deleted else "waiting") def send(self, mtype, **kwargs): kwargs["type"] = mtype diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index b39f2c5..e70cc1a 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -453,7 +453,7 @@ class Cleanup(ServerBase, unittest.TestCase): yield send_d yield receive_d - cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + cids = self._rendezvous.get_app(cmd_send.APPID).get_claimed() self.assertEqual(len(cids), 0) @inlineCallbacks @@ -482,6 +482,6 @@ class Cleanup(ServerBase, unittest.TestCase): yield self.assertFailure(send_d, WrongPasswordError) yield self.assertFailure(receive_d, WrongPasswordError) - cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + cids = self._rendezvous.get_app(cmd_send.APPID).get_claimed() self.assertEqual(len(cids), 0) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index c52bd50..8708dd1 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -210,7 +210,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): yield c1.sync() self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_allocated(), set()) + self.assertEqual(app.get_claimed(), set()) c1.send(u"list") msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"channelids") @@ -221,7 +221,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] self.failUnlessIsInstance(cid, int) - self.assertEqual(app.get_allocated(), set([cid])) + self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) @@ -230,11 +230,11 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"deallocate") + c1.send(u"release") msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") - self.assertEqual(app.get_allocated(), set()) + self.assertEqual(app.get_claimed(), set()) c1.send(u"list") msg = yield c1.next_non_ack() @@ -249,13 +249,13 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"bind", appid=u"appid", side=u"side") yield c1.sync() app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_allocated(), set()) + self.assertEqual(app.get_claimed(), set()) c1.send(u"allocate") msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] self.failUnlessIsInstance(cid, int) - self.assertEqual(app.get_allocated(), set([cid])) + self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) @@ -268,7 +268,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c2.send(u"add", phase="1", body="") yield c2.sync() - self.assertEqual(app.get_allocated(), set([cid])) + self.assertEqual(app.get_claimed(), set([cid])) self.assertEqual(strip_messages(channel.get_messages()), [{"phase": "1", "body": ""}]) @@ -282,14 +282,14 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"deallocate") + c1.send(u"release") msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"waiting") - c2.send(u"deallocate") + c2.send(u"release") msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") c2.send(u"list") @@ -420,8 +420,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): class Summary(unittest.TestCase): def test_summarize(self): c = rendezvous.Channel(None, None, None, None, False, None, None) - A = rendezvous.ALLOCATE - D = rendezvous.DEALLOCATE + A = rendezvous.CLAIM + D = rendezvous.RELEASE messages = [{"server_rx": 1, "side": "a", "phase": A}] self.failUnlessEqual(c._summarize(messages, 2), diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 248d1b9..64d83c3 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -84,7 +84,7 @@ class _Wormhole: self._sleepers = [] self._confirmation_failed = False self._closed = False - self._deallocated_status = None + self._released_status = None self._timing_started = self._timing.add("wormhole") self._ws = None self._ws_t = None # timing Event @@ -535,7 +535,7 @@ class _Wormhole: raise TypeError(type(mood)) with self._timing.add("API close"): - yield self._deallocate(mood) + yield self._release(mood) # TODO: mark WebSocket as don't-reconnect self._ws.transport.loseConnection() # probably flushes del self._ws @@ -544,17 +544,17 @@ class _Wormhole: returnValue(f) @inlineCallbacks - def _deallocate(self, mood): - with self._timing.add("deallocate"): - yield self._ws_send(u"deallocate", mood=mood) - while self._deallocated_status is None: + def _release(self, mood): + with self._timing.add("release"): + yield self._ws_send(u"release", mood=mood) + while self._released_status is None: yield self._sleep(wake_on_error=False) # TODO: set a timeout, don't wait forever for an ack # TODO: if the connection is lost, let it go - returnValue(self._deallocated_status) + returnValue(self._released_status) - def _ws_handle_deallocated(self, msg): - self._deallocated_status = msg["status"] + def _ws_handle_released(self, msg): + self._released_status = msg["status"] self._wakeup() def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): From 31491bb939b0c2cb6f3f16b26ca57c88fa645794 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 17:48:26 -0700 Subject: [PATCH 11/51] update docs --- src/wormhole/server/rendezvous_websocket.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 3cecd12..40c0b50 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -8,7 +8,10 @@ from autobahn.twisted import websocket # (which must be the first message on the connection). The channelid is set # by either a "allocate" message (where the server picks the channelid), or # by a "claim" message (where the client picks it). All three values must be -# set before any other message (watch, add, deallocate) can be sent. +# set before any other message (watch, add, deallocate) can be sent. Channels +# are maintained (saved from deletion) by a "claim" message (and also +# incidentally by "allocate"). Channels are deleted when the last claim is +# released with "release". # All websocket messages are JSON-encoded. The client can send us "inbound" # messages (marked as "->" below), which may (or may not) provoke immediate @@ -35,8 +38,10 @@ from autobahn.twisted import websocket # -> {type: "watch"} -> message # sends old messages and more in future # <- {type: "message", message: {phase:, body:}} # body is hex # -> {type: "add", phase: str, body: hex} # may send echo -# -> {type: "deallocate", mood: str} -> deallocated -# <- {type: "deallocated", status: waiting|deleted} +# +# -> {type: "release", mood: str} -> deallocated +# <- {type: "released", status: waiting|deleted} +# # <- {type: "error", error: str, orig: {}} # in response to malformed msgs # for tests that need to know when a message has been processed: From c14e982ae7d9185cdad9570279ac03ad3dbf61ba Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 18:01:56 -0700 Subject: [PATCH 12/51] rendezvous: allow multiple channels per connection --- src/wormhole/server/rendezvous_websocket.py | 98 +++++++++++---------- src/wormhole/test/test_server.py | 39 +++----- src/wormhole/twisted/transcribe.py | 7 +- 3 files changed, 65 insertions(+), 79 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 40c0b50..9ec113d 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -3,26 +3,26 @@ from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket -# Each WebSocket connection is bound to one "appid", one "side", and one -# "channelid". The connection's appid and side are set by the "bind" message -# (which must be the first message on the connection). The channelid is set -# by either a "allocate" message (where the server picks the channelid), or -# by a "claim" message (where the client picks it). All three values must be -# set before any other message (watch, add, deallocate) can be sent. Channels -# are maintained (saved from deletion) by a "claim" message (and also -# incidentally by "allocate"). Channels are deleted when the last claim is -# released with "release". +# Each WebSocket connection is bound to one "appid", one "side", and zero or +# more "channelids". The connection's appid and side are set by the "bind" +# message (which must be the first message on the connection). Both must be +# set before any other message (allocate, claim, watch, add, deallocate) will +# be accepted. Short channel IDs can be obtained from the server with an +# "allocate" message. Longer ones can be selected independently by the +# client. Channels are maintained (saved from deletion) by a "claim" message +# (and also incidentally by "allocate"). Channels are deleted when the last +# claim is released with "release". # All websocket messages are JSON-encoded. The client can send us "inbound" # messages (marked as "->" below), which may (or may not) provoke immediate # (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed # correlation between requests and responses. In this list, "A -> B" means # that some time after A is received, at least one message of type B will be -# sent out. +# sent out (probably). # All outbound messages include a "server_tx" key, which is a float (seconds # since epoch) with the server clock just before the outbound message was -# written to the socket. +# written to the socket. Unrecognized keys will be ignored. # connection -> welcome # <- {type: "welcome", welcome: {}} # .welcome keys are all optional: @@ -30,17 +30,20 @@ from autobahn.twisted import websocket # motd: all clients display message, then continue normally # error: all clients display mesage, then terminate with error # -> {type: "bind", appid:, side:} +# # -> {type: "list"} -> channelids # <- {type: "channelids", channelids: [int..]} # -> {type: "allocate"} -> allocated # <- {type: "allocated", channelid: int} # -> {type: "claim", channelid: int} -# -> {type: "watch"} -> message # sends old messages and more in future -# <- {type: "message", message: {phase:, body:}} # body is hex -# -> {type: "add", phase: str, body: hex} # may send echo # -# -> {type: "release", mood: str} -> deallocated -# <- {type: "released", status: waiting|deleted} +# -> {type: "watch", channelid: int} -> message +# sends old messages and more in future +# <- {type: "message", channelid: int, message: {phase:, body:}} # body is hex +# -> {type: "add", channelid: int, phase: str, body: hex} # will send echo +# +# -> {type: "release", channelid: int, mood: str} -> deallocated +# <- {type: "released", channelid: int, status: waiting|deleted} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs @@ -57,8 +60,8 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): websocket.WebSocketServerProtocol.__init__(self) self._app = None self._side = None - self._channel = None - self._watching = False + self._did_allocate = False # only one allocate() per websocket + self._channels = {} # channel-id -> Channel (claimed) def onConnect(self, request): rv = self.factory.rendezvous @@ -95,26 +98,17 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_allocate() if mtype == "claim": return self.handle_claim(msg) - - if not self._channel: - raise Error("Must set channel first") if mtype == "watch": - return self.handle_watch(self._channel, msg) + return self.handle_watch(msg) if mtype == "add": - return self.handle_add(self._channel, msg, server_rx) + return self.handle_add(msg, server_rx) if mtype == "release": - return self.handle_release(self._channel, msg) + return self.handle_release(msg) raise Error("Unknown type") except Error as e: self.send("error", error=e._explain, orig=msg) - def send_rendezvous_event(self, event): - self.send("message", message=event) - - def stop_rendezvous_watcher(self): - self._reactor.callLater(0, self.transport.loseConnection) - def handle_ping(self, msg): if "ping" not in msg: raise Error("ping requires 'ping'") @@ -135,34 +129,39 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self.send("channelids", channelids=channelids) def handle_allocate(self): - if self._channel: - raise Error("Already bound to a channelid") + if self._did_allocate: + raise Error("You already allocated one channel, don't be greedy") channelid = self._app.find_available_channelid() - self._channel = self._app.claim_channel(channelid, self._side) + self._did_allocate = True + channel = self._app.claim_channel(channelid, self._side) + self._channels[channelid] = channel self.send("allocated", channelid=channelid) def handle_claim(self, msg): if "channelid" not in msg: raise Error("claim requires 'channelid'") - # we allow allocate+claim as long as they match - if self._channel is not None: - old_cid = self._channel.get_channelid() - if msg["channelid"] != old_cid: - raise Error("Already bound to channelid %d" % old_cid) - self._channel = self._app.claim_channel(msg["channelid"], self._side) + channelid = msg["channelid"] + if channelid not in self._channels: + channel = self._app.claim_channel(channelid, self._side) + self._channels[channelid] = channel - def handle_watch(self, channel, msg): - if self._watching: - raise Error("already watching") - self._watching = True + def handle_watch(self, msg): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before watching") + channel = self._channels[channelid] def _send(event): - self.send_rendezvous_event(event) + self.send("message", channelid=channelid, message=event) def _stop(): - self.stop_rendezvous_watcher() + self._reactor.callLater(0, self.transport.loseConnection) for old_message in channel.add_listener(self, _send, _stop): _send(old_message) - def handle_add(self, channel, msg, server_rx): + def handle_add(self, msg, server_rx): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before adding") + channel = self._channels[channelid] if "phase" not in msg: raise Error("missing 'phase'") if "body" not in msg: @@ -171,8 +170,13 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channel.add_message(self._side, msg["phase"], msg["body"], server_rx, msgid) - def handle_release(self, channel, msg): + def handle_release(self, msg): + channelid = msg["channelid"] + if channelid not in self._channels: + raise Error("must claim channel before releasing") + channel = self._channels[channelid] deleted = channel.release(self._side, msg.get("mood")) + del self._channels[channelid] self.send("released", status="deleted" if deleted else "waiting") def send(self, mtype, **kwargs): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 8708dd1..6deef6c 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -230,7 +230,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"release") + c1.send(u"release", channelid=cid) msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") @@ -265,7 +265,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.check_welcome(msg) c2.send(u"bind", appid=u"appid", side=u"side-2") c2.send(u"claim", channelid=cid) - c2.send(u"add", phase="1", body="") + c2.send(u"add", channelid=cid, phase="1", body="") yield c2.sync() self.assertEqual(app.get_claimed(), set([cid])) @@ -282,12 +282,12 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["type"], u"channelids") self.assertEqual(msg["channelids"], [cid]) - c1.send(u"release") + c1.send(u"release", channelid=cid) msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"waiting") - c2.send(u"release") + c2.send(u"release", channelid=cid) msg = yield c2.next_non_ack() self.assertEqual(msg["type"], u"released") self.assertEqual(msg["status"], u"deleted") @@ -312,25 +312,6 @@ class WebSocketAPI(ServerBase, unittest.TestCase): # there should no error self.assertEqual(c1.errors, []) - @inlineCallbacks - def test_allocate_and_claim_different(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid+1) - yield c1.sync() - # that should signal an error - self.assertEqual(len(c1.errors), 1, c1.errors) - msg = c1.errors[0] - self.assertEqual(msg["type"], "error") - self.assertEqual(msg["error"], "Already bound to channelid %d" % cid) - self.assertEqual(msg["orig"], {"type": "claim", "channelid": cid+1}) - @inlineCallbacks def test_message(self): c1 = yield self.make_client() @@ -345,13 +326,13 @@ class WebSocketAPI(ServerBase, unittest.TestCase): channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) - c1.send(u"watch") + c1.send(u"watch", channelid=cid) yield c1.sync() self.assertEqual(len(channel._listeners), 1) c1.strip_acks() self.assertEqual(c1.events, []) - c1.send(u"add", phase="1", body="msg1A") + c1.send(u"add", channelid=cid, phase="1", body="msg1A") yield c1.sync() c1.strip_acks() self.assertEqual(strip_messages(channel.get_messages()), @@ -364,8 +345,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertIn("server_tx", msg) self.assertIsInstance(msg["server_tx"], float) - c1.send(u"add", phase="1", body="msg1B") - c1.send(u"add", phase="2", body="msg2A") + c1.send(u"add", channelid=cid, phase="1", body="msg1B") + c1.send(u"add", channelid=cid, phase="2", body="msg2A") msg = yield c1.next_non_ack() self.assertEqual(msg["type"], "message") @@ -390,7 +371,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c2.send(u"bind", appid=u"appid", side=u"side") c2.send(u"claim", channelid=cid) # 'watch' triggers delivery of old messages, in temporal order - c2.send(u"watch") + c2.send(u"watch", channelid=cid) msg = yield c2.next_non_ack() self.assertEqual(msg["type"], "message") @@ -408,7 +389,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): {"phase": "2", "body": "msg2A"}) # adding a duplicate is not an error, and clients will ignore it - c1.send(u"add", phase="2", body="msg2A") + c1.send(u"add", channelid=cid, phase="2", body="msg2A") # the duplicate message *does* get stored, and delivered msg = yield c2.next_non_ack() diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 64d83c3..5f362a3 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -211,7 +211,7 @@ class _Wormhole: if not self._ws_channel_claimed: yield self._ws_send(u"claim", channelid=self._channelid) self._ws_channel_claimed = True - yield self._ws_send(u"watch") + yield self._ws_send(u"watch", channelid=self._channelid) # entry point 1: generate a new code @inlineCallbacks @@ -406,7 +406,7 @@ class _Wormhole: # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send(u"add", phase=phase, + yield self._ws_send(u"add", channelid=self._channelid, phase=phase, body=hexlify(body).decode("ascii")) if wait: while phase not in self._delivered_messages: @@ -546,7 +546,8 @@ class _Wormhole: @inlineCallbacks def _release(self, mood): with self._timing.add("release"): - yield self._ws_send(u"release", mood=mood) + yield self._ws_send(u"release", channelid=self._channelid, + mood=mood) while self._released_status is None: yield self._sleep(wake_on_error=False) # TODO: set a timeout, don't wait forever for an ack From 1198977e069bcf0ea01e66f538269972dc06060a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 13 May 2016 00:37:53 -0700 Subject: [PATCH 13/51] SCHEMA CHANGE: channelids are now strs, not ints This will enable the use of large randomly-generated hex or base32 channelids, for post-startup or resumed-connection channels. --- src/wormhole/codes.py | 3 ++- src/wormhole/server/db-schemas/v1.sql | 4 ++-- src/wormhole/server/rendezvous.py | 13 ++++++++----- src/wormhole/server/rendezvous_websocket.py | 5 +++++ src/wormhole/test/test_server.py | 4 ++-- src/wormhole/twisted/transcribe.py | 4 +++- 6 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/wormhole/codes.py b/src/wormhole/codes.py index fd08694..44f2665 100644 --- a/src/wormhole/codes.py +++ b/src/wormhole/codes.py @@ -4,6 +4,7 @@ from .wordlist import (byte_to_even_word, byte_to_odd_word, even_words_lowercase, odd_words_lowercase) def make_code(channel_id, code_length): + assert isinstance(channel_id, type(u"")), type(channel_id) words = [] for i in range(code_length): # we start with an "odd word" @@ -11,7 +12,7 @@ def make_code(channel_id, code_length): words.append(byte_to_odd_word[os.urandom(1)].lower()) else: words.append(byte_to_even_word[os.urandom(1)].lower()) - return u"%d-%s" % (channel_id, u"-".join(words)) + return u"%s-%s" % (channel_id, u"-".join(words)) def extract_channel_id(code): channel_id = int(code.split("-")[0]) diff --git a/src/wormhole/server/db-schemas/v1.sql b/src/wormhole/server/db-schemas/v1.sql index 7d7115a..2740661 100644 --- a/src/wormhole/server/db-schemas/v1.sql +++ b/src/wormhole/server/db-schemas/v1.sql @@ -10,9 +10,9 @@ CREATE TABLE `version` CREATE TABLE `messages` ( `appid` VARCHAR, - `channelid` INTEGER, + `channelid` VARCHAR, `side` VARCHAR, - `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string + `phase` VARCHAR, -- numeric or string -- phase="_allocate" and "_deallocate" are used internally `body` VARCHAR, `server_rx` INTEGER, diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 9f20b61..691d06e 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -212,28 +212,31 @@ class AppNamespace: claimed = self.get_claimed() for size in range(1,4): # stick to 1-999 for now available = set() - for cid in range(10**(size-1), 10**size): + for cid_int in range(10**(size-1), 10**size): + cid = u"%d" % cid_int if cid not in claimed: available.add(cid) if available: return random.choice(list(available)) # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): - cid = random.randrange(1000, 1000*1000) + cid_int = random.randrange(1000, 1000*1000) + cid = u"%d" % cid_int if cid not in claimed: return cid raise ValueError("unable to find a free channel-id") def claim_channel(self, channelid, side): + assert isinstance(channelid, type(u"")), type(channelid) channel = self.get_channel(channelid) channel.claim(side) return channel def get_channel(self, channelid): - assert isinstance(channelid, int) + assert isinstance(channelid, type(u"")) if not channelid in self._channels: if self._log_requests: - log.msg("spawning #%d for appid %s" % (channelid, self._appid)) + log.msg("spawning #%s for appid %s" % (channelid, self._appid)) self._channels[channelid] = Channel(self, self._db, self._welcome, self._blur_usage, self._log_requests, @@ -247,7 +250,7 @@ class AppNamespace: if channelid in self._channels: self._channels.pop(channelid) if self._log_requests: - log.msg("freed+killed #%d, now have %d DB channels, %d live" % + log.msg("freed+killed #%s, now have %d DB channels, %d live" % (channelid, len(self.get_claimed()), len(self._channels))) def prune_old_channels(self): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 9ec113d..358063d 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -132,6 +132,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if self._did_allocate: raise Error("You already allocated one channel, don't be greedy") channelid = self._app.find_available_channelid() + assert isinstance(channelid, type(u"")) self._did_allocate = True channel = self._app.claim_channel(channelid, self._side) self._channels[channelid] = channel @@ -141,6 +142,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if "channelid" not in msg: raise Error("claim requires 'channelid'") channelid = msg["channelid"] + assert isinstance(channelid, type(u"")), type(channelid) if channelid not in self._channels: channel = self._app.claim_channel(channelid, self._side) self._channels[channelid] = channel @@ -149,6 +151,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before watching") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] def _send(event): self.send("message", channelid=channelid, message=event) @@ -161,6 +164,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before adding") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] if "phase" not in msg: raise Error("missing 'phase'") @@ -174,6 +178,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before releasing") + assert isinstance(channelid, type(u"")) channel = self._channels[channelid] deleted = channel.release(self._side, msg.get("mood")) del self._channels[channelid] diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 6deef6c..ad3e227 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -220,7 +220,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) + self.failUnlessIsInstance(cid, type(u"")) self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) @@ -254,7 +254,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): msg = yield c1.next_non_ack() self.assertEqual(msg["type"], u"allocated") cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) + self.failUnlessIsInstance(cid, type(u"")) self.assertEqual(app.get_claimed(), set([cid])) channel = app.get_channel(cid) self.assertEqual(channel.get_messages(), []) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 5f362a3..d8f39d6 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -234,6 +234,7 @@ class _Wormhole: if self._channelid is not None: return self._signal_error("got duplicate channelid") self._channelid = msg["channelid"] + assert isinstance(self._channelid, type(u"")), type(self._channelid) self._wakeup() def _start(self): @@ -322,7 +323,8 @@ class _Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) with self._timing.add("API set_code"): - self._channelid = int(mo.group(1)) + self._channelid = mo.group(1) + assert isinstance(self._channelid, type(u"")), type(self._channelid) self._set_code(code) self._start() From c4c0cf71eb9ed38af3b9e55b48c9763ba21635ac Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 13 May 2016 00:43:59 -0700 Subject: [PATCH 14/51] add test --- src/wormhole/test/test_server.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index ad3e227..a32c878 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -201,6 +201,34 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.check_welcome(msg) self.assertEqual(self._rendezvous._apps, {}) + @inlineCallbacks + def test_claim(self): + r = self._rendezvous.get_app(u"appid") + c1 = yield self.make_client() + msg = yield c1.next_non_ack() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"claim", channelid=u"1") + yield c1.sync() + self.assertEqual(r.get_claimed(), set(u"1")) + + c1.send(u"claim", channelid=u"2") + yield c1.sync() + self.assertEqual(r.get_claimed(), set([u"1", u"2"])) + + c1.send(u"claim", channelid=u"72aoqnnnbj7r2") + yield c1.sync() + self.assertEqual(r.get_claimed(), set([u"1", u"2", u"72aoqnnnbj7r2"])) + + c1.send(u"release", channelid=u"2") + yield c1.sync() + self.assertEqual(r.get_claimed(), set([u"1", u"72aoqnnnbj7r2"])) + + c1.send(u"release", channelid=u"1") + yield c1.sync() + self.assertEqual(r.get_claimed(), set([u"72aoqnnnbj7r2"])) + + @inlineCallbacks def test_allocate_1(self): c1 = yield self.make_client() From 5dd91c73110a9dff2f72ef24cf1d74a0ea197ce2 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 13 May 2016 00:46:12 -0700 Subject: [PATCH 15/51] test too-many-allocate, allocate+claim --- src/wormhole/test/test_server.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index a32c878..8fa709c 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -327,6 +327,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def test_allocate_and_claim(self): + r = self._rendezvous.get_app(u"appid") c1 = yield self.make_client() msg = yield c1.next_non_ack() self.check_welcome(msg) @@ -339,6 +340,43 @@ class WebSocketAPI(ServerBase, unittest.TestCase): yield c1.sync() # there should no error self.assertEqual(c1.errors, []) + self.assertEqual(r.get_claimed(), set([cid])) + + # but trying to allocate twice is an error + c1.send(u"allocate") + yield c1.sync() + self.assertEqual(len(c1.errors), 1) + self.assertEqual(c1.errors[0]["error"], + "You already allocated one channel, don't be greedy") + self.assertEqual(r.get_claimed(), set([cid])) + + @inlineCallbacks + def test_allocate_and_claim_two(self): + r = self._rendezvous.get_app(u"appid") + c1 = yield self.make_client() + msg = yield c1.next_non_ack() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"allocate") + msg = yield c1.next_non_ack() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + c1.send(u"claim", channelid=cid) + yield c1.sync() + # there should no error + self.assertEqual(c1.errors, []) + + c1.send(u"claim", channelid=u"other") + yield c1.sync() + self.assertEqual(c1.errors, []) + self.assertEqual(r.get_claimed(), set([cid, u"other"])) + + c1.send(u"release", channelid=cid) + yield c1.sync() + self.assertEqual(r.get_claimed(), set([u"other"])) + c1.send(u"release", channelid="other") + yield c1.sync() + self.assertEqual(r.get_claimed(), set()) @inlineCallbacks def test_message(self): From a74b1b1e3a398104630e47b32a43d5481d771550 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 16 May 2016 22:04:25 -0700 Subject: [PATCH 16/51] WIP: new server protocol --- src/wormhole/server/database.py | 2 +- src/wormhole/server/db-schemas/v1.sql | 43 --------- src/wormhole/server/db-schemas/v2.sql | 98 +++++++++++++++++++++ src/wormhole/server/rendezvous_websocket.py | 81 +++++++++++------ 4 files changed, 151 insertions(+), 73 deletions(-) delete mode 100644 src/wormhole/server/db-schemas/v1.sql create mode 100644 src/wormhole/server/db-schemas/v2.sql diff --git a/src/wormhole/server/database.py b/src/wormhole/server/database.py index d8edee0..163c8cd 100644 --- a/src/wormhole/server/database.py +++ b/src/wormhole/server/database.py @@ -22,7 +22,7 @@ def get_db(dbfile, stderr=sys.stderr): raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) db.row_factory = sqlite3.Row - VERSION = 1 + VERSION = 2 if must_create: schema = get_schema(VERSION) db.executescript(schema) diff --git a/src/wormhole/server/db-schemas/v1.sql b/src/wormhole/server/db-schemas/v1.sql deleted file mode 100644 index 2740661..0000000 --- a/src/wormhole/server/db-schemas/v1.sql +++ /dev/null @@ -1,43 +0,0 @@ - --- note: anything which isn't an boolean, integer, or human-readable unicode --- string, (i.e. binary strings) will be stored as hex - -CREATE TABLE `version` -( - `version` INTEGER -- contains one row, set to 1 -); - -CREATE TABLE `messages` -( - `appid` VARCHAR, - `channelid` VARCHAR, - `side` VARCHAR, - `phase` VARCHAR, -- numeric or string - -- phase="_allocate" and "_deallocate" are used internally - `body` VARCHAR, - `server_rx` INTEGER, - `msgid` VARCHAR -); -CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`); - -CREATE TABLE `usage` -( - `type` VARCHAR, -- "rendezvous" or "transit" - `started` INTEGER, -- seconds since epoch, rounded to one day - `result` VARCHAR, -- happy, scary, lonely, errory, pruney - -- rendezvous moods: - -- "happy": both sides close with mood=happy - -- "scary": any side closes with mood=scary (bad MAC, probably wrong pw) - -- "lonely": any side closes with mood=lonely (no response from 2nd side) - -- "errory": any side closes with mood=errory (other errors) - -- "pruney": channels which get pruned for inactivity - -- "crowded": three or more sides were involved - -- transit moods: - -- "errory": this side have the wrong handshake - -- "lonely": good handshake, but the other side never showed up - -- "happy": both sides gave correct handshake - `total_bytes` INTEGER, -- for transit, total bytes relayed (both directions) - `total_time` INTEGER, -- seconds from start to closed, or None - `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None -); -CREATE INDEX `usage_idx` ON `usage` (`started`); diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql new file mode 100644 index 0000000..795d49d --- /dev/null +++ b/src/wormhole/server/db-schemas/v2.sql @@ -0,0 +1,98 @@ + +-- note: anything which isn't an boolean, integer, or human-readable unicode +-- string, (i.e. binary strings) will be stored as hex + +CREATE TABLE `version` +( + `version` INTEGER -- contains one row, set to 2 +); + + +-- Wormhole codes use a "nameplate": a short identifier which is only used to +-- reference a specific (long-named) mailbox. The codes only use numeric +-- nameplates, but the protocol and server allow can use arbitrary strings. +CREATE TABLE `nameplates` +( + `app_id` VARCHAR, + `id` VARCHAR PRIMARY KEY, + `mailbox_id` VARCHAR, -- really a foreign key + `side1` VARCHAR, -- side name, or NULL + `side2` VARCHAR -- side name, or NULL +); +CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`); +CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`); + +-- Clients exchange messages through a "mailbox", which has a long (randomly +-- unique) identifier and a queue of messages. +CREATE TABLE `mailboxes` +( + `app_id` VARCHAR, + `id` VARCHAR, + `side1` VARCHAR -- side name, or NULL + `side2` VARCHAR -- side name, or NULL + -- timing data for the (optional) linked nameplate + `nameplate_started` INTEGER, -- time when related nameplace was opened + `nameplate_second` INTEGER, -- time when second side opened + `nameplate_closed` INTEGER, -- time when closed + -- timing data for the mailbox itself + `started` INTEGER, -- time when opened + `second` INTEGER, -- time when second side opened + `closed` INTEGER -- time when closed +); +CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`); + +CREATE TABLE `messages` +( + `app_id` VARCHAR, + `mailbox_id` VARCHAR, + `side` VARCHAR, + `phase` VARCHAR, -- numeric or string + `body` VARCHAR, + `server_rx` INTEGER, + `msg_id` VARCHAR +); +CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`); + +CREATE TABLE `nameplate_usage` +( + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `result` VARCHAR, -- happy, lonely, pruney, crowded + -- nameplate moods: + -- "happy": two sides open and close + -- "lonely": one side opens and closes (no response from 2nd side) + -- "pruney": channels which get pruned for inactivity + -- "crowded": three or more sides were involved + `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None +); +CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`started`); + +CREATE TABLE `mailbox_usage` +( + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `result` VARCHAR -- happy, scary, lonely, errory, pruney + -- rendezvous moods: + -- "happy": both sides close with mood=happy + -- "scary": any side closes with mood=scary (bad MAC, probably wrong pw) + -- "lonely": any side closes with mood=lonely (no response from 2nd side) + -- "errory": any side closes with mood=errory (other errors) + -- "pruney": channels which get pruned for inactivity + -- "crowded": three or more sides were involved +); +CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`started`); + +CREATE TABLE `transit_usage` +( + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `total_bytes` INTEGER, -- total bytes relayed (both directions) + `result` VARCHAR -- happy, scary, lonely, errory, pruney + -- transit moods: + -- "errory": one side gave the wrong handshake + -- "lonely": good handshake, but the other side never showed up + -- "happy": both sides gave correct handshake +); +CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`); diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 358063d..e7f1fde 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -3,26 +3,47 @@ from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket -# Each WebSocket connection is bound to one "appid", one "side", and zero or -# more "channelids". The connection's appid and side are set by the "bind" -# message (which must be the first message on the connection). Both must be -# set before any other message (allocate, claim, watch, add, deallocate) will -# be accepted. Short channel IDs can be obtained from the server with an -# "allocate" message. Longer ones can be selected independently by the -# client. Channels are maintained (saved from deletion) by a "claim" message -# (and also incidentally by "allocate"). Channels are deleted when the last -# claim is released with "release". +# The WebSocket allows the client to send "commands" to the server, and the +# server to send "responses" to the client. Note that commands and responses +# are not necessarily one-to-one. All commands provoke an "ack" response +# (with a copy of the original message) for timing, testing, and +# synchronization purposes. All commands and responses are JSON-encoded. -# All websocket messages are JSON-encoded. The client can send us "inbound" -# messages (marked as "->" below), which may (or may not) provoke immediate -# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed -# correlation between requests and responses. In this list, "A -> B" means -# that some time after A is received, at least one message of type B will be -# sent out (probably). +# Each WebSocket connection is bound to one "appid" and one "side", which are +# set by the "bind" command (which must be the first command on the +# connection), and must be set before any other command will be accepted. -# All outbound messages include a "server_tx" key, which is a float (seconds -# since epoch) with the server clock just before the outbound message was -# written to the socket. Unrecognized keys will be ignored. +# Each connection can be bound to a single "mailbox" (a two-sided +# store-and-forward queue, identified by the "mailbox id": a long, randomly +# unique string identifier) by using the "open" command. This protects the +# mailbox from idle closure, enables the "add" command (to put new messages +# in the queue), and triggers delivery of past and future messages via the +# "message" response. The "close" command removes the binding (but note that +# it does not enable the subsequent binding of a second mailbox). When the +# last side closes a mailbox, its contents are deleted. + +# Additionally, the connection can be bound a single "nameplate", which is +# short identifier that makes up the first component of a wormhole code. Each +# nameplate points to a single long-id "mailbox". The "allocate" message +# determines the shortest available numeric nameplate, reserves it, and +# returns the nameplate id. "list" returns a list of all numeric nameplates +# which currently have only one side active (i.e. they are waiting for a +# partner). The "claim" message reserves an arbitrary nameplate id (perhaps +# the receiver of a wormhole connection typed in a code they got from the +# sender, or perhaps the two sides agreed upon a code offline and are both +# typing it in), and the "release" message releases it. When every side that +# has claimed the nameplate has also released it, the nameplate is +# deallocated (but they will probably keep the underlying mailbox open). + +# Inbound (client to server) commands are marked as "->" below. Unrecognized +# inbound keys will be ignored. Outbound (server to client) responses use +# "<-". There is no guaranteed correlation between requests and responses. In +# this list, "A -> B" means that some time after A is received, at least one +# message of type B will be sent out (probably). + +# All responses include a "server_tx" key, which is a float (seconds since +# epoch) with the server clock just before the outbound response was written +# to the socket. # connection -> welcome # <- {type: "welcome", welcome: {}} # .welcome keys are all optional: @@ -31,19 +52,21 @@ from autobahn.twisted import websocket # error: all clients display mesage, then terminate with error # -> {type: "bind", appid:, side:} # -# -> {type: "list"} -> channelids -# <- {type: "channelids", channelids: [int..]} -# -> {type: "allocate"} -> allocated -# <- {type: "allocated", channelid: int} -# -> {type: "claim", channelid: int} +# -> {type: "list"} -> nameplates +# <- {type: "nameplates", nameplates: [str..]} +# -> {type: "allocate"} -> nameplate, mailbox +# <- {type: "nameplate", nameplate: str} +# -> {type: "claim", nameplate: str} -> mailbox +# <- {type: "mailbox", mailbox: str} +# -> {type: "release"} # -# -> {type: "watch", channelid: int} -> message -# sends old messages and more in future -# <- {type: "message", channelid: int, message: {phase:, body:}} # body is hex -# -> {type: "add", channelid: int, phase: str, body: hex} # will send echo +# -> {type: "open", mailbox: str} -> message +# sends old messages now, and subscribes to deliver future messages +# <- {type: "message", message: {phase:, body:}} # body is hex +# -> {type: "add", phase: str, body: hex} # will send echo in a "message" # -# -> {type: "release", channelid: int, mood: str} -> deallocated -# <- {type: "released", channelid: int, status: waiting|deleted} +# -> {type: "close", mood: str} -> closed +# <- {type: "closed", status: waiting|deleted} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs From 2ea5d9629034b969fefa8b1c6798d2919536a3b1 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 17 May 2016 17:35:44 -0700 Subject: [PATCH 17/51] Channels don't need "welcome" anymore --- src/wormhole/server/rendezvous.py | 5 ++--- src/wormhole/test/test_server.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 691d06e..2072812 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -16,8 +16,7 @@ CLAIM = u"_claim" RELEASE = u"_release" class Channel: - def __init__(self, app, db, welcome, blur_usage, log_requests, - appid, channelid): + def __init__(self, app, db, blur_usage, log_requests, appid, channelid): self._app = app self._db = db self._blur_usage = blur_usage @@ -237,7 +236,7 @@ class AppNamespace: if not channelid in self._channels: if self._log_requests: log.msg("spawning #%s for appid %s" % (channelid, self._appid)) - self._channels[channelid] = Channel(self, self._db, self._welcome, + self._channels[channelid] = Channel(self, self._db, self._blur_usage, self._log_requests, self._appid, channelid) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 8fa709c..e98f4d1 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -466,7 +466,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): class Summary(unittest.TestCase): def test_summarize(self): - c = rendezvous.Channel(None, None, None, None, False, None, None) + c = rendezvous.Channel(None, None, None, False, None, None) A = rendezvous.CLAIM D = rendezvous.RELEASE From 5994eb11d4398a981e2fcd4ab442d7f26b3eea0b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 18 May 2016 00:16:46 -0700 Subject: [PATCH 18/51] WIP new proto --- src/wormhole/server/db-schemas/v2.sql | 2 +- src/wormhole/server/rendezvous.py | 195 +++++++++++++++----- src/wormhole/server/rendezvous_websocket.py | 70 ++++--- 3 files changed, 187 insertions(+), 80 deletions(-) diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql index 795d49d..6659fc6 100644 --- a/src/wormhole/server/db-schemas/v2.sql +++ b/src/wormhole/server/db-schemas/v2.sql @@ -14,7 +14,7 @@ CREATE TABLE `version` CREATE TABLE `nameplates` ( `app_id` VARCHAR, - `id` VARCHAR PRIMARY KEY, + `id` VARCHAR, `mailbox_id` VARCHAR, -- really a foreign key `side1` VARCHAR, -- side name, or NULL `side2` VARCHAR -- side name, or NULL diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 2072812..1007579 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -15,13 +15,75 @@ EXPIRATION_CHECK_PERIOD = 2*HOUR CLAIM = u"_claim" RELEASE = u"_release" -class Channel: - def __init__(self, app, db, blur_usage, log_requests, appid, channelid): +def get_sides(row): + return set([s for s in [row["side1"], row["side2"]] if s]) +def make_sides(side1, side2): + return list(sides) + [None] * (2 - len(sides)) +def generate_mailbox_id(): + return base64.b32encode(os.urandom(8)).lower().strip("=") + +# Unlike Channels, these instances are ephemeral, and are created and +# destroyed casually. +class Nameplate: + def __init__(self, app_id, db, id, mailbox_id): + self._app_id = app_id + self._db = db + self._id = id + self._mailbox_id = mailbox_id + + def get_id(self): + return self._id + + def get_mailbox_id(self): + return self._mailbox_id + + def claim(self, side, when): + db = self._db + sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._id)).fetchone()) + old_sides = len(sides) + sides.add(side) + if len(sides) > 2: + # XXX: crowded: bail + pass + sides12 = make_sides(sides) + db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" + " WHERE `app_id`=? AND `id`=?", + (sides12[0], sides12[1], self._app_id, self._id)) + if old_sides == 0: + db.execute("UPDATE `mailboxes` SET `nameplate_started`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + else: + db.execute("UPDATE `mailboxes` SET `nameplate_second`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + db.commit() + + def release(self, side, when): + db = self._db + sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._id)).fetchone()) + sides.discard(side) + sides12 = make_sides(sides) + db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" + " WHERE `app_id`=? AND `id`=?", + (sides12[0], sides12[1], self._app_id, self._id)) + if len(sides) == 0: + db.execute("UPDATE `mailboxes` SET `nameplate_closed`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + db.commit() + +class Mailbox: + def __init__(self, app, db, blur_usage, log_requests, app_id, channelid): self._app = app self._db = db self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid + self._app_id = app_id self._channelid = channelid self._listeners = {} # handle -> (send_f, stop_f) # "handle" is a hashable object, for deregistration @@ -34,9 +96,9 @@ class Channel: messages = [] db = self._db for row in db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx` ASC", - (self._appid, self._channelid)).fetchall(): + (self._app_id, self._channelid)).fetchall(): if row["phase"] in (CLAIM, RELEASE): continue messages.append({"phase": row["phase"], "body": row["body"], @@ -58,10 +120,10 @@ class Channel: def _add_message(self, side, phase, body, server_rx, msgid): db = self._db db.execute("INSERT INTO `messages`" - " (`appid`, `channelid`, `side`, `phase`, `body`," + " (`app_id`, `channelid`, `side`, `phase`, `body`," " `server_rx`, `msgid`)" " VALUES (?,?,?,?,?, ?,?)", - (self._appid, self._channelid, side, phase, body, + (self._app_id, self._channelid, side, phase, body, server_rx, msgid)) db.commit() @@ -78,13 +140,13 @@ class Channel: db = self._db seen = set([row["side"] for row in db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid))]) + " WHERE `app_id`=? AND `channelid`=?", + (self._app_id, self._channelid))]) freed = set([row["side"] for row in db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " AND `phase`=?", - (self._appid, self._channelid, RELEASE))]) + (self._app_id, self._channelid, RELEASE))]) if seen - freed: return False self.delete_and_summarize() @@ -94,9 +156,9 @@ class Channel: if self._listeners: return False c = self._db.execute("SELECT `server_rx` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx` DESC LIMIT 1", - (self._appid, self._channelid)) + (self._app_id, self._channelid)) rows = c.fetchall() if not rows: return True @@ -169,15 +231,15 @@ class Channel: def delete_and_summarize(self): db = self._db c = self._db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx`", - (self._appid, self._channelid)) + (self._app_id, self._channelid)) messages = c.fetchall() summary = self._summarize(messages, time.time()) self._store_summary(summary) db.execute("DELETE FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid)) + " WHERE `app_id`=? AND `channelid`=?", + (self._app_id, self._channelid)) db.commit() # Shut down any listeners, just in case they're still lingering @@ -193,37 +255,70 @@ class Channel: stop_f() class AppNamespace: - def __init__(self, db, welcome, blur_usage, log_requests, appid): + def __init__(self, db, welcome, blur_usage, log_requests, app_id): self._db = db self._welcome = welcome self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid + self._app_id = app_id self._channels = {} - def get_claimed(self): + def get_nameplate_ids(self): db = self._db - c = db.execute("SELECT DISTINCT `channelid` FROM `messages`" - " WHERE `appid`=?", (self._appid,)) - return set([row["channelid"] for row in c.fetchall()]) + # TODO: filter this to numeric ids? + c = db.execute("SELECT DISTINCT `id` FROM `nameplates`" + " WHERE `app_id`=?", (self._app_id,)) + return set([row["id"] for row in c.fetchall()]) - def find_available_channelid(self): - claimed = self.get_claimed() + def find_available_nameplate_id(self): + claimed = self.get_nameplate_ids() for size in range(1,4): # stick to 1-999 for now available = set() - for cid_int in range(10**(size-1), 10**size): - cid = u"%d" % cid_int - if cid not in claimed: - available.add(cid) + for id_int in range(10**(size-1), 10**size): + id = u"%d" % id_int + if id not in claimed: + available.add(id) if available: return random.choice(list(available)) # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): - cid_int = random.randrange(1000, 1000*1000) - cid = u"%d" % cid_int - if cid not in claimed: - return cid - raise ValueError("unable to find a free channel-id") + id_int = random.randrange(1000, 1000*1000) + id = u"%d" % id_int + if id not in claimed: + return id + raise ValueError("unable to find a free nameplate-id") + + def _get_mailbox_id(self, nameplate_id): + row = self._db.execute("SELECT `mailbox_id` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + return row["mailbox_id"] + + def claim_nameplate(self, nameplate_id, side, when): + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + db = self._db + rows = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)) + if rows: + mailbox_id = rows[0]["mailbox_id"] + else: + if self._log_requests: + log.msg("creating nameplate#%s for app_id %s" % + (nameplate_id, self._app_id)) + mailbox_id = UUID() + db.execute("INSERT INTO `mailboxes`" + " (`app_id`, `id`)" + " VALUES(?,?)", + (self._app_id, mailbox_id)) + db.execute("INSERT INTO `nameplates`" + " (`app_id`, `id`, `mailbox_id`, `side1`, `side2`)" + " VALUES(?,?,?,?,?)", + (self._app_id, nameplate_id, mailbox_id, None, None)) + + nameplate = Nameplate(self._app_id, self._db, nameplate_id, mailbox_id) + nameplate.claim(side, when) + return nameplate def claim_channel(self, channelid, side): assert isinstance(channelid, type(u"")), type(channelid) @@ -235,11 +330,11 @@ class AppNamespace: assert isinstance(channelid, type(u"")) if not channelid in self._channels: if self._log_requests: - log.msg("spawning #%s for appid %s" % (channelid, self._appid)) + log.msg("spawning #%s for app_id %s" % (channelid, self._app_id)) self._channels[channelid] = Channel(self, self._db, self._blur_usage, self._log_requests, - self._appid, channelid) + self._app_id, channelid) return self._channels[channelid] def free_channel(self, channelid): @@ -293,28 +388,28 @@ class Rendezvous(service.MultiService): def get_log_requests(self): return self._log_requests - def get_app(self, appid): - assert isinstance(appid, type(u"")) - if not appid in self._apps: + def get_app(self, app_id): + assert isinstance(app_id, type(u"")) + if not app_id in self._apps: if self._log_requests: - log.msg("spawning appid %s" % (appid,)) - self._apps[appid] = AppNamespace(self._db, self._welcome, + log.msg("spawning app_id %s" % (app_id,)) + self._apps[app_id] = AppNamespace(self._db, self._welcome, self._blur_usage, - self._log_requests, appid) - return self._apps[appid] + self._log_requests, app_id) + return self._apps[app_id] def prune(self): # As with AppNamespace.prune_old_channels, we log for now. log.msg("beginning app prune") - c = self._db.execute("SELECT DISTINCT `appid` FROM `messages`") - apps = set([row["appid"] for row in c.fetchall()]) # these have messages + c = self._db.execute("SELECT DISTINCT `app_id` FROM `messages`") + apps = set([row["app_id"] for row in c.fetchall()]) # these have messages apps.update(self._apps) # these might have listeners - for appid in apps: - log.msg(" app prune checking %r" % (appid,)) - still_active = self.get_app(appid).prune_old_channels() + for app_id in apps: + log.msg(" app prune checking %r" % (app_id,)) + still_active = self.get_app(app_id).prune_old_channels() if not still_active: - log.msg("prune pops app %r" % (appid,)) - self._apps.pop(appid) + log.msg("prune pops app %r" % (app_id,)) + self._apps.pop(app_id) log.msg("app prune ends, %d remaining apps" % len(self._apps)) def stopService(self): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index e7f1fde..b2bf3f2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -102,10 +102,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): try: if "type" not in msg: raise Error("missing 'type'") - if "id" in msg: - # Only ack clients modern enough to include [id]. Older ones - # won't recognize the message, then they'll abort. - self.send("ack", id=msg["id"]) + self.send("ack", id=msg.get("id")) mtype = msg["type"] if mtype == "ping": @@ -118,15 +115,18 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if mtype == "list": return self.handle_list() if mtype == "allocate": - return self.handle_allocate() + return self.handle_allocate(server_rx) if mtype == "claim": - return self.handle_claim(msg) - if mtype == "watch": - return self.handle_watch(msg) + return self.handle_claim(msg, server_rx) + if mtype == "release": + return self.handle_release(msg, server_rx) + + if mtype == "open": + return self.handle_open(msg) if mtype == "add": return self.handle_add(msg, server_rx) - if mtype == "release": - return self.handle_release(msg) + if mtype == "close": + return self.handle_close(msg) raise Error("Unknown type") except Error as e: @@ -147,30 +147,42 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = self.factory.rendezvous.get_app(msg["appid"]) self._side = msg["side"] - def handle_list(self): - channelids = sorted(self._app.get_claimed()) - self.send("channelids", channelids=channelids) - def handle_allocate(self): + def handle_list(self): + nameplate_ids = sorted(self._app.get_nameplate_ids()) + self.send("nameplates", nameplates=nameplate_ids) + + def handle_allocate(self, server_rx): if self._did_allocate: raise Error("You already allocated one channel, don't be greedy") - channelid = self._app.find_available_channelid() - assert isinstance(channelid, type(u"")) + nameplate_id = self._app.find_available_nameplate_id() + assert isinstance(nameplate_id, type(u"")) self._did_allocate = True - channel = self._app.claim_channel(channelid, self._side) - self._channels[channelid] = channel - self.send("allocated", channelid=channelid) + self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + self.send("nameplate", nameplate=nameplate_id) - def handle_claim(self, msg): - if "channelid" not in msg: - raise Error("claim requires 'channelid'") - channelid = msg["channelid"] - assert isinstance(channelid, type(u"")), type(channelid) - if channelid not in self._channels: - channel = self._app.claim_channel(channelid, self._side) - self._channels[channelid] = channel + def handle_claim(self, msg, server_rx): + if "nameplate" not in msg: + raise Error("claim requires 'nameplate'") + nameplate_id = msg["nameplate"] + assert isinstance(nameplate_id, type(u"")), type(nameplate) + if self._nameplate and self._nameplate.get_id() != nameplate_id: + raise Error("claimed nameplate doesn't match allocated nameplate") + self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + mailbox_id = self._nameplate.get_mailbox_id() + self.send("mailbox", mailbox=mailbox_id) - def handle_watch(self, msg): + def handle_release(self, server_rx): + if not self._nameplate: + raise Error("must claim a nameplate before releasing it") + + deleted = self._nameplate.release(self._side, server_rx) + self._nameplate = None + + + def handle_open(self, msg): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before watching") @@ -197,7 +209,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channel.add_message(self._side, msg["phase"], msg["body"], server_rx, msgid) - def handle_release(self, msg): + def handle_close(self, msg): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before releasing") From 0e72422ffab627797d7c92b00a6483bedc68143e Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 19 May 2016 14:18:49 -0700 Subject: [PATCH 19/51] WIP --- src/wormhole/server/db-schemas/v2.sql | 10 +-- src/wormhole/server/rendezvous.py | 87 ++++++++++++++++----- src/wormhole/server/rendezvous_websocket.py | 19 ++--- 3 files changed, 82 insertions(+), 34 deletions(-) diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql index 6659fc6..a3a3485 100644 --- a/src/wormhole/server/db-schemas/v2.sql +++ b/src/wormhole/server/db-schemas/v2.sql @@ -17,7 +17,11 @@ CREATE TABLE `nameplates` `id` VARCHAR, `mailbox_id` VARCHAR, -- really a foreign key `side1` VARCHAR, -- side name, or NULL - `side2` VARCHAR -- side name, or NULL + `side2` VARCHAR, -- side name, or NULL + -- timing data + `started` INTEGER, -- time when nameplace was opened + `second` INTEGER, -- time when second side opened + `closed` INTEGER -- time when closed ); CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`); CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`); @@ -30,10 +34,6 @@ CREATE TABLE `mailboxes` `id` VARCHAR, `side1` VARCHAR -- side name, or NULL `side2` VARCHAR -- side name, or NULL - -- timing data for the (optional) linked nameplate - `nameplate_started` INTEGER, -- time when related nameplace was opened - `nameplate_second` INTEGER, -- time when second side opened - `nameplate_closed` INTEGER, -- time when closed -- timing data for the mailbox itself `started` INTEGER, -- time when opened `second` INTEGER, -- time when second side opened diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 1007579..6571f09 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -17,7 +17,7 @@ RELEASE = u"_release" def get_sides(row): return set([s for s in [row["side1"], row["side2"]] if s]) -def make_sides(side1, side2): +def make_sides(sides): return list(sides) + [None] * (2 - len(sides)) def generate_mailbox_id(): return base64.b32encode(os.urandom(8)).lower().strip("=") @@ -270,7 +270,7 @@ class AppNamespace: " WHERE `app_id`=?", (self._app_id,)) return set([row["id"] for row in c.fetchall()]) - def find_available_nameplate_id(self): + def _find_available_nameplate_id(self): claimed = self.get_nameplate_ids() for size in range(1,4): # stick to 1-999 for now available = set() @@ -288,6 +288,12 @@ class AppNamespace: return id raise ValueError("unable to find a free nameplate-id") + def allocate_nameplate(self, side, when): + nameplate_id = self._find_available_nameplate_id() + mailbox_id = self.claim_nameplate(self, nameplate_id, side, when) + del mailbox_id # ignored, they'll learn it from claim() + return nameplate_id + def _get_mailbox_id(self, nameplate_id): row = self._db.execute("SELECT `mailbox_id` FROM `nameplates`" " WHERE `app_id`=? AND `id`=?", @@ -295,36 +301,81 @@ class AppNamespace: return row["mailbox_id"] def claim_nameplate(self, nameplate_id, side, when): + # when we're done: + # * there will be one row for the nameplate + # * side1 or side2 will be populated + # * started or second will be populated + # * a mailbox id will be created, but not a mailbox row + # (ids are randomly unique, so we can defer creation until 'open') assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + assert isinstance(side, type(u"")), type(side) db = self._db - rows = db.execute("SELECT * FROM `nameplates`" - " WHERE `app_id`=? AND `id`=?", - (self._app_id, nameplate_id)) - if rows: - mailbox_id = rows[0]["mailbox_id"] + row = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + if row: + mailbox_id = row["mailbox_id"] + sides = [row["side1"], row["sides2"]] + if side not in sides: + if sides[0] and sides[1]: + raise XXXERROR("crowded") + sides[1] = side + db.execute("UPDATE `nameplates` SET " + "`side1`=?, `side2`=?, `mailbox_id`=?, `second`=?" + " WHERE `app_id`=? AND `id`=?", + (sides[0], sides[1], mailbox_id, when, + self._app_id, nameplate_id)) else: if self._log_requests: log.msg("creating nameplate#%s for app_id %s" % (nameplate_id, self._app_id)) - mailbox_id = UUID() - db.execute("INSERT INTO `mailboxes`" - " (`app_id`, `id`)" - " VALUES(?,?)", - (self._app_id, mailbox_id)) + mailbox_id = generate_mailbox_id() db.execute("INSERT INTO `nameplates`" - " (`app_id`, `id`, `mailbox_id`, `side1`, `side2`)" + " (`app_id`, `id`, `mailbox_id`, `side1`, `started`)" " VALUES(?,?,?,?,?)", - (self._app_id, nameplate_id, mailbox_id, None, None)) + (self._app_id, nameplate_id, mailbox_id, side, when)) + db.commit() + return mailbox_id - nameplate = Nameplate(self._app_id, self._db, nameplate_id, mailbox_id) - nameplate.claim(side, when) - return nameplate + def release_nameplate(self, nameplate_id, side, when): + # when we're done: + # * in the nameplate row, side1 or side2 will be removed + # * if the nameplate is now unused: + # * mailbox.nameplate_closed will be populated + # * the nameplate row will be removed + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + assert isinstance(side, type(u"")), type(side) + db = self._db + row = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + if not row: + return + sides = get_sides(row) + if side not in sides: + return + sides.discard(side) + if sides: + s12 = make_sides(sides) + db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" + " WHERE `app_id`=? AND `id`=?", + (s12[0], s12[1], self._app_id, nameplate_id)) + else: + db.execute("DELETE FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)) + self._summarize_nameplate(row) - def claim_channel(self, channelid, side): + def open_mailbox(self, channelid, side): assert isinstance(channelid, type(u"")), type(channelid) channel = self.get_channel(channelid) channel.claim(side) return channel + # some of this overlaps with open() on a new mailbox + db.execute("INSERT INTO `mailboxes`" + " (`app_id`, `id`, `nameplate_started`, `started`)" + " VALUES(?,?,?,?)", + (self._app_id, mailbox_id, when, when)) def get_channel(self, channelid): assert isinstance(channelid, type(u"")) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index b2bf3f2..612b1f3 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -155,31 +155,28 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_allocate(self, server_rx): if self._did_allocate: raise Error("You already allocated one channel, don't be greedy") - nameplate_id = self._app.find_available_nameplate_id() + nameplate_id = self._app.allocate_nameplate(self._side, server_rx) assert isinstance(nameplate_id, type(u"")) self._did_allocate = True - self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, - server_rx) self.send("nameplate", nameplate=nameplate_id) def handle_claim(self, msg, server_rx): if "nameplate" not in msg: raise Error("claim requires 'nameplate'") nameplate_id = msg["nameplate"] + self._nameplate_id = nameplate_id assert isinstance(nameplate_id, type(u"")), type(nameplate) - if self._nameplate and self._nameplate.get_id() != nameplate_id: - raise Error("claimed nameplate doesn't match allocated nameplate") - self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, - server_rx) - mailbox_id = self._nameplate.get_mailbox_id() + mailbox_id = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) self.send("mailbox", mailbox=mailbox_id) def handle_release(self, server_rx): - if not self._nameplate: + if not self._nameplate_id: raise Error("must claim a nameplate before releasing it") - deleted = self._nameplate.release(self._side, server_rx) - self._nameplate = None + deleted = self._app.release_nameplate(self._nameplate_id, + self._side, server_rx) + self._nameplate_id = None def handle_open(self, msg): From e39a8291e316b652728a0919cc69e043bc64d936 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 19 May 2016 18:09:17 -0700 Subject: [PATCH 20/51] checkpointing: server roughed out --- src/wormhole/server/db-schemas/v2.sql | 23 +- src/wormhole/server/rendezvous.py | 478 +++++++++++--------- src/wormhole/server/rendezvous_websocket.py | 63 ++- 3 files changed, 298 insertions(+), 266 deletions(-) diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql index a3a3485..83751ce 100644 --- a/src/wormhole/server/db-schemas/v2.sql +++ b/src/wormhole/server/db-schemas/v2.sql @@ -18,12 +18,14 @@ CREATE TABLE `nameplates` `mailbox_id` VARCHAR, -- really a foreign key `side1` VARCHAR, -- side name, or NULL `side2` VARCHAR, -- side name, or NULL + `crowded` BOOLEAN, -- at some point, three or more sides were involved + `updated` INTEGER, -- time of last activity, used for pruning -- timing data `started` INTEGER, -- time when nameplace was opened - `second` INTEGER, -- time when second side opened - `closed` INTEGER -- time when closed + `second` INTEGER -- time when second side opened ); CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`); +CREATE INDEX `nameplates_updated_idx` ON `nameplates` (`app_id`, `updated`); CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`); -- Clients exchange messages through a "mailbox", which has a long (randomly @@ -34,10 +36,11 @@ CREATE TABLE `mailboxes` `id` VARCHAR, `side1` VARCHAR -- side name, or NULL `side2` VARCHAR -- side name, or NULL + `crowded` BOOLEAN, -- at some point, three or more sides were involved + `first_mood` VARCHAR, -- timing data for the mailbox itself `started` INTEGER, -- time when opened - `second` INTEGER, -- time when second side opened - `closed` INTEGER -- time when closed + `second` INTEGER -- time when second side opened ); CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`); @@ -55,20 +58,22 @@ CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`); CREATE TABLE `nameplate_usage` ( + `app_id` VARCHAR, `started` INTEGER, -- seconds since epoch, rounded to "blur time" - `total_time` INTEGER, -- seconds from open to last close - `result` VARCHAR, -- happy, lonely, pruney, crowded + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `total_time` INTEGER, -- seconds from open to last close/prune + `result` VARCHAR -- happy, lonely, pruney, crowded -- nameplate moods: -- "happy": two sides open and close -- "lonely": one side opens and closes (no response from 2nd side) -- "pruney": channels which get pruned for inactivity -- "crowded": three or more sides were involved - `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None ); -CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`started`); +CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`app_id`, `started`); CREATE TABLE `mailbox_usage` ( + `app_id` VARCHAR, `started` INTEGER, -- seconds since epoch, rounded to "blur time" `total_time` INTEGER, -- seconds from open to last close `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None @@ -81,7 +86,7 @@ CREATE TABLE `mailbox_usage` -- "pruney": channels which get pruned for inactivity -- "crowded": three or more sides were involved ); -CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`started`); +CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`app_id`, `started`); CREATE TABLE `transit_usage` ( diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 6571f09..76a8fed 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -1,5 +1,6 @@ from __future__ import print_function -import time, random +import os, time, random, base64 +from collections import namedtuple from twisted.python import log from twisted.application import service, internet @@ -22,83 +23,78 @@ def make_sides(sides): def generate_mailbox_id(): return base64.b32encode(os.urandom(8)).lower().strip("=") -# Unlike Channels, these instances are ephemeral, and are created and -# destroyed casually. -class Nameplate: - def __init__(self, app_id, db, id, mailbox_id): - self._app_id = app_id - self._db = db - self._id = id - self._mailbox_id = mailbox_id - def get_id(self): - return self._id +SideResult = namedtuple("SideResult", ["changed", "empty", "side1", "side2"]) +Unchanged = SideResult(changed=False, empty=False, side1=None, side2=None) +class CrowdedError(Exception): + pass - def get_mailbox_id(self): - return self._mailbox_id +def add_side(row, new_side): + old_sides = [s for s in [row["side1"], row["side2"]] if s] + if new_side in old_sides: + return Unchanged + if len(old_sides) == 2: + raise CrowdedError("too many sides for this thing") + return SideResult(changed=True, empty=False, + side1=old_sides[0], side2=new_side) - def claim(self, side, when): - db = self._db - sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" - " WHERE `app_id`=? AND `id`=?", - (self._app_id, self._id)).fetchone()) - old_sides = len(sides) - sides.add(side) - if len(sides) > 2: - # XXX: crowded: bail - pass - sides12 = make_sides(sides) - db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" - " WHERE `app_id`=? AND `id`=?", - (sides12[0], sides12[1], self._app_id, self._id)) - if old_sides == 0: - db.execute("UPDATE `mailboxes` SET `nameplate_started`=?" - " WHERE `app_id`=? AND `id`=?", - (when, self._app_id, self._mailbox_id)) - else: - db.execute("UPDATE `mailboxes` SET `nameplate_second`=?" - " WHERE `app_id`=? AND `id`=?", - (when, self._app_id, self._mailbox_id)) - db.commit() +def remove_side(row, side): + old_sides = [s for s in [row["side1"], row["side2"]] if s] + if side not in old_sides: + return Unchanged + remaining_sides = old_sides[:] + remaining_sides.remove(side) + if remaining_sides: + return SideResult(changed=True, empty=False, side1=remaining_sides[0], + side2=None) + return SideResult(changed=True, empty=True, side1=None, side2=None) - def release(self, side, when): - db = self._db - sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" - " WHERE `app_id`=? AND `id`=?", - (self._app_id, self._id)).fetchone()) - sides.discard(side) - sides12 = make_sides(sides) - db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" - " WHERE `app_id`=? AND `id`=?", - (sides12[0], sides12[1], self._app_id, self._id)) - if len(sides) == 0: - db.execute("UPDATE `mailboxes` SET `nameplate_closed`=?" - " WHERE `app_id`=? AND `id`=?", - (when, self._app_id, self._mailbox_id)) - db.commit() +Usage = namedtuple("Usage", ["started", "waiting_time", "total_time", "result"]) +TransitUsage = namedtuple("TransitUsage", + ["started", "waiting_time", "total_time", + "total_bytes", "result"]) class Mailbox: - def __init__(self, app, db, blur_usage, log_requests, app_id, channelid): + def __init__(self, app, db, blur_usage, log_requests, app_id, mailbox_id): self._app = app self._db = db self._blur_usage = blur_usage self._log_requests = log_requests self._app_id = app_id - self._channelid = channelid + self._mailbox_id = mailbox_id self._listeners = {} # handle -> (send_f, stop_f) # "handle" is a hashable object, for deregistration # send_f() takes a JSONable object, stop_f() has no args - def get_channelid(self): - return self._channelid + def open(self, side, when): + # requires caller to db.commit() + assert isinstance(side, type(u"")), type(side) + db = self._db + row = db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)).fetchone() + try: + sr = add_side(row, side) + except CrowdedError: + db.execute("UPDATE `mailboxes` SET `crowded`=?" + " WHERE `app_id`=? AND `id`=?", + (True, self._app_id, self._mailbox_id)) + db.commit() + raise + if sr.changed: + db.execute("UPDATE `mailboxes` SET" + " `side1`=?, `side2`=?, `second`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, + self._app_id, self._mailbox_id)) def get_messages(self): messages = [] db = self._db for row in db.execute("SELECT * FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?" + " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` ASC", - (self._app_id, self._channelid)).fetchall(): + (self._app_id, self._mailbox_id)).fetchall(): if row["phase"] in (CLAIM, RELEASE): continue messages.append({"phase": row["phase"], "body": row["body"], @@ -120,45 +116,107 @@ class Mailbox: def _add_message(self, side, phase, body, server_rx, msgid): db = self._db db.execute("INSERT INTO `messages`" - " (`app_id`, `channelid`, `side`, `phase`, `body`," + " (`app_id`, `mailbox_id`, `side`, `phase`, `body`," " `server_rx`, `msgid`)" " VALUES (?,?,?,?,?, ?,?)", - (self._app_id, self._channelid, side, phase, body, + (self._app_id, self._mailbox_id, side, phase, body, server_rx, msgid)) db.commit() - def claim(self, side): - self._add_message(side, CLAIM, None, time.time(), None) - def add_message(self, side, phase, body, server_rx, msgid): self._add_message(side, phase, body, server_rx, msgid) self.broadcast_message(phase, body, server_rx, msgid) return self.get_messages() # for rendezvous_web.py POST /add - def release(self, side, mood): - self._add_message(side, RELEASE, mood, time.time(), None) + def close(self, side, mood, when): + assert isinstance(side, type(u"")), type(side) db = self._db - seen = set([row["side"] for row in - db.execute("SELECT `side` FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?", - (self._app_id, self._channelid))]) - freed = set([row["side"] for row in - db.execute("SELECT `side` FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?" - " AND `phase`=?", - (self._app_id, self._channelid, RELEASE))]) - if seen - freed: - return False - self.delete_and_summarize() - return True + row = db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)).fetchone() + if not row: + return + sr = remove_side(row, side) + if sr.empty: + rows = db.execute("SELECT DISTINCT(side) FROM `messages`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)).fetchall() + num_sides = len(rows) + self._summarize_and_store(row, num_sides, mood, when, pruned=False) + self._delete() + db.commit() + elif sr.changed: + db.execute("UPDATE `mailboxes`" + " SET `side1`=?, `side2`=?, `first_mood`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, mood, + self._app_id, self._mailbox_id)) + db.commit() + + def _delete(self): + # requires caller to db.commit() + self._db.execute("DELETE FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)) + self._db.execute("DELETE FROM `messages`" + " WHERE `app_id`=? AND `mailbox_id`=?", + (self._app_id, self._mailbox_id)) + + # Shut down any listeners, just in case they're still lingering + # around. + for (send_f, stop_f) in self._listeners.values(): + stop_f() + + self._app.free_mailbox(self._mailbox_id) + + def _summarize_and_store(self, row, num_sides, second_mood, delete_time, + pruned): + u = self._summarize(row, num_sides, second_mood, delete_time, pruned) + self._db.execute("INSERT INTO `mailbox_usage`" + " (`app_id`, " + " `started`, `total_time`, `waiting_time`, `result`)" + " VALUES (?, ?,?,?,?)", + (self._app_id, + u.started, u.total_time, u.waiting_time, u.result)) + + def _summarize(self, row, num_sides, second_mood, delete_time, pruned): + started = row["started"] + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + waiting_time = None + if row["second"]: + waiting_time = row["second"] - row["started"] + total_time = delete_time - row["started"] + + if len(num_sides) == 0: + result = u"quiet" + elif len(num_sides) == 1: + result = u"lonely" + else: + result = u"happy" + + moods = set([row["first_mood"], second_mood]) + if u"lonely" in moods: + result = u"lonely" + if u"errory" in moods: + result = u"errory" + if u"scary" in moods: + result = u"scary" + if pruned: + result = u"pruney" + if row["crowded"]: + result = u"crowded" + + return Usage(started=started, waiting_time=waiting_time, + total_time=total_time, result=result) def is_idle(self): if self._listeners: return False c = self._db.execute("SELECT `server_rx` FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?" + " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` DESC LIMIT 1", - (self._app_id, self._channelid)) + (self._app_id, self._mailbox_id)) rows = c.fetchall() if not rows: return True @@ -167,88 +225,6 @@ class Mailbox: return True return False - def _store_summary(self, summary): - (started, result, total_time, waiting_time) = summary - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - self._db.execute("INSERT INTO `usage`" - " (`type`, `started`, `result`," - " `total_time`, `waiting_time`)" - " VALUES (?,?,?, ?,?)", - (u"rendezvous", started, result, - total_time, waiting_time)) - self._db.commit() - - def _summarize(self, messages, delete_time): - all_sides = set([m["side"] for m in messages]) - if len(all_sides) == 0: - log.msg("_summarize was given zero messages") # shouldn't happen - return - - started = min([m["server_rx"] for m in messages]) - # 'total_time' is how long the channel was occupied. That ends now, - # both for channels that got pruned for inactivity, and for channels - # that got pruned because of two RELEASE messages - total_time = delete_time - started - - if len(all_sides) == 1: - return (started, "lonely", total_time, None) - if len(all_sides) > 2: - # TODO: it'll be useful to have more detail here - return (started, "crowded", total_time, None) - - # exactly two sides were involved - A_side = sorted(messages, key=lambda m: m["server_rx"])[0]["side"] - B_side = list(all_sides - set([A_side]))[0] - - # How long did the first side wait until the second side showed up? - first_A = min([m["server_rx"] for m in messages if m["side"] == A_side]) - first_B = min([m["server_rx"] for m in messages if m["side"] == B_side]) - waiting_time = first_B - first_A - - # now, were all sides closed? If not, this is "pruney" - A_deallocs = [m for m in messages - if m["phase"] == RELEASE and m["side"] == A_side] - B_deallocs = [m for m in messages - if m["phase"] == RELEASE and m["side"] == B_side] - if not A_deallocs or not B_deallocs: - return (started, "pruney", total_time, None) - - # ok, both sides closed. figure out the mood - A_mood = A_deallocs[0]["body"] # maybe None - B_mood = B_deallocs[0]["body"] # maybe None - mood = "quiet" - if A_mood == u"happy" and B_mood == u"happy": - mood = "happy" - if A_mood == u"lonely" or B_mood == u"lonely": - mood = "lonely" - if A_mood == u"errory" or B_mood == u"errory": - mood = "errory" - if A_mood == u"scary" or B_mood == u"scary": - mood = "scary" - return (started, mood, total_time, waiting_time) - - def delete_and_summarize(self): - db = self._db - c = self._db.execute("SELECT * FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?" - " ORDER BY `server_rx`", - (self._app_id, self._channelid)) - messages = c.fetchall() - summary = self._summarize(messages, time.time()) - self._store_summary(summary) - db.execute("DELETE FROM `messages`" - " WHERE `app_id`=? AND `channelid`=?", - (self._app_id, self._channelid)) - db.commit() - - # Shut down any listeners, just in case they're still lingering - # around. - for (send_f, stop_f) in self._listeners.values(): - stop_f() - - self._app.free_channel(self._channelid) - def _shutdown(self): # used at test shutdown to accelerate client disconnects for (send_f, stop_f) in self._listeners.values(): @@ -261,7 +237,7 @@ class AppNamespace: self._blur_usage = blur_usage self._log_requests = log_requests self._app_id = app_id - self._channels = {} + self._mailboxes = {} def get_nameplate_ids(self): db = self._db @@ -315,15 +291,19 @@ class AppNamespace: (self._app_id, nameplate_id)).fetchone() if row: mailbox_id = row["mailbox_id"] - sides = [row["side1"], row["sides2"]] - if side not in sides: - if sides[0] and sides[1]: - raise XXXERROR("crowded") - sides[1] = side - db.execute("UPDATE `nameplates` SET " - "`side1`=?, `side2`=?, `mailbox_id`=?, `second`=?" + try: + sr = add_side(row, side) + except CrowdedError: + db.execute("UPDATE `nameplates` SET `crowded`=?" " WHERE `app_id`=? AND `id`=?", - (sides[0], sides[1], mailbox_id, when, + (True, self._app_id, nameplate_id)) + db.commit() + raise + if sr.changed: + db.execute("UPDATE `nameplates` SET" + " `side1`=?, `side2`=?, `updated`=?, `second`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, when, self._app_id, nameplate_id)) else: if self._log_requests: @@ -331,9 +311,11 @@ class AppNamespace: (nameplate_id, self._app_id)) mailbox_id = generate_mailbox_id() db.execute("INSERT INTO `nameplates`" - " (`app_id`, `id`, `mailbox_id`, `side1`, `started`)" - " VALUES(?,?,?,?,?)", - (self._app_id, nameplate_id, mailbox_id, side, when)) + " (`app_id`, `id`, `mailbox_id`, `side1`," + " `updated`, `started`)" + " VALUES(?,?,?,?, ?,?)", + (self._app_id, nameplate_id, mailbox_id, side, + when, when)) db.commit() return mailbox_id @@ -351,75 +333,120 @@ class AppNamespace: (self._app_id, nameplate_id)).fetchone() if not row: return - sides = get_sides(row) - if side not in sides: - return - sides.discard(side) - if sides: - s12 = make_sides(sides) - db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" - " WHERE `app_id`=? AND `id`=?", - (s12[0], s12[1], self._app_id, nameplate_id)) - else: + sr = remove_side(row, side) + if sr.empty: db.execute("DELETE FROM `nameplates`" " WHERE `app_id`=? AND `id`=?", (self._app_id, nameplate_id)) - self._summarize_nameplate(row) + self._summarize_nameplate_and_store(row, when, pruned=False) + db.commit() + elif sr.changed: + db.execute("UPDATE `nameplates`" + " SET `side1`=?, `side2`=?, `updated`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, + self._app_id, nameplate_id)) + db.commit() - def open_mailbox(self, channelid, side): - assert isinstance(channelid, type(u"")), type(channelid) - channel = self.get_channel(channelid) - channel.claim(side) - return channel - # some of this overlaps with open() on a new mailbox - db.execute("INSERT INTO `mailboxes`" - " (`app_id`, `id`, `nameplate_started`, `started`)" - " VALUES(?,?,?,?)", - (self._app_id, mailbox_id, when, when)) + def _summarize_nameplate_and_store(self, row, delete_time, pruned): + # requires caller to db.commit() + u = self._summarize_nameplate_usage(row, delete_time, pruned) + self._db.execute("INSERT INTO `nameplate_usage`" + " (`app_id`," + " `started`, `total_time`, `waiting_time`, `result`)" + " VALUES (?, ?,?,?,?)", + (self._app_id, + u.started, u.total_time, u.waiting_time, u.result)) - def get_channel(self, channelid): - assert isinstance(channelid, type(u"")) - if not channelid in self._channels: + def _summarize_nameplate_usage(self, row, delete_time, pruned): + started = row["started"] + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + waiting_time = None + if row["second"]: + waiting_time = row["second"] - row["started"] + total_time = delete_time - row["started"] + result = u"lonely" + if pruned: + result = u"pruney" + if row["second"]: + result = u"happy" + if row["crowded"]: + result = u"crowded" + return Usage(started=started, waiting_time=waiting_time, + total_time=total_time, result=result) + + def _prune_nameplate(self, row, delete_time): + # requires caller to db.commit() + db = self._db + db.execute("DELETE FROM `nameplates` WHERE `app_id`=? AND `id`=?", + (self._app_id, row["id"])) + self._summarize_nameplate_and_store(row, delete_time, pruned=True) + # TODO: make a Nameplate object, keep track of when there's a + # websocket that's watching it, don't prune a nameplate that someone + # is watching, even if they started watching a long time ago + + def prune_nameplates(self, old): + db = self._db + for row in db.execute("SELECT * FROM `nameplates`" + " WHERE `updated` < ?", + (old,)).fetchall(): + self._prune_nameplate(row) + count = db.execute("SELECT COUNT(*) FROM `nameplates`").fetchone()[0] + return count + + def open_mailbox(self, mailbox_id, side, when): + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + db = self._db + if not mailbox_id in self._mailboxes: if self._log_requests: - log.msg("spawning #%s for app_id %s" % (channelid, self._app_id)) - self._channels[channelid] = Channel(self, self._db, - self._blur_usage, - self._log_requests, - self._app_id, channelid) - return self._channels[channelid] + log.msg("spawning #%s for app_id %s" % (mailbox_id, + self._app_id)) + db.execute("INSERT INTO `mailboxes`" + " (`app_id`, `id`, `started`)" + " VALUES(?,?,?)", + (self._app_id, mailbox_id, when)) + self._mailboxes[mailbox_id] = Mailbox(self, self._db, + self._blur_usage, + self._log_requests, + self._app_id, mailbox_id) + mailbox = self._mailboxes[mailbox_id] + mailbox.open(side) + db.commit() + return mailbox - def free_channel(self, channelid): - # called from Channel.delete_and_summarize(), which deletes any + def free_mailbox(self, mailbox_id): + # called from Mailbox.delete_and_summarize(), which deletes any # messages - if channelid in self._channels: - self._channels.pop(channelid) + if mailbox_id in self._mailboxes: + self._mailboxes.pop(mailbox_id) if self._log_requests: - log.msg("freed+killed #%s, now have %d DB channels, %d live" % - (channelid, len(self.get_claimed()), len(self._channels))) + log.msg("freed+killed #%s, now have %d DB mailboxes, %d live" % + (mailbox_id, len(self.get_claimed()), len(self._mailboxes))) - def prune_old_channels(self): + def prune_mailboxes(self, old): # For now, pruning is logged even if log_requests is False, to debug # the pruning process, and since pruning is triggered by a timer - # instead of by user action. It does reveal which channels were + # instead of by user action. It does reveal which mailboxes were # present when the pruning process began, though, so in the log run # it should do less logging. log.msg(" channel prune begins") # a channel is deleted when there are no listeners and there have # been no messages added in CHANNEL_EXPIRATION_TIME seconds - channels = set(self.get_claimed()) # these have messages - channels.update(self._channels) # these might have listeners - for channelid in channels: - log.msg(" channel prune checking %d" % channelid) - channel = self.get_channel(channelid) + mailboxes = set(self.get_claimed()) # these have messages + mailboxes.update(self._mailboxes) # these might have listeners + for mailbox_id in mailboxes: + log.msg(" channel prune checking %d" % mailbox_id) + channel = self.get_channel(mailbox_id) if channel.is_idle(): - log.msg(" channel prune expiring %d" % channelid) + log.msg(" channel prune expiring %d" % mailbox_id) channel.delete_and_summarize() # calls self.free_channel - log.msg(" channel prune done, %r left" % (self._channels.keys(),)) - return bool(self._channels) + log.msg(" channel prune done, %r left" % (self._mailboxes.keys(),)) + return bool(self._mailboxes) def _shutdown(self): - for channel in self._channels.values(): + for channel in self._mailboxes.values(): channel._shutdown() class Rendezvous(service.MultiService): @@ -427,7 +454,7 @@ class Rendezvous(service.MultiService): service.MultiService.__init__(self) self._db = db self._welcome = welcome - self._blur_usage = blur_usage + self._blur_usage = None log_requests = blur_usage is None self._log_requests = log_requests self._apps = {} @@ -449,15 +476,18 @@ class Rendezvous(service.MultiService): self._log_requests, app_id) return self._apps[app_id] - def prune(self): - # As with AppNamespace.prune_old_channels, we log for now. + def prune(self, old=None): + # As with AppNamespace.prune_old_mailboxes, we log for now. log.msg("beginning app prune") + if old is None: + old = time.time() - CHANNEL_EXPIRATION_TIME c = self._db.execute("SELECT DISTINCT `app_id` FROM `messages`") apps = set([row["app_id"] for row in c.fetchall()]) # these have messages apps.update(self._apps) # these might have listeners for app_id in apps: log.msg(" app prune checking %r" % (app_id,)) - still_active = self.get_app(app_id).prune_old_channels() + app = self.get_app(app_id) + still_active = app.prune_nameplates(old) + app.prune_mailboxes(old) if not still_active: log.msg("prune pops app %r" % (app_id,)) self._apps.pop(app_id) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 612b1f3..0f6d5d2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -2,6 +2,7 @@ import json, time from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket +from .rendezvous import CrowdedError # The WebSocket allows the client to send "commands" to the server, and the # server to send "responses" to the client. Note that commands and responses @@ -84,7 +85,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = None self._side = None self._did_allocate = False # only one allocate() per websocket - self._channels = {} # channel-id -> Channel (claimed) + self._mailbox = None def onConnect(self, request): rv = self.factory.rendezvous @@ -122,11 +123,11 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_release(msg, server_rx) if mtype == "open": - return self.handle_open(msg) + return self.handle_open(msg, server_rx) if mtype == "add": return self.handle_add(msg, server_rx) if mtype == "close": - return self.handle_close(msg) + return self.handle_close(msg, server_rx) raise Error("Unknown type") except Error as e: @@ -154,7 +155,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_allocate(self, server_rx): if self._did_allocate: - raise Error("You already allocated one channel, don't be greedy") + raise Error("You already allocated one mailbox, don't be greedy") nameplate_id = self._app.allocate_nameplate(self._side, server_rx) assert isinstance(nameplate_id, type(u"")) self._did_allocate = True @@ -164,56 +165,52 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if "nameplate" not in msg: raise Error("claim requires 'nameplate'") nameplate_id = msg["nameplate"] + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) self._nameplate_id = nameplate_id - assert isinstance(nameplate_id, type(u"")), type(nameplate) - mailbox_id = self._app.claim_nameplate(nameplate_id, self._side, - server_rx) + try: + mailbox_id = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + except CrowdedError: + raise Error("crowded") self.send("mailbox", mailbox=mailbox_id) def handle_release(self, server_rx): if not self._nameplate_id: raise Error("must claim a nameplate before releasing it") - - deleted = self._app.release_nameplate(self._nameplate_id, - self._side, server_rx) + self._app.release_nameplate(self._nameplate_id, self._side, server_rx) self._nameplate_id = None - def handle_open(self, msg): - channelid = msg["channelid"] - if channelid not in self._channels: - raise Error("must claim channel before watching") - assert isinstance(channelid, type(u"")) - channel = self._channels[channelid] + def handle_open(self, msg, server_rx): + if self._mailbox: + raise Error("you already have a mailbox open") + mailbox_id = msg["mailbox_id"] + assert isinstance(mailbox_id, type(u"")) + self._mailbox = self._app.open_mailbox(mailbox_id, self._side, + server_rx) def _send(event): - self.send("message", channelid=channelid, message=event) + self.send("message", message=event) def _stop(): self._reactor.callLater(0, self.transport.loseConnection) - for old_message in channel.add_listener(self, _send, _stop): + for old_message in self._mailbox.add_listener(self, _send, _stop): _send(old_message) def handle_add(self, msg, server_rx): - channelid = msg["channelid"] - if channelid not in self._channels: - raise Error("must claim channel before adding") - assert isinstance(channelid, type(u"")) - channel = self._channels[channelid] + if not self._mailbox: + raise Error("must open mailbox before adding") if "phase" not in msg: raise Error("missing 'phase'") if "body" not in msg: raise Error("missing 'body'") msgid = msg.get("id") # optional - channel.add_message(self._side, msg["phase"], msg["body"], - server_rx, msgid) + self._mailbox.add_message(self._side, msg["phase"], msg["body"], + server_rx, msgid) - def handle_close(self, msg): - channelid = msg["channelid"] - if channelid not in self._channels: - raise Error("must claim channel before releasing") - assert isinstance(channelid, type(u"")) - channel = self._channels[channelid] - deleted = channel.release(self._side, msg.get("mood")) - del self._channels[channelid] + def handle_close(self, msg, server_rx): + if not self._mailbox: + raise Error("must open mailbox before closing") + deleted = self._mailbox.close(self._side, msg.get("mood"), server_rx) + self._mailbox = None self.send("released", status="deleted" if deleted else "waiting") def send(self, mtype, **kwargs): From 335ed00cb75795c9f62b5876b5426b7133440474 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 19 May 2016 19:55:11 -0700 Subject: [PATCH 21/51] build out tests --- src/wormhole/server/db-schemas/v2.sql | 4 +- src/wormhole/server/rendezvous.py | 41 ++-- src/wormhole/server/transit_server.py | 12 +- src/wormhole/test/test_server.py | 314 +++++++++++++++++++++----- 4 files changed, 287 insertions(+), 84 deletions(-) diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql index 83751ce..b436897 100644 --- a/src/wormhole/server/db-schemas/v2.sql +++ b/src/wormhole/server/db-schemas/v2.sql @@ -34,8 +34,8 @@ CREATE TABLE `mailboxes` ( `app_id` VARCHAR, `id` VARCHAR, - `side1` VARCHAR -- side name, or NULL - `side2` VARCHAR -- side name, or NULL + `side1` VARCHAR, -- side name, or NULL + `side2` VARCHAR, -- side name, or NULL `crowded` BOOLEAN, -- at some point, three or more sides were involved `first_mood` VARCHAR, -- timing data for the mailbox itself diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 76a8fed..f2b0829 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -31,6 +31,7 @@ class CrowdedError(Exception): def add_side(row, new_side): old_sides = [s for s in [row["side1"], row["side2"]] if s] + assert old_sides if new_side in old_sides: return Unchanged if len(old_sides) == 2: @@ -98,7 +99,7 @@ class Mailbox: if row["phase"] in (CLAIM, RELEASE): continue messages.append({"phase": row["phase"], "body": row["body"], - "server_rx": row["server_rx"], "id": row["msgid"]}) + "server_rx": row["server_rx"], "id": row["msg_id"]}) return messages def add_listener(self, handle, send_f, stop_f): @@ -117,7 +118,7 @@ class Mailbox: db = self._db db.execute("INSERT INTO `messages`" " (`app_id`, `mailbox_id`, `side`, `phase`, `body`," - " `server_rx`, `msgid`)" + " `server_rx`, `msg_id`)" " VALUES (?,?,?,?,?, ?,?)", (self._app_id, self._mailbox_id, side, phase, body, server_rx, msgid)) @@ -138,8 +139,8 @@ class Mailbox: return sr = remove_side(row, side) if sr.empty: - rows = db.execute("SELECT DISTINCT(side) FROM `messages`" - " WHERE `app_id`=? AND `id`=?", + rows = db.execute("SELECT DISTINCT(`side`) FROM `messages`" + " WHERE `app_id`=? AND `mailbox_id`=?", (self._app_id, self._mailbox_id)).fetchall() num_sides = len(rows) self._summarize_and_store(row, num_sides, mood, when, pruned=False) @@ -188,9 +189,9 @@ class Mailbox: waiting_time = row["second"] - row["started"] total_time = delete_time - row["started"] - if len(num_sides) == 0: + if num_sides == 0: result = u"quiet" - elif len(num_sides) == 1: + elif num_sides == 1: result = u"lonely" else: result = u"happy" @@ -266,7 +267,7 @@ class AppNamespace: def allocate_nameplate(self, side, when): nameplate_id = self._find_available_nameplate_id() - mailbox_id = self.claim_nameplate(self, nameplate_id, side, when) + mailbox_id = self.claim_nameplate(nameplate_id, side, when) del mailbox_id # ignored, they'll learn it from claim() return nameplate_id @@ -311,10 +312,10 @@ class AppNamespace: (nameplate_id, self._app_id)) mailbox_id = generate_mailbox_id() db.execute("INSERT INTO `nameplates`" - " (`app_id`, `id`, `mailbox_id`, `side1`," + " (`app_id`, `id`, `mailbox_id`, `side1`, `crowded`," " `updated`, `started`)" - " VALUES(?,?,?,?, ?,?)", - (self._app_id, nameplate_id, mailbox_id, side, + " VALUES(?,?,?,?,?, ?,?)", + (self._app_id, nameplate_id, mailbox_id, side, False, when, when)) db.commit() return mailbox_id @@ -367,10 +368,10 @@ class AppNamespace: waiting_time = row["second"] - row["started"] total_time = delete_time - row["started"] result = u"lonely" - if pruned: - result = u"pruney" if row["second"]: result = u"happy" + if pruned: + result = u"pruney" if row["crowded"]: result = u"crowded" return Usage(started=started, waiting_time=waiting_time, @@ -403,15 +404,17 @@ class AppNamespace: log.msg("spawning #%s for app_id %s" % (mailbox_id, self._app_id)) db.execute("INSERT INTO `mailboxes`" - " (`app_id`, `id`, `started`)" - " VALUES(?,?,?)", - (self._app_id, mailbox_id, when)) + " (`app_id`, `id`, `side1`, `crowded`, `started`)" + " VALUES(?,?,?,?,?)", + (self._app_id, mailbox_id, side, False, when)) + db.commit() # XXX + # mailbox.open() does a SELECT to find the old sides self._mailboxes[mailbox_id] = Mailbox(self, self._db, self._blur_usage, self._log_requests, self._app_id, mailbox_id) mailbox = self._mailboxes[mailbox_id] - mailbox.open(side) + mailbox.open(side, when) db.commit() return mailbox @@ -421,9 +424,9 @@ class AppNamespace: if mailbox_id in self._mailboxes: self._mailboxes.pop(mailbox_id) - if self._log_requests: - log.msg("freed+killed #%s, now have %d DB mailboxes, %d live" % - (mailbox_id, len(self.get_claimed()), len(self._mailboxes))) + #if self._log_requests: + # log.msg("freed+killed #%s, now have %d DB mailboxes, %d live" % + # (mailbox_id, len(self.get_claimed()), len(self._mailboxes))) def prune_mailboxes(self, old): # For now, pruning is logged even if log_requests is False, to debug diff --git a/src/wormhole/server/transit_server.py b/src/wormhole/server/transit_server.py index 2412c89..7c55b44 100644 --- a/src/wormhole/server/transit_server.py +++ b/src/wormhole/server/transit_server.py @@ -186,12 +186,12 @@ class Transit(protocol.ServerFactory, service.MultiService): if self._blur_usage: started = self._blur_usage * (started // self._blur_usage) total_bytes = blur_size(total_bytes) - self._db.execute("INSERT INTO `usage`" - " (`type`, `started`, `result`, `total_bytes`," - " `total_time`, `waiting_time`)" - " VALUES (?,?,?,?, ?,?)", - (u"transit", started, result, total_bytes, - total_time, waiting_time)) + self._db.execute("INSERT INTO `transit_usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?, ?,?)", + (started, total_time, waiting_time, + total_bytes, result)) self._db.commit() def transitFinished(self, p, token, description): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index e98f4d1..1497e7d 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -10,6 +10,7 @@ from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase from ..server import rendezvous, transit_server +from ..server.rendezvous import Usage class Reachable(ServerBase, unittest.TestCase): @@ -35,6 +36,218 @@ class Reachable(ServerBase, unittest.TestCase): d.addCallback(_got) return d +class Server(ServerBase, unittest.TestCase): + def test_apps(self): + app1 = self._rendezvous.get_app(u"appid1") + self.assertIdentical(app1, self._rendezvous.get_app(u"appid1")) + app2 = self._rendezvous.get_app(u"appid2") + self.assertNotIdentical(app1, app2) + + def test_nameplate_allocation(self): + app = self._rendezvous.get_app(u"appid") + nids = set() + # this takes a second, and claims all the short-numbered nameplates + def add(): + nameplate_id = app.allocate_nameplate(u"side1", 0) + self.assertEqual(type(nameplate_id), type(u"")) + nid = int(nameplate_id) + nids.add(nid) + for i in range(9): add() + self.assertNotIn(0, nids) + self.assertEqual(set(range(1,10)), nids) + + for i in range(100-10): add() + self.assertEqual(len(nids), 99) + self.assertEqual(set(range(1,100)), nids) + + for i in range(1000-100): add() + self.assertEqual(len(nids), 999) + self.assertEqual(set(range(1,1000)), nids) + + add() + self.assertEqual(len(nids), 1000) + biggest = max(nids) + self.assert_(1000 <= biggest < 1000000, biggest) + + def _nameplate(self, app, nameplate_id): + return app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`=?", + (nameplate_id,)).fetchone() + + def test_nameplate(self): + app = self._rendezvous.get_app(u"appid") + nameplate_id = app.allocate_nameplate(u"side1", 0) + self.assertEqual(type(nameplate_id), type(u"")) + nid = int(nameplate_id) + self.assert_(0 < nid < 10, nid) + self.assertEqual(app.get_nameplate_ids(), set([nameplate_id])) + # allocate also does a claim + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + mailbox_id = app.claim_nameplate(nameplate_id, u"side1", 1) + self.assertEqual(type(mailbox_id), type(u"")) + # duplicate claims by the same side are combined + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + + mailbox_id2 = app.claim_nameplate(nameplate_id, u"side1", 2) + self.assertEqual(mailbox_id, mailbox_id2) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # claim by the second side is new + mailbox_id3 = app.claim_nameplate(nameplate_id, u"side2", 3) + self.assertEqual(mailbox_id, mailbox_id3) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 3) + + # a third claim marks the nameplate as "crowded", but leaves the two + # existing claims alone + self.assertRaises(rendezvous.CrowdedError, + app.claim_nameplate, nameplate_id, u"side3", 0) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + + # releasing a non-existent nameplate is ignored + app.release_nameplate(nameplate_id+u"not", u"side4", 0) + + # releasing a side that never claimed the nameplate is ignored + app.release_nameplate(nameplate_id, u"side4", 0) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + + # releasing one side leaves the second claim + app.release_nameplate(nameplate_id, u"side1", 5) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + # releasing one side multiple times is ignored + app.release_nameplate(nameplate_id, u"side1", 5) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + # releasing the second side frees the nameplate, and adds usage + app.release_nameplate(nameplate_id, u"side2", 6) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row, None) + usage = app._db.execute("SELECT * FROM `nameplate_usage`").fetchone() + self.assertEqual(usage["app_id"], u"appid") + self.assertEqual(usage["started"], 0) + self.assertEqual(usage["waiting_time"], 3) + self.assertEqual(usage["total_time"], 6) + self.assertEqual(usage["result"], u"crowded") + + + def _mailbox(self, app, mailbox_id): + return app._db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`='appid' AND `id`=?", + (mailbox_id,)).fetchone() + + def test_mailbox(self): + app = self._rendezvous.get_app(u"appid") + mailbox_id = u"mid" + m1 = app.open_mailbox(mailbox_id, u"side1", 0) + + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # opening the same mailbox twice, by the same side, gets the same + # object + self.assertIdentical(m1, app.open_mailbox(mailbox_id, u"side1", 1)) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # opening a second side gets the same object, and adds a new claim + self.assertIdentical(m1, app.open_mailbox(mailbox_id, u"side2", 2)) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # a third open marks it as crowded + self.assertRaises(rendezvous.CrowdedError, + app.open_mailbox, mailbox_id, u"side3", 3) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing a side that never claimed the mailbox is ignored + m1.close(u"side4", u"mood", 4) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing one side leaves the second claim + m1.close(u"side1", u"mood", 5) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing one side multiple is ignored + m1.close(u"side1", u"mood", 6) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing the second side frees the mailbox, and adds usage + m1.close(u"side2", u"mood", 7) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row, None) + usage = app._db.execute("SELECT * FROM `mailbox_usage`").fetchone() + self.assertEqual(usage["app_id"], u"appid") + self.assertEqual(usage["started"], 0) + self.assertEqual(usage["waiting_time"], 2) + self.assertEqual(usage["total_time"], 7) + self.assertEqual(usage["result"], u"crowded") + + def test_messages(self): + app = self._rendezvous.get_app(u"appid") + mailbox_id = u"mid" + m1 = app.open_mailbox(mailbox_id, u"side1", 0) + m1.add_message(u"side1", u"phase", u"body", 1, u"msgid") + # XXX more + + def strip_message(msg): m2 = msg.copy() m2.pop("id", None) @@ -465,73 +678,60 @@ class WebSocketAPI(ServerBase, unittest.TestCase): class Summary(unittest.TestCase): - def test_summarize(self): - c = rendezvous.Channel(None, None, None, False, None, None) - A = rendezvous.CLAIM - D = rendezvous.RELEASE + def test_mailbox(self): + c = rendezvous.Mailbox(None, None, None, False, None, None) + # starts at time 1, maybe gets second open at time 3, closes at 5 + base_row = {u"started": 1, u"second": None, + u"first_mood": None, u"crowded": False} + def summ(num_sides, second_mood=None, pruned=False, **kwargs): + row = base_row.copy() + row.update(kwargs) + return c._summarize(row, num_sides, second_mood, 5, pruned) - messages = [{"server_rx": 1, "side": "a", "phase": A}] - self.failUnlessEqual(c._summarize(messages, 2), - (1, "lonely", 1, None)) + self.assertEqual(summ(1), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, u"lonely"), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, u"errory"), Usage(1, None, 4, u"errory")) + self.assertEqual(summ(1, crowded=True), Usage(1, None, 4, u"crowded")) - messages = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "a", "phase": D, "body": "lonely"}, - ] - self.failUnlessEqual(c._summarize(messages, 3), - (1, "lonely", 2, None)) + self.assertEqual(summ(2, first_mood=u"happy", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"happy")) - messages = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "b", "phase": A}, - {"server_rx": 3, "side": "c", "phase": A}, - ] - self.failUnlessEqual(c._summarize(messages, 4), - (1, "crowded", 3, None)) + self.assertEqual(summ(2, first_mood=u"errory", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"errory")) - base = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "a", "phase": "pake", "body": "msg1"}, - {"server_rx": 10, "side": "b", "phase": "pake", "body": "msg2"}, - {"server_rx": 11, "side": "b", "phase": "data", "body": "msg3"}, - {"server_rx": 20, "side": "a", "phase": "data", "body": "msg4"}, - ] - def make_moods(A_mood, B_mood): - return base + [ - {"server_rx": 21, "side": "a", "phase": D, "body": A_mood}, - {"server_rx": 30, "side": "b", "phase": D, "body": B_mood}, - ] + self.assertEqual(summ(2, first_mood=u"happy", + second=3, second_mood=u"errory"), + Usage(1, 2, 4, u"errory")) - self.failUnlessEqual(c._summarize(make_moods("happy", "happy"), 41), - (1, "happy", 40, 9)) + self.assertEqual(summ(2, first_mood=u"scary", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"scary")) - self.failUnlessEqual(c._summarize(make_moods("scary", "happy"), 41), - (1, "scary", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "scary"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(2, first_mood=u"scary", + second=3, second_mood=u"errory"), + Usage(1, 2, 4, u"scary")) - self.failUnlessEqual(c._summarize(make_moods("lonely", "happy"), 41), - (1, "lonely", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "lonely"), 41), - (1, "lonely", 40, 9)) + self.assertEqual(summ(2, first_mood=u"happy", second=3, pruned=True), + Usage(1, 2, 4, u"pruney")) - self.failUnlessEqual(c._summarize(make_moods("errory", "happy"), 41), - (1, "errory", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "errory"), 41), - (1, "errory", 40, 9)) + def test_nameplate(self): + a = rendezvous.AppNamespace(None, None, None, False, None) + # starts at time 1, maybe gets second open at time 3, closes at 5 + base_row = {u"started": 1, u"second": None, u"crowded": False} + def summ(num_sides, pruned=False, **kwargs): + row = base_row.copy() + row.update(kwargs) + return a._summarize_nameplate_usage(row, 5, pruned) - # scary trumps other moods - self.failUnlessEqual(c._summarize(make_moods("scary", "lonely"), 41), - (1, "scary", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("scary", "errory"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(1), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, crowded=True), Usage(1, None, 4, u"crowded")) - # older clients don't send a mood - self.failUnlessEqual(c._summarize(make_moods(None, None), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "happy"), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "happy"), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(2, second=3), Usage(1, 2, 4, u"happy")) + + self.assertEqual(summ(2, second=3, pruned=True), + Usage(1, 2, 4, u"pruney")) class Accumulator(protocol.Protocol): def __init__(self): From f044ef0efaf95f53af745d1b1b2e32026e6adc76 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 19 May 2016 23:50:22 -0700 Subject: [PATCH 22/51] tests almost good --- src/wormhole/server/rendezvous.py | 11 - src/wormhole/server/rendezvous_websocket.py | 15 +- src/wormhole/test/test_server.py | 318 ++++++++++++++++++-- 3 files changed, 304 insertions(+), 40 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index f2b0829..38f7af3 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -13,9 +13,6 @@ MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY EXPIRATION_CHECK_PERIOD = 2*HOUR -CLAIM = u"_claim" -RELEASE = u"_release" - def get_sides(row): return set([s for s in [row["side1"], row["side2"]] if s]) def make_sides(sides): @@ -96,8 +93,6 @@ class Mailbox: " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` ASC", (self._app_id, self._mailbox_id)).fetchall(): - if row["phase"] in (CLAIM, RELEASE): - continue messages.append({"phase": row["phase"], "body": row["body"], "server_rx": row["server_rx"], "id": row["msg_id"]}) return messages @@ -271,12 +266,6 @@ class AppNamespace: del mailbox_id # ignored, they'll learn it from claim() return nameplate_id - def _get_mailbox_id(self, nameplate_id): - row = self._db.execute("SELECT `mailbox_id` FROM `nameplates`" - " WHERE `app_id`=? AND `id`=?", - (self._app_id, nameplate_id)).fetchone() - return row["mailbox_id"] - def claim_nameplate(self, nameplate_id, side, when): # when we're done: # * there will be one row for the nameplate diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 0f6d5d2..7e72813 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -85,6 +85,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = None self._side = None self._did_allocate = False # only one allocate() per websocket + self._nameplate_id = None self._mailbox = None def onConnect(self, request): @@ -112,7 +113,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_bind(msg) if not self._app: - raise Error("Must bind first") + raise Error("must bind first") if mtype == "list": return self.handle_list() if mtype == "allocate": @@ -120,7 +121,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if mtype == "claim": return self.handle_claim(msg, server_rx) if mtype == "release": - return self.handle_release(msg, server_rx) + return self.handle_release(server_rx) if mtype == "open": return self.handle_open(msg, server_rx) @@ -155,7 +156,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_allocate(self, server_rx): if self._did_allocate: - raise Error("You already allocated one mailbox, don't be greedy") + raise Error("you already allocated one, don't be greedy") nameplate_id = self._app.allocate_nameplate(self._side, server_rx) assert isinstance(nameplate_id, type(u"")) self._did_allocate = True @@ -184,7 +185,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_open(self, msg, server_rx): if self._mailbox: raise Error("you already have a mailbox open") - mailbox_id = msg["mailbox_id"] + if "mailbox" not in msg: + raise Error("open requires 'mailbox'") + mailbox_id = msg["mailbox"] assert isinstance(mailbox_id, type(u"")) self._mailbox = self._app.open_mailbox(mailbox_id, self._side, server_rx) @@ -209,9 +212,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_close(self, msg, server_rx): if not self._mailbox: raise Error("must open mailbox before closing") - deleted = self._mailbox.close(self._side, msg.get("mood"), server_rx) + self._mailbox.close(self._side, msg.get("mood"), server_rx) self._mailbox = None - self.send("released", status="deleted" if deleted else "waiting") + self.send("closed") def send(self, mtype, **kwargs): kwargs["type"] = mtype diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 1497e7d..5627a1c 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -229,8 +229,13 @@ class Server(ServerBase, unittest.TestCase): self.assertEqual(row["started"], 0) self.assertEqual(row["second"], 2) + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + m1.add_listener("handle1", l1.append, stop1_f) + # closing the second side frees the mailbox, and adds usage m1.close(u"side2", u"mood", 7) + self.assertEqual(stop1, [True]) + row = self._mailbox(app, mailbox_id) self.assertEqual(row, None) usage = app._db.execute("SELECT * FROM `mailbox_usage`").fetchone() @@ -240,12 +245,55 @@ class Server(ServerBase, unittest.TestCase): self.assertEqual(usage["total_time"], 7) self.assertEqual(usage["result"], u"crowded") + def _messages(self, app): + c = app._db.execute("SELECT * FROM `messages`" + " WHERE `app_id`='appid' AND `mailbox_id`='mid'") + return c.fetchall() + def test_messages(self): app = self._rendezvous.get_app(u"appid") mailbox_id = u"mid" m1 = app.open_mailbox(mailbox_id, u"side1", 0) m1.add_message(u"side1", u"phase", u"body", 1, u"msgid") - # XXX more + msgs = self._messages(app) + self.assertEqual(len(msgs), 1) + self.assertEqual(msgs[0]["body"], u"body") + + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + l2 = []; stop2 = []; stop2_f = lambda: stop2.append(True) + old = m1.add_listener("handle1", l1.append, stop1_f) + self.assertEqual(len(old), 1) + self.assertEqual(old[0]["body"], u"body") + + m1.add_message(u"side1", u"phase2", u"body2", 1, u"msgid") + self.assertEqual(len(l1), 1) + self.assertEqual(l1[0]["body"], u"body2") + old = m1.add_listener("handle2", l2.append, stop2_f) + self.assertEqual(len(old), 2) + + m1.add_message(u"side1", u"phase3", u"body3", 1, u"msgid") + self.assertEqual(len(l1), 2) + self.assertEqual(l1[-1]["body"], u"body3") + self.assertEqual(len(l2), 1) + self.assertEqual(l2[-1]["body"], u"body3") + + m1.remove_listener("handle1") + + m1.add_message(u"side1", u"phase4", u"body4", 1, u"msgid") + self.assertEqual(len(l1), 2) + self.assertEqual(l1[-1]["body"], u"body3") + self.assertEqual(len(l2), 2) + self.assertEqual(l2[-1]["body"], u"body4") + + m1._shutdown() + self.assertEqual(stop1, []) + self.assertEqual(stop2, [True]) + + # message adds are not idempotent: clients filter duplicates + m1.add_message(u"side1", u"phase", u"body", 1, u"msgid") + msgs = self._messages(app) + self.assertEqual(len(msgs), 5) + self.assertEqual(msgs[-1]["body"], u"body") def strip_message(msg): @@ -415,35 +463,259 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(self._rendezvous._apps, {}) @inlineCallbacks - def test_claim(self): - r = self._rendezvous.get_app(u"appid") + def test_bind(self): c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) + yield c1.next_non_ack() + + c1.send(u"bind", appid=u"appid") # missing side= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"bind requires 'side'") + + c1.send(u"bind", side=u"side") # missing appid= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"bind requires 'appid'") + c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"claim", channelid=u"1") yield c1.sync() - self.assertEqual(r.get_claimed(), set(u"1")) + self.assertEqual(self._rendezvous._apps.keys(), [u"appid"]) - c1.send(u"claim", channelid=u"2") - yield c1.sync() - self.assertEqual(r.get_claimed(), set([u"1", u"2"])) + c1.send(u"bind", appid=u"appid", side=u"side") # duplicate + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"already bound") - c1.send(u"claim", channelid=u"72aoqnnnbj7r2") - yield c1.sync() - self.assertEqual(r.get_claimed(), set([u"1", u"2", u"72aoqnnnbj7r2"])) + @inlineCallbacks + def test_list(self): + c1 = yield self.make_client() + yield c1.next_non_ack() - c1.send(u"release", channelid=u"2") - yield c1.sync() - self.assertEqual(r.get_claimed(), set([u"1", u"72aoqnnnbj7r2"])) + c1.send(u"list") # too early, must bind first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must bind first") - c1.send(u"release", channelid=u"1") + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"list") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"nameplates") + self.assertEqual(m[u"nameplates"], []) + + app = self._rendezvous.get_app(u"appid") + nameplate_id1 = app.allocate_nameplate(u"side", 0) + app.claim_nameplate(u"np2", u"side", 0) + + c1.send(u"list") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"nameplates") + self.assertEqual(set(m[u"nameplates"]), set([nameplate_id1, u"np2"])) + + @inlineCallbacks + def test_allocate(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + + c1.send(u"allocate") # too early, must bind first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must bind first") + + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + c1.send(u"allocate") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"nameplate") + nameplate_id = m[u"nameplate"] + + nids = app.get_nameplate_ids() + self.assertEqual(len(nids), 1) + self.assertEqual(nameplate_id, list(nids)[0]) + + c1.send(u"allocate") + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"you already allocated one, don't be greedy") + + c1.send(u"claim", nameplate=nameplate_id) # allocate+claim is ok yield c1.sync() - self.assertEqual(r.get_claimed(), set([u"72aoqnnnbj7r2"])) + row = app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`=?", + (nameplate_id,)).fetchone() + self.assertEqual(row["side1"], u"side") + self.assertEqual(row["side2"], None) + + @inlineCallbacks + def test_claim(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + c1.send(u"claim") # missing nameplate= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"claim requires 'nameplate'") + + c1.send(u"claim", nameplate=u"np1") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"mailbox") + mailbox_id = m[u"mailbox"] + self.assertEqual(type(mailbox_id), type(u"")) + + nids = app.get_nameplate_ids() + self.assertEqual(len(nids), 1) + self.assertEqual(u"np1", list(nids)[0]) + + # claiming a nameplate will assign a random mailbox id, but won't + # create the mailbox itself + mailboxes = app._db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`='appid'").fetchall() + self.assertEqual(len(mailboxes), 0) + + @inlineCallbacks + def test_claim_crowded(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + app.claim_nameplate(u"np1", u"side1", 0) + app.claim_nameplate(u"np1", u"side2", 0) + + # the third claim will signal crowding + c1.send(u"claim", nameplate=u"np1") + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"crowded") + + @inlineCallbacks + def test_release(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + app.claim_nameplate(u"np1", u"side2", 0) + + c1.send(u"release") # didn't do claim first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"must claim a nameplate before releasing it") + + c1.send(u"claim", nameplate=u"np1") + yield c1.next_non_ack() + + c1.send(u"release") + yield c1.sync() + + row = app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`='np1'").fetchone() + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + c1.send(u"release") # no longer claimed + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"must claim a nameplate before releasing it") + + @inlineCallbacks + def test_open(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + c1.send(u"open") # missing mailbox= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"open requires 'mailbox'") + + mb1 = app.open_mailbox(u"mb1", u"side2", 0) + mb1.add_message(u"side2", u"phase", u"body", 0, u"msgid") + + c1.send(u"open", mailbox=u"mb1") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"message"][u"body"], u"body") + + mb1.add_message(u"side2", u"phase2", u"body2", 0, u"msgid") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"message"][u"body"], u"body2") + + c1.send(u"open", mailbox=u"mb1") + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"you already have a mailbox open") + + @inlineCallbacks + def test_add(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + mb1 = app.open_mailbox(u"mb1", u"side2", 0) + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + mb1.add_listener("handle1", l1.append, stop1_f) + + c1.send(u"add") # didn't open first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before adding") + + c1.send(u"open", mailbox=u"mb1") + + c1.send(u"add", body=u"body") # missing phase= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'phase'") + + c1.send(u"add", phase=u"phase") # missing body= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'body'") + + c1.send(u"add", phase=u"phase", body=u"body") + m = yield c1.next_non_ack() # echoed back + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"message"][u"body"], u"body") + + self.assertEqual(len(l1), 1) + self.assertEqual(l1[0][u"body"], u"body") + + @inlineCallbacks + def test_close(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + c1.send(u"close", mood=u"mood") # must open first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before closing") + + c1.send(u"open", mailbox=u"mb1") + c1.send(u"close", mood=u"mood") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"closed") + + return + print("doing last close") + c1.send(u"close", mood=u"mood") # already closed # XXX not getting through + print("did last close") + err = yield c1.next_non_ack() + print("done") + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before closing") @inlineCallbacks - def test_allocate_1(self): + def OFFtest_allocate_1(self): c1 = yield self.make_client() msg = yield c1.next_non_ack() self.check_welcome(msg) @@ -483,7 +755,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["channelids"], []) @inlineCallbacks - def test_allocate_2(self): + def OFFtest_allocate_2(self): c1 = yield self.make_client() msg = yield c1.next_non_ack() self.check_welcome(msg) @@ -539,7 +811,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(msg["channelids"], []) @inlineCallbacks - def test_allocate_and_claim(self): + def OFFtest_allocate_and_claim(self): r = self._rendezvous.get_app(u"appid") c1 = yield self.make_client() msg = yield c1.next_non_ack() @@ -564,7 +836,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(r.get_claimed(), set([cid])) @inlineCallbacks - def test_allocate_and_claim_two(self): + def OFFtest_allocate_and_claim_two(self): r = self._rendezvous.get_app(u"appid") c1 = yield self.make_client() msg = yield c1.next_non_ack() @@ -592,7 +864,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(r.get_claimed(), set()) @inlineCallbacks - def test_message(self): + def OFFtest_message(self): c1 = yield self.make_client() msg = yield c1.next_non_ack() self.check_welcome(msg) From 399efb374c99e94ba85e0d5383fa27b38ac0c823 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 11:07:21 -0700 Subject: [PATCH 23/51] don't close websocket when mailbox is deleted This made sense for ServerSentEvent channels (which has no purpose once the channel was gone), but not so much for websockets. And it prevented testing duplicate-close. --- src/wormhole/server/rendezvous_websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 7e72813..e37fb6b 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -194,7 +194,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def _send(event): self.send("message", message=event) def _stop(): - self._reactor.callLater(0, self.transport.loseConnection) + pass for old_message in self._mailbox.add_listener(self, _send, _stop): _send(old_message) From 0a14901e94831513dad0e250fe995e951e698175 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 11:08:10 -0700 Subject: [PATCH 24/51] full coverage of websocket --- src/wormhole/server/rendezvous_websocket.py | 2 +- src/wormhole/test/test_server.py | 25 ++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index e37fb6b..b0fb10a 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -130,7 +130,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if mtype == "close": return self.handle_close(msg, server_rx) - raise Error("Unknown type") + raise Error("unknown type") except Error as e: self.send("error", error=e._explain, orig=msg) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 5627a1c..4bd1167 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -349,6 +349,10 @@ class WSClient(websocket.WebSocketClientProtocol): payload = json.dumps(kwargs).encode("utf-8") self.sendMessage(payload, False) + def send_notype(self, **kwargs): + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + @inlineCallbacks def sync(self): ping = next(self.ping_counter) @@ -486,6 +490,21 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(err[u"type"], u"error") self.assertEqual(err[u"error"], u"already bound") + c1.send_notype(other="misc") # missing 'type' + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'type'") + + c1.send("___unknown") # unknown type + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"unknown type") + + c1.send("ping") # missing 'ping' + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"ping requires 'ping'") + @inlineCallbacks def test_list(self): c1 = yield self.make_client() @@ -704,12 +723,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): m = yield c1.next_non_ack() self.assertEqual(m[u"type"], u"closed") - return - print("doing last close") - c1.send(u"close", mood=u"mood") # already closed # XXX not getting through - print("did last close") + c1.send(u"close", mood=u"mood") # already closed err = yield c1.next_non_ack() - print("done") self.assertEqual(err[u"type"], u"error") self.assertEqual(err[u"error"], u"must open mailbox before closing") From ce06d379d929858b3d65a1d12142308b6354145a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 11:09:45 -0700 Subject: [PATCH 25/51] remove old tests --- src/wormhole/test/test_server.py | 235 ------------------------------- 1 file changed, 235 deletions(-) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 4bd1167..9c57ce4 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -729,241 +729,6 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(err[u"error"], u"must open mailbox before closing") - @inlineCallbacks - def OFFtest_allocate_1(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - yield c1.sync() - self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) - app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_claimed(), set()) - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) - - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - self.failUnlessIsInstance(cid, type(u"")) - self.assertEqual(app.get_claimed(), set([cid])) - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) - - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c1.send(u"release", channelid=cid) - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"released") - self.assertEqual(msg["status"], u"deleted") - self.assertEqual(app.get_claimed(), set()) - - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) - - @inlineCallbacks - def OFFtest_allocate_2(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - yield c1.sync() - app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_claimed(), set()) - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - self.failUnlessIsInstance(cid, type(u"")) - self.assertEqual(app.get_claimed(), set([cid])) - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) - - # second caller increases the number of known sides to 2 - c2 = yield self.make_client() - msg = yield c2.next_non_ack() - self.check_welcome(msg) - c2.send(u"bind", appid=u"appid", side=u"side-2") - c2.send(u"claim", channelid=cid) - c2.send(u"add", channelid=cid, phase="1", body="") - yield c2.sync() - - self.assertEqual(app.get_claimed(), set([cid])) - self.assertEqual(strip_messages(channel.get_messages()), - [{"phase": "1", "body": ""}]) - - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c2.send(u"list") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c1.send(u"release", channelid=cid) - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"released") - self.assertEqual(msg["status"], u"waiting") - - c2.send(u"release", channelid=cid) - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"released") - self.assertEqual(msg["status"], u"deleted") - - c2.send(u"list") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) - - @inlineCallbacks - def OFFtest_allocate_and_claim(self): - r = self._rendezvous.get_app(u"appid") - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid) - yield c1.sync() - # there should no error - self.assertEqual(c1.errors, []) - self.assertEqual(r.get_claimed(), set([cid])) - - # but trying to allocate twice is an error - c1.send(u"allocate") - yield c1.sync() - self.assertEqual(len(c1.errors), 1) - self.assertEqual(c1.errors[0]["error"], - "You already allocated one channel, don't be greedy") - self.assertEqual(r.get_claimed(), set([cid])) - - @inlineCallbacks - def OFFtest_allocate_and_claim_two(self): - r = self._rendezvous.get_app(u"appid") - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid) - yield c1.sync() - # there should no error - self.assertEqual(c1.errors, []) - - c1.send(u"claim", channelid=u"other") - yield c1.sync() - self.assertEqual(c1.errors, []) - self.assertEqual(r.get_claimed(), set([cid, u"other"])) - - c1.send(u"release", channelid=cid) - yield c1.sync() - self.assertEqual(r.get_claimed(), set([u"other"])) - c1.send(u"release", channelid="other") - yield c1.sync() - self.assertEqual(r.get_claimed(), set()) - - @inlineCallbacks - def OFFtest_message(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - app = self._rendezvous.get_app(u"appid") - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) - - c1.send(u"watch", channelid=cid) - yield c1.sync() - self.assertEqual(len(channel._listeners), 1) - c1.strip_acks() - self.assertEqual(c1.events, []) - - c1.send(u"add", channelid=cid, phase="1", body="msg1A") - yield c1.sync() - c1.strip_acks() - self.assertEqual(strip_messages(channel.get_messages()), - [{"phase": "1", "body": "msg1A"}]) - self.assertEqual(len(c1.events), 1) # echo should be sent right away - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1A"}) - self.assertIn("server_tx", msg) - self.assertIsInstance(msg["server_tx"], float) - - c1.send(u"add", channelid=cid, phase="1", body="msg1B") - c1.send(u"add", channelid=cid, phase="2", body="msg2A") - - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1B"}) - - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) - - self.assertEqual(strip_messages(channel.get_messages()), [ - {"phase": "1", "body": "msg1A"}, - {"phase": "1", "body": "msg1B"}, - {"phase": "2", "body": "msg2A"}, - ]) - - # second client should see everything - c2 = yield self.make_client() - msg = yield c2.next_non_ack() - self.check_welcome(msg) - c2.send(u"bind", appid=u"appid", side=u"side") - c2.send(u"claim", channelid=cid) - # 'watch' triggers delivery of old messages, in temporal order - c2.send(u"watch", channelid=cid) - - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1A"}) - - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1B"}) - - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) - - # adding a duplicate is not an error, and clients will ignore it - c1.send(u"add", channelid=cid, phase="2", body="msg2A") - - # the duplicate message *does* get stored, and delivered - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) - - class Summary(unittest.TestCase): def test_mailbox(self): c = rendezvous.Mailbox(None, None, None, False, None, None) From 6c5b517ad1fc69e512245c04345677690fc7caec Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 11:10:17 -0700 Subject: [PATCH 26/51] hush --- src/wormhole/test/test_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 9c57ce4..5289a98 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -711,7 +711,6 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1 = yield self.make_client() yield c1.next_non_ack() c1.send(u"bind", appid=u"appid", side=u"side") - app = self._rendezvous.get_app(u"appid") c1.send(u"close", mood=u"mood") # must open first err = yield c1.next_non_ack() From 390cd08b5305717ef0b67f36ecf372db5498af2c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 11:35:30 -0700 Subject: [PATCH 27/51] better command/response names: allocate+allocated, claim+claimed --- src/wormhole/server/rendezvous_websocket.py | 8 ++++---- src/wormhole/test/test_server.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index b0fb10a..4129400 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -56,9 +56,9 @@ from .rendezvous import CrowdedError # -> {type: "list"} -> nameplates # <- {type: "nameplates", nameplates: [str..]} # -> {type: "allocate"} -> nameplate, mailbox -# <- {type: "nameplate", nameplate: str} +# <- {type: "allocated", nameplate: str} # -> {type: "claim", nameplate: str} -> mailbox -# <- {type: "mailbox", mailbox: str} +# <- {type: "claimed", mailbox: str} # -> {type: "release"} # # -> {type: "open", mailbox: str} -> message @@ -160,7 +160,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): nameplate_id = self._app.allocate_nameplate(self._side, server_rx) assert isinstance(nameplate_id, type(u"")) self._did_allocate = True - self.send("nameplate", nameplate=nameplate_id) + self.send("allocated", nameplate=nameplate_id) def handle_claim(self, msg, server_rx): if "nameplate" not in msg: @@ -173,7 +173,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): server_rx) except CrowdedError: raise Error("crowded") - self.send("mailbox", mailbox=mailbox_id) + self.send("claimed", mailbox=mailbox_id) def handle_release(self, server_rx): if not self._nameplate_id: diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 5289a98..d7377d1 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -544,7 +544,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): app = self._rendezvous.get_app(u"appid") c1.send(u"allocate") m = yield c1.next_non_ack() - self.assertEqual(m[u"type"], u"nameplate") + self.assertEqual(m[u"type"], u"allocated") nameplate_id = m[u"nameplate"] nids = app.get_nameplate_ids() @@ -579,7 +579,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"claim", nameplate=u"np1") m = yield c1.next_non_ack() - self.assertEqual(m[u"type"], u"mailbox") + self.assertEqual(m[u"type"], u"claimed") mailbox_id = m[u"mailbox"] self.assertEqual(type(mailbox_id), type(u"")) From 3b86571de33c39250c81bc492adf3c09632c5d8a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 12:12:07 -0700 Subject: [PATCH 28/51] fix py3 --- src/wormhole/server/rendezvous.py | 2 +- src/wormhole/test/test_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 38f7af3..018b8e2 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -18,7 +18,7 @@ def get_sides(row): def make_sides(sides): return list(sides) + [None] * (2 - len(sides)) def generate_mailbox_id(): - return base64.b32encode(os.urandom(8)).lower().strip("=") + return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii") SideResult = namedtuple("SideResult", ["changed", "empty", "side1", "side2"]) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d7377d1..b654a24 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -483,7 +483,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"bind", appid=u"appid", side=u"side") yield c1.sync() - self.assertEqual(self._rendezvous._apps.keys(), [u"appid"]) + self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) c1.send(u"bind", appid=u"appid", side=u"side") # duplicate err = yield c1.next_non_ack() From 05aa5ca76ee7995c99003178bfbd1acab5a4bc14 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 13:51:05 -0700 Subject: [PATCH 29/51] WIP Wormhole --- src/wormhole/twisted/transcribe.py | 241 ++++++++++++++++++----------- 1 file changed, 148 insertions(+), 93 deletions(-) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index d8f39d6..555499b 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -28,6 +28,17 @@ def make_confmsg(confkey, nonce): def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") +# We send the following messages through the relay server to the far side (by +# sending "add" commands to the server, and getting "message" responses): +# +# phase=setup: +# * unauthenticated version strings (but why?) +# * early warmup for connection hints ("I can do tor, spin up HS") +# * wordlist l10n identifier +# phase=pake: just the SPAKE2 'start' message (binary) +# phase=confirm: key verification (HKDF(key, nonce)+nonce) +# phase=1,2,3,..: application messages + class WSClient(websocket.WebSocketClientProtocol): def onOpen(self): self.wormhole_open = True @@ -35,7 +46,7 @@ class WSClient(websocket.WebSocketClientProtocol): def onMessage(self, payload, isBinary): assert not isBinary - self.wormhole._ws_dispatch_msg(payload) + self.wormhole._ws_dispatch_response(payload) def onClose(self, wasClean, code, reason): if self.wormhole_open: @@ -70,9 +81,20 @@ class _Wormhole: self._tor_manager = tor_manager self._timing = timing or DebugTiming() self._reactor = reactor + + self._ws_connected = defer.Deferred() # XXX + self._side = hexlify(os.urandom(5)).decode("ascii") self._code = None - self._channelid = None + + self._nameplate_id = None + self._nameplate_claimed = False + self._nameplate_released = False + + self._mailbox_id = None + self._mailbox_opened = False + self._mailbox_closed = False + self._key = None self._started_get_code = False self._next_outbound_phase = 0 @@ -88,7 +110,6 @@ class _Wormhole: self._timing_started = self._timing.add("wormhole") self._ws = None self._ws_t = None # timing Event - self._ws_channel_claimed = False self._error = None def _make_endpoint(self, hostname, port): @@ -100,11 +121,9 @@ class _Wormhole: @inlineCallbacks def _get_websocket(self): if not self._ws: - # TODO: if we lose the connection, make a new one - #from twisted.python import log - #log.startLogging(sys.stderr) + # TODO: if we lose the connection, make a new one, re-establish + # the state assert self._side - assert not self._ws_channel_claimed p = urlparse(self._ws_url) f = WSFactory(self._ws_url) f.wormhole = self @@ -118,12 +137,12 @@ class _Wormhole: self._ws_t = self._timing.add("websocket") # f.d is errbacked if WebSocket negotiation fails yield f.d # WebSocket drops data sent before onOpen() fires - self._ws_send(u"bind", appid=self._appid, side=self._side) - # the socket is connected, and bound, but no channel has been claimed + self._ws_send_command(u"bind", appid=self._appid, side=self._side) + # the socket is connected, and bound, but no nameplate has been claimed returnValue(self._ws) @inlineCallbacks - def _ws_send(self, mtype, **kwargs): + def _ws_send_command(self, mtype, **kwargs): ws = yield self._get_websocket() # msgid is used by misc/dump-timing.py to correlate our sends with # their receives, and vice versa. They are also correlated with the @@ -135,7 +154,7 @@ class _Wormhole: self._timing.add("ws_send", _side=self._side, **kwargs) ws.sendMessage(payload, False) - def _ws_dispatch_msg(self, payload): + def _ws_dispatch_response(self, payload): msg = json.loads(payload.decode("utf-8")) self._timing.add("ws_receive", _side=self._side, message=msg) mtype = msg["type"] @@ -205,36 +224,120 @@ class _Wormhole: return self._signal_error(err) @inlineCallbacks - def _claim_channel_and_watch(self): - assert self._channelid is not None - yield self._get_websocket() - if not self._ws_channel_claimed: - yield self._ws_send(u"claim", channelid=self._channelid) - self._ws_channel_claimed = True - yield self._ws_send(u"watch", channelid=self._channelid) + def _claim_nameplate(self): + if not self._nameplate_id: raise UsageError + if self._nameplate_claimed: raise UsageError + yield self._ws_send_command(u"claim", nameplate=self._nameplate_id) + # provokes "claimed" response + + def _ws_handle_claimed(self, msg): + mailbox_id = msg["mailbox"] + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + self._mailbox_id = mailbox_id + self._open_mailbox() + + @inlineCallbacks + def _release_nameplate(self): + if not self._nameplate_claimed: raise UsageError + if self._nameplate_released: raise UsageError + yield self._ws_send_command(u"release") + self._nameplate_released = True + + + @inlineCallbacks + def _open_mailbox(self): + if not self._mailbox_id: raise UsageError + if self._mailbox_opened: raise UsageError + yield self._ws_send_command(u"open", mailbox=self._mailbox_id) + self._mailbox_opened = True + # causes old messages to be sent now, and subscribes to new messages + + @inlineCallbacks + def _close_mailbox(self): + if not self._mailbox_id: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + yield self._ws_send_command(u"close") + self._mailbox_closed = True + + + @inlineCallbacks + def _msg_send(self, phase, body, wait=False): + if phase in self._sent_messages: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + self._sent_messages[phase] = body + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + t = self._timing.add("add", phase=phase, wait=wait) + yield self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + if wait: + while phase not in self._delivered_messages: + yield self._sleep() + t.finish() + + def _ws_handle_message(self, msg): + # any message in the mailbox means we no longer need the nameplate + if not self._nameplate_released: + self._release_nameplate() # XXX returns Deferred + + m = msg["message"] + phase = m["phase"] + body = unhexlify(m["body"].encode("ascii")) + if phase in self._sent_messages and self._sent_messages[phase] == body: + self._delivered_messages.add(phase) # ack by server + self._wakeup() + return # ignore echoes of our outbound messages + if phase in self._received_messages: + # a nameplate collision would cause this + err = ServerError("got duplicate phase %s" % phase, self._ws_url) + return self._signal_error(err) + self._received_messages[phase] = body + if phase == u"confirm": + # TODO: we might not have a master key yet, if the caller wasn't + # waiting in _get_master_key() when a back-to-back pake+_confirm + # message pair arrived. + confkey = self.derive_key(u"wormhole:confirmation") + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + # this makes all API calls fail + return self._signal_error(WrongPasswordError()) + # now notify anyone waiting on it + self._wakeup() + + @inlineCallbacks + def _msg_get(self, phase): + with self._timing.add("get", phase=phase): + while phase not in self._received_messages: + yield self._sleep() # we can wait a long time here + # that will throw an error if something goes wrong + msg = self._received_messages[phase] + returnValue(msg) # entry point 1: generate a new code @inlineCallbacks - def get_code(self, code_length=2): # rename to allocate_code()? create_? + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? if self._code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True with self._timing.add("API get_code"): with self._timing.add("allocate"): - yield self._ws_send(u"allocate") - while self._channelid is None: + yield self._ws_send_command(u"allocate") + while self._nameplate_id is None: yield self._sleep() - code = codes.make_code(self._channelid, code_length) + code = codes.make_code(self._nameplate_id, code_length) assert isinstance(code, type(u"")), type(code) self._set_code(code) self._start() returnValue(code) def _ws_handle_allocated(self, msg): - if self._channelid is not None: - return self._signal_error("got duplicate channelid") - self._channelid = msg["channelid"] - assert isinstance(self._channelid, type(u"")), type(self._channelid) + if self._nameplate_id is not None: + return self._signal_error("got duplicate 'allocated' response") + nid = msg["nameplate"] + assert isinstance(nid, type(u"")), type(nid) + self._nameplate_id = nid self._wakeup() def _start(self): @@ -248,8 +351,8 @@ class _Wormhole: @inlineCallbacks def input_code(self, prompt="Enter wormhole code: ", code_length=2): def _lister(): - return blockingCallFromThread(self._reactor, self._list_channels) - # fetch the list of channels ahead of time, to give us a chance to + return blockingCallFromThread(self._reactor, self._list_nameplates) + # fetch the list of nameplates ahead of time, to give us a chance to # discover the welcome message (and warn the user about an obsolete # client) # @@ -257,13 +360,13 @@ class _Wormhole: # latency in the user's indecision and slow typing. If we're lucky # the answer will come back before they hit TAB. with self._timing.add("API input_code"): - initial_channelids = yield self._list_channels() + initial_nameplate_ids = yield self._list_nameplates() with self._timing.add("input code", waiting="user"): t = self._reactor.addSystemEventTrigger("before", "shutdown", self._warn_readline) code = yield deferToThread(codes.input_code_with_completion, prompt, - initial_channelids, _lister, + initial_nameplate_ids, _lister, code_length) self._reactor.removeSystemEventTrigger(t) returnValue(code) # application will give this to set_code() @@ -295,7 +398,7 @@ class _Wormhole: # --internal-get-code"), run it as a subprocess, let it inherit # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It # needs an RPC mechanism (over some extra file descriptors) to ask - # us to fetch the current channelid list. + # us to fetch the current nameplate_id list. # # Note that hard-terminating our process with os.kill(os.getpid(), # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread @@ -303,16 +406,16 @@ class _Wormhole: # readline finish. @inlineCallbacks - def _list_channels(self): + def _list_nameplates(self): with self._timing.add("list"): - self._latest_channelids = None - yield self._ws_send(u"list") - while self._latest_channelids is None: + self._latest_nameplate_ids = None + yield self._ws_send_command(u"list") + while self._latest_nameplate_ids is None: yield self._sleep() - returnValue(self._latest_channelids) + returnValue(self._latest_nameplate_ids) - def _ws_handle_channelids(self, msg): - self._latest_channelids = msg["channelids"] + def _ws_handle_nameplates(self, msg): + self._latest_nameplate_ids = msg["nameplates"] self._wakeup() # entry point 2b: paste in a fully-formed code @@ -323,8 +426,8 @@ class _Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) with self._timing.add("API set_code"): - self._channelid = mo.group(1) - assert isinstance(self._channelid, type(u"")), type(self._channelid) + self._nameplate_id = mo.group(1) + assert isinstance(self._nameplate_id, type(u"")), type(self._nameplate_id) self._set_code(code) self._start() @@ -344,7 +447,7 @@ class _Wormhole: "appid": self._appid, "relay_url": self._relay_url, "code": self._code, - "channelid": self._channelid, + "nameplate_id": self._nameplate_id, "side": self._side, "spake2": json.loads(self._sp.serialize().decode("ascii")), "msg1": hexlify(self._msg1).decode("ascii"), @@ -357,7 +460,7 @@ class _Wormhole: d = json.loads(data) self = klass(d["appid"], d["relay_url"], reactor) self._side = d["side"] - self._channelid = d["channelid"] + self._nameplate_id = d["nameplate_id"] self._set_code(d["code"]) sp_data = json.dumps(d["spake2"]).encode("ascii") self._sp = SPAKE2_Symmetric.from_serialized(sp_data) @@ -385,7 +488,7 @@ class _Wormhole: def _get_master_key(self): # TODO: prevent multiple invocation if not self._key: - yield self._claim_channel_and_watch() + yield self._claim_nameplate_and_watch() yield self._msg_send(u"pake", self._msg1) pake_msg = yield self._msg_get(u"pake") @@ -401,54 +504,6 @@ class _Wormhole: confmsg = make_confmsg(confkey, nonce) yield self._msg_send(u"confirm", confmsg, wait=True) - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - if phase in self._sent_messages: raise UsageError - self._sent_messages[phase] = body - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send(u"add", channelid=self._channelid, phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while phase not in self._delivered_messages: - yield self._sleep() - t.finish() - - def _ws_handle_message(self, msg): - m = msg["message"] - phase = m["phase"] - body = unhexlify(m["body"].encode("ascii")) - if phase in self._sent_messages and self._sent_messages[phase] == body: - self._delivered_messages.add(phase) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - if phase in self._received_messages: - # a channel collision would cause this - err = ServerError("got duplicate phase %s" % phase, self._ws_url) - return self._signal_error(err) - self._received_messages[phase] = body - if phase == u"confirm": - # TODO: we might not have a master key yet, if the caller wasn't - # waiting in _get_master_key() when a back-to-back pake+_confirm - # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - # this makes all API calls fail - return self._signal_error(WrongPasswordError()) - # now notify anyone waiting on it - self._wakeup() - - @inlineCallbacks - def _msg_get(self, phase): - with self._timing.add("get", phase=phase): - while phase not in self._received_messages: - yield self._sleep() # we can wait a long time here - # that will throw an error if something goes wrong - msg = self._received_messages[phase] - returnValue(msg) - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: @@ -548,8 +603,7 @@ class _Wormhole: @inlineCallbacks def _release(self, mood): with self._timing.add("release"): - yield self._ws_send(u"release", channelid=self._channelid, - mood=mood) + yield self._ws_send_command(u"release", mood=mood) while self._released_status is None: yield self._sleep(wake_on_error=False) # TODO: set a timeout, don't wait forever for an ack @@ -562,6 +616,7 @@ class _Wormhole: def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) + w._start() return w def wormhole_from_serialized(data, reactor): From 181ef04a91672a1e833bc23e7450d3a674934b96 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 16:39:59 -0700 Subject: [PATCH 30/51] break out more message components, use SidedMessage --- src/wormhole/server/rendezvous.py | 39 ++++++++-------- src/wormhole/server/rendezvous_websocket.py | 19 ++++---- src/wormhole/test/test_server.py | 51 +++++++++++++-------- 3 files changed, 65 insertions(+), 44 deletions(-) diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 018b8e2..d439fbb 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -52,6 +52,9 @@ TransitUsage = namedtuple("TransitUsage", ["started", "waiting_time", "total_time", "total_bytes", "result"]) +SidedMessage = namedtuple("SidedMessage", ["side", "phase", "body", + "server_rx", "msg_id"]) + class Mailbox: def __init__(self, app, db, blur_usage, log_requests, app_id, mailbox_id): self._app = app @@ -93,8 +96,10 @@ class Mailbox: " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` ASC", (self._app_id, self._mailbox_id)).fetchall(): - messages.append({"phase": row["phase"], "body": row["body"], - "server_rx": row["server_rx"], "id": row["msg_id"]}) + sm = SidedMessage(side=row["side"], phase=row["phase"], + body=row["body"], server_rx=row["server_rx"], + msg_id=row["msg_id"]) + messages.append(sm) return messages def add_listener(self, handle, send_f, stop_f): @@ -104,25 +109,23 @@ class Mailbox: def remove_listener(self, handle): self._listeners.pop(handle) - def broadcast_message(self, phase, body, server_rx, msgid): + def broadcast_message(self, sm): for (send_f, stop_f) in self._listeners.values(): - send_f({"phase": phase, "body": body, - "server_rx": server_rx, "id": msgid}) + send_f(sm) - def _add_message(self, side, phase, body, server_rx, msgid): - db = self._db - db.execute("INSERT INTO `messages`" - " (`app_id`, `mailbox_id`, `side`, `phase`, `body`," - " `server_rx`, `msg_id`)" - " VALUES (?,?,?,?,?, ?,?)", - (self._app_id, self._mailbox_id, side, phase, body, - server_rx, msgid)) - db.commit() + def _add_message(self, sm): + self._db.execute("INSERT INTO `messages`" + " (`app_id`, `mailbox_id`, `side`, `phase`, `body`," + " `server_rx`, `msg_id`)" + " VALUES (?,?,?,?,?, ?,?)", + (self._app_id, self._mailbox_id, sm.side, + sm.phase, sm.body, sm.server_rx, sm.msg_id)) + self._db.commit() - def add_message(self, side, phase, body, server_rx, msgid): - self._add_message(side, phase, body, server_rx, msgid) - self.broadcast_message(phase, body, server_rx, msgid) - return self.get_messages() # for rendezvous_web.py POST /add + def add_message(self, sm): + assert isinstance(sm, SidedMessage) + self._add_message(sm) + self.broadcast_message(sm) def close(self, side, mood, when): assert isinstance(side, type(u"")), type(side) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 4129400..2a7e82f 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -2,7 +2,7 @@ import json, time from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket -from .rendezvous import CrowdedError +from .rendezvous import CrowdedError, SidedMessage # The WebSocket allows the client to send "commands" to the server, and the # server to send "responses" to the client. Note that commands and responses @@ -63,7 +63,7 @@ from .rendezvous import CrowdedError # # -> {type: "open", mailbox: str} -> message # sends old messages now, and subscribes to deliver future messages -# <- {type: "message", message: {phase:, body:}} # body is hex +# <- {type: "message", side:, phase:, body:, msg_id:}} # body is hex # -> {type: "add", phase: str, body: hex} # will send echo in a "message" # # -> {type: "close", mood: str} -> closed @@ -191,12 +191,13 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): assert isinstance(mailbox_id, type(u"")) self._mailbox = self._app.open_mailbox(mailbox_id, self._side, server_rx) - def _send(event): - self.send("message", message=event) + def _send(sm): + self.send("message", side=sm.side, phase=sm.phase, + body=sm.body, server_rx=sm.server_rx, id=sm.msg_id) def _stop(): pass - for old_message in self._mailbox.add_listener(self, _send, _stop): - _send(old_message) + for old_sm in self._mailbox.add_listener(self, _send, _stop): + _send(old_sm) def handle_add(self, msg, server_rx): if not self._mailbox: @@ -206,8 +207,10 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if "body" not in msg: raise Error("missing 'body'") msgid = msg.get("id") # optional - self._mailbox.add_message(self._side, msg["phase"], msg["body"], - server_rx, msgid) + sm = SidedMessage(side=self._side, phase=msg["phase"], + body=msg["body"], server_rx=server_rx, + msg_id=msgid) + self._mailbox.add_message(sm) def handle_close(self, msg, server_rx): if not self._mailbox: diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index b654a24..4d78bf4 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -10,7 +10,7 @@ from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase from ..server import rendezvous, transit_server -from ..server.rendezvous import Usage +from ..server.rendezvous import Usage, SidedMessage class Reachable(ServerBase, unittest.TestCase): @@ -254,7 +254,9 @@ class Server(ServerBase, unittest.TestCase): app = self._rendezvous.get_app(u"appid") mailbox_id = u"mid" m1 = app.open_mailbox(mailbox_id, u"side1", 0) - m1.add_message(u"side1", u"phase", u"body", 1, u"msgid") + m1.add_message(SidedMessage(side=u"side1", phase=u"phase", + body=u"body", server_rx=1, + msg_id=u"msgid")) msgs = self._messages(app) self.assertEqual(len(msgs), 1) self.assertEqual(msgs[0]["body"], u"body") @@ -263,34 +265,43 @@ class Server(ServerBase, unittest.TestCase): l2 = []; stop2 = []; stop2_f = lambda: stop2.append(True) old = m1.add_listener("handle1", l1.append, stop1_f) self.assertEqual(len(old), 1) - self.assertEqual(old[0]["body"], u"body") + self.assertEqual(old[0].side, u"side1") + self.assertEqual(old[0].body, u"body") - m1.add_message(u"side1", u"phase2", u"body2", 1, u"msgid") + m1.add_message(SidedMessage(side=u"side1", phase=u"phase2", + body=u"body2", server_rx=1, + msg_id=u"msgid")) self.assertEqual(len(l1), 1) - self.assertEqual(l1[0]["body"], u"body2") + self.assertEqual(l1[0].body, u"body2") old = m1.add_listener("handle2", l2.append, stop2_f) self.assertEqual(len(old), 2) - m1.add_message(u"side1", u"phase3", u"body3", 1, u"msgid") + m1.add_message(SidedMessage(side=u"side1", phase=u"phase3", + body=u"body3", server_rx=1, + msg_id=u"msgid")) self.assertEqual(len(l1), 2) - self.assertEqual(l1[-1]["body"], u"body3") + self.assertEqual(l1[-1].body, u"body3") self.assertEqual(len(l2), 1) - self.assertEqual(l2[-1]["body"], u"body3") + self.assertEqual(l2[-1].body, u"body3") m1.remove_listener("handle1") - m1.add_message(u"side1", u"phase4", u"body4", 1, u"msgid") + m1.add_message(SidedMessage(side=u"side1", phase=u"phase4", + body=u"body4", server_rx=1, + msg_id=u"msgid")) self.assertEqual(len(l1), 2) - self.assertEqual(l1[-1]["body"], u"body3") + self.assertEqual(l1[-1].body, u"body3") self.assertEqual(len(l2), 2) - self.assertEqual(l2[-1]["body"], u"body4") + self.assertEqual(l2[-1].body, u"body4") m1._shutdown() self.assertEqual(stop1, []) self.assertEqual(stop2, [True]) # message adds are not idempotent: clients filter duplicates - m1.add_message(u"side1", u"phase", u"body", 1, u"msgid") + m1.add_message(SidedMessage(side=u"side1", phase=u"phase", + body=u"body", server_rx=1, + msg_id=u"msgid")) msgs = self._messages(app) self.assertEqual(len(msgs), 5) self.assertEqual(msgs[-1]["body"], u"body") @@ -654,17 +665,21 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(err[u"error"], u"open requires 'mailbox'") mb1 = app.open_mailbox(u"mb1", u"side2", 0) - mb1.add_message(u"side2", u"phase", u"body", 0, u"msgid") + mb1.add_message(SidedMessage(side=u"side2", phase=u"phase", + body=u"body", server_rx=0, + msg_id=u"msgid")) c1.send(u"open", mailbox=u"mb1") m = yield c1.next_non_ack() self.assertEqual(m[u"type"], u"message") - self.assertEqual(m[u"message"][u"body"], u"body") + self.assertEqual(m[u"body"], u"body") - mb1.add_message(u"side2", u"phase2", u"body2", 0, u"msgid") + mb1.add_message(SidedMessage(side=u"side2", phase=u"phase2", + body=u"body2", server_rx=0, + msg_id=u"msgid")) m = yield c1.next_non_ack() self.assertEqual(m[u"type"], u"message") - self.assertEqual(m[u"message"][u"body"], u"body2") + self.assertEqual(m[u"body"], u"body2") c1.send(u"open", mailbox=u"mb1") err = yield c1.next_non_ack() @@ -701,10 +716,10 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"add", phase=u"phase", body=u"body") m = yield c1.next_non_ack() # echoed back self.assertEqual(m[u"type"], u"message") - self.assertEqual(m[u"message"][u"body"], u"body") + self.assertEqual(m[u"body"], u"body") self.assertEqual(len(l1), 1) - self.assertEqual(l1[0][u"body"], u"body") + self.assertEqual(l1[0].body, u"body") @inlineCallbacks def test_close(self): From 53bbcc33f633d85f58e64c73bd6eb53f684b9b51 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 18:49:20 -0700 Subject: [PATCH 31/51] new file, state-machine based --- events.dot | 38 +++ src/wormhole/cli/cmd_receive.py | 9 +- src/wormhole/wormhole.py | 465 ++++++++++++++++++++++++++++++++ 3 files changed, 508 insertions(+), 4 deletions(-) create mode 100644 events.dot create mode 100644 src/wormhole/wormhole.py diff --git a/events.dot b/events.dot new file mode 100644 index 0000000..2a280da --- /dev/null +++ b/events.dot @@ -0,0 +1,38 @@ +digraph { + event_learned_code [label="learned\ncode" style="bold"] + event_learned_nameplate [label="learned\nnameplate" style="bold"] + event_learned_mailbox [label="learned\nmailbox" style="bold"] + event_connected [label="connected" style="bold"] + event_built_msg1 [label="built\nmsg1" style="bold"] + event_mailbox_used [label="mailbox\nused" style="bold"] + event_learned_PAKE [label="learned\nmsg2" style="bold"] + event_established_key [label="established\nkey" style="bold"] + event_computed_verifier [label="computed\nverifier" style="bold"] + event_received_confirm [label="received\nconfirm" style="bold"] + + maybe_build_msg1 [label="build\nmsg1"] + maybe_get_mailbox [label="get\nmailbox"] + maybe_send_pake [label="send\npake"] + maybe_send_phase_messages [label="send\nphase\nmessages"] + + event_connected -> maybe_get_mailbox + + event_built_msg1 -> maybe_send_pake + + event_learned_code -> maybe_build_msg1 + event_learned_code -> event_learned_nameplate + + maybe_build_msg1 -> event_built_msg1 + event_learned_nameplate -> maybe_get_mailbox + + maybe_get_mailbox -> event_learned_mailbox [style="dashed"] + maybe_get_mailbox -> event_mailbox_used [style="dashed"] + maybe_get_mailbox -> event_learned_PAKE [style="dashed"] + maybe_get_mailbox -> event_received_confirm [style="dashed"] + + event_learned_mailbox -> maybe_send_pake + event_learned_mailbox -> maybe_send_phase_messages + + event_learned_PAKE -> event_established_key + event_established_key -> event_computed_verifier +} diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 5d308a3..7813d75 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -99,10 +99,11 @@ class TwistedReceiver: if self.args.zeromode: assert not code code = u"0-" - if not code: - code = yield w.input_code("Enter receive wormhole code: ", - self.args.code_length) - yield w.set_code(code) + if code: + w.set_code(code) + else: + yield w.input_code("Enter receive wormhole code: ", + self.args.code_length) def show_verifier(self, verifier): verifier_hex = binascii.hexlify(verifier).decode("ascii") diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py new file mode 100644 index 0000000..ca3d7c7 --- /dev/null +++ b/src/wormhole/wormhole.py @@ -0,0 +1,465 @@ +from __future__ import print_function +import os, sys, json, re, unicodedata +from six.moves.urllib_parse import urlparse +from binascii import hexlify, unhexlify +from twisted.internet import defer, endpoints, error +from twisted.internet.threads import deferToThread, blockingCallFromThread +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.python import log +from autobahn.twisted import websocket +from nacl.secret import SecretBox +from nacl.exceptions import CryptoError +from nacl import utils +from spake2 import SPAKE2_Symmetric +from .. import __version__ +from .. import codes +from ..errors import ServerError, Timeout, WrongPasswordError, UsageError +from ..timing import DebugTiming +from hkdf import Hkdf + +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + +CONFMSG_NONCE_LENGTH = 128//8 +CONFMSG_MAC_LENGTH = 256//8 +def make_confmsg(confkey, nonce): + return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) + +def to_bytes(u): + return unicodedata.normalize("NFC", u).encode("utf-8") + +# We send the following messages through the relay server to the far side (by +# sending "add" commands to the server, and getting "message" responses): +# +# phase=setup: +# * unauthenticated version strings (but why?) +# * early warmup for connection hints ("I can do tor, spin up HS") +# * wordlist l10n identifier +# phase=pake: just the SPAKE2 'start' message (binary) +# phase=confirm: key verification (HKDF(key, nonce)+nonce) +# phase=1,2,3,..: application messages + +class WSClient(websocket.WebSocketClientProtocol): + def onOpen(self): + self.wormhole_open = True + self.factory.d.callback(self) + + def onMessage(self, payload, isBinary): + assert not isBinary + self.wormhole._ws_dispatch_response(payload) + + def onClose(self, wasClean, code, reason): + if self.wormhole_open: + self.wormhole._ws_closed(wasClean, code, reason) + else: + # we closed before establishing a connection (onConnect) or + # finishing WebSocket negotiation (onOpen): errback + self.factory.d.errback(error.ConnectError(reason)) + +class WSFactory(websocket.WebSocketClientFactory): + protocol = WSClient + def buildProtocol(self, addr): + proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) + proto.wormhole = self.wormhole + proto.wormhole_open = False + return proto + + +class _GetCode: + def __init__(self, code_length, send_command): + self._code_length = code_length + self._send_command = send_command + self._allocated_d = defer.Deferred() + + @inlineCallbacks + def go(self): + with self._timing.add("allocate"): + self._send_command(u"allocate") + nameplate_id = yield self._allocated_d + code = codes.make_code(nameplate_id, self._code_length) + assert isinstance(code, type(u"")), type(code) + returnValue(code) + + def _ws_handle_allocated(self, msg): + nid = msg["nameplate"] + assert isinstance(nid, type(u"")), type(nid) + self._allocated_d.callback(nid) + +class _InputCode: + def __init__(self, reactor, prompt, code_length, send_command): + self._reactor = reactor + self._prompt = prompt + self._code_length = code_length + self._send_command = send_command + + @inlineCallbacks + def _list(self): + self._lister_d = defer.Deferred() + self._send_command(u"list") + nameplates = yield self._lister_d + self._lister_d = None + returnValue(nameplates) + + def _list_blocking(self): + return blockingCallFromThread(self._reactor, self._list) + + @inlineCallbacks + def go(self): + # fetch the list of nameplates ahead of time, to give us a chance to + # discover the welcome message (and warn the user about an obsolete + # client) + # + # TODO: send the request early, show the prompt right away, hide the + # latency in the user's indecision and slow typing. If we're lucky + # the answer will come back before they hit TAB. + + initial_nameplate_ids = yield self._list() + with self._timing.add("input code", waiting="user"): + t = self._reactor.addSystemEventTrigger("before", "shutdown", + self._warn_readline) + code = yield deferToThread(codes.input_code_with_completion, + self._prompt, + initial_nameplate_ids, + self._list_blocking, + self._code_length) + self._reactor.removeSystemEventTrigger(t) + returnValue(code) + + def _ws_handle_nameplates(self, msg): + nameplates = msg["nameplates"] + assert isinstance(nameplates, list), type(nameplates) + for nameplate_id in nameplates: + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + self._lister_d.callback(nameplates) + + def _warn_readline(self): + # When our process receives a SIGINT, Twisted's SIGINT handler will + # stop the reactor and wait for all threads to terminate before the + # process exits. However, if we were waiting for + # input_code_with_completion() when SIGINT happened, the readline + # thread will be blocked waiting for something on stdin. Trick the + # user into satisfying the blocking read so we can exit. + print("\nCommand interrupted: please press Return to quit", + file=sys.stderr) + + # Other potential approaches to this problem: + # * hard-terminate our process with os._exit(1), but make sure the + # tty gets reset to a normal mode ("cooked"?) first, so that the + # next shell command the user types is echoed correctly + # * track down the thread (t.p.threadable.getThreadID from inside the + # thread), get a cffi binding to pthread_kill, deliver SIGINT to it + # * allocate a pty pair (pty.openpty), replace sys.stdin with the + # slave, build a pty bridge that copies bytes (and other PTY + # things) from the real stdin to the master, then close the slave + # at shutdown, so readline sees EOF + # * write tab-completion and basic editing (TTY raw mode, + # backspace-is-erase) without readline, probably with curses or + # twisted.conch.insults + # * write a separate program to get codes (maybe just "wormhole + # --internal-get-code"), run it as a subprocess, let it inherit + # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It + # needs an RPC mechanism (over some extra file descriptors) to ask + # us to fetch the current nameplate_id list. + # + # Note that hard-terminating our process with os.kill(os.getpid(), + # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread + # doesn't see the signal, and we must still wait for stdin to make + # readline finish. + + + +class _Wormhole: + def __init__(self): + self._connected = None + self._flag_need_mailbox = True + self._flag_need_to_see_mailbox_used = True + self._flag_need_to_build_msg1 = True + self._flag_need_to_send_PAKE = True + self._flag_need_PAKE = True + self._flag_need_key = True # rename to not self._key + + self._next_send_phase = 0 + self._phase_messages_to_send = [] # not yet acked by server + + self._next_receive_phase = 0 + self._phase_messages_received = {} # phase -> message + + + def _start(self): + d = self._connect() # causes stuff to happen + d.addErrback(log.err) + return d # fires when connection is established, if you care + + def _make_endpoint(self, hostname, port): + if self._tor_manager: + return self._tor_manager.get_endpoint_for(hostname, port) + # note: HostnameEndpoints have a default 30s timeout + return endpoints.HostnameEndpoint(self._reactor, hostname, port) + + def _connect(self): + # TODO: if we lose the connection, make a new one, re-establish the + # state + assert self._side + p = urlparse(self._ws_url) + f = WSFactory(self._ws_url) + f.wormhole = self + f.d = defer.Deferred() + # TODO: if hostname="localhost", I get three factories starting + # and stopping (maybe 127.0.0.1, ::1, and something else?), and + # an error in the factory is masked. + ep = self._make_endpoint(p.hostname, p.port or 80) + # .connect errbacks if the TCP connection fails + d = ep.connect(f) + d.addCallback(self._event_connected) + # f.d is errbacked if WebSocket negotiation fails, and the WebSocket + # drops any data sent before onOpen() fires, so we must wait for it + d.addCallback(self._event_ws_opened) + return d + + def _event_connected(self, ws, f): + self._ws = ws + self._ws_t = self._timing.add("websocket") + + def _event_ws_opened(self, _): + self._connected = True + self._ws_send_command(u"bind", appid=self._appid, side=self._side) + self._maybe_get_mailbox() + + def _ws_handle_welcome(self, msg): + welcome = msg["welcome"] + if ("motd" in welcome and + not self.motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % + (self._ws_url, motd_formatted), file=sys.stderr) + self.motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("-" not in __version__ and + not self.version_warning_displayed and + welcome["current_version"] != __version__): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], __version__), file=sys.stderr) + self.version_warning_displayed = True + + if "error" in welcome: + return self._signal_error(welcome["error"]) + + + # entry point 1: generate a new code + @inlineCallbacks + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + if self._code is not None: raise UsageError + if self._started_get_code: raise UsageError + self._started_get_code = True + with self._timing.add("API get_code"): + gc = _GetCode(code_length, self._ws_send_command) + self._ws_handle_allocated = gc._ws_handle_allocated + code = yield gc.go() + self._event_learned_code(code) + returnValue(code) + + # entry point 2: interactively type in a code, with completion + @inlineCallbacks + def input_code(self, prompt="Enter wormhole code: ", code_length=2): + if self._code is not None: raise UsageError + if self._started_input_code: raise UsageError + self._started_input_code = True + with self._timing.add("API input_code"): + gc = _InputCode(prompt, code_length, self._ws_send_command) + self._ws_handle_nameplates = gc._ws_handle_nameplates + code = yield gc.go() + self._event_learned_code(code) + returnValue(None) + + # entry point 3: paste in a fully-formed code + def set_code(self, code): + self._timing.add("API set_code") + if not isinstance(code, type(u"")): raise TypeError(type(code)) + if self._code is not None: raise UsageError + self._event_learned_code(code) + + def _event_learned_code(self, code): + self._timing.add("code established") + self._code = code + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + nid = mo.group(1) + assert isinstance(nid, type(u"")), type(nid) + self._nameplate_id = nid + # fire more events + self._maybe_build_msg1() + self._event_learned_nameplate() + + def _maybe_build_msg1(self): + if not (self._code and self._flag_need_to_build_msg1): + return + with self._timing.add("pake1", waiting="crypto"): + self._sp = SPAKE2_Symmetric(to_bytes(self._code), + idSymmetric=to_bytes(self._appid)) + self._msg1 = self._sp.start() + self._flag_need_to_build_msg1 = False + self._event_built_msg1() + + def _event_built_msg1(self): + self._maybe_send_pake() + + # every _maybe_X starts with a set of conditions + # for each such condition Y, every _event_Y must call _maybe_X + + def _event_learned_nameplate(self): + self._maybe_get_mailbox() + + def _maybe_get_mailbox(self): + if not (self._flag_need_mailbox and self._nameplate_id + and self._connected): + return + self._ws_send_command(u"claim", nameplate=self._nameplate_id) + + def _ws_handle_claimed(self, msg): + mailbox_id = msg["mailbox"] + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + self._mailbox_id = mailbox_id + self._event_learned_mailbox() + + def _event_welcome(self): + pass + + def _event_learned_mailbox(self): + self._flag_need_mailbox = False + if not self._mailbox_id: raise UsageError + if self._mailbox_opened: raise UsageError + self._ws_send_command(u"open", mailbox=self._mailbox_id) + # causes old messages to be sent now, and subscribes to new messages + self._maybe_send_pake() + self._maybe_send_phase_messages() + + def _maybe_send_pake(self): + # TODO: deal with reentrant call + if not (self._connected and self._mailbox + and self._flag_need_to_send_PAKE): + return + d = self._msg_send(u"pake", self._msg1) + def _pake_sent(res): + self._flag_need_to_send_PAKE = False + d.addCallback(_pake_sent) + d.addErrback(log.err) + + def _maybe_send_phase_messages(self): + # TODO: deal with reentrant call + if not (self._connected and self._mailbox and self._key): + return + for pm in self._phase_messages_to_send: + (phase, message) = pm + d = self._msg_send(phase, message) + def _phase_message_sent(res, pm=pm): + try: + self._phase_messages_to_send.remove(pm) + except ValueError: + pass + d.addCallback(_phase_message_sent) + d.addErrback(log.err) + + + + def _event_received_message(self, msg): + pass + def _event_mailbox_used(self): + if self._flag_need_to_see_mailbox_used: + self._ws_send_command(u"release") + self._flag_need_to_see_mailbox_used = False + + def _event_learned_PAKE(self, pake_msg): + with self._timing.add("pake2", waiting="crypto"): + self._key = self._sp.finish(pake_msg) + self._event_established_key() + + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + if self._key is None: + # call after get_verifier() or get() + raise UsageError + return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) + + def _event_established_key(self): + self._timing.add("key established") + if self._send_confirm: + # both sides send different (random) confirmation messages + confkey = self.derive_key(u"wormhole:confirmation") + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + self._msg_send(u"confirm", confmsg, wait=True) + verifier = self.derive_key(u"wormhole:verifier") + self._event_computed_verifier(verifier) + pass + def _event_computed_verifier(self, verifier): + self._verifier = verifier + d, self._verifier_waiter = self._verifier_waiter, None + if d: + d.callback(verifier) + + def _event_received_confirm(self, body): + # TODO: we might not have a master key yet, if the caller wasn't + # waiting in _get_master_key() when a back-to-back pake+_confirm + # message pair arrived. + confkey = self.derive_key(u"wormhole:confirmation") + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + # this makes all API calls fail + return self._signal_error(WrongPasswordError()) + + def _event_received_phase_message(self, phase, message): + self._phase_messages_received[phase] = message + if phase in self._phase_message_waiters: + d = self._phase_message_waiters.pop(phase) + d.callback(message) + + def _ws_handle_message(self, msg): + side = msg["side"] + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + if side == self._side: + return + self._event_received_peer_message(phase, body) + + def XXXackstuff(): + if phase in self._sent_messages and self._sent_messages[phase] == body: + self._delivered_messages.add(phase) # ack by server + self._wakeup() + return # ignore echoes of our outbound messages + + def _event_received_peer_message(self, phase, body): + # any message in the mailbox means we no longer need the nameplate + self._event_mailbox_used() + #if phase in self._received_messages: + # # a nameplate collision would cause this + # err = ServerError("got duplicate phase %s" % phase, self._ws_url) + # return self._signal_error(err) + #self._received_messages[phase] = body + if phase == u"confirm": + self._event_received_confirm(body) + # now notify anyone waiting on it + self._wakeup() + + def _event_asked_to_send_phase_message(self, phase, message): + pm = (phase, message) + self._phase_messages_to_send.append(pm) + self._maybe_send_phase_messages() + + def _event_asked_to_close(self): + pass + + + +def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): + w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) + w._start() + return w + +def wormhole_from_serialized(data, reactor): + w = _Wormhole.from_serialized(data, reactor) + return w From 0ee56e12b05cc77ddbf8fdc699f702eeed29607a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 11:01:44 -0700 Subject: [PATCH 32/51] change 'list' protocol, make room for nameplate attributes --- src/wormhole/server/rendezvous_websocket.py | 7 +++++-- src/wormhole/test/test_server.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 2a7e82f..153b037 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -54,7 +54,7 @@ from .rendezvous import CrowdedError, SidedMessage # -> {type: "bind", appid:, side:} # # -> {type: "list"} -> nameplates -# <- {type: "nameplates", nameplates: [str..]} +# <- {type: "nameplates", nameplates: [{id: str,..},..]} # -> {type: "allocate"} -> nameplate, mailbox # <- {type: "allocated", nameplate: str} # -> {type: "claim", nameplate: str} -> mailbox @@ -152,7 +152,10 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): def handle_list(self): nameplate_ids = sorted(self._app.get_nameplate_ids()) - self.send("nameplates", nameplates=nameplate_ids) + # provide room to add nameplate attributes later (like which wordlist + # is used for each, maybe how many words) + nameplates = [{"id": nid} for nid in nameplate_ids] + self.send("nameplates", nameplates=nameplates) def handle_allocate(self, server_rx): if self._did_allocate: diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 4d78bf4..90c03eb 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -539,7 +539,12 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1.send(u"list") m = yield c1.next_non_ack() self.assertEqual(m[u"type"], u"nameplates") - self.assertEqual(set(m[u"nameplates"]), set([nameplate_id1, u"np2"])) + nids = set() + for n in m[u"nameplates"]: + self.assertEqual(type(n), dict) + self.assertEqual(list(n.keys()), [u"id"]) + nids.add(n[u"id"]) + self.assertEqual(nids, set([nameplate_id1, u"np2"])) @inlineCallbacks def test_allocate(self): From 3da52b0a3ec22f69ce279346d1fb7e4dc56fc927 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 11:31:00 -0700 Subject: [PATCH 33/51] add 'mock', building out test_wormhole --- src/wormhole/test/test_wormhole.py | 74 ++++++++ src/wormhole/wormhole.py | 273 ++++++++++++++++++++--------- tox.ini | 1 + 3 files changed, 270 insertions(+), 78 deletions(-) create mode 100644 src/wormhole/test/test_wormhole.py diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py new file mode 100644 index 0000000..1591530 --- /dev/null +++ b/src/wormhole/test/test_wormhole.py @@ -0,0 +1,74 @@ +from __future__ import print_function +import json +import mock +from twisted.trial import unittest +from twisted.internet import reactor +from twisted.internet.defer import gatherResults, inlineCallbacks +#from ..twisted.transcribe import (wormhole, wormhole_from_serialized, +# UsageError, WrongPasswordError) +#from .common import ServerBase +from ..wormhole import _Wormhole, _WelcomeHandler +from ..timing import DebugTiming + +APPID = u"appid" + +class MockWebSocket: + def __init__(self): + self._payloads = [] + def sendMessage(self, payload, is_binary): + assert not is_binary + self._payloads.append(payload) + + def outbound(self): + out = [] + while self._payloads: + p = self._payloads.pop(0) + out.append(json.loads(p.decode("utf-8"))) + return out + +def response(w, **kwargs): + payload = json.dumps(kwargs).encode("utf-8") + w._ws_dispatch_response(payload) + +class Welcome(unittest.TestCase): + def test_no_current_version(self): + # WelcomeHandler should tolerate lack of ["current_version"] + w = _WelcomeHandler(u"relay_url", u"current_version") + w.handle_welcome({}) + + +class Basic(unittest.TestCase): + def test_create(self): + w = _Wormhole(APPID, u"relay_url", reactor, None, None) + + def test_basic(self): + # We don't call w._start(), so this doesn't create a WebSocket + # connection. We provide a mock connection instead. + timing = DebugTiming() + with mock.patch("wormhole.wormhole._WelcomeHandler") as whc: + w = _Wormhole(APPID, u"relay_url", reactor, None, timing) + wh = whc.return_value + #w._welcomer = mock.Mock() + # w._connect = lambda self: None + # w._event_connected(mock_ws) + # w._event_ws_opened() + # w._ws_dispatch_response(payload) + self.assertEqual(w._ws_url, u"relay_url") + ws = MockWebSocket() + w._event_connected(ws) + out = ws.outbound() + self.assertEqual(len(out), 0) + + w._event_ws_opened(None) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["type"], u"bind") + self.assertEqual(out[0]["appid"], APPID) + self.assertEqual(out[0]["side"], w._side) + self.assertIn(u"id", out[0]) + + # WelcomeHandler should get called upon 'welcome' response + WELCOME = {u"foo": u"bar"} + response(w, type="welcome", welcome=WELCOME) + self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)]) + diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index ca3d7c7..2a16bb7 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -1,4 +1,4 @@ -from __future__ import print_function +from __future__ import print_function, absolute_import import os, sys, json, re, unicodedata from six.moves.urllib_parse import urlparse from binascii import hexlify, unhexlify @@ -11,10 +11,11 @@ from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming +from . import __version__ +from . import codes +#from .errors import ServerError, Timeout +from .errors import WrongPasswordError, UsageError +#from .timing import DebugTiming from hkdf import Hkdf def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -80,7 +81,7 @@ class _GetCode: assert isinstance(code, type(u"")), type(code) returnValue(code) - def _ws_handle_allocated(self, msg): + def _response_handle_allocated(self, msg): nid = msg["nameplate"] assert isinstance(nid, type(u"")), type(nid) self._allocated_d.callback(nid) @@ -125,12 +126,16 @@ class _InputCode: self._reactor.removeSystemEventTrigger(t) returnValue(code) - def _ws_handle_nameplates(self, msg): + def _response_handle_nameplates(self, msg): nameplates = msg["nameplates"] assert isinstance(nameplates, list), type(nameplates) - for nameplate_id in nameplates: + nids = [] + for n in nameplates: + assert isinstance(n, dict), type(n) + nameplate_id = n[u"id"] assert isinstance(nameplate_id, type(u"")), type(nameplate_id) - self._lister_d.callback(nameplates) + nids.append(nameplate_id) + self._lister_d.callback(nids) def _warn_readline(self): # When our process receives a SIGINT, Twisted's SIGINT handler will @@ -166,11 +171,53 @@ class _InputCode: # doesn't see the signal, and we must still wait for stdin to make # readline finish. +class _WelcomeHandler: + def __init__(self, url, current_version): + self._ws_url = url + self._version_warning_displayed = False + self._motd_displayed = False + self._current_version = current_version + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self._motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % + (self._ws_url, motd_formatted), file=sys.stderr) + self._motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("current_version" in welcome + and "-" not in self._current_version + and not self._version_warning_displayed + and welcome["current_version"] != self._current_version): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], self._current_version), + file=sys.stderr) + self._version_warning_displayed = True + + if "error" in welcome: + return self._signal_error(welcome["error"]) class _Wormhole: - def __init__(self): + def __init__(self, appid, relay_url, reactor, tor_manager, timing): + self._appid = appid + self._ws_url = relay_url + self._reactor = reactor + self._tor_manager = tor_manager + self._timing = timing + + self._welcomer = _WelcomeHandler(self._ws_url, __version__) + self._side = hexlify(os.urandom(5)).decode("ascii") self._connected = None + self._nameplate_id = None + self._mailbox_id = None + self._mailbox_opened = False + self._mailbox_closed = False self._flag_need_mailbox = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True @@ -179,9 +226,11 @@ class _Wormhole: self._flag_need_key = True # rename to not self._key self._next_send_phase = 0 + self._plaintext_to_send = [] # (phase, plaintext, deferred) self._phase_messages_to_send = [] # not yet acked by server self._next_receive_phase = 0 + self._receive_waiters = {} # phase -> Deferred self._phase_messages_received = {} # phase -> message @@ -213,10 +262,11 @@ class _Wormhole: d.addCallback(self._event_connected) # f.d is errbacked if WebSocket negotiation fails, and the WebSocket # drops any data sent before onOpen() fires, so we must wait for it + d.addCallback(lambda _: f.d) d.addCallback(self._event_ws_opened) return d - def _event_connected(self, ws, f): + def _event_connected(self, ws): self._ws = ws self._ws_t = self._timing.add("websocket") @@ -224,30 +274,35 @@ class _Wormhole: self._connected = True self._ws_send_command(u"bind", appid=self._appid, side=self._side) self._maybe_get_mailbox() + self._maybe_send_pake() - def _ws_handle_welcome(self, msg): - welcome = msg["welcome"] - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % - (self._ws_url, motd_formatted), file=sys.stderr) - self.motd_displayed = True + def _ws_send_command(self, mtype, **kwargs): + # msgid is used by misc/dump-timing.py to correlate our sends with + # their receives, and vice versa. They are also correlated with the + # ACKs we get back from the server (which we otherwise ignore). There + # are so few messages, 16 bits is enough to be mostly-unique. + kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + self._timing.add("ws_send", _side=self._side, **kwargs) + self._ws.sendMessage(payload, False) - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True + def _ws_dispatch_response(self, payload): + msg = json.loads(payload.decode("utf-8")) + self._timing.add("ws_receive", _side=self._side, message=msg) + mtype = msg["type"] + meth = getattr(self, "_response_handle_"+mtype, None) + if not meth: + # make tests fail, but real application will ignore it + log.err(ValueError("Unknown inbound message type %r" % (msg,))) + return + return meth(msg) - if "error" in welcome: - return self._signal_error(welcome["error"]) + def _response_handle_ack(self, msg): + pass + def _response_handle_welcome(self, msg): + self._welcomer.handle_welcome(msg["welcome"]) # entry point 1: generate a new code @inlineCallbacks @@ -257,7 +312,7 @@ class _Wormhole: self._started_get_code = True with self._timing.add("API get_code"): gc = _GetCode(code_length, self._ws_send_command) - self._ws_handle_allocated = gc._ws_handle_allocated + self._response_handle_allocated = gc._response_handle_allocated code = yield gc.go() self._event_learned_code(code) returnValue(code) @@ -269,9 +324,9 @@ class _Wormhole: if self._started_input_code: raise UsageError self._started_input_code = True with self._timing.add("API input_code"): - gc = _InputCode(prompt, code_length, self._ws_send_command) - self._ws_handle_nameplates = gc._ws_handle_nameplates - code = yield gc.go() + ic = _InputCode(prompt, code_length, self._ws_send_command) + self._response_handle_nameplates = ic._response_handle_nameplates + code = yield ic.go() self._event_learned_code(code) returnValue(None) @@ -320,15 +375,12 @@ class _Wormhole: return self._ws_send_command(u"claim", nameplate=self._nameplate_id) - def _ws_handle_claimed(self, msg): + def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] assert isinstance(mailbox_id, type(u"")), type(mailbox_id) self._mailbox_id = mailbox_id self._event_learned_mailbox() - def _event_welcome(self): - pass - def _event_learned_mailbox(self): self._flag_need_mailbox = False if not self._mailbox_id: raise UsageError @@ -340,7 +392,7 @@ class _Wormhole: def _maybe_send_pake(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox + if not (self._connected and self._mailbox_opened and self._flag_need_to_send_PAKE): return d = self._msg_send(u"pake", self._msg1) @@ -349,42 +401,11 @@ class _Wormhole: d.addCallback(_pake_sent) d.addErrback(log.err) - def _maybe_send_phase_messages(self): - # TODO: deal with reentrant call - if not (self._connected and self._mailbox and self._key): - return - for pm in self._phase_messages_to_send: - (phase, message) = pm - d = self._msg_send(phase, message) - def _phase_message_sent(res, pm=pm): - try: - self._phase_messages_to_send.remove(pm) - except ValueError: - pass - d.addCallback(_phase_message_sent) - d.addErrback(log.err) - - - - def _event_received_message(self, msg): - pass - def _event_mailbox_used(self): - if self._flag_need_to_see_mailbox_used: - self._ws_send_command(u"release") - self._flag_need_to_see_mailbox_used = False - def _event_learned_PAKE(self, pake_msg): with self._timing.add("pake2", waiting="crypto"): self._key = self._sp.finish(pake_msg) self._event_established_key() - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - if self._key is None: - # call after get_verifier() or get() - raise UsageError - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - def _event_established_key(self): self._timing.add("key established") if self._send_confirm: @@ -395,7 +416,8 @@ class _Wormhole: self._msg_send(u"confirm", confmsg, wait=True) verifier = self.derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) - pass + self._maybe_send_phase_messages() + def _event_computed_verifier(self, verifier): self._verifier = verifier d, self._verifier_waiter = self._verifier_waiter, None @@ -412,13 +434,85 @@ class _Wormhole: # this makes all API calls fail return self._signal_error(WrongPasswordError()) + + @inlineCallbacks + def send(self, outbound_data, wait=False): + if not isinstance(outbound_data, type(b"")): + raise TypeError(type(outbound_data)) + if self._closed: raise UsageError + phase = self._next_send_phase + self._next_send_phase += 1 + d = defer.Deferred() + self._plaintext_to_send.append( (phase, outbound_data, d) ) + with self._timing.add("API send", phase=phase, wait=wait): + self._maybe_send_phase_messages() + if wait: + yield d + + def _maybe_send_phase_messages(self): + # TODO: deal with reentrant call + if not (self._connected and self._mailbox_opened and self._key): + return + plaintexts = self._plaintext_to_send + self._plaintext_to_send = [] + for pm in plaintexts: + (phase, plaintext, wait_d) = pm + data_key = self.derive_key(u"wormhole:phase:%d" % phase) + encrypted = self._encrypt_data(data_key, plaintext) + d = self._msg_send(phase, encrypted) + d.addBoth(wait_d.callback) + d.addErrback(log.err) + + def _encrypt_data(self, key, data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and we automatically ignore + # reflections. + assert isinstance(key, type(b"")), type(key) + assert isinstance(data, type(b"")), type(data) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + @inlineCallbacks + def _msg_send(self, phase, body, wait=False): + if phase in self._sent_messages: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + self._sent_messages[phase] = body + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + t = self._timing.add("add", phase=phase, wait=wait) + yield self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + if wait: + while phase not in self._delivered_messages: + yield self._sleep() + t.finish() + + + def _event_received_message(self, msg): + pass + def _event_mailbox_used(self): + if self._flag_need_to_see_mailbox_used: + self._ws_send_command(u"release") + self._flag_need_to_see_mailbox_used = False + + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + if self._key is None: + # call after get_verifier() or get() + raise UsageError + return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) + def _event_received_phase_message(self, phase, message): self._phase_messages_received[phase] = message if phase in self._phase_message_waiters: d = self._phase_message_waiters.pop(phase) d.callback(message) - def _ws_handle_message(self, msg): + def _response_handle_message(self, msg): side = msg["side"] phase = msg["phase"] body = unhexlify(msg["body"].encode("ascii")) @@ -443,12 +537,35 @@ class _Wormhole: if phase == u"confirm": self._event_received_confirm(body) # now notify anyone waiting on it - self._wakeup() + try: + data_key = self.derive_key(u"wormhole:phase:%s" % phase) + inbound_data = self._decrypt_data(data_key, body) + except CryptoError: + raise WrongPasswordError + self._phase_messages_received[phase] = inbound_data + if phase in self._receive_waiters: + d = self._receive_waiters.pop(phase) + d.callback(inbound_data) - def _event_asked_to_send_phase_message(self, phase, message): - pm = (phase, message) - self._phase_messages_to_send.append(pm) - self._maybe_send_phase_messages() + def _decrypt_data(self, key, encrypted): + assert isinstance(key, type(b"")), type(key) + assert isinstance(encrypted, type(b"")), type(encrypted) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + + @inlineCallbacks + def get(self): + if self._closed: raise UsageError + if self._code is None: raise UsageError + phase = self._next_receive_phase + self._next_receive_phase += 1 + with self._timing.add("API get", phase=phase): + if phase in self._phase_messages_received: + returnValue(self._phase_messages_received[phase]) + d = self._receive_waiters[phase] = defer.Deferred() + yield d def _event_asked_to_close(self): pass diff --git a/tox.ini b/tox.ini index 7ee4127..326afa5 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ skip_missing_interpreters = True [testenv] deps = pyflakes >= 1.2.3 + mock {env:EXTRA_DEPENDENCY:} commands = pyflakes setup.py src From c6ba55c6b53393fd376b2fdf36451ef5e0146c09 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 11:31:15 -0700 Subject: [PATCH 34/51] better diagram --- events.dot | 44 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/events.dot b/events.dot index 2a280da..66ea93b 100644 --- a/events.dot +++ b/events.dot @@ -1,14 +1,24 @@ digraph { - event_learned_code [label="learned\ncode" style="bold"] - event_learned_nameplate [label="learned\nnameplate" style="bold"] - event_learned_mailbox [label="learned\nmailbox" style="bold"] - event_connected [label="connected" style="bold"] - event_built_msg1 [label="built\nmsg1" style="bold"] - event_mailbox_used [label="mailbox\nused" style="bold"] - event_learned_PAKE [label="learned\nmsg2" style="bold"] - event_established_key [label="established\nkey" style="bold"] - event_computed_verifier [label="computed\nverifier" style="bold"] - event_received_confirm [label="received\nconfirm" style="bold"] + event_learned_code [label="learned\ncode" shape="box"] + event_learned_nameplate [label="learned\nnameplate" shape="box"] + event_learned_mailbox [label="learned\nmailbox" shape="box"] + event_connected [label="connected" shape="box"] + event_built_msg1 [label="built\nmsg1" shape="box"] + event_mailbox_used [label="mailbox\nused" shape="box"] + event_learned_PAKE [label="learned\nmsg2" shape="box"] + event_established_key [label="established\nkey" shape="box"] + event_computed_verifier [label="computed\nverifier" shape="box"] + event_received_confirm [label="received\nconfirm" shape="box"] + + event_connected -> api_get_code + event_connected -> api_input_code + api_get_code [label="get_code" shape="hexagon"] + api_input_code [label="input_code" shape="hexagon"] + api_set_code [label="set_code" shape="hexagon"] + api_get_code -> event_learned_code + api_input_code -> event_learned_code + api_set_code -> event_learned_code + maybe_build_msg1 [label="build\nmsg1"] maybe_get_mailbox [label="get\nmailbox"] @@ -16,6 +26,7 @@ digraph { maybe_send_phase_messages [label="send\nphase\nmessages"] event_connected -> maybe_get_mailbox + event_connected -> maybe_send_pake event_built_msg1 -> maybe_send_pake @@ -30,9 +41,22 @@ digraph { maybe_get_mailbox -> event_learned_PAKE [style="dashed"] maybe_get_mailbox -> event_received_confirm [style="dashed"] + event_learned_mailbox -> event_learned_PAKE [style="dashed"] + event_learned_PAKE -> event_mailbox_used [style="dashed"] + event_mailbox_used -> event_received_confirm [style="dashed"] + + send [label="API\nsend" shape="hexagon"] + send -> maybe_send_phase_messages + event_mailbox_used -> release event_learned_mailbox -> maybe_send_pake event_learned_mailbox -> maybe_send_phase_messages event_learned_PAKE -> event_established_key event_established_key -> event_computed_verifier + event_established_key -> maybe_send_phase_messages + + check_verifier [label="check\nverifier"] + event_computed_verifier -> check_verifier + event_received_confirm -> check_verifier + } From c10fd9816740fa5bd5f5dfda3c8a44956b09abb8 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 18:40:44 -0700 Subject: [PATCH 35/51] many tests working * add "released" ack-response for "release" command, to sync w.close() * move websocket URL to root * relayurl= should now be a "ws://" URL * many tests pass (except for test_twisted, which will be removed, and test_scripts) * still moving integration tests from test_twisted to test_wormhole.Wormholes --- src/wormhole/server/rendezvous_websocket.py | 2 + src/wormhole/server/server.py | 6 +- src/wormhole/test/common.py | 3 +- src/wormhole/test/test_server.py | 29 +- src/wormhole/test/test_wormhole.py | 578 +++++++++++++++++++- src/wormhole/wormhole.py | 237 +++++--- 6 files changed, 717 insertions(+), 138 deletions(-) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 153b037..05e2700 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -60,6 +60,7 @@ from .rendezvous import CrowdedError, SidedMessage # -> {type: "claim", nameplate: str} -> mailbox # <- {type: "claimed", mailbox: str} # -> {type: "release"} +# <- {type: "released"} # # -> {type: "open", mailbox: str} -> message # sends old messages now, and subscribes to deliver future messages @@ -183,6 +184,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): raise Error("must claim a nameplate before releasing it") self._app.release_nameplate(self._nameplate_id, self._side, server_rx) self._nameplate_id = None + self.send("released") def handle_open(self, msg, server_rx): diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 694ed19..b847bc7 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -45,12 +45,8 @@ class RelayServer(service.MultiService): rendezvous = Rendezvous(db, welcome, blur_usage) rendezvous.setServiceParent(self) # for the pruning timer - root = Root() - wr = resource.Resource() - root.putChild(b"wormhole-relay", wr) - wsrf = WebSocketRendezvousFactory(None, rendezvous) - wr.putChild(b"ws", WebSocketResource(wsrf)) + root = WebSocketResource(wsrf) site = PrivacyEnhancedSite(root) if blur_usage: diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 382cbfe..48c0685 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -17,8 +17,7 @@ class ServerBase: s.setServiceParent(self.sp) self._rendezvous = s._rendezvous self._transit_server = s._transit - self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport - self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws" + self.relayurl = u"ws://127.0.0.1:%d/" % relayport self.rdv_ws_port = relayport # ws://127.0.0.1:%d/wormhole-relay/ws self.transit = u"tcp:127.0.0.1:%d" % transitport diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 90c03eb..aeaa494 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -12,30 +12,6 @@ from .common import ServerBase from ..server import rendezvous, transit_server from ..server.rendezvous import Usage, SidedMessage -class Reachable(ServerBase, unittest.TestCase): - - def test_getPage(self): - # client.getPage requires bytes URL, returns bytes - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - d = getPage(url) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d - - def test_agent(self): - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - agent = Agent(reactor) - d = agent.request(b"GET", url) - def _check(resp): - self.failUnlessEqual(resp.code, 200) - return readBody(resp) - d.addCallback(_check) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d - class Server(ServerBase, unittest.TestCase): def test_apps(self): app1 = self._rendezvous.get_app(u"appid1") @@ -459,7 +435,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def make_client(self): - f = WSFactory(self.rdv_ws_url) + f = WSFactory(self.relayurl) f.d = defer.Deferred() reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f) c = yield f.d @@ -644,7 +620,8 @@ class WebSocketAPI(ServerBase, unittest.TestCase): yield c1.next_non_ack() c1.send(u"release") - yield c1.sync() + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"released") row = app._db.execute("SELECT * FROM `nameplates`" " WHERE `app_id`='appid' AND `id`='np1'").fetchone() diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 1591530..426bc0f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,13 +1,13 @@ from __future__ import print_function -import json +import os, json, re +from binascii import hexlify, unhexlify import mock from twisted.trial import unittest from twisted.internet import reactor -from twisted.internet.defer import gatherResults, inlineCallbacks -#from ..twisted.transcribe import (wormhole, wormhole_from_serialized, -# UsageError, WrongPasswordError) -#from .common import ServerBase -from ..wormhole import _Wormhole, _WelcomeHandler +from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks +from .common import ServerBase +from .. import wormhole +from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming APPID = u"appid" @@ -31,29 +31,131 @@ def response(w, **kwargs): w._ws_dispatch_response(payload) class Welcome(unittest.TestCase): - def test_no_current_version(self): - # WelcomeHandler should tolerate lack of ["current_version"] - w = _WelcomeHandler(u"relay_url", u"current_version") + def test_tolerate_no_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) w.handle_welcome({}) + def test_print_motd(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"motd": u"message of\nthe day"}) + self.assertEqual(stderr.method_calls, + [mock.call.write(u"Server (at relay_url) says:\n" + " message of\n the day"), + mock.call.write(u"\n")]) + # motd is only displayed once + with mock.patch("sys.stderr") as stderr2: + w.handle_welcome({u"motd": u"second message"}) + self.assertEqual(stderr2.method_calls, []) + + def test_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"2.0"}) + self.assertEqual(stderr.method_calls, []) + + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + exp1 = (u"Warning: errors may occur unless both sides are" + " running the same version") + exp2 = (u"Server claims 3.0 is current, but ours is 2.0") + self.assertEqual(stderr.method_calls, + [mock.call.write(exp1), + mock.call.write(u"\n"), + mock.call.write(exp2), + mock.call.write(u"\n"), + ]) + + # warning is only displayed once + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_non_release_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0-dirty", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_signal_error(self): + se = mock.Mock() + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", se) + w.handle_welcome({}) + self.assertEqual(se.mock_calls, []) + + w.handle_welcome({u"error": u"oops"}) + self.assertEqual(se.mock_calls, [mock.call(u"oops")]) + +class InputCode(unittest.TestCase): + def test_list(self): + send_command = mock.Mock() + ic = wormhole._InputCode(None, u"prompt", 2, send_command, + DebugTiming()) + d = ic._list() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"list")]) + ic._response_handle_nameplates({u"type": u"nameplates", + u"nameplates": [{u"id": u"123"}]}) + res = self.successResultOf(d) + self.assertEqual(res, [u"123"]) + +class GetCode(unittest.TestCase): + def test_get(self): + send_command = mock.Mock() + gc = wormhole._GetCode(2, send_command, DebugTiming()) + d = gc.go() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"allocate")]) + # TODO: nameplate attributes get added and checked here + gc._response_handle_allocated({u"type": u"allocated", + u"nameplate": u"123"}) + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + class Basic(unittest.TestCase): def test_create(self): - w = _Wormhole(APPID, u"relay_url", reactor, None, None) + wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + + def check_out(self, out, **kwargs): + # Assert that each kwarg is present in the 'out' dict. Ignore other + # keys ('msgid' in particular) + for key, value in kwargs.items(): + self.assertIn(key, out) + self.assertEqual(out[key], value, (out, key, value)) + + def check_outbound(self, ws, types): + out = ws.outbound() + self.assertEqual(len(out), len(types), (out, types)) + for i,t in enumerate(types): + self.assertEqual(out[i][u"type"], t, (i,t,out)) + return out def test_basic(self): # We don't call w._start(), so this doesn't create a WebSocket - # connection. We provide a mock connection instead. - timing = DebugTiming() - with mock.patch("wormhole.wormhole._WelcomeHandler") as whc: - w = _Wormhole(APPID, u"relay_url", reactor, None, timing) - wh = whc.return_value - #w._welcomer = mock.Mock() + # connection. We provide a mock connection instead. If we wanted to + # exercise _connect, we'd mock out WSFactory. # w._connect = lambda self: None # w._event_connected(mock_ws) # w._event_ws_opened() # w._ws_dispatch_response(payload) + + timing = DebugTiming() + with mock.patch("wormhole.wormhole._WelcomeHandler") as wh_c: + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + wh = wh_c.return_value self.assertEqual(w._ws_url, u"relay_url") + self.assertTrue(w._flag_need_nameplate) + self.assertTrue(w._flag_need_to_build_msg1) + self.assertTrue(w._flag_need_to_send_PAKE) + + v = w.get_verifier() + + w._drop_connection = mock.Mock() ws = MockWebSocket() w._event_connected(ws) out = ws.outbound() @@ -62,13 +164,449 @@ class Basic(unittest.TestCase): w._event_ws_opened(None) out = ws.outbound() self.assertEqual(len(out), 1) - self.assertEqual(out[0]["type"], u"bind") - self.assertEqual(out[0]["appid"], APPID) - self.assertEqual(out[0]["side"], w._side) + self.check_out(out[0], type=u"bind", appid=APPID, side=w._side) self.assertIn(u"id", out[0]) - # WelcomeHandler should get called upon 'welcome' response + # WelcomeHandler should get called upon 'welcome' response. Its full + # behavior is exercised in 'Welcome' above. WELCOME = {u"foo": u"bar"} response(w, type="welcome", welcome=WELCOME) self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)]) + # because we're connected, setting the code also claims the mailbox + CODE = u"123-foo-bar" + w.set_code(CODE) + self.assertFalse(w._flag_need_to_build_msg1) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"claim", nameplate=u"123") + + # the server reveals the linked mailbox + response(w, type=u"claimed", mailbox=u"mb456") + + # that triggers event_learned_mailbox, which should send open() and + # PAKE + self.assertTrue(w._mailbox_opened) + out = ws.outbound() + self.assertEqual(len(out), 2) + self.check_out(out[0], type=u"open", mailbox=u"mb456") + self.check_out(out[1], type=u"add", phase=u"pake") + self.assertNoResult(v) + + # server echoes back all "add" messages + response(w, type=u"message", phase=u"pake", body=out[1][u"body"], + side=w._side) + self.assertNoResult(v) + + # next we build the simulated peer's PAKE operation + side2 = w._side + u"other" + msg1 = unhexlify(out[1][u"body"].encode("ascii")) + sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE), + idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + key = sp2.finish(msg1) + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) + + # hearing the peer's PAKE (msg2) makes us release the nameplate, send + # the confirmation message, delivered the verifier, and sends any + # queued phase messages + self.assertFalse(w._flag_need_to_see_mailbox_used) + self.assertEqual(w._key, key) + out = ws.outbound() + self.assertEqual(len(out), 2, out) + self.check_out(out[0], type=u"release") + self.check_out(out[1], type=u"add", phase=u"confirm") + verifier = self.successResultOf(v) + self.assertEqual(verifier, w.derive_key(u"wormhole:verifier")) + + # hearing a valid confirmation message doesn't throw an error + confkey = w.derive_key(u"wormhole:confirmation") + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + confirm2 = wormhole.make_confmsg(confkey, nonce) + confirm2_hex = hexlify(confirm2).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=confirm2_hex, + side=side2) + + # an outbound message can now be sent immediately + w.send(b"phase0-outbound") + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"add", phase=u"0") + # decrypt+check the outbound message + p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) + msgkey0 = w.derive_key(u"wormhole:phase:0") + p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) + self.assertEqual(p0_plaintext, b"phase0-outbound") + + # get() waits for the inbound message to arrive + md = w.get() + self.assertNoResult(md) + self.assertIn(u"0", w._receive_waiters) + self.assertNotIn(u"0", w._received_messages) + p0_inbound = w._encrypt_data(msgkey0, b"phase0-inbound") + p0_inbound_hex = hexlify(p0_inbound).decode("ascii") + response(w, type=u"message", phase=u"0", body=p0_inbound_hex, + side=side2) + p0_in = self.successResultOf(md) + self.assertEqual(p0_in, b"phase0-inbound") + self.assertNotIn(u"0", w._receive_waiters) + self.assertIn(u"0", w._received_messages) + + # receiving an inbound message will queue it until get() is called + msgkey1 = w.derive_key(u"wormhole:phase:1") + p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"1", body=p1_inbound_hex, + side=side2) + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + p1_in = self.successResultOf(w.get()) + self.assertEqual(p1_in, b"phase1-inbound") + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + + w.close() + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"close", mood=u"happy") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_1(self): + # close after claiming the nameplate, but before opening the mailbox + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"release"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_2(self): + # close after both claiming the nameplate and opening the mailbox + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_close_wait_3(self): + # close after claiming the nameplate, opening the mailbox, then + # releasing the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + + w._key = b"" + msgkey = w.derive_key(u"wormhole:phase:misc") + p1_inbound = w._encrypt_data(msgkey, b"") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"misc", side=u"side2", + body=p1_inbound_hex) + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", + u"release"]) + + d = w.close(wait=True) + self.check_outbound(ws, [u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") + self.successResultOf(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + def test_get_code_mock(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() # TODO: mock w._ws_send_command instead + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + gc_c = mock.Mock() + gc = gc_c.return_value = mock.Mock() + gc_d = gc.go.return_value = Deferred() + with mock.patch("wormhole.wormhole._GetCode", gc_c): + d = w.get_code() + self.assertNoResult(d) + + gc_d.callback(u"123-foo-bar") + code = self.successResultOf(d) + self.assertEqual(code, u"123-foo-bar") + + def test_get_code_real(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + d = w.get_code() + + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"allocate") + # TODO: nameplate attributes go here + self.assertNoResult(d) + + response(w, type=u"allocated", nameplate=u"123") + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + +# event orderings to exercise: +# +# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, +# learn_phase1 +# * normal receiver (argv[2]=code): set_code, connected, learn_msg1, +# learn_phase1, send_phase1, +# * normal receiver (readline): connected, input_code +# * +# * set_code, then connected +# * connected, receive_pake, send_phase, set_code + +class Wormholes(ServerBase, unittest.TestCase): + # integration test, with a real server + + def doBoth(self, d1, d2): + return gatherResults([d1, d2], True) + + @inlineCallbacks + def test_basic(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) + +class Off: + + @inlineCallbacks + def test_same_message(self): + # the two sides use random nonces for their messages, so it's ok for + # both to try and send the same body: they'll result in distinct + # encrypted messages + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + yield self.doBoth(w1.send(b"data"), w2.send(b"data")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data") + self.assertEqual(dataY, b"data") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_interleaved(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + res = yield self.doBoth(w1.send(b"data1"), w2.get()) + (_, dataY) = res + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) + (dataX, _) = dl + self.assertEqual(dataX, b"data2") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_fixed_code(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield self.doBoth(w1.close(), w2.close()) + + + @inlineCallbacks + def test_multiple_messages(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data4") + self.assertEqual(dataY, b"data3") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_multiple_messages_2(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + # TODO: set_code should be sufficient to kick things off, but for now + # we must also let both sides do at least one send() or get() + yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) + yield w1.get() + yield w1.send(b"data2") + yield w1.send(b"data3") + data = yield w2.get() + self.assertEqual(data, b"data1") + data = yield w2.get() + self.assertEqual(data, b"data2") + data = yield w2.get() + self.assertEqual(data, b"data3") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_wrong_password(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code+"not") + + # w2 can't throw WrongPasswordError until it sees a CONFIRM message, + # and w1 won't send CONFIRM until it sees a PAKE message, which w2 + # won't send until we call get. So we need both sides to be + # running at the same time for this test. + d1 = w1.send(b"data1") + # at this point, w1 should be waiting for w2.PAKE + + yield self.assertFailure(w2.get(), WrongPasswordError) + # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, + # send w2.CONFIRM, then wait for w1.DATA. + # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. + # * w1 might also get w2.CONFIRM, and may notice the error before it + # sends w1.CONFIRM, in which case the wait=True will signal an + # error inside _get_master_key() (inside send), and d1 will + # errback. + # * but w1 might not see w2.CONFIRM yet, in which case it won't + # errback until we do w1.get() + # * w2 gets w1.CONFIRM, notices the error, records it. + # * w2 (waiting for w1.DATA) wakes up, sees the error, throws + # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not + # have arrived by then + try: + yield d1 + except WrongPasswordError: + pass + + # When we ask w1 to get(), one of two things might happen: + # * if w2.CONFIRM arrived already, it will have recorded the error. + # When w1.get() sleeps (waiting for w2.DATA), we'll notice the + # error before sleeping, and throw WrongPasswordError + # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM + # arrives, we notice and record the error, and wake up, and throw + + # Note that we didn't do w2.send(), so we're hoping that w1 will + # have enough information to detect the error before it sleeps + # (waiting for w2.DATA). Checking for the error both before sleeping + # and after waking up makes this happen. + + # so now w1 should have enough information to throw too + yield self.assertFailure(w1.get(), WrongPasswordError) + + # both sides are closed automatically upon error, but it's still + # legal to call .close(), and should be idempotent + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_no_confirm(self): + # newer versions (which check confirmations) should will work with + # older versions (that don't send confirmations) + w1 = Wormhole(APPID, self.relayurl) + w1._send_confirm = False + w2 = Wormhole(APPID, self.relayurl) + + code = yield w1.get_code() + w2.set_code(code) + dl = yield self.doBoth(w1.send(b"data1"), w2.get()) + self.assertEqual(dl[1], b"data1") + dl = yield self.doBoth(w1.get(), w2.send(b"data2")) + self.assertEqual(dl[0], b"data2") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_verifier(self): + w1 = Wormhole(APPID, self.relayurl) + w2 = Wormhole(APPID, self.relayurl) + code = yield w1.get_code() + w2.set_code(code) + res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) + v1, v2 = res + self.failUnlessEqual(type(v1), type(b"")) + self.failUnlessEqual(v1, v2) + yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield self.doBoth(w1.close(), w2.close()) + + @inlineCallbacks + def test_errors(self): + w1 = Wormhole(APPID, self.relayurl) + yield self.assertFailure(w1.get_verifier(), UsageError) + yield self.assertFailure(w1.send(b"data"), UsageError) + yield self.assertFailure(w1.get(), UsageError) + w1.set_code(u"123-purple-elephant") + yield self.assertRaises(UsageError, w1.set_code, u"123-nope") + yield self.assertFailure(w1.get_code(), UsageError) + w2 = Wormhole(APPID, self.relayurl) + yield w2.get_code() + yield self.assertFailure(w2.get_code(), UsageError) + yield self.doBoth(w1.close(), w2.close()) + diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 2a16bb7..31ffe4c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -15,7 +15,7 @@ from . import __version__ from . import codes #from .errors import ServerError, Timeout from .errors import WrongPasswordError, UsageError -#from .timing import DebugTiming +from .timing import DebugTiming from hkdf import Hkdf def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -67,9 +67,10 @@ class WSFactory(websocket.WebSocketClientFactory): class _GetCode: - def __init__(self, code_length, send_command): + def __init__(self, code_length, send_command, timing): self._code_length = code_length self._send_command = send_command + self._timing = timing self._allocated_d = defer.Deferred() @inlineCallbacks @@ -87,11 +88,12 @@ class _GetCode: self._allocated_d.callback(nid) class _InputCode: - def __init__(self, reactor, prompt, code_length, send_command): + def __init__(self, reactor, prompt, code_length, send_command, timing): self._reactor = reactor self._prompt = prompt self._code_length = code_length self._send_command = send_command + self._timing = timing @inlineCallbacks def _list(self): @@ -172,11 +174,12 @@ class _InputCode: # readline finish. class _WelcomeHandler: - def __init__(self, url, current_version): + def __init__(self, url, current_version, signal_error): self._ws_url = url self._version_warning_displayed = False self._motd_displayed = False self._current_version = current_version + self._signal_error = signal_error def handle_welcome(self, welcome): if ("motd" in welcome and @@ -211,28 +214,45 @@ class _Wormhole: self._tor_manager = tor_manager self._timing = timing - self._welcomer = _WelcomeHandler(self._ws_url, __version__) + self._welcomer = _WelcomeHandler(self._ws_url, __version__, + self._signal_error) self._side = hexlify(os.urandom(5)).decode("ascii") self._connected = None + self._connection_waiters = [] + self._started_get_code = False + self._code = None self._nameplate_id = None + self._nameplate_claimed = False + self._nameplate_released = False + self._release_waiter = defer.Deferred() self._mailbox_id = None self._mailbox_opened = False self._mailbox_closed = False - self._flag_need_mailbox = True + self._close_waiter = defer.Deferred() + self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True self._flag_need_to_send_PAKE = True - self._flag_need_PAKE = True - self._flag_need_key = True # rename to not self._key + self._key = None + self._closed = False + self._mood = u"happy" + + self._get_verifier_called = False + self._verifier_waiter = defer.Deferred() self._next_send_phase = 0 - self._plaintext_to_send = [] # (phase, plaintext, deferred) - self._phase_messages_to_send = [] # not yet acked by server + # send() queues plaintext here, waiting for a connection and the key + self._plaintext_to_send = [] # (phase, plaintext) + self._sent_phases = set() # to detect double-send self._next_receive_phase = 0 self._receive_waiters = {} # phase -> Deferred - self._phase_messages_received = {} # phase -> message + self._received_messages = {} # phase -> plaintext + def _signal_error(self, error): + # close the mailbox with an "errory" mood, errback all Deferreds, + # record the error, fail all subsequent API calls + pass # XXX def _start(self): d = self._connect() # causes stuff to happen @@ -275,6 +295,16 @@ class _Wormhole: self._ws_send_command(u"bind", appid=self._appid, side=self._side) self._maybe_get_mailbox() self._maybe_send_pake() + waiters, self._connection_waiters = self._connection_waiters, [] + for d in waiters: + d.callback(None) + + def _when_connected(self): + if self._connected: + return defer.succeed(None) + d = defer.Deferred() + self._connection_waiters.append(d) + return d def _ws_send_command(self, mtype, **kwargs): # msgid is used by misc/dump-timing.py to correlate our sends with @@ -287,8 +317,10 @@ class _Wormhole: self._timing.add("ws_send", _side=self._side, **kwargs) self._ws.sendMessage(payload, False) + DEBUG=False def _ws_dispatch_response(self, payload): msg = json.loads(payload.decode("utf-8")) + if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg) self._timing.add("ws_receive", _side=self._side, message=msg) mtype = msg["type"] meth = getattr(self, "_response_handle_"+mtype, None) @@ -311,9 +343,11 @@ class _Wormhole: if self._started_get_code: raise UsageError self._started_get_code = True with self._timing.add("API get_code"): - gc = _GetCode(code_length, self._ws_send_command) + yield self._when_connected() + gc = _GetCode(code_length, self._ws_send_command, self._timing) self._response_handle_allocated = gc._response_handle_allocated code = yield gc.go() + self._nameplate_claimed = True # side-effect of allocation self._event_learned_code(code) returnValue(code) @@ -324,6 +358,7 @@ class _Wormhole: if self._started_input_code: raise UsageError self._started_input_code = True with self._timing.add("API input_code"): + yield self._when_connected() ic = _InputCode(prompt, code_length, self._ws_send_command) self._response_handle_nameplates = ic._response_handle_nameplates code = yield ic.go() @@ -337,6 +372,17 @@ class _Wormhole: if self._code is not None: raise UsageError self._event_learned_code(code) + # TODO: entry point 4: restore pre-contact saved state (we haven't heard + # from the peer yet, so we still need the nameplate) + + # TODO: entry point 5: restore post-contact saved state (so we don't need + # or use the nameplate, only the mailbox) + def _restore_post_contact_state(self, state): + # ... + self._flag_need_nameplate = False + #self._mailbox_id = X(state) + self._event_learned_mailbox() + def _event_learned_code(self, code): self._timing.add("code established") self._code = code @@ -370,10 +416,10 @@ class _Wormhole: self._maybe_get_mailbox() def _maybe_get_mailbox(self): - if not (self._flag_need_mailbox and self._nameplate_id - and self._connected): + if not (self._nameplate_id and self._connected): return self._ws_send_command(u"claim", nameplate=self._nameplate_id) + self._nameplate_claimed = True def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] @@ -382,10 +428,10 @@ class _Wormhole: self._event_learned_mailbox() def _event_learned_mailbox(self): - self._flag_need_mailbox = False if not self._mailbox_id: raise UsageError if self._mailbox_opened: raise UsageError self._ws_send_command(u"open", mailbox=self._mailbox_id) + self._mailbox_opened = True # causes old messages to be sent now, and subscribes to new messages self._maybe_send_pake() self._maybe_send_phase_messages() @@ -395,34 +441,36 @@ class _Wormhole: if not (self._connected and self._mailbox_opened and self._flag_need_to_send_PAKE): return - d = self._msg_send(u"pake", self._msg1) - def _pake_sent(res): - self._flag_need_to_send_PAKE = False - d.addCallback(_pake_sent) - d.addErrback(log.err) + self._msg_send(u"pake", self._msg1) + self._flag_need_to_send_PAKE = False - def _event_learned_PAKE(self, pake_msg): + def _event_received_pake(self, pake_msg): with self._timing.add("pake2", waiting="crypto"): self._key = self._sp.finish(pake_msg) self._event_established_key() def _event_established_key(self): self._timing.add("key established") - if self._send_confirm: - # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - self._msg_send(u"confirm", confmsg, wait=True) + + # both sides send different (random) confirmation messages + confkey = self.derive_key(u"wormhole:confirmation") + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + self._msg_send(u"confirm", confmsg) + verifier = self.derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) + self._maybe_send_phase_messages() + def get_verifier(self): + if self._closed: raise UsageError + if self._get_verifier_called: raise UsageError + self._get_verifier_called = True + return self._verifier_waiter + def _event_computed_verifier(self, verifier): - self._verifier = verifier - d, self._verifier_waiter = self._verifier_waiter, None - if d: - d.callback(verifier) + self._verifier_waiter.callback(verifier) def _event_received_confirm(self, body): # TODO: we might not have a master key yet, if the caller wasn't @@ -435,19 +483,15 @@ class _Wormhole: return self._signal_error(WrongPasswordError()) - @inlineCallbacks - def send(self, outbound_data, wait=False): + def send(self, outbound_data): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if self._closed: raise UsageError phase = self._next_send_phase self._next_send_phase += 1 - d = defer.Deferred() - self._plaintext_to_send.append( (phase, outbound_data, d) ) - with self._timing.add("API send", phase=phase, wait=wait): + self._plaintext_to_send.append( (phase, outbound_data) ) + with self._timing.add("API send", phase=phase): self._maybe_send_phase_messages() - if wait: - yield d def _maybe_send_phase_messages(self): # TODO: deal with reentrant call @@ -456,18 +500,19 @@ class _Wormhole: plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: - (phase, plaintext, wait_d) = pm + (phase, plaintext) = pm + assert isinstance(phase, int), type(phase) data_key = self.derive_key(u"wormhole:phase:%d" % phase) encrypted = self._encrypt_data(data_key, plaintext) - d = self._msg_send(phase, encrypted) - d.addBoth(wait_d.callback) - d.addErrback(log.err) + self._msg_send(u"%d" % phase, encrypted) def _encrypt_data(self, key, data): # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random # nonces to keep the messages distinct, and we automatically ignore # reflections. + # TODO: HKDF(side, nonce, key) ?? include 'side' to prevent + # reflections, since we no longer compare messages assert isinstance(key, type(b"")), type(key) assert isinstance(data, type(b"")), type(data) assert len(key) == SecretBox.KEY_SIZE, len(key) @@ -475,57 +520,39 @@ class _Wormhole: nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(data, nonce) - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - if phase in self._sent_messages: raise UsageError + def _msg_send(self, phase, body): + if phase in self._sent_phases: raise UsageError if not self._mailbox_opened: raise UsageError if self._mailbox_closed: raise UsageError - self._sent_messages[phase] = body + self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send_command(u"add", phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while phase not in self._delivered_messages: - yield self._sleep() - t.finish() + self._timing.add("add", phase=phase) + self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) - def _event_received_message(self, msg): - pass def _event_mailbox_used(self): + if self.DEBUG: print("_event_mailbox_used") if self._flag_need_to_see_mailbox_used: - self._ws_send_command(u"release") + self._maybe_release_nameplate() self._flag_need_to_see_mailbox_used = False def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: - # call after get_verifier() or get() - raise UsageError + raise UsageError # call derive_key after get_verifier() or get() return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - def _event_received_phase_message(self, phase, message): - self._phase_messages_received[phase] = message - if phase in self._phase_message_waiters: - d = self._phase_message_waiters.pop(phase) - d.callback(message) - def _response_handle_message(self, msg): side = msg["side"] phase = msg["phase"] + assert isinstance(phase, type(u"")), type(phase) body = unhexlify(msg["body"].encode("ascii")) if side == self._side: return self._event_received_peer_message(phase, body) - def XXXackstuff(): - if phase in self._sent_messages and self._sent_messages[phase] == body: - self._delivered_messages.add(phase) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - def _event_received_peer_message(self, phase, body): # any message in the mailbox means we no longer need the nameplate self._event_mailbox_used() @@ -534,18 +561,23 @@ class _Wormhole: # err = ServerError("got duplicate phase %s" % phase, self._ws_url) # return self._signal_error(err) #self._received_messages[phase] = body + if phase == u"pake": + self._event_received_pake(body) + return if phase == u"confirm": self._event_received_confirm(body) + return + # now notify anyone waiting on it try: data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) + plaintext = self._decrypt_data(data_key, body) except CryptoError: - raise WrongPasswordError - self._phase_messages_received[phase] = inbound_data + raise WrongPasswordError # TODO: signal + self._received_messages[phase] = plaintext if phase in self._receive_waiters: d = self._receive_waiters.pop(phase) - d.callback(inbound_data) + d.callback(plaintext) def _decrypt_data(self, key, encrypted): assert isinstance(key, type(b"")), type(key) @@ -555,28 +587,63 @@ class _Wormhole: data = box.decrypt(encrypted) return data - @inlineCallbacks def get(self): if self._closed: raise UsageError - if self._code is None: raise UsageError - phase = self._next_receive_phase + phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 with self._timing.add("API get", phase=phase): - if phase in self._phase_messages_received: - returnValue(self._phase_messages_received[phase]) + if phase in self._received_messages: + return defer.succeed(self._received_messages[phase]) d = self._receive_waiters[phase] = defer.Deferred() - yield d + return d - def _event_asked_to_close(self): + @inlineCallbacks + def close(self, mood=None, wait=False): + # TODO: auto-close on error, mostly for load-from-state + if self._closed: raise UsageError + if mood: + self._mood = mood + self._maybe_release_nameplate() + self._maybe_close_mailbox() + if wait: + if self._nameplate_claimed: + yield self._release_waiter + if self._mailbox_opened: + yield self._close_waiter + self._drop_connection() + + def _maybe_release_nameplate(self): + if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) + if self._nameplate_claimed and not self._nameplate_released: + if self.DEBUG: print(" sending release") + self._ws_send_command(u"release") + self._nameplate_released = True + + def _response_handle_released(self, msg): + self._release_waiter.callback(None) + + def _maybe_close_mailbox(self): + if self._mailbox_opened and not self._mailbox_closed: + self._ws_send_command(u"close", mood=self._mood) + self._mailbox_closed = True + + def _response_handle_closed(self, msg): + self._close_waiter.callback(None) + + def _drop_connection(self): + self._ws.transport.loseConnection() # probably flushes + # calls _ws_closed() when done + + def _ws_closed(self, wasClean, code, reason): pass - - def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): + timing = timing or DebugTiming() w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) w._start() return w -def wormhole_from_serialized(data, reactor): - w = _Wormhole.from_serialized(data, reactor) +def wormhole_from_serialized(data, reactor, timing=None): + timing = timing or DebugTiming() + w = _Wormhole.from_serialized(data, reactor, timing) return w From c88d6937c2e8d695049fffa042104b044cc213d5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 22 May 2016 18:45:50 -0700 Subject: [PATCH 36/51] close(wait=True): wait for connection to be dropped --- src/wormhole/test/test_wormhole.py | 17 ++++++++++++++--- src/wormhole/wormhole.py | 5 ++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 426bc0f..1a6d865 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -290,8 +290,11 @@ class Basic(unittest.TestCase): self.assertEqual(w._drop_connection.mock_calls, []) response(w, type=u"released") - self.successResultOf(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) def test_close_wait_2(self): # close after both claiming the nameplate and opening the mailbox @@ -314,10 +317,14 @@ class Basic(unittest.TestCase): response(w, type=u"released") self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") - self.successResultOf(d) + self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_close_wait_3(self): # close after claiming the nameplate, opening the mailbox, then # releasing the nameplate @@ -348,10 +355,14 @@ class Basic(unittest.TestCase): response(w, type=u"released") self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") - self.successResultOf(d) + self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_get_code_mock(self): timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 31ffe4c..ccb7f9c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -235,6 +235,7 @@ class _Wormhole: self._flag_need_to_send_PAKE = True self._key = None self._closed = False + self._disconnect_waiter = defer.Deferred() self._mood = u"happy" self._get_verifier_called = False @@ -611,6 +612,8 @@ class _Wormhole: if self._mailbox_opened: yield self._close_waiter self._drop_connection() + if wait: + yield self._disconnect_waiter def _maybe_release_nameplate(self): if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) @@ -635,7 +638,7 @@ class _Wormhole: # calls _ws_closed() when done def _ws_closed(self, wasClean, code, reason): - pass + self._disconnect_waiter.callback(None) def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): timing = timing or DebugTiming() From 528092dd970d9f8d10c0e0a7f0088be07de2a2d5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 00:14:39 -0700 Subject: [PATCH 37/51] improve error signalling --- src/wormhole/errors.py | 4 + src/wormhole/test/test_server.py | 1 - src/wormhole/test/test_wormhole.py | 144 +++++++++++++++++++++++++++-- src/wormhole/wormhole.py | 60 ++++++++++-- 4 files changed, 191 insertions(+), 18 deletions(-) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 4d91270..0c140ac 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -20,6 +20,10 @@ def handle_server_error(func): class Timeout(Exception): pass +class WelcomeError(Exception): + """The server told us to signal an error, probably because our version is + too old to possibly work.""" + class WrongPasswordError(Exception): """ Key confirmation failed. Either you or your correspondent typed the code diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index aeaa494..437e8bd 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -5,7 +5,6 @@ from twisted.trial import unittest from twisted.internet import protocol, reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web.client import getPage, Agent, readBody from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 1a6d865..ad8e5ef 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -1,5 +1,5 @@ from __future__ import print_function -import os, json, re +import os, json, re, gc from binascii import hexlify, unhexlify import mock from twisted.trial import unittest @@ -84,7 +84,14 @@ class Welcome(unittest.TestCase): self.assertEqual(se.mock_calls, []) w.handle_welcome({u"error": u"oops"}) - self.assertEqual(se.mock_calls, [mock.call(u"oops")]) + self.assertEqual(len(se.mock_calls), 1) + self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs + we = se.mock_calls[0][1][0] + self.assertIsInstance(we, wormhole.WelcomeError) + self.assertEqual(we.args, (u"oops",)) + # alas WelcomeError instances don't compare against each other + #self.assertEqual(se.mock_calls, + # [mock.call(wormhole.WelcomeError(u"oops"))]) class InputCode(unittest.TestCase): def test_list(self): @@ -116,10 +123,10 @@ class GetCode(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) - class Basic(unittest.TestCase): - def test_create(self): - wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + def tearDown(self): + # flush out any errorful Deferreds left dangling in cycles + gc.collect() def check_out(self, out, **kwargs): # Assert that each kwarg is present in the 'out' dict. Ignore other @@ -135,6 +142,17 @@ class Basic(unittest.TestCase): self.assertEqual(out[i][u"type"], t, (i,t,out)) return out + def make_pake(self, code, side, msg1): + sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code), + idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + key = sp2.finish(msg1) + return key, msg2_hex + + def test_create(self): + wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + def test_basic(self): # We don't call w._start(), so this doesn't create a WebSocket # connection. We provide a mock connection instead. If we wanted to @@ -201,11 +219,7 @@ class Basic(unittest.TestCase): # next we build the simulated peer's PAKE operation side2 = w._side + u"other" msg1 = unhexlify(out[1][u"body"].encode("ascii")) - sp2 = SPAKE2_Symmetric(wormhole.to_bytes(CODE), - idSymmetric=wormhole.to_bytes(APPID)) - msg2 = sp2.start() - msg2_hex = hexlify(msg2).decode("ascii") - key = sp2.finish(msg1) + key, msg2_hex = self.make_pake(CODE, side2, msg1) response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) # hearing the peer's PAKE (msg2) makes us release the nameplate, send @@ -272,6 +286,24 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"close", mood=u"happy") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + def test_close_wait_0(self): + # close before even claiming the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + + d = w.close(wait=True) + self.check_outbound(ws, [u"bind"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_close_wait_1(self): # close after claiming the nameplate, but before opening the mailbox timing = DebugTiming() @@ -406,6 +438,98 @@ class Basic(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + def test_api_errors(self): + # doing things you're not supposed to do + pass + + def test_welcome_error(self): + # A welcome message could arrive at any time, with an [error] key + # that should make us halt. In practice, though, this gets sent as + # soon as the connection is established, which limits the possible + # states in which we might see it. + + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + WE = wormhole.WelcomeError + d1 = w.get() + d2 = w.get_verifier() + d3 = w.get_code() + # TODO (tricky): test w.input_code + + self.assertNoResult(d1) + self.assertNoResult(d2) + self.assertNoResult(d3) + + w._signal_error(WE(u"you are not actually welcome")) + self.failureResultOf(d1, WE) + self.failureResultOf(d2, WE) + self.failureResultOf(d3, WE) + + # once the error is signalled, all API calls should fail + self.assertRaises(WE, w.send, u"foo") + self.assertRaises(WE, w.derive_key, u"foo") + self.failureResultOf(w.get(), WE) + self.failureResultOf(w.get_verifier(), WE) + + def test_confirm_error(self): + # we should only receive the "confirm" message after we receive the + # PAKE message, by which point we should know the key. If the + # confirmation message doesn't decrypt, we signal an error. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + w.set_code(u"123-foo-bar") + response(w, type=u"claimed", mailbox=u"mb456") + + WP = wormhole.WrongPasswordError + d1 = w.get() + d2 = w.get_verifier() + self.assertNoResult(d1) + self.assertNoResult(d2) + + out = ws.outbound() + # [u"bind", u"claim", u"open", u"add"] + self.assertEqual(len(out), 4) + self.assertEqual(out[3][u"type"], u"add") + + sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2") + self.assertNoResult(d1) + self.successResultOf(d2) # early get_verifier is unaffected + # TODO: get_verifier would be a lovely place to signal a confirmation + # error, but that's at odds with delivering the verifier as early as + # possible. The confirmation messages should be hot on the heels of + # the PAKE message that produced the verifier. Maybe get_verifier() + # should explicitly wait for confirm()? + + # sending a random confirm message will cause a confirmation error + confkey = w.derive_key(u"WRONG") + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + badconfirm = wormhole.make_confmsg(confkey, nonce) + badconfirm_hex = hexlify(badconfirm).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=badconfirm_hex, + side=u"s2") + + self.failureResultOf(d1, WP) + + # once the error is signalled, all API calls should fail + self.assertRaises(WP, w.send, u"foo") + self.assertRaises(WP, w.derive_key, u"foo") + self.failureResultOf(w.get(), WP) + self.failureResultOf(w.get_verifier(), WP) + + # event orderings to exercise: # # * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index ccb7f9c..53cefca 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,7 @@ from spake2 import SPAKE2_Symmetric from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import WrongPasswordError, UsageError +from .errors import WrongPasswordError, UsageError, WelcomeError from .timing import DebugTiming from hkdf import Hkdf @@ -203,7 +203,7 @@ class _WelcomeHandler: self._version_warning_displayed = True if "error" in welcome: - return self._signal_error(welcome["error"]) + return self._signal_error(WelcomeError(welcome["error"])) class _Wormhole: @@ -220,15 +220,16 @@ class _Wormhole: self._connected = None self._connection_waiters = [] self._started_get_code = False + self._get_code = None self._code = None self._nameplate_id = None self._nameplate_claimed = False self._nameplate_released = False - self._release_waiter = defer.Deferred() + self._release_waiter = None self._mailbox_id = None self._mailbox_opened = False self._mailbox_closed = False - self._close_waiter = defer.Deferred() + self._close_waiter = None self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True @@ -237,6 +238,7 @@ class _Wormhole: self._closed = False self._disconnect_waiter = defer.Deferred() self._mood = u"happy" + self._error = None self._get_verifier_called = False self._verifier_waiter = defer.Deferred() @@ -253,7 +255,24 @@ class _Wormhole: def _signal_error(self, error): # close the mailbox with an "errory" mood, errback all Deferreds, # record the error, fail all subsequent API calls - pass # XXX + if self.DEBUG: print("_signal_error", error) + self._error = error # causes new API calls to fail + for d in self._connection_waiters: + d.errback(error) + if self._get_code: + self._get_code._allocated_d.errback(error) + if not self._verifier_waiter.called: + self._verifier_waiter.errback(error) + for d in self._receive_waiters.values(): + d.errback(error) + + self._maybe_close(mood=u"errory") + if self._release_waiter and not self._release_waiter.called: + self._release_waiter.errback(error) + if self._close_waiter and not self._close_waiter.called: + self._close_waiter.errback(error) + # leave self._disconnect_waiter alone + if self.DEBUG: print("_signal_error done") def _start(self): d = self._connect() # causes stuff to happen @@ -346,8 +365,11 @@ class _Wormhole: with self._timing.add("API get_code"): yield self._when_connected() gc = _GetCode(code_length, self._ws_send_command, self._timing) + self._get_code = gc self._response_handle_allocated = gc._response_handle_allocated + # TODO: signal_error code = yield gc.go() + self._get_code = None self._nameplate_claimed = True # side-effect of allocation self._event_learned_code(code) returnValue(code) @@ -362,6 +384,7 @@ class _Wormhole: yield self._when_connected() ic = _InputCode(prompt, code_length, self._ws_send_command) self._response_handle_nameplates = ic._response_handle_nameplates + # TODO: signal_error code = yield ic.go() self._event_learned_code(code) returnValue(None) @@ -465,9 +488,11 @@ class _Wormhole: self._maybe_send_phase_messages() def get_verifier(self): + if self._error: return defer.fail(self._error) if self._closed: raise UsageError if self._get_verifier_called: raise UsageError self._get_verifier_called = True + # TODO: maybe have this wait on _event_received_confirm too return self._verifier_waiter def _event_computed_verifier(self, verifier): @@ -481,10 +506,12 @@ class _Wormhole: nonce = body[:CONFMSG_NONCE_LENGTH] if body != make_confmsg(confkey, nonce): # this makes all API calls fail + if self.DEBUG: print("CONFIRM FAILED") return self._signal_error(WrongPasswordError()) def send(self, outbound_data): + if self._error: raise self._error if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if self._closed: raise UsageError @@ -540,6 +567,7 @@ class _Wormhole: self._flag_need_to_see_mailbox_used = False def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if self._error: raise self._error if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() @@ -589,6 +617,7 @@ class _Wormhole: return data def get(self): + if self._error: return defer.fail(self._error) if self._closed: raise UsageError phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 @@ -598,21 +627,34 @@ class _Wormhole: d = self._receive_waiters[phase] = defer.Deferred() return d + def _maybe_close(self, mood): + if self._closed: + return + self.close(mood) + @inlineCallbacks def close(self, mood=None, wait=False): # TODO: auto-close on error, mostly for load-from-state + if self.DEBUG: print("close", wait) if self._closed: raise UsageError + self._closed = True if mood: self._mood = mood self._maybe_release_nameplate() self._maybe_close_mailbox() if wait: if self._nameplate_claimed: + if self.DEBUG: print("waiting for released") + self._release_waiter = defer.Deferred() yield self._release_waiter if self._mailbox_opened: + if self.DEBUG: print("waiting for closed") + self._close_waiter = defer.Deferred() yield self._close_waiter + if self.DEBUG: print("dropping connection") self._drop_connection() if wait: + if self.DEBUG: print("waiting for disconnect") yield self._disconnect_waiter def _maybe_release_nameplate(self): @@ -623,15 +665,19 @@ class _Wormhole: self._nameplate_released = True def _response_handle_released(self, msg): - self._release_waiter.callback(None) + if self._release_waiter and not self._release_waiter.called: + self._release_waiter.callback(None) def _maybe_close_mailbox(self): + if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed) if self._mailbox_opened and not self._mailbox_closed: + if self.DEBUG: print(" sending close") self._ws_send_command(u"close", mood=self._mood) self._mailbox_closed = True def _response_handle_closed(self, msg): - self._close_waiter.callback(None) + if self._close_waiter and not self._close_waiter.called: + self._close_waiter.callback(None) def _drop_connection(self): self._ws.transport.loseConnection() # probably flushes From 7bcefa78e6dbb2b3c9d39b426181a9996bdb188d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 22:53:15 -0700 Subject: [PATCH 38/51] remove test_twisted, now in test_wormhole --- src/wormhole/test/test_twisted.py | 235 ------------------------------ 1 file changed, 235 deletions(-) delete mode 100644 src/wormhole/test/test_twisted.py diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py deleted file mode 100644 index da87c97..0000000 --- a/src/wormhole/test/test_twisted.py +++ /dev/null @@ -1,235 +0,0 @@ -from __future__ import print_function -import json -from twisted.trial import unittest -from twisted.internet import reactor -from twisted.internet.defer import gatherResults, inlineCallbacks -from ..twisted.transcribe import (wormhole, wormhole_from_serialized, - UsageError, WrongPasswordError) -from .common import ServerBase - -APPID = u"appid" - -def Wormhole(appid, relayurl): - return wormhole(appid, relayurl, reactor) - -class Basic(ServerBase, unittest.TestCase): - - def doBoth(self, d1, d2): - return gatherResults([d1, d2], True) - - @inlineCallbacks - def test_basic(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_same_message(self): - # the two sides use random nonces for their messages, so it's ok for - # both to try and send the same body: they'll result in distinct - # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - yield self.doBoth(w1.send(b"data"), w2.send(b"data")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data") - self.assertEqual(dataY, b"data") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - res = yield self.doBoth(w1.send(b"data1"), w2.get()) - (_, dataY) = res - self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - (dataX, _) = dl - self.assertEqual(dataX, b"data2") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - - @inlineCallbacks - def test_multiple_messages(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data4") - self.assertEqual(dataY, b"data3") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_multiple_messages_2(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - # TODO: set_code should be sufficient to kick things off, but for now - # we must also let both sides do at least one send() or get() - yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) - yield w1.get() - yield w1.send(b"data2") - yield w1.send(b"data3") - data = yield w2.get() - self.assertEqual(data, b"data1") - data = yield w2.get() - self.assertEqual(data, b"data2") - data = yield w2.get() - self.assertEqual(data, b"data3") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code+"not") - - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get. So we need both sides to be - # running at the same time for this test. - d1 = w1.send(b"data1") - # at this point, w1 should be waiting for w2.PAKE - - yield self.assertFailure(w2.get(), WrongPasswordError) - # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, - # send w2.CONFIRM, then wait for w1.DATA. - # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. - # * w1 might also get w2.CONFIRM, and may notice the error before it - # sends w1.CONFIRM, in which case the wait=True will signal an - # error inside _get_master_key() (inside send), and d1 will - # errback. - # * but w1 might not see w2.CONFIRM yet, in which case it won't - # errback until we do w1.get() - # * w2 gets w1.CONFIRM, notices the error, records it. - # * w2 (waiting for w1.DATA) wakes up, sees the error, throws - # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not - # have arrived by then - try: - yield d1 - except WrongPasswordError: - pass - - # When we ask w1 to get(), one of two things might happen: - # * if w2.CONFIRM arrived already, it will have recorded the error. - # When w1.get() sleeps (waiting for w2.DATA), we'll notice the - # error before sleeping, and throw WrongPasswordError - # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM - # arrives, we notice and record the error, and wake up, and throw - - # Note that we didn't do w2.send(), so we're hoping that w1 will - # have enough information to detect the error before it sleeps - # (waiting for w2.DATA). Checking for the error both before sleeping - # and after waking up makes this happen. - - # so now w1 should have enough information to throw too - yield self.assertFailure(w1.get(), WrongPasswordError) - - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - code = yield w1.get_code() - w2.set_code(code) - dl = yield self.doBoth(w1.send(b"data1"), w2.get()) - self.assertEqual(dl[1], b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - self.assertEqual(dl[0], b"data2") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failUnlessEqual(v1, v2) - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - yield self.assertFailure(w1.get_verifier(), UsageError) - yield self.assertFailure(w1.send(b"data"), UsageError) - yield self.assertFailure(w1.get(), UsageError) - w1.set_code(u"123-purple-elephant") - yield self.assertRaises(UsageError, w1.set_code, u"123-nope") - yield self.assertFailure(w1.get_code(), UsageError) - w2 = Wormhole(APPID, self.relayurl) - yield w2.get_code() - yield self.assertFailure(w2.get_code(), UsageError) - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_serialize(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.serialize) # too early - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - self.assertRaises(UsageError, w2.serialize) # too early - w2.set_code(code) - w2.serialize() # ok - s = w1.serialize() - self.assertEqual(type(s), type("")) - unpacked = json.loads(s) # this is supposed to be JSON - self.assertEqual(type(unpacked), dict) - - self.new_w1 = wormhole_from_serialized(s, reactor) - yield self.doBoth(self.new_w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(self.new_w1.get(), w2.get()) - (dataX, dataY) = dl - self.assertEqual((dataX, dataY), (b"data2", b"data1")) - self.assertRaises(UsageError, w2.serialize) # too late - yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], True) - From e11a6f82431c64b33ab0bad8854e85f8a4cb85fc Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 22:53:00 -0700 Subject: [PATCH 39/51] new connection management, test_wormhole passes --- events.dot | 67 +++- src/wormhole/errors.py | 3 + src/wormhole/server/rendezvous_websocket.py | 2 +- src/wormhole/test/test_wormhole.py | 359 +++++++++++--------- src/wormhole/wormhole.py | 306 +++++++++++------ 5 files changed, 459 insertions(+), 278 deletions(-) diff --git a/events.dot b/events.dot index 66ea93b..f6aa1c3 100644 --- a/events.dot +++ b/events.dot @@ -1,31 +1,39 @@ digraph { + api_get_code [label="get_code" shape="hexagon" color="red"] + api_input_code [label="input_code" shape="hexagon" color="red"] + api_set_code [label="set_code" shape="hexagon" color="red"] + send [label="API\nsend" shape="hexagon" color="red"] + get [label="API\nget" shape="hexagon" color="red"] + close [label="API\nclose" shape="hexagon" color="red"] + + event_connected [label="connected" shape="box"] event_learned_code [label="learned\ncode" shape="box"] event_learned_nameplate [label="learned\nnameplate" shape="box"] - event_learned_mailbox [label="learned\nmailbox" shape="box"] - event_connected [label="connected" shape="box"] + event_received_mailbox [label="received\nmailbox" shape="box"] + event_opened_mailbox [label="opened\nmailbox" shape="box"] event_built_msg1 [label="built\nmsg1" shape="box"] event_mailbox_used [label="mailbox\nused" shape="box"] event_learned_PAKE [label="learned\nmsg2" shape="box"] event_established_key [label="established\nkey" shape="box"] event_computed_verifier [label="computed\nverifier" shape="box"] event_received_confirm [label="received\nconfirm" shape="box"] + event_received_message [label="received\nmessage" shape="box"] + event_received_released [label="ack\nreleased" shape="box"] + event_received_closed [label="ack\nclosed" shape="box"] event_connected -> api_get_code event_connected -> api_input_code - api_get_code [label="get_code" shape="hexagon"] - api_input_code [label="input_code" shape="hexagon"] - api_set_code [label="set_code" shape="hexagon"] api_get_code -> event_learned_code api_input_code -> event_learned_code api_set_code -> event_learned_code maybe_build_msg1 [label="build\nmsg1"] - maybe_get_mailbox [label="get\nmailbox"] + maybe_claim_nameplate [label="claim\nnameplate"] maybe_send_pake [label="send\npake"] maybe_send_phase_messages [label="send\nphase\nmessages"] - event_connected -> maybe_get_mailbox + event_connected -> maybe_claim_nameplate event_connected -> maybe_send_pake event_built_msg1 -> maybe_send_pake @@ -34,22 +42,23 @@ digraph { event_learned_code -> event_learned_nameplate maybe_build_msg1 -> event_built_msg1 - event_learned_nameplate -> maybe_get_mailbox + event_learned_nameplate -> maybe_claim_nameplate + maybe_claim_nameplate -> event_received_mailbox [style="dashed"] - maybe_get_mailbox -> event_learned_mailbox [style="dashed"] - maybe_get_mailbox -> event_mailbox_used [style="dashed"] - maybe_get_mailbox -> event_learned_PAKE [style="dashed"] - maybe_get_mailbox -> event_received_confirm [style="dashed"] + event_received_mailbox -> event_opened_mailbox + maybe_claim_nameplate -> event_learned_PAKE [style="dashed"] + maybe_claim_nameplate -> event_received_confirm [style="dashed"] - event_learned_mailbox -> event_learned_PAKE [style="dashed"] + event_opened_mailbox -> event_learned_PAKE [style="dashed"] event_learned_PAKE -> event_mailbox_used [style="dashed"] - event_mailbox_used -> event_received_confirm [style="dashed"] + event_learned_PAKE -> event_received_confirm [style="dashed"] + event_received_confirm -> event_received_message [style="dashed"] - send [label="API\nsend" shape="hexagon"] send -> maybe_send_phase_messages - event_mailbox_used -> release - event_learned_mailbox -> maybe_send_pake - event_learned_mailbox -> maybe_send_phase_messages + release_nameplate [label="release\nnameplate"] + event_mailbox_used -> release_nameplate + event_opened_mailbox -> maybe_send_pake + event_opened_mailbox -> maybe_send_phase_messages event_learned_PAKE -> event_established_key event_established_key -> event_computed_verifier @@ -59,4 +68,26 @@ digraph { event_computed_verifier -> check_verifier event_received_confirm -> check_verifier + check_verifier -> error + event_received_message -> error + event_received_message -> get + event_established_key -> get + + close -> close_mailbox + close -> release_nameplate + error [label="signal\nerror"] + error -> close_mailbox + error -> release_nameplate + + release_nameplate -> event_received_released [style="dashed"] + close_mailbox [label="close\nmailbox"] + close_mailbox -> event_received_closed [style="dashed"] + + maybe_close_websocket [label="close\nwebsocket"] + event_received_released -> maybe_close_websocket + event_received_closed -> maybe_close_websocket + maybe_close_websocket -> event_websocket_closed [style="dashed"] + event_websocket_closed [label="websocket\nclosed"] + + } diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 0c140ac..141523d 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -41,5 +41,8 @@ class ReflectionAttack(Exception): class UsageError(Exception): """The programmer did something wrong.""" +class WormholeClosedError(UsageError): + """API calls may not be made after close() is called.""" + class TransferError(Exception): """Something bad happened and the transfer failed.""" diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 05e2700..2a19ab2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -68,7 +68,7 @@ from .rendezvous import CrowdedError, SidedMessage # -> {type: "add", phase: str, body: hex} # will send echo in a "message" # # -> {type: "close", mood: str} -> closed -# <- {type: "closed", status: waiting|deleted} +# <- {type: "closed"} # # <- {type: "error", error: str, orig: {}} # in response to malformed msgs diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index ad8e5ef..c0f2823 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -7,8 +7,10 @@ from twisted.internet import reactor from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from .common import ServerBase from .. import wormhole +from ..errors import WrongPasswordError, WelcomeError, UsageError from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming +from nacl.secret import SecretBox APPID = u"appid" @@ -87,11 +89,10 @@ class Welcome(unittest.TestCase): self.assertEqual(len(se.mock_calls), 1) self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs we = se.mock_calls[0][1][0] - self.assertIsInstance(we, wormhole.WelcomeError) + self.assertIsInstance(we, WelcomeError) self.assertEqual(we.args, (u"oops",)) # alas WelcomeError instances don't compare against each other - #self.assertEqual(se.mock_calls, - # [mock.call(wormhole.WelcomeError(u"oops"))]) + #self.assertEqual(se.mock_calls, [mock.call(WelcomeError(u"oops"))]) class InputCode(unittest.TestCase): def test_list(self): @@ -171,7 +172,7 @@ class Basic(unittest.TestCase): self.assertTrue(w._flag_need_to_build_msg1) self.assertTrue(w._flag_need_to_send_PAKE) - v = w.get_verifier() + v = w.verify() w._drop_connection = mock.Mock() ws = MockWebSocket() @@ -204,7 +205,7 @@ class Basic(unittest.TestCase): # that triggers event_learned_mailbox, which should send open() and # PAKE - self.assertTrue(w._mailbox_opened) + self.assertEqual(w._mailbox_state, wormhole.OPEN) out = ws.outbound() self.assertEqual(len(out), 2) self.check_out(out[0], type=u"open", mailbox=u"mb456") @@ -232,10 +233,11 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"release") self.check_out(out[1], type=u"add", phase=u"confirm") verifier = self.successResultOf(v) - self.assertEqual(verifier, w.derive_key(u"wormhole:verifier")) + self.assertEqual(verifier, + w.derive_key(u"wormhole:verifier", SecretBox.KEY_SIZE)) # hearing a valid confirmation message doesn't throw an error - confkey = w.derive_key(u"wormhole:confirmation") + confkey = w.derive_key(u"wormhole:confirmation", SecretBox.KEY_SIZE) nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) confirm2 = wormhole.make_confmsg(confkey, nonce) confirm2_hex = hexlify(confirm2).decode("ascii") @@ -249,7 +251,7 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"add", phase=u"0") # decrypt+check the outbound message p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) - msgkey0 = w.derive_key(u"wormhole:phase:0") + msgkey0 = w.derive_key(u"wormhole:phase:0", SecretBox.KEY_SIZE) p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) self.assertEqual(p0_plaintext, b"phase0-outbound") @@ -268,7 +270,7 @@ class Basic(unittest.TestCase): self.assertIn(u"0", w._received_messages) # receiving an inbound message will queue it until get() is called - msgkey1 = w.derive_key(u"wormhole:phase:1") + msgkey1 = w.derive_key(u"wormhole:phase:1", SecretBox.KEY_SIZE) p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"1", body=p1_inbound_hex, @@ -284,9 +286,34 @@ class Basic(unittest.TestCase): out = ws.outbound() self.assertEqual(len(out), 1) self.check_out(out[0], type=u"close", mood=u"happy") + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) def test_close_wait_0(self): + # Close before the connection is established. The connection still + # gets established, but it is then torn down before sending anything. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + + d = w.close(wait=True) + self.assertNoResult(d) + + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_1(self): # close before even claiming the nameplate timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -304,8 +331,40 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_1(self): + def test_close_wait_2(self): + # Close after claiming the nameplate, but before opening the mailbox. + # The 'claimed' response arrives before we close. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + response(w, type=u"claimed", mailbox=u"mb123") + + d = w.close(wait=True) + self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"closed") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_3(self): # close after claiming the nameplate, but before opening the mailbox + # The 'claimed' response arrives after we start to close. timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) w._drop_connection = mock.Mock() @@ -317,6 +376,7 @@ class Basic(unittest.TestCase): self.check_outbound(ws, [u"bind", u"claim"]) d = w.close(wait=True) + response(w, type=u"claimed", mailbox=u"mb123") self.check_outbound(ws, [u"release"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -328,7 +388,7 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_2(self): + def test_close_wait_4(self): # close after both claiming the nameplate and opening the mailbox timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -357,7 +417,7 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) - def test_close_wait_3(self): + def test_close_wait_5(self): # close after claiming the nameplate, opening the mailbox, then # releasing the nameplate timing = DebugTiming() @@ -371,7 +431,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") w._key = b"" - msgkey = w.derive_key(u"wormhole:phase:misc") + msgkey = w.derive_key(u"wormhole:phase:misc", SecretBox.KEY_SIZE) p1_inbound = w._encrypt_data(msgkey, b"") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"misc", side=u"side2", @@ -395,6 +455,11 @@ class Basic(unittest.TestCase): w._ws_closed(True, None, None) self.successResultOf(d) + def test_close_errbacks(self): + # make sure the Deferreds returned by verify() and get() are properly + # errbacked upon close + pass + def test_get_code_mock(self): timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) @@ -438,6 +503,11 @@ class Basic(unittest.TestCase): self.assertEqual(len(pieces), 3) # nameplate plus two words self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + def test_verifier(self): + # make sure verify() can be called both before and after the verifier + # is computed + pass + def test_api_errors(self): # doing things you're not supposed to do pass @@ -456,9 +526,8 @@ class Basic(unittest.TestCase): w._event_ws_opened(None) self.check_outbound(ws, [u"bind"]) - WE = wormhole.WelcomeError d1 = w.get() - d2 = w.get_verifier() + d2 = w.verify() d3 = w.get_code() # TODO (tricky): test w.input_code @@ -466,16 +535,17 @@ class Basic(unittest.TestCase): self.assertNoResult(d2) self.assertNoResult(d3) - w._signal_error(WE(u"you are not actually welcome")) - self.failureResultOf(d1, WE) - self.failureResultOf(d2, WE) - self.failureResultOf(d3, WE) + w._signal_error(WelcomeError(u"you are not actually welcome"), u"pouty") + self.failureResultOf(d1, WelcomeError) + self.failureResultOf(d2, WelcomeError) + self.failureResultOf(d3, WelcomeError) # once the error is signalled, all API calls should fail - self.assertRaises(WE, w.send, u"foo") - self.assertRaises(WE, w.derive_key, u"foo") - self.failureResultOf(w.get(), WE) - self.failureResultOf(w.get_verifier(), WE) + self.assertRaises(WelcomeError, w.send, u"foo") + self.assertRaises(WelcomeError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WelcomeError) + self.failureResultOf(w.verify(), WelcomeError) def test_confirm_error(self): # we should only receive the "confirm" message after we receive the @@ -490,9 +560,8 @@ class Basic(unittest.TestCase): w.set_code(u"123-foo-bar") response(w, type=u"claimed", mailbox=u"mb456") - WP = wormhole.WrongPasswordError d1 = w.get() - d2 = w.get_verifier() + d2 = w.verify() self.assertNoResult(d1) self.assertNoResult(d2) @@ -506,28 +575,25 @@ class Basic(unittest.TestCase): msg2_hex = hexlify(msg2).decode("ascii") response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2") self.assertNoResult(d1) - self.successResultOf(d2) # early get_verifier is unaffected - # TODO: get_verifier would be a lovely place to signal a confirmation - # error, but that's at odds with delivering the verifier as early as - # possible. The confirmation messages should be hot on the heels of - # the PAKE message that produced the verifier. Maybe get_verifier() - # should explicitly wait for confirm()? + self.successResultOf(d2) # early verify is unaffected + # TODO: change verify() to wait for "confirm" # sending a random confirm message will cause a confirmation error - confkey = w.derive_key(u"WRONG") + confkey = w.derive_key(u"WRONG", SecretBox.KEY_SIZE) nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) badconfirm = wormhole.make_confmsg(confkey, nonce) badconfirm_hex = hexlify(badconfirm).decode("ascii") response(w, type=u"message", phase=u"confirm", body=badconfirm_hex, side=u"s2") - self.failureResultOf(d1, WP) + self.failureResultOf(d1, WrongPasswordError) # once the error is signalled, all API calls should fail - self.assertRaises(WP, w.send, u"foo") - self.assertRaises(WP, w.derive_key, u"foo") - self.failureResultOf(w.get(), WP) - self.failureResultOf(w.get_verifier(), WP) + self.assertRaises(WrongPasswordError, w.send, u"foo") + self.assertRaises(WrongPasswordError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WrongPasswordError) + self.failureResultOf(w.verify(), WrongPasswordError) # event orderings to exercise: @@ -562,60 +628,88 @@ class Wormholes(ServerBase, unittest.TestCase): yield w1.close(wait=True) yield w2.close(wait=True) -class Off: - @inlineCallbacks def test_same_message(self): # the two sides use random nonces for their messages, so it's ok for # both to try and send the same body: they'll result in distinct # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - yield self.doBoth(w1.send(b"data"), w2.send(b"data")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl + w1.send(b"data") + w2.send(b"data") + dataX = yield w1.get() + dataY = yield w2.get() self.assertEqual(dataX, b"data") self.assertEqual(dataY, b"data") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - res = yield self.doBoth(w1.send(b"data1"), w2.get()) - (_, dataY) = res + w1.send(b"data1") + dataY = yield w2.get() self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - (dataX, _) = dl + d = w1.get() + w2.send(b"data2") + dataX = yield d self.assertEqual(dataX, b"data2") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + + @inlineCallbacks + def test_unidirectional(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + dataY = yield w2.get() + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) + + @inlineCallbacks + def test_early(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w1.send(b"data1") + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + d = w2.get() + w1.set_code(u"123-abc-def") + w2.set_code(u"123-abc-def") + dataY = yield d + self.assertEqual(dataY, b"data1") + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) + w1.send(b"data1"), w2.send(b"data2") dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_multiple_messages(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) w1.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - yield self.doBoth(w1.send(b"data3"), w2.send(b"data4")) + w1.send(b"data1"), w2.send(b"data2") + w1.send(b"data3"), w2.send(b"data4") dl = yield self.doBoth(w1.get(), w2.get()) (dataX, dataY) = dl self.assertEqual(dataX, b"data2") @@ -624,124 +718,69 @@ class Off: (dataX, dataY) = dl self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data3") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_multiple_messages_2(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - # TODO: set_code should be sufficient to kick things off, but for now - # we must also let both sides do at least one send() or get() - yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored")) - yield w1.get() - yield w1.send(b"data2") - yield w1.send(b"data3") - data = yield w2.get() - self.assertEqual(data, b"data1") - data = yield w2.get() - self.assertEqual(data, b"data2") - data = yield w2.get() - self.assertEqual(data, b"data3") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) @inlineCallbacks def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code+"not") + # That's enough to allow both sides to discover the mismatch, but + # only after the confirmation message gets through. API calls that + # don't wait will appear to work until the mismatched confirmation + # message arrives. + w1.send(b"should still work") + w2.send(b"should still work") - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get. So we need both sides to be - # running at the same time for this test. - d1 = w1.send(b"data1") - # at this point, w1 should be waiting for w2.PAKE - + # API calls that wait (i.e. get) will errback yield self.assertFailure(w2.get(), WrongPasswordError) - # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, - # send w2.CONFIRM, then wait for w1.DATA. - # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. - # * w1 might also get w2.CONFIRM, and may notice the error before it - # sends w1.CONFIRM, in which case the wait=True will signal an - # error inside _get_master_key() (inside send), and d1 will - # errback. - # * but w1 might not see w2.CONFIRM yet, in which case it won't - # errback until we do w1.get() - # * w2 gets w1.CONFIRM, notices the error, records it. - # * w2 (waiting for w1.DATA) wakes up, sees the error, throws - # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not - # have arrived by then - try: - yield d1 - except WrongPasswordError: - pass - - # When we ask w1 to get(), one of two things might happen: - # * if w2.CONFIRM arrived already, it will have recorded the error. - # When w1.get() sleeps (waiting for w2.DATA), we'll notice the - # error before sleeping, and throw WrongPasswordError - # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM - # arrives, we notice and record the error, and wake up, and throw - - # Note that we didn't do w2.send(), so we're hoping that w1 will - # have enough information to detect the error before it sleeps - # (waiting for w2.DATA). Checking for the error both before sleeping - # and after waking up makes this happen. - - # so now w1 should have enough information to throw too yield self.assertFailure(w1.get(), WrongPasswordError) - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - code = yield w1.get_code() - w2.set_code(code) - dl = yield self.doBoth(w1.send(b"data1"), w2.get()) - self.assertEqual(dl[1], b"data1") - dl = yield self.doBoth(w1.get(), w2.send(b"data2")) - self.assertEqual(dl[0], b"data2") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + self.flushLoggedErrors(WrongPasswordError) @inlineCallbacks def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) code = yield w1.get_code() w2.set_code(code) - res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - v1, v2 = res + v1 = yield w1.verify() + v2 = yield w2.verify() self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) - yield self.doBoth(w1.send(b"data1"), w2.send(b"data2")) - dl = yield self.doBoth(w1.get(), w2.get()) - (dataX, dataY) = dl + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) + yield w1.close(wait=True) + yield w2.close(wait=True) + +class Errors(ServerBase, unittest.TestCase): + @inlineCallbacks + def test_codes_1(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + # definitely too early + self.assertRaises(UsageError, w.derive_key, u"purpose", 12) + + w.set_code(u"123-purple-elephant") + # code can only be set once + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close(wait=True) @inlineCallbacks - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - yield self.assertFailure(w1.get_verifier(), UsageError) - yield self.assertFailure(w1.send(b"data"), UsageError) - yield self.assertFailure(w1.get(), UsageError) - w1.set_code(u"123-purple-elephant") - yield self.assertRaises(UsageError, w1.set_code, u"123-nope") - yield self.assertFailure(w1.get_code(), UsageError) - w2 = Wormhole(APPID, self.relayurl) - yield w2.get_code() - yield self.assertFailure(w2.get_code(), UsageError) - yield self.doBoth(w1.close(), w2.close()) + def test_codes_2(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + yield w.get_code() + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close(wait=True) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 53cefca..134aa0b 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -14,7 +14,8 @@ from spake2 import SPAKE2_Symmetric from . import __version__ from . import codes #from .errors import ServerError, Timeout -from .errors import WrongPasswordError, UsageError, WelcomeError +from .errors import (WrongPasswordError, UsageError, WelcomeError, + WormholeClosedError) from .timing import DebugTiming from hkdf import Hkdf @@ -205,6 +206,9 @@ class _WelcomeHandler: if "error" in welcome: return self._signal_error(WelcomeError(welcome["error"])) +# states for nameplates, mailboxes, and the websocket connection +(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing") + class _Wormhole: def __init__(self, appid, relay_url, reactor, tor_manager, timing): @@ -217,31 +221,28 @@ class _Wormhole: self._welcomer = _WelcomeHandler(self._ws_url, __version__, self._signal_error) self._side = hexlify(os.urandom(5)).decode("ascii") - self._connected = None + self._connection_state = CLOSED self._connection_waiters = [] self._started_get_code = False self._get_code = None self._code = None self._nameplate_id = None - self._nameplate_claimed = False - self._nameplate_released = False - self._release_waiter = None + self._nameplate_state = CLOSED self._mailbox_id = None - self._mailbox_opened = False - self._mailbox_closed = False - self._close_waiter = None + self._mailbox_state = CLOSED self._flag_need_nameplate = True self._flag_need_to_see_mailbox_used = True self._flag_need_to_build_msg1 = True self._flag_need_to_send_PAKE = True self._key = None - self._closed = False + self._close_called = False # the close() API has been called + self._closing = False # we've started shutdown self._disconnect_waiter = defer.Deferred() - self._mood = u"happy" self._error = None self._get_verifier_called = False - self._verifier_waiter = defer.Deferred() + self._verifier = None + self._verifier_waiter = None self._next_send_phase = 0 # send() queues plaintext here, waiting for a connection and the key @@ -252,33 +253,62 @@ class _Wormhole: self._receive_waiters = {} # phase -> Deferred self._received_messages = {} # phase -> plaintext - def _signal_error(self, error): - # close the mailbox with an "errory" mood, errback all Deferreds, - # record the error, fail all subsequent API calls - if self.DEBUG: print("_signal_error", error) - self._error = error # causes new API calls to fail - for d in self._connection_waiters: - d.errback(error) - if self._get_code: - self._get_code._allocated_d.errback(error) - if not self._verifier_waiter.called: - self._verifier_waiter.errback(error) - for d in self._receive_waiters.values(): - d.errback(error) + # API METHODS for applications to call - self._maybe_close(mood=u"errory") - if self._release_waiter and not self._release_waiter.called: - self._release_waiter.errback(error) - if self._close_waiter and not self._close_waiter.called: - self._close_waiter.errback(error) - # leave self._disconnect_waiter alone - if self.DEBUG: print("_signal_error done") + # You must use at least one of these entry points, to establish the + # wormhole code. Other APIs will stall or be queued until we have one. + + # entry point 1: generate a new code. returns a Deferred + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + return self._API_get_code(code_length) + + # entry point 2: interactively type in a code, with completion. returns + # Deferred + def input_code(self, prompt="Enter wormhole code: ", code_length=2): + return self._API_input_code(prompt, code_length) + + # entry point 3: paste in a fully-formed code. No return value. + def set_code(self, code): + self._API_set_code(code) + + # todo: restore-saved-state entry points + + def verify(self): + """Returns a Deferred that fires when we've heard back from the other + side, and have confirmed that they used the right wormhole code. When + successful, the Deferred fires with a "verifier" (a bytestring) which + can be compared out-of-band before making additional API calls. If + they used the wrong wormhole code, the Deferred errbacks with + WrongPasswordError. + """ + return self._API_verify() + + def send(self, outbound_data): + return self._API_send(outbound_data) + + def get(self): + return self._API_get() + + def derive_key(self, purpose, length): + """Derive a new key from the established wormhole channel for some + other purpose. This is a deterministic randomized function of the + session key and the 'purpose' string (unicode/py3-string). This + cannot be called until verify() or get() has fired. + """ + return self._API_derive_key(purpose, length) + + def close(self, wait=False): + return self._API_close(wait) + + # INTERNAL METHODS beyond here def _start(self): d = self._connect() # causes stuff to happen d.addErrback(log.err) return d # fires when connection is established, if you care + + def _make_endpoint(self, hostname, port): if self._tor_manager: return self._tor_manager.get_endpoint_for(hostname, port) @@ -289,6 +319,7 @@ class _Wormhole: # TODO: if we lose the connection, make a new one, re-establish the # state assert self._side + self._connection_state = OPENING p = urlparse(self._ws_url) f = WSFactory(self._ws_url) f.wormhole = self @@ -311,16 +342,18 @@ class _Wormhole: self._ws_t = self._timing.add("websocket") def _event_ws_opened(self, _): - self._connected = True + self._connection_state = OPEN + if self._closing: + return self._maybe_finished_closing() self._ws_send_command(u"bind", appid=self._appid, side=self._side) - self._maybe_get_mailbox() + self._maybe_claim_nameplate() self._maybe_send_pake() waiters, self._connection_waiters = self._connection_waiters, [] for d in waiters: d.callback(None) def _when_connected(self): - if self._connected: + if self._connection_state == OPEN: return defer.succeed(None) d = defer.Deferred() self._connection_waiters.append(d) @@ -331,6 +364,7 @@ class _Wormhole: # their receives, and vice versa. They are also correlated with the # ACKs we get back from the server (which we otherwise ignore). There # are so few messages, 16 bits is enough to be mostly-unique. + if self.DEBUG: print("SEND", mtype) kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") kwargs["type"] = mtype payload = json.dumps(kwargs).encode("utf-8") @@ -358,7 +392,7 @@ class _Wormhole: # entry point 1: generate a new code @inlineCallbacks - def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + def _API_get_code(self, code_length): if self._code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True @@ -370,13 +404,13 @@ class _Wormhole: # TODO: signal_error code = yield gc.go() self._get_code = None - self._nameplate_claimed = True # side-effect of allocation + self._nameplate_state = OPEN self._event_learned_code(code) returnValue(code) # entry point 2: interactively type in a code, with completion @inlineCallbacks - def input_code(self, prompt="Enter wormhole code: ", code_length=2): + def _API_input_code(self, prompt, code_length): if self._code is not None: raise UsageError if self._started_input_code: raise UsageError self._started_input_code = True @@ -390,7 +424,7 @@ class _Wormhole: returnValue(None) # entry point 3: paste in a fully-formed code - def set_code(self, code): + def _API_set_code(self, code): self._timing.add("API set_code") if not isinstance(code, type(u"")): raise TypeError(type(code)) if self._code is not None: raise UsageError @@ -437,13 +471,13 @@ class _Wormhole: # for each such condition Y, every _event_Y must call _maybe_X def _event_learned_nameplate(self): - self._maybe_get_mailbox() + self._maybe_claim_nameplate() - def _maybe_get_mailbox(self): - if not (self._nameplate_id and self._connected): + def _maybe_claim_nameplate(self): + if not (self._nameplate_id and self._connection_state == OPEN): return self._ws_send_command(u"claim", nameplate=self._nameplate_id) - self._nameplate_claimed = True + self._nameplate_state = OPEN def _response_handle_claimed(self, msg): mailbox_id = msg["mailbox"] @@ -453,16 +487,19 @@ class _Wormhole: def _event_learned_mailbox(self): if not self._mailbox_id: raise UsageError - if self._mailbox_opened: raise UsageError + assert self._mailbox_state == CLOSED, self._mailbox_state + if self._closing: + return self._ws_send_command(u"open", mailbox=self._mailbox_id) - self._mailbox_opened = True + self._mailbox_state = OPEN # causes old messages to be sent now, and subscribes to new messages self._maybe_send_pake() self._maybe_send_phase_messages() def _maybe_send_pake(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox_opened + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN and self._flag_need_to_send_PAKE): return self._msg_send(u"pake", self._msg1) @@ -477,44 +514,52 @@ class _Wormhole: self._timing.add("key established") # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") + confkey = self._derive_key(u"wormhole:confirmation") nonce = os.urandom(CONFMSG_NONCE_LENGTH) confmsg = make_confmsg(confkey, nonce) self._msg_send(u"confirm", confmsg) - verifier = self.derive_key(u"wormhole:verifier") + verifier = self._derive_key(u"wormhole:verifier") self._event_computed_verifier(verifier) self._maybe_send_phase_messages() - def get_verifier(self): + def _API_verify(self): + # TODO: rename "verify()", make it stall until confirm received. If + # you want to discover WrongPasswordError before doing send(), call + # verify() first. If you also want to deny a successful MitM (and + # have some other way to check a long verifier), use the return value + # of verify(). if self._error: return defer.fail(self._error) - if self._closed: raise UsageError if self._get_verifier_called: raise UsageError self._get_verifier_called = True + if self._verifier: + return defer.succeed(self._verifier) # TODO: maybe have this wait on _event_received_confirm too + self._verifier_waiter = defer.Deferred() return self._verifier_waiter def _event_computed_verifier(self, verifier): - self._verifier_waiter.callback(verifier) + self._verifier = verifier + if self._verifier_waiter: + self._verifier_waiter.callback(verifier) def _event_received_confirm(self, body): # TODO: we might not have a master key yet, if the caller wasn't # waiting in _get_master_key() when a back-to-back pake+_confirm # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") + confkey = self._derive_key(u"wormhole:confirmation") nonce = body[:CONFMSG_NONCE_LENGTH] if body != make_confmsg(confkey, nonce): # this makes all API calls fail if self.DEBUG: print("CONFIRM FAILED") - return self._signal_error(WrongPasswordError()) + return self._signal_error(WrongPasswordError(), u"scary") - def send(self, outbound_data): + def _API_send(self, outbound_data): if self._error: raise self._error if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) - if self._closed: raise UsageError phase = self._next_send_phase self._next_send_phase += 1 self._plaintext_to_send.append( (phase, outbound_data) ) @@ -523,14 +568,16 @@ class _Wormhole: def _maybe_send_phase_messages(self): # TODO: deal with reentrant call - if not (self._connected and self._mailbox_opened and self._key): + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN + and self._key): return plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: (phase, plaintext) = pm assert isinstance(phase, int), type(phase) - data_key = self.derive_key(u"wormhole:phase:%d" % phase) + data_key = self._derive_key(u"wormhole:phase:%d" % phase) encrypted = self._encrypt_data(data_key, plaintext) self._msg_send(u"%d" % phase, encrypted) @@ -550,8 +597,7 @@ class _Wormhole: def _msg_send(self, phase, body): if phase in self._sent_phases: raise UsageError - if not self._mailbox_opened: raise UsageError - if self._mailbox_closed: raise UsageError + assert self._mailbox_state == OPEN, self._mailbox_state self._sent_phases.add(phase) # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. @@ -566,8 +612,11 @@ class _Wormhole: self._maybe_release_nameplate() self._flag_need_to_see_mailbox_used = False - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): + def _API_derive_key(self, purpose, length): if self._error: raise self._error + return self._derive_key(purpose, length) + + def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() @@ -597,12 +646,19 @@ class _Wormhole: self._event_received_confirm(body) return - # now notify anyone waiting on it + # It's a phase message, aimed at the application above us. Decrypt + # and deliver upstairs, notifying anyone waiting on it try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) + data_key = self._derive_key(u"wormhole:phase:%s" % phase) plaintext = self._decrypt_data(data_key, body) except CryptoError: - raise WrongPasswordError # TODO: signal + e = WrongPasswordError() + self._signal_error(e, u"scary") # flunk all other API calls + # make tests fail, if they aren't explicitly catching it + if self.DEBUG: print("CryptoError in msg received") + log.err(e) + if self.DEBUG: print(" did log.err", e) + return # ignore this message self._received_messages[phase] = plaintext if phase in self._receive_waiters: d = self._receive_waiters.pop(phase) @@ -616,9 +672,8 @@ class _Wormhole: data = box.decrypt(encrypted) return data - def get(self): + def _API_get(self): if self._error: return defer.fail(self._error) - if self._closed: raise UsageError phase = u"%d" % self._next_receive_phase self._next_receive_phase += 1 with self._timing.add("API get", phase=phase): @@ -627,64 +682,117 @@ class _Wormhole: d = self._receive_waiters[phase] = defer.Deferred() return d - def _maybe_close(self, mood): - if self._closed: + def _signal_error(self, error, mood): + if self.DEBUG: print("_signal_error", error, mood) + if self._error: return - self.close(mood) + self._maybe_close(error, mood) + if self.DEBUG: print("_signal_error done") @inlineCallbacks - def close(self, mood=None, wait=False): - # TODO: auto-close on error, mostly for load-from-state + def _API_close(self, wait=False, mood=u"happy"): if self.DEBUG: print("close", wait) - if self._closed: raise UsageError - self._closed = True - if mood: - self._mood = mood - self._maybe_release_nameplate() - self._maybe_close_mailbox() - if wait: - if self._nameplate_claimed: - if self.DEBUG: print("waiting for released") - self._release_waiter = defer.Deferred() - yield self._release_waiter - if self._mailbox_opened: - if self.DEBUG: print("waiting for closed") - self._close_waiter = defer.Deferred() - yield self._close_waiter - if self.DEBUG: print("dropping connection") - self._drop_connection() + if self._close_called: raise UsageError + self._close_called = True + self._maybe_close(WormholeClosedError(), mood) if wait: if self.DEBUG: print("waiting for disconnect") yield self._disconnect_waiter + def _maybe_close(self, error, mood): + if self._closing: + return + + # ordering constraints: + # * must wait for nameplate/mailbox acks before closing the websocket + # * must mark APIs for failure before errbacking Deferreds + # * since we give up control + # * must mark self._closing before errbacking Deferreds + # * since caller may call close() when we give up control + # * and close() will reenter _maybe_close + + self._error = error # causes new API calls to fail + + # since we're about to give up control by errbacking any API + # Deferreds, set self._closing, to make sure that a new call to + # close() isn't going to confuse anything + self._closing = True + + # now errback all API deferreds except close(): get_code, + # input_code, verify, get + for d in self._connection_waiters: # input_code, get_code (early) + if self.DEBUG: print("EB cw") + d.errback(error) + if self._get_code: # get_code (late) + if self.DEBUG: print("EB gc") + self._get_code._allocated_d.errback(error) + if self._verifier_waiter and not self._verifier_waiter.called: + if self.DEBUG: print("EB VW") + self._verifier_waiter.errback(error) + for d in self._receive_waiters.values(): + if self.DEBUG: print("EB RW") + d.errback(error) + # Release nameplate and close mailbox, if either was claimed/open. + # Since _closing is True when both ACKs come back, the handlers will + # close the websocket. When *that* finishes, _disconnect_waiter() + # will fire. + self._maybe_release_nameplate() + self._maybe_close_mailbox(mood) + # In the off chance we got closed before we even claimed the + # nameplate, give _maybe_finished_closing a chance to run now. + self._maybe_finished_closing() + def _maybe_release_nameplate(self): - if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) - if self._nameplate_claimed and not self._nameplate_released: + if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state) + if self._nameplate_state == OPEN: if self.DEBUG: print(" sending release") self._ws_send_command(u"release") - self._nameplate_released = True + self._nameplate_state = CLOSING def _response_handle_released(self, msg): - if self._release_waiter and not self._release_waiter.called: - self._release_waiter.callback(None) + self._nameplate_state = CLOSED + self._maybe_finished_closing() - def _maybe_close_mailbox(self): - if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_opened, self._mailbox_closed) - if self._mailbox_opened and not self._mailbox_closed: + def _maybe_close_mailbox(self, mood): + if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state) + if self._mailbox_state == OPEN: if self.DEBUG: print(" sending close") - self._ws_send_command(u"close", mood=self._mood) - self._mailbox_closed = True + self._ws_send_command(u"close", mood=mood) + self._mailbox_state = CLOSING def _response_handle_closed(self, msg): - if self._close_waiter and not self._close_waiter.called: - self._close_waiter.callback(None) + self._mailbox_state = CLOSED + self._maybe_finished_closing() + + def _maybe_finished_closing(self): + if self.DEBUG: print("_maybe_finished_closing", self._closing, self._nameplate_state, self._mailbox_state, self._connection_state) + if not self._closing: + return + if (self._nameplate_state == CLOSED + and self._mailbox_state == CLOSED + and self._connection_state == OPEN): + self._connection_state = CLOSING + self._drop_connection() def _drop_connection(self): - self._ws.transport.loseConnection() # probably flushes + # separate method so it can be overridden by tests + self._ws.transport.loseConnection() # probably flushes output # calls _ws_closed() when done def _ws_closed(self, wasClean, code, reason): + # For now (until we add reconnection), losing the websocket means + # losing everything. Make all API callers fail. Help someone waiting + # in close() to finish + self._connection_state = CLOSED self._disconnect_waiter.callback(None) + self._maybe_finished_closing() + + # what needs to happen when _ws_closed() happens unexpectedly + # * errback all API deferreds + # * maybe: cause new API calls to fail + # * obviously can't release nameplate or close mailbox + # * can't re-close websocket + # * close(wait=True) callers should fire right away def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): timing = timing or DebugTiming() From 9bd5afe7df097bc46263bb327c46a5be588d1d16 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 23:59:49 -0700 Subject: [PATCH 40/51] make close() always wait --- src/wormhole/test/test_wormhole.py | 56 ++++++++++++++++-------------- src/wormhole/wormhole.py | 35 +++++++++++++++---- 2 files changed, 57 insertions(+), 34 deletions(-) diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index c0f2823..c9e9a0f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -282,7 +282,8 @@ class Basic(unittest.TestCase): self.assertIn(u"1", w._received_messages) self.assertNotIn(u"1", w._receive_waiters) - w.close() + d = w.close() + self.assertNoResult(d) out = ws.outbound() self.assertEqual(len(out), 1) self.check_out(out[0], type=u"close", mood=u"happy") @@ -293,6 +294,7 @@ class Basic(unittest.TestCase): response(w, type=u"closed") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) w._ws_closed(True, None, None) + self.assertEqual(self.successResultOf(d), None) def test_close_wait_0(self): # Close before the connection is established. The connection still @@ -301,7 +303,7 @@ class Basic(unittest.TestCase): w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) w._drop_connection = mock.Mock() - d = w.close(wait=True) + d = w.close() self.assertNoResult(d) ws = MockWebSocket() @@ -322,7 +324,7 @@ class Basic(unittest.TestCase): w._event_connected(ws) w._event_ws_opened(None) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"bind"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) @@ -346,7 +348,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb123") - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -375,7 +377,7 @@ class Basic(unittest.TestCase): w.set_code(CODE) self.check_outbound(ws, [u"bind", u"claim"]) - d = w.close(wait=True) + d = w.close() response(w, type=u"claimed", mailbox=u"mb123") self.check_outbound(ws, [u"release"]) self.assertNoResult(d) @@ -401,7 +403,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"release", u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -439,7 +441,7 @@ class Basic(unittest.TestCase): self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", u"release"]) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -625,8 +627,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_same_message(self): @@ -643,8 +645,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data") self.assertEqual(dataY, b"data") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_interleaved(self): @@ -659,8 +661,8 @@ class Wormholes(ServerBase, unittest.TestCase): w2.send(b"data2") dataX = yield d self.assertEqual(dataX, b"data2") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_unidirectional(self): @@ -671,8 +673,8 @@ class Wormholes(ServerBase, unittest.TestCase): w1.send(b"data1") dataY = yield w2.get() self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_early(self): @@ -684,8 +686,8 @@ class Wormholes(ServerBase, unittest.TestCase): w2.set_code(u"123-abc-def") dataY = yield d self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_fixed_code(self): @@ -698,8 +700,8 @@ class Wormholes(ServerBase, unittest.TestCase): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks @@ -718,8 +720,8 @@ class Wormholes(ServerBase, unittest.TestCase): (dataX, dataY) = dl self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data3") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_wrong_password(self): @@ -738,8 +740,8 @@ class Wormholes(ServerBase, unittest.TestCase): yield self.assertFailure(w2.get(), WrongPasswordError) yield self.assertFailure(w1.get(), WrongPasswordError) - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() self.flushLoggedErrors(WrongPasswordError) @inlineCallbacks @@ -758,8 +760,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() class Errors(ServerBase, unittest.TestCase): @inlineCallbacks @@ -773,7 +775,7 @@ class Errors(ServerBase, unittest.TestCase): self.assertRaises(UsageError, w.set_code, u"123-nope") yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError) - yield w.close(wait=True) + yield w.close() @inlineCallbacks def test_codes_2(self): @@ -782,5 +784,5 @@ class Errors(ServerBase, unittest.TestCase): self.assertRaises(UsageError, w.set_code, u"123-nope") yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError) - yield w.close(wait=True) + yield w.close() diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 134aa0b..26c05a7 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -297,8 +297,29 @@ class _Wormhole: """ return self._API_derive_key(purpose, length) - def close(self, wait=False): - return self._API_close(wait) + def close(self, res=None): + """Collapse the wormhole, freeing up server resources and flushing + all pending messages. Returns a Deferred that fires when everything + is done. It fires with any argument close() was given, to enable use + as a d.addBoth() handler: + + w = wormhole(...) + d = w.get() + .. + d.addBoth(w.close) + return d + + Another reasonable approach is to use inlineCallbacks: + + @inlineCallbacks + def pair(self, code): + w = wormhole(...) + try: + them = yield w.get() + finally: + yield w.close() + """ + return self._API_close(res) # INTERNAL METHODS beyond here @@ -690,14 +711,14 @@ class _Wormhole: if self.DEBUG: print("_signal_error done") @inlineCallbacks - def _API_close(self, wait=False, mood=u"happy"): - if self.DEBUG: print("close", wait) + def _API_close(self, res, mood=u"happy"): + if self.DEBUG: print("close") if self._close_called: raise UsageError self._close_called = True self._maybe_close(WormholeClosedError(), mood) - if wait: - if self.DEBUG: print("waiting for disconnect") - yield self._disconnect_waiter + if self.DEBUG: print("waiting for disconnect") + yield self._disconnect_waiter + returnValue(res) def _maybe_close(self, error, mood): if self._closing: From 2c64805ea11361788efaa7788d9982dc156ef995 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 00:00:04 -0700 Subject: [PATCH 41/51] fix input_code --- src/wormhole/wormhole.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 26c05a7..662b6e4 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -225,6 +225,7 @@ class _Wormhole: self._connection_waiters = [] self._started_get_code = False self._get_code = None + self._started_input_code = False self._code = None self._nameplate_id = None self._nameplate_state = CLOSED @@ -437,7 +438,8 @@ class _Wormhole: self._started_input_code = True with self._timing.add("API input_code"): yield self._when_connected() - ic = _InputCode(prompt, code_length, self._ws_send_command) + ic = _InputCode(self._reactor, prompt, code_length, + self._ws_send_command, self._timing) self._response_handle_nameplates = ic._response_handle_nameplates # TODO: signal_error code = yield ic.go() From e2aa43d0a90f5b41cc14418b20ade98afd893617 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 00:00:21 -0700 Subject: [PATCH 42/51] transit: expose desired key length --- src/wormhole/twisted/transit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index d0eef5e..cbc8724 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -548,6 +548,7 @@ def there_can_be_only_one(contenders): class Common: RELAY_DELAY = 2.0 + TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE def __init__(self, transit_relay, no_listen=False, tor_manager=None, reactor=reactor, timing=None): From 3a062eaa26e47b64e34b920e4a33c7c11065fa05 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 00:00:44 -0700 Subject: [PATCH 43/51] bring scripts and tests up to date * use wormhole instead of transcribe.py * send() no longer waits * get_verifier -> verify * derive_key demands a length --- src/wormhole/cli/cmd_receive.py | 16 ++++++++-------- src/wormhole/cli/cmd_send.py | 11 ++++++----- src/wormhole/test/test_scripts.py | 5 +++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 7813d75..2e0447a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -3,7 +3,7 @@ import os, sys, json, binascii, six, tempfile, zipfile from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue -from ..twisted.transcribe import wormhole +from ..wormhole import wormhole from ..twisted.transit import TransitReceiver from ..errors import TransferError @@ -64,12 +64,12 @@ class TwistedReceiver: @inlineCallbacks def _go(self, w, tor_manager): yield self.handle_code(w) - verifier = yield w.get_verifier() + verifier = yield w.verify() self.show_verifier(verifier) them_d = yield self.get_data(w) try: if "message" in them_d: - yield self.handle_text(them_d, w) + self.handle_text(them_d, w) returnValue(None) if "file" in them_d: f = self.handle_file(them_d) @@ -89,7 +89,7 @@ class TwistedReceiver: raise RespondError("unknown offer type") except RespondError as r: data = json.dumps({"error": r.response}).encode("utf-8") - yield w.send(data) + w.send(data) raise TransferError(r.response) returnValue(None) @@ -119,12 +119,11 @@ class TwistedReceiver: raise TransferError(them_d["error"]) returnValue(them_d) - @inlineCallbacks def handle_text(self, them_d, w): # we're receiving a text message self.msg(them_d["message"]) data = json.dumps({"message_ack": "ok"}).encode("utf-8") - yield w.send(data, wait=True) + w.send(data) def handle_file(self, them_d): file_data = them_d["file"] @@ -183,12 +182,13 @@ class TwistedReceiver: @inlineCallbacks def establish_transit(self, w, them_d, tor_manager): - transit_key = w.derive_key(APPID+u"/transit-key") transit_receiver = TransitReceiver(self.args.transit_helper, no_listen=self.args.no_listen, tor_manager=tor_manager, reactor=self._reactor, timing=self.args.timing) + transit_key = w.derive_key(APPID+u"/transit-key", + transit_receiver.TRANSIT_KEY_LENGTH) transit_receiver.set_transit_key(transit_key) direct_hints = yield transit_receiver.get_direct_hints() relay_hints = yield transit_receiver.get_relay_hints() @@ -199,7 +199,7 @@ class TwistedReceiver: "relay_connection_hints": relay_hints, }, }).encode("utf-8") - yield w.send(data) + w.send(data) # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 8ad6973..dba5ee0 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -5,7 +5,7 @@ from twisted.protocols import basic from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from ..errors import TransferError -from ..twisted.transcribe import wormhole +from ..wormhole import wormhole from ..twisted.transit import TransitSender APPID = u"lothar.com/wormhole/text-or-file-xfer" @@ -83,7 +83,7 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager): # get the verifier, because that also lets us derive the transit key, # which we want to set before revealing the connection hints to the far # side, so we'll be ready for them when they connect - verifier_bytes = yield w.get_verifier() + verifier_bytes = yield w.verify() verifier = binascii.hexlify(verifier_bytes).decode("ascii") if args.verify: @@ -94,14 +94,15 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager): if ok.lower() == "no": err = "sender rejected verification check, abandoned transfer" reject_data = json.dumps({"error": err}).encode("utf-8") - yield w.send(reject_data) + w.send(reject_data) raise TransferError(err) if fd_to_send is not None: - transit_key = w.derive_key(APPID+"/transit-key") + transit_key = w.derive_key(APPID+"/transit-key", + transit_sender.TRANSIT_KEY_LENGTH) transit_sender.set_transit_key(transit_key) my_phase1_bytes = json.dumps(phase1).encode("utf-8") - yield w.send(my_phase1_bytes) + w.send(my_phase1_bytes) # this may raise WrongPasswordError them_phase1_bytes = yield w.get() diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index e70cc1a..71fd33a 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -453,7 +453,7 @@ class Cleanup(ServerBase, unittest.TestCase): yield send_d yield receive_d - cids = self._rendezvous.get_app(cmd_send.APPID).get_claimed() + cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) @inlineCallbacks @@ -482,6 +482,7 @@ class Cleanup(ServerBase, unittest.TestCase): yield self.assertFailure(send_d, WrongPasswordError) yield self.assertFailure(receive_d, WrongPasswordError) - cids = self._rendezvous.get_app(cmd_send.APPID).get_claimed() + cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) + self.flushLoggedErrors(WrongPasswordError) From 1ef6218b5be7a468be7489c530e7fba838203a3a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 00:01:22 -0700 Subject: [PATCH 44/51] remove old twisted/transcribe.py, now just wormhole.py --- src/wormhole/twisted/transcribe.py | 624 ----------------------------- 1 file changed, 624 deletions(-) delete mode 100644 src/wormhole/twisted/transcribe.py diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py deleted file mode 100644 index 555499b..0000000 --- a/src/wormhole/twisted/transcribe.py +++ /dev/null @@ -1,624 +0,0 @@ -from __future__ import print_function -import os, sys, json, re, unicodedata -from six.moves.urllib_parse import urlparse -from binascii import hexlify, unhexlify -from twisted.internet import defer, endpoints, error -from twisted.internet.threads import deferToThread, blockingCallFromThread -from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.python import log -from autobahn.twisted import websocket -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils -from spake2 import SPAKE2_Symmetric -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming -from hkdf import Hkdf - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - -CONFMSG_NONCE_LENGTH = 128//8 -CONFMSG_MAC_LENGTH = 256//8 -def make_confmsg(confkey, nonce): - return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) - -def to_bytes(u): - return unicodedata.normalize("NFC", u).encode("utf-8") - -# We send the following messages through the relay server to the far side (by -# sending "add" commands to the server, and getting "message" responses): -# -# phase=setup: -# * unauthenticated version strings (but why?) -# * early warmup for connection hints ("I can do tor, spin up HS") -# * wordlist l10n identifier -# phase=pake: just the SPAKE2 'start' message (binary) -# phase=confirm: key verification (HKDF(key, nonce)+nonce) -# phase=1,2,3,..: application messages - -class WSClient(websocket.WebSocketClientProtocol): - def onOpen(self): - self.wormhole_open = True - self.factory.d.callback(self) - - def onMessage(self, payload, isBinary): - assert not isBinary - self.wormhole._ws_dispatch_response(payload) - - def onClose(self, wasClean, code, reason): - if self.wormhole_open: - self.wormhole._ws_closed(wasClean, code, reason) - else: - # we closed before establishing a connection (onConnect) or - # finishing WebSocket negotiation (onOpen): errback - self.factory.d.errback(error.ConnectError(reason)) - -class WSFactory(websocket.WebSocketClientFactory): - protocol = WSClient - def buildProtocol(self, addr): - proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) - proto.wormhole = self.wormhole - proto.wormhole_open = False - return proto - -class _Wormhole: - motd_displayed = False - version_warning_displayed = False - _send_confirm = True - - def __init__(self, appid, relay_url, reactor, - tor_manager=None, timing=None): - if not isinstance(appid, type(u"")): raise TypeError(type(appid)) - if not isinstance(relay_url, type(u"")): - raise TypeError(type(relay_url)) - if not relay_url.endswith(u"/"): raise UsageError - self._appid = appid - self._relay_url = relay_url - self._ws_url = relay_url.replace("http:", "ws:") + "ws" - self._tor_manager = tor_manager - self._timing = timing or DebugTiming() - self._reactor = reactor - - self._ws_connected = defer.Deferred() # XXX - - self._side = hexlify(os.urandom(5)).decode("ascii") - self._code = None - - self._nameplate_id = None - self._nameplate_claimed = False - self._nameplate_released = False - - self._mailbox_id = None - self._mailbox_opened = False - self._mailbox_closed = False - - self._key = None - self._started_get_code = False - self._next_outbound_phase = 0 - self._sent_messages = {} # phase -> body_bytes - self._delivered_messages = set() # phase - self._next_inbound_phase = 0 - self._received_messages = {} # phase -> body_bytes - self._got_phases = set() # phases, to prohibit double-read - self._sleepers = [] - self._confirmation_failed = False - self._closed = False - self._released_status = None - self._timing_started = self._timing.add("wormhole") - self._ws = None - self._ws_t = None # timing Event - self._error = None - - def _make_endpoint(self, hostname, port): - if self._tor_manager: - return self._tor_manager.get_endpoint_for(hostname, port) - # note: HostnameEndpoints have a default 30s timeout - return endpoints.HostnameEndpoint(self._reactor, hostname, port) - - @inlineCallbacks - def _get_websocket(self): - if not self._ws: - # TODO: if we lose the connection, make a new one, re-establish - # the state - assert self._side - p = urlparse(self._ws_url) - f = WSFactory(self._ws_url) - f.wormhole = self - f.d = defer.Deferred() - # TODO: if hostname="localhost", I get three factories starting - # and stopping (maybe 127.0.0.1, ::1, and something else?), and - # an error in the factory is masked. - ep = self._make_endpoint(p.hostname, p.port or 80) - # .connect errbacks if the TCP connection fails - self._ws = yield ep.connect(f) - self._ws_t = self._timing.add("websocket") - # f.d is errbacked if WebSocket negotiation fails - yield f.d # WebSocket drops data sent before onOpen() fires - self._ws_send_command(u"bind", appid=self._appid, side=self._side) - # the socket is connected, and bound, but no nameplate has been claimed - returnValue(self._ws) - - @inlineCallbacks - def _ws_send_command(self, mtype, **kwargs): - ws = yield self._get_websocket() - # msgid is used by misc/dump-timing.py to correlate our sends with - # their receives, and vice versa. They are also correlated with the - # ACKs we get back from the server (which we otherwise ignore). There - # are so few messages, 16 bits is enough to be mostly-unique. - kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") - kwargs["type"] = mtype - payload = json.dumps(kwargs).encode("utf-8") - self._timing.add("ws_send", _side=self._side, **kwargs) - ws.sendMessage(payload, False) - - def _ws_dispatch_response(self, payload): - msg = json.loads(payload.decode("utf-8")) - self._timing.add("ws_receive", _side=self._side, message=msg) - mtype = msg["type"] - meth = getattr(self, "_ws_handle_"+mtype, None) - if not meth: - # make tests fail, but real application will ignore it - log.err(ValueError("Unknown inbound message type %r" % (msg,))) - return - return meth(msg) - - def _ws_handle_ack(self, msg): - pass - - def _ws_handle_welcome(self, msg): - welcome = msg["welcome"] - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % - (self._ws_url, motd_formatted), file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - return self._signal_error(welcome["error"]) - - @inlineCallbacks - def _sleep(self, wake_on_error=True): - if wake_on_error and self._error: - # don't sleep if the bed's already on fire, unless we're waiting - # for the fire department to respond, in which case sure, keep on - # sleeping - raise self._error - d = defer.Deferred() - self._sleepers.append(d) - yield d - if wake_on_error and self._error: - raise self._error - - def _wakeup(self): - sleepers = self._sleepers - self._sleepers = [] - for d in sleepers: - d.callback(None) - # NOTE: callers should avoid reentrancy themselves. An - # eventual-send would be safer here, but it makes synchronizing - # unit tests annoying. - - def _signal_error(self, error): - assert isinstance(error, Exception) - self._error = error - self._wakeup() - - def _ws_handle_error(self, msg): - err = ServerError("%s: %s" % (msg["error"], msg["orig"]), - self._ws_url) - return self._signal_error(err) - - @inlineCallbacks - def _claim_nameplate(self): - if not self._nameplate_id: raise UsageError - if self._nameplate_claimed: raise UsageError - yield self._ws_send_command(u"claim", nameplate=self._nameplate_id) - # provokes "claimed" response - - def _ws_handle_claimed(self, msg): - mailbox_id = msg["mailbox"] - assert isinstance(mailbox_id, type(u"")), type(mailbox_id) - self._mailbox_id = mailbox_id - self._open_mailbox() - - @inlineCallbacks - def _release_nameplate(self): - if not self._nameplate_claimed: raise UsageError - if self._nameplate_released: raise UsageError - yield self._ws_send_command(u"release") - self._nameplate_released = True - - - @inlineCallbacks - def _open_mailbox(self): - if not self._mailbox_id: raise UsageError - if self._mailbox_opened: raise UsageError - yield self._ws_send_command(u"open", mailbox=self._mailbox_id) - self._mailbox_opened = True - # causes old messages to be sent now, and subscribes to new messages - - @inlineCallbacks - def _close_mailbox(self): - if not self._mailbox_id: raise UsageError - if not self._mailbox_opened: raise UsageError - if self._mailbox_closed: raise UsageError - yield self._ws_send_command(u"close") - self._mailbox_closed = True - - - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - if phase in self._sent_messages: raise UsageError - if not self._mailbox_opened: raise UsageError - if self._mailbox_closed: raise UsageError - self._sent_messages[phase] = body - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send_command(u"add", phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while phase not in self._delivered_messages: - yield self._sleep() - t.finish() - - def _ws_handle_message(self, msg): - # any message in the mailbox means we no longer need the nameplate - if not self._nameplate_released: - self._release_nameplate() # XXX returns Deferred - - m = msg["message"] - phase = m["phase"] - body = unhexlify(m["body"].encode("ascii")) - if phase in self._sent_messages and self._sent_messages[phase] == body: - self._delivered_messages.add(phase) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - if phase in self._received_messages: - # a nameplate collision would cause this - err = ServerError("got duplicate phase %s" % phase, self._ws_url) - return self._signal_error(err) - self._received_messages[phase] = body - if phase == u"confirm": - # TODO: we might not have a master key yet, if the caller wasn't - # waiting in _get_master_key() when a back-to-back pake+_confirm - # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - # this makes all API calls fail - return self._signal_error(WrongPasswordError()) - # now notify anyone waiting on it - self._wakeup() - - @inlineCallbacks - def _msg_get(self, phase): - with self._timing.add("get", phase=phase): - while phase not in self._received_messages: - yield self._sleep() # we can wait a long time here - # that will throw an error if something goes wrong - msg = self._received_messages[phase] - returnValue(msg) - - # entry point 1: generate a new code - @inlineCallbacks - def get_code(self, code_length=2): # XX rename to allocate_code()? create_? - if self._code is not None: raise UsageError - if self._started_get_code: raise UsageError - self._started_get_code = True - with self._timing.add("API get_code"): - with self._timing.add("allocate"): - yield self._ws_send_command(u"allocate") - while self._nameplate_id is None: - yield self._sleep() - code = codes.make_code(self._nameplate_id, code_length) - assert isinstance(code, type(u"")), type(code) - self._set_code(code) - self._start() - returnValue(code) - - def _ws_handle_allocated(self, msg): - if self._nameplate_id is not None: - return self._signal_error("got duplicate 'allocated' response") - nid = msg["nameplate"] - assert isinstance(nid, type(u"")), type(nid) - self._nameplate_id = nid - self._wakeup() - - def _start(self): - # allocate the rest now too, so it can be serialized - with self._timing.add("pake1", waiting="crypto"): - self._sp = SPAKE2_Symmetric(to_bytes(self._code), - idSymmetric=to_bytes(self._appid)) - self._msg1 = self._sp.start() - - # entry point 2a: interactively type in a code, with completion - @inlineCallbacks - def input_code(self, prompt="Enter wormhole code: ", code_length=2): - def _lister(): - return blockingCallFromThread(self._reactor, self._list_nameplates) - # fetch the list of nameplates ahead of time, to give us a chance to - # discover the welcome message (and warn the user about an obsolete - # client) - # - # TODO: send the request early, show the prompt right away, hide the - # latency in the user's indecision and slow typing. If we're lucky - # the answer will come back before they hit TAB. - with self._timing.add("API input_code"): - initial_nameplate_ids = yield self._list_nameplates() - with self._timing.add("input code", waiting="user"): - t = self._reactor.addSystemEventTrigger("before", "shutdown", - self._warn_readline) - code = yield deferToThread(codes.input_code_with_completion, - prompt, - initial_nameplate_ids, _lister, - code_length) - self._reactor.removeSystemEventTrigger(t) - returnValue(code) # application will give this to set_code() - - def _warn_readline(self): - # When our process receives a SIGINT, Twisted's SIGINT handler will - # stop the reactor and wait for all threads to terminate before the - # process exits. However, if we were waiting for - # input_code_with_completion() when SIGINT happened, the readline - # thread will be blocked waiting for something on stdin. Trick the - # user into satisfying the blocking read so we can exit. - print("\nCommand interrupted: please press Return to quit", - file=sys.stderr) - - # Other potential approaches to this problem: - # * hard-terminate our process with os._exit(1), but make sure the - # tty gets reset to a normal mode ("cooked"?) first, so that the - # next shell command the user types is echoed correctly - # * track down the thread (t.p.threadable.getThreadID from inside the - # thread), get a cffi binding to pthread_kill, deliver SIGINT to it - # * allocate a pty pair (pty.openpty), replace sys.stdin with the - # slave, build a pty bridge that copies bytes (and other PTY - # things) from the real stdin to the master, then close the slave - # at shutdown, so readline sees EOF - # * write tab-completion and basic editing (TTY raw mode, - # backspace-is-erase) without readline, probably with curses or - # twisted.conch.insults - # * write a separate program to get codes (maybe just "wormhole - # --internal-get-code"), run it as a subprocess, let it inherit - # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It - # needs an RPC mechanism (over some extra file descriptors) to ask - # us to fetch the current nameplate_id list. - # - # Note that hard-terminating our process with os.kill(os.getpid(), - # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread - # doesn't see the signal, and we must still wait for stdin to make - # readline finish. - - @inlineCallbacks - def _list_nameplates(self): - with self._timing.add("list"): - self._latest_nameplate_ids = None - yield self._ws_send_command(u"list") - while self._latest_nameplate_ids is None: - yield self._sleep() - returnValue(self._latest_nameplate_ids) - - def _ws_handle_nameplates(self, msg): - self._latest_nameplate_ids = msg["nameplates"] - self._wakeup() - - # entry point 2b: paste in a fully-formed code - def set_code(self, code): - if not isinstance(code, type(u"")): raise TypeError(type(code)) - if self._code is not None: raise UsageError - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - with self._timing.add("API set_code"): - self._nameplate_id = mo.group(1) - assert isinstance(self._nameplate_id, type(u"")), type(self._nameplate_id) - self._set_code(code) - self._start() - - def _set_code(self, code): - if self._code is not None: raise UsageError - self._timing.add("code established") - self._code = code - - def serialize(self): - # I can only be serialized after get_code/set_code and before - # get_verifier/get - if self._code is None: raise UsageError - if self._key is not None: raise UsageError - if self._sent_messages: raise UsageError - if self._got_phases: raise UsageError - data = { - "appid": self._appid, - "relay_url": self._relay_url, - "code": self._code, - "nameplate_id": self._nameplate_id, - "side": self._side, - "spake2": json.loads(self._sp.serialize().decode("ascii")), - "msg1": hexlify(self._msg1).decode("ascii"), - } - return json.dumps(data) - - # entry point 3: resume a previously-serialized session - @classmethod - def from_serialized(klass, data, reactor): - d = json.loads(data) - self = klass(d["appid"], d["relay_url"], reactor) - self._side = d["side"] - self._nameplate_id = d["nameplate_id"] - self._set_code(d["code"]) - sp_data = json.dumps(d["spake2"]).encode("ascii") - self._sp = SPAKE2_Symmetric.from_serialized(sp_data) - self._msg1 = unhexlify(d["msg1"].encode("ascii")) - return self - - @inlineCallbacks - def get_verifier(self): - if self._closed: raise UsageError - if self._code is None: raise UsageError - with self._timing.add("API get_verifier"): - yield self._get_master_key() - # If the caller cares about the verifier, then they'll probably - # also willing to wait a moment to see the _confirm message. Each - # side sends this as soon as it sees the other's PAKE message. So - # the sender should see this hot on the heels of the inbound PAKE - # message (a moment after _get_master_key() returns). The - # receiver will see this a round-trip after they send their PAKE - # (because the sender is using wait=True inside _get_master_key, - # below: otherwise the sender might go do some blocking call). - yield self._msg_get(u"confirm") - returnValue(self._verifier) - - @inlineCallbacks - def _get_master_key(self): - # TODO: prevent multiple invocation - if not self._key: - yield self._claim_nameplate_and_watch() - yield self._msg_send(u"pake", self._msg1) - pake_msg = yield self._msg_get(u"pake") - - with self._timing.add("pake2", waiting="crypto"): - self._key = self._sp.finish(pake_msg) - self._verifier = self.derive_key(u"wormhole:verifier") - self._timing.add("key established") - - if self._send_confirm: - # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - yield self._msg_send(u"confirm", confmsg, wait=True) - - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - if self._key is None: - # call after get_verifier() or get() - raise UsageError - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - - def _encrypt_data(self, key, data): - assert isinstance(key, type(b"")), type(key) - assert isinstance(data, type(b"")), type(data) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - nonce = utils.random(SecretBox.NONCE_SIZE) - return box.encrypt(data, nonce) - - def _decrypt_data(self, key, encrypted): - assert isinstance(key, type(b"")), type(key) - assert isinstance(encrypted, type(b"")), type(encrypted) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - data = box.decrypt(encrypted) - return data - - @inlineCallbacks - def send(self, outbound_data, wait=False): - if not isinstance(outbound_data, type(b"")): - raise TypeError(type(outbound_data)) - if self._closed: raise UsageError - if self._code is None: - raise UsageError("You must set_code() before send()") - phase = self._next_outbound_phase - self._next_outbound_phase += 1 - with self._timing.add("API send", phase=phase, wait=wait): - # Without predefined roles, we can't derive predictably unique - # keys for each side, so we use the same key for both. We use - # random nonces to keep the messages distinct, and we - # automatically ignore reflections. - yield self._get_master_key() - data_key = self.derive_key(u"wormhole:phase:%d" % phase) - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - yield self._msg_send(phase, outbound_encrypted, wait) - - @inlineCallbacks - def get(self): - if self._closed: raise UsageError - if self._code is None: raise UsageError - phase = self._next_inbound_phase - self._next_inbound_phase += 1 - with self._timing.add("API get", phase=phase): - yield self._get_master_key() - body = yield self._msg_get(phase) # we can wait a long time here - try: - data_key = self.derive_key(u"wormhole:phase:%d" % phase) - inbound_data = self._decrypt_data(data_key, body) - returnValue(inbound_data) - except CryptoError: - raise WrongPasswordError - - def _ws_closed(self, wasClean, code, reason): - self._ws = None - self._ws_t.finish() - # TODO: schedule reconnect, unless we're done - - @inlineCallbacks - def close(self, f=None, mood=None): - """Do d.addBoth(w.close) at the end of your chain.""" - if self._closed: - returnValue(None) - self._closed = True - if not self._ws: - returnValue(None) - - if mood is None: - mood = u"happy" - if f: - if f.check(Timeout): - mood = u"lonely" - elif f.check(WrongPasswordError): - mood = u"scary" - elif f.check(TypeError, UsageError): - # preconditions don't warrant reporting mood - pass - else: - mood = u"errory" # other errors do - if not isinstance(mood, (type(None), type(u""))): - raise TypeError(type(mood)) - - with self._timing.add("API close"): - yield self._release(mood) - # TODO: mark WebSocket as don't-reconnect - self._ws.transport.loseConnection() # probably flushes - del self._ws - self._ws_t.finish() - self._timing_started.finish(mood=mood) - returnValue(f) - - @inlineCallbacks - def _release(self, mood): - with self._timing.add("release"): - yield self._ws_send_command(u"release", mood=mood) - while self._released_status is None: - yield self._sleep(wake_on_error=False) - # TODO: set a timeout, don't wait forever for an ack - # TODO: if the connection is lost, let it go - returnValue(self._released_status) - - def _ws_handle_released(self, msg): - self._released_status = msg["status"] - self._wakeup() - -def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): - w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) - w._start() - return w - -def wormhole_from_serialized(data, reactor): - w = _Wormhole.from_serialized(data, reactor) - return w From bc908ef07e390ac98a7a2418b57854a62d5fc86b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 00:10:16 -0700 Subject: [PATCH 45/51] setup.py: pin twisted==16.1.1, remove pytrie * To avoid an incompatible patch that landed in Twisted trunk after the 16.1.1 release, autobahn pinned their requirement on Twisted to be <=16.1.1 . However Twisted reverted the patch before making a release. The new 16.2.0 is fine. Since autobahn has this pin, and since pip doesn't do full dependency resolution, I must add the pin too, so that 'pip install magic-wormhole' can work. I plan to remove this pin as soon as autobahn does the same upstream. https://github.com/crossbario/autobahn-python/issues/680 * A previous version of autobahn had a bug where it tried to import something that wasn't actually depended upon, exposed by having pynacl installed. Installing 'pytrie' manually fixed it. This doesn't seem to be a problem anymore, so I'm removing the manual dependency. --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 803c131..1f2a3de 100644 --- a/setup.py +++ b/setup.py @@ -25,10 +25,10 @@ setup(name="magic-wormhole", "wormhole-server = wormhole.server.runner:entry", ]}, install_requires=["spake2==0.3", "pynacl", "argparse", - "six", "twisted >= 16.1.0", "hkdf", "tqdm", - "autobahn[twisted]", "pytrie", - # autobahn seems to have a bug, and one plugin throws - # errors unless pytrie is installed + "six", + "twisted==16.1.1", # since autobahn pins it + "autobahn[twisted]", + "hkdf", "tqdm", ], extras_require={"tor": ["txtorcon", "ipaddr"]}, test_suite="wormhole.test", From 77661bf94e88dd7c774cf6479991817fead2155a Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:10:45 -0700 Subject: [PATCH 46/51] use new relay URL, for new protocol --- src/wormhole/cli/public_relay.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wormhole/cli/public_relay.py b/src/wormhole/cli/public_relay.py index de7b339..7d82c66 100644 --- a/src/wormhole/cli/public_relay.py +++ b/src/wormhole/cli/public_relay.py @@ -1,5 +1,5 @@ # This is a relay I run on a personal server. If it gets too expensive to # run, I'll shut it down. -RENDEZVOUS_RELAY = u"http://wormhole-relay.petmail.org:3000/wormhole-relay/" -TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:3001" +RENDEZVOUS_RELAY = u"ws://wormhole-relay.petmail.org:4000/" +TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:4001" From b72f0ce934aa48ff691490a534e4a0f6be380fff Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:14:34 -0700 Subject: [PATCH 47/51] INCOMPATIBLE CHANGE: switch to spake2-0.7 This changes the way keys are derived, and thus is incompatible with previous versions. We pin "spake2==0.7" to avoid future surprises. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1f2a3de..d1c239a 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup(name="magic-wormhole", ["wormhole = wormhole.cli.runner:entry", "wormhole-server = wormhole.server.runner:entry", ]}, - install_requires=["spake2==0.3", "pynacl", "argparse", + install_requires=["spake2==0.7", "pynacl", "argparse", "six", "twisted==16.1.1", # since autobahn pins it "autobahn[twisted]", From 7c8e5fb062137284f514b3044b2cf68a4021c90d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:26:08 -0700 Subject: [PATCH 48/51] factor out key-derivation, prepare for change --- src/wormhole/wormhole.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 662b6e4..95575bc 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -533,11 +533,14 @@ class _Wormhole: self._key = self._sp.finish(pake_msg) self._event_established_key() + def _derive_confirmation_key(self): + return self._derive_key(u"wormhole:confirmation") + def _event_established_key(self): self._timing.add("key established") # both sides send different (random) confirmation messages - confkey = self._derive_key(u"wormhole:confirmation") + confkey = self._derive_confirmation_key() nonce = os.urandom(CONFMSG_NONCE_LENGTH) confmsg = make_confmsg(confkey, nonce) self._msg_send(u"confirm", confmsg) @@ -571,7 +574,7 @@ class _Wormhole: # TODO: we might not have a master key yet, if the caller wasn't # waiting in _get_master_key() when a back-to-back pake+_confirm # message pair arrived. - confkey = self._derive_key(u"wormhole:confirmation") + confkey = self._derive_confirmation_key() nonce = body[:CONFMSG_NONCE_LENGTH] if body != make_confmsg(confkey, nonce): # this makes all API calls fail @@ -589,6 +592,10 @@ class _Wormhole: with self._timing.add("API send", phase=phase): self._maybe_send_phase_messages() + #def _derive_phase_key(self, side, phase): + def _derive_phase_key(self, phase): + return self._derive_key(u"wormhole:phase:%s" % phase) + def _maybe_send_phase_messages(self): # TODO: deal with reentrant call if not (self._connection_state == OPEN @@ -600,7 +607,7 @@ class _Wormhole: for pm in plaintexts: (phase, plaintext) = pm assert isinstance(phase, int), type(phase) - data_key = self._derive_key(u"wormhole:phase:%d" % phase) + data_key = self._derive_phase_key(u"%d" % phase) encrypted = self._encrypt_data(data_key, plaintext) self._msg_send(u"%d" % phase, encrypted) @@ -672,7 +679,7 @@ class _Wormhole: # It's a phase message, aimed at the application above us. Decrypt # and deliver upstairs, notifying anyone waiting on it try: - data_key = self._derive_key(u"wormhole:phase:%s" % phase) + data_key = self._derive_phase_key(phase) plaintext = self._decrypt_data(data_key, body) except CryptoError: e = WrongPasswordError() From 97c5d08b6adb9a5c4793376ebb5f55676e2f9e05 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:31:03 -0700 Subject: [PATCH 49/51] internally, _derive_key now takes bytes The w.derive_key(purpose) API still requires unicode. --- src/wormhole/wormhole.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 95575bc..c56148d 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -534,7 +534,7 @@ class _Wormhole: self._event_established_key() def _derive_confirmation_key(self): - return self._derive_key(u"wormhole:confirmation") + return self._derive_key(b"wormhole:confirmation") def _event_established_key(self): self._timing.add("key established") @@ -545,7 +545,7 @@ class _Wormhole: confmsg = make_confmsg(confkey, nonce) self._msg_send(u"confirm", confmsg) - verifier = self._derive_key(u"wormhole:verifier") + verifier = self._derive_key(b"wormhole:verifier") self._event_computed_verifier(verifier) self._maybe_send_phase_messages() @@ -594,7 +594,9 @@ class _Wormhole: #def _derive_phase_key(self, side, phase): def _derive_phase_key(self, phase): - return self._derive_key(u"wormhole:phase:%s" % phase) + assert isinstance(phase, type(b"")), type(phase) + purpose = b"wormhole:phase:" + phase + return self._derive_key(purpose) def _maybe_send_phase_messages(self): # TODO: deal with reentrant call @@ -607,7 +609,8 @@ class _Wormhole: for pm in plaintexts: (phase, plaintext) = pm assert isinstance(phase, int), type(phase) - data_key = self._derive_phase_key(u"%d" % phase) + phase_bytes = (u"%d" % phase).encode("ascii") + data_key = self._derive_phase_key(phase_bytes) encrypted = self._encrypt_data(data_key, plaintext) self._msg_send(u"%d" % phase, encrypted) @@ -644,13 +647,14 @@ class _Wormhole: def _API_derive_key(self, purpose, length): if self._error: raise self._error - return self._derive_key(purpose, length) + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + return self._derive_key(to_bytes(purpose), length) def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) if self._key is None: raise UsageError # call derive_key after get_verifier() or get() - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) + return HKDF(self._key, length, CTXinfo=purpose) def _response_handle_message(self, msg): side = msg["side"] @@ -678,8 +682,9 @@ class _Wormhole: # It's a phase message, aimed at the application above us. Decrypt # and deliver upstairs, notifying anyone waiting on it + phase_bytes = phase.encode("ascii") try: - data_key = self._derive_phase_key(phase) + data_key = self._derive_phase_key(phase_bytes) plaintext = self._decrypt_data(data_key, body) except CryptoError: e = WrongPasswordError() From 30ab9400342088118b50034650a123afbf28626e Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 13:47:15 -0700 Subject: [PATCH 50/51] INCOMPATIBLE: change derivation of phase keys to include side Previously the encryption key used for "phase messages" (anything sent from one side to the other, protected by the shared PAKE-generated session key) was derived just from the session key and the phase name. The two sides would use the same key for their first message (but with random, thus different, nonces). This uses the sending side's string (a random 5-byte/10-character hex string) in the derivation process too, so the two sides use different keys. This gives us an easy way to reject reflected messages. We already ignore messages that claim to use a "side" which matches our own (to ignore server echoes of our own outbound messages). With this change, an attacker (or the server) can't swap in the payload of an outbound message, change the "side" to make it look like a peer message, and then let us decrypt it correctly. It also changes the derivation function to combine the phase and side values safely. This didn't matter much when we only had one externally-provided string, but with two, there's an opportunity for format confusion if they were combined with a simple delimiter. Now we hash both values before concatenating them. This breaks interoperability with clients from before this change. They will always get WrongPasswordErrors. --- src/wormhole/test/test_wormhole.py | 11 ++++++----- src/wormhole/wormhole.py | 30 +++++++++++++++++------------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index c9e9a0f..8948535 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -251,7 +251,7 @@ class Basic(unittest.TestCase): self.check_out(out[0], type=u"add", phase=u"0") # decrypt+check the outbound message p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) - msgkey0 = w.derive_key(u"wormhole:phase:0", SecretBox.KEY_SIZE) + msgkey0 = w._derive_phase_key(w._side, u"0") p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) self.assertEqual(p0_plaintext, b"phase0-outbound") @@ -260,7 +260,8 @@ class Basic(unittest.TestCase): self.assertNoResult(md) self.assertIn(u"0", w._receive_waiters) self.assertNotIn(u"0", w._received_messages) - p0_inbound = w._encrypt_data(msgkey0, b"phase0-inbound") + msgkey1 = w._derive_phase_key(side2, u"0") + p0_inbound = w._encrypt_data(msgkey1, b"phase0-inbound") p0_inbound_hex = hexlify(p0_inbound).decode("ascii") response(w, type=u"message", phase=u"0", body=p0_inbound_hex, side=side2) @@ -270,8 +271,8 @@ class Basic(unittest.TestCase): self.assertIn(u"0", w._received_messages) # receiving an inbound message will queue it until get() is called - msgkey1 = w.derive_key(u"wormhole:phase:1", SecretBox.KEY_SIZE) - p1_inbound = w._encrypt_data(msgkey1, b"phase1-inbound") + msgkey2 = w._derive_phase_key(side2, u"1") + p1_inbound = w._encrypt_data(msgkey2, b"phase1-inbound") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"1", body=p1_inbound_hex, side=side2) @@ -433,7 +434,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") w._key = b"" - msgkey = w.derive_key(u"wormhole:phase:misc", SecretBox.KEY_SIZE) + msgkey = w._derive_phase_key(u"side2", u"misc") p1_inbound = w._encrypt_data(msgkey, b"") p1_inbound_hex = hexlify(p1_inbound).decode("ascii") response(w, type=u"message", phase=u"misc", side=u"side2", diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index c56148d..d84af3d 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -11,6 +11,7 @@ from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric +from hashlib import sha256 from . import __version__ from . import codes #from .errors import ServerError, Timeout @@ -592,10 +593,14 @@ class _Wormhole: with self._timing.add("API send", phase=phase): self._maybe_send_phase_messages() - #def _derive_phase_key(self, side, phase): - def _derive_phase_key(self, phase): - assert isinstance(phase, type(b"")), type(phase) - purpose = b"wormhole:phase:" + phase + def _derive_phase_key(self, side, phase): + assert isinstance(side, type(u"")), type(side) + assert isinstance(phase, type(u"")), type(phase) + side_bytes = side.encode("ascii") + phase_bytes = phase.encode("ascii") + purpose = (b"wormhole:phase:" + + sha256(side_bytes).digest() + + sha256(phase_bytes).digest()) return self._derive_key(purpose) def _maybe_send_phase_messages(self): @@ -607,12 +612,12 @@ class _Wormhole: plaintexts = self._plaintext_to_send self._plaintext_to_send = [] for pm in plaintexts: - (phase, plaintext) = pm - assert isinstance(phase, int), type(phase) - phase_bytes = (u"%d" % phase).encode("ascii") - data_key = self._derive_phase_key(phase_bytes) + (phase_int, plaintext) = pm + assert isinstance(phase_int, int), type(phase_int) + phase = u"%d" % phase_int + data_key = self._derive_phase_key(self._side, phase) encrypted = self._encrypt_data(data_key, plaintext) - self._msg_send(u"%d" % phase, encrypted) + self._msg_send(phase, encrypted) def _encrypt_data(self, key, data): # Without predefined roles, we can't derive predictably unique keys @@ -663,9 +668,9 @@ class _Wormhole: body = unhexlify(msg["body"].encode("ascii")) if side == self._side: return - self._event_received_peer_message(phase, body) + self._event_received_peer_message(side, phase, body) - def _event_received_peer_message(self, phase, body): + def _event_received_peer_message(self, side, phase, body): # any message in the mailbox means we no longer need the nameplate self._event_mailbox_used() #if phase in self._received_messages: @@ -682,9 +687,8 @@ class _Wormhole: # It's a phase message, aimed at the application above us. Decrypt # and deliver upstairs, notifying anyone waiting on it - phase_bytes = phase.encode("ascii") try: - data_key = self._derive_phase_key(phase_bytes) + data_key = self._derive_phase_key(side, phase) plaintext = self._decrypt_data(data_key, body) except CryptoError: e = WrongPasswordError() From 90e6d23c1732cfe0683f212cf47223560d31632f Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 24 May 2016 14:11:57 -0700 Subject: [PATCH 51/51] change server default port to match new public relay --- src/wormhole/server/cli_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wormhole/server/cli_args.py b/src/wormhole/server/cli_args.py index d68d8db..3f3b02b 100644 --- a/src/wormhole/server/cli_args.py +++ b/src/wormhole/server/cli_args.py @@ -18,9 +18,9 @@ s = parser.add_subparsers(title="subcommands", dest="subcommand") # CLI: run-server sp_start = s.add_parser("start", description="Start a relay server", usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", +sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", +sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", help="endpoint specification for the transit-relay port") sp_start.add_argument("--advertise-version", metavar="VERSION", help="version to recommend to clients")