diff --git a/docs/api.md b/docs/api.md index fd7e6cb..2efb436 100644 --- a/docs/api.md +++ b/docs/api.md @@ -10,6 +10,15 @@ short string that is transcribed from one machine to the other by the users at the keyboard. This works in conjunction with a baked-in "rendezvous server" that relays information from one machine to the other. +The "Wormhole" object provides a secure record pipe between any two programs +that use the same wormhole code (and are configured with the same application +ID and rendezvous server). Each side can send multiple messages to the other, +but the encrypted data for all messages must pass through (and be temporarily +stored on) the rendezvous server, which is a shared resource. For this +reason, larger data (including bulk file transfers) should use the Transit +class instead. The Wormhole object has a method to create a Transit object +for this purpose. + ## Modes This library will eventually offer multiple modes. For now, only "transcribe @@ -39,26 +48,36 @@ string. The two machines participating in the wormhole setup are not distinguished: it doesn't matter which one goes first, and both use the same Wormhole class. In the first variant, one side calls `get_code()` while the other calls -`set_code()`. In the second variant, both sides call `set_code()`. Note that +`set_code()`. In the second variant, both sides call `set_code()`. (Note that this is not true for the "Transit" protocol used for bulk data-transfer: the Transit class currently distinguishes "Sender" from "Receiver", so the -programs on each side must have some way to decide (ahead of time) which is -which. +programs on each side must have some way to decide ahead of time which is +which). -Each side gets to do one `send_data()` call and one `get_data()` call per -phase (see below). `get_data` will wait until the other side has done -`send_data`, so the application developer must be careful to avoid deadlocks -(don't get before you send on both sides in the same protocol). When both -sides are done, they must call `close()`, to let the library know that the -connection is complete and it can deallocate the channel. If you forget to -call `close()`, the server will not free the channel, and other users will -suffer longer invitation codes as a result. To encourage `close()`, the -library will log an error if a Wormhole object is destroyed before being -closed. +Each side can then do an arbitrary number of `send()` and `get()` calls. +`send()` writes a message into the channel. `get()` waits for a new message +to be available, then returns it. The Wormhole is not meant as a long-term +communication channel, but some protocols work better if they can exchange an +initial pair of messages (perhaps offering some set of negotiable +capabilities), and then follow up with a second pair (to reveal the results +of the negotiation). Another use case is for an ACK that gets sent at the end +of a file transfer: the Wormhole is held open until the Transit object +reports completion, and the last message is a hash of the file contents to +prove it was received correctly. + +Note: the application developer must be careful to avoid deadlocks (if both +sides want to `get()`, somebody has to `send()` first). + +When both sides are done, they must call `close()`, to let the library know +that the connection is complete and it can deallocate the channel. If you +forget to call `close()`, the server will not free the channel, and other +users will suffer longer invitation codes as a result. To encourage +`close()`, the library will log an error if a Wormhole object is destroyed +before being closed. To make it easier to call `close()`, the blocking Wormhole objects can be used as a context manager. Just put your code in the body of a `with -Wormhole(ARGS) as w:` statement, and `close()` will automatically be called +wormhole(ARGS) as w:` statement, and `close()` will automatically be called when the block exits (either successfully or due to an exception). ## Examples @@ -66,27 +85,27 @@ when the block exits (either successfully or due to an exception). The synchronous+blocking flow looks like this: ```python -from wormhole.blocking.transcribe import Wormhole +from wormhole.blocking.transcribe import wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"initiator's data" -with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: +with wormhole(u"appid", RENDEZVOUS_RELAY) as i: code = i.get_code() print("Invitation Code: %s" % code) - i.send_data(mydata) - theirdata = i.get_data() + i.send(mydata) + theirdata = i.get() print("Their data: %s" % theirdata.decode("ascii")) ``` ```python import sys -from wormhole.blocking.transcribe import Wormhole +from wormhole.blocking.transcribe import wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"receiver's data" code = sys.argv[1] -with Wormhole(u"appid", RENDEZVOUS_RELAY) as r: +with wormhole(u"appid", RENDEZVOUS_RELAY) as r: r.set_code(code) - r.send_data(mydata) - theirdata = r.get_data() + r.send(mydata) + theirdata = r.get() print("Their data: %s" % theirdata.decode("ascii")) ``` @@ -97,18 +116,18 @@ The Twisted-friendly flow looks like this: ```python from twisted.internet import reactor from wormhole.public_relay import RENDEZVOUS_RELAY -from wormhole.twisted.transcribe import Wormhole +from wormhole.twisted.transcribe import wormhole outbound_message = b"outbound data" -w1 = Wormhole(u"appid", RENDEZVOUS_RELAY) +w1 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor) d = w1.get_code() def _got_code(code): print "Invitation Code:", code - return w1.send_data(outbound_message) + return w1.send(outbound_message) d.addCallback(_got_code) -d.addCallback(lambda _: w1.get_data()) -def _got_data(inbound_message): +d.addCallback(lambda _: w1.get()) +def _got(inbound_message): print "Inbound message:", inbound_message -d.addCallback(_got_data) +d.addCallback(_got) d.addCallback(w1.close) d.addBoth(lambda _: reactor.stop()) reactor.run() @@ -117,9 +136,9 @@ reactor.run() On the other side, you call `set_code()` instead of waiting for `get_code()`: ```python -w2 = Wormhole(u"appid", RENDEZVOUS_RELAY) +w2 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor) w2.set_code(code) -d = w2.send_data(my_message) +d = w2.send(my_message) ... ``` @@ -127,56 +146,45 @@ Note that the Twisted-form `close()` accepts (and returns) an optional argument, so you can use `d.addCallback(w.close)` instead of `d.addCallback(lambda _: w.close())`. -## Phases - -If necessary, more than one message can be exchanged through the relay -server. It is not meant as a long-term communication channel, but some -protocols work better if they can exchange an initial pair of messages -(perhaps offering some set of negotiable capabilities), and then follow up -with a second pair (to reveal the results of the negotiation). - -To support this, `send_data()/get_data()` accept a "phase" argument: an -arbitrary (unicode) string. It must match the other side: calling -`send_data(data, phase=u"offer")` on one side will deliver that data to -`get_data(phase=u"offer")` on the other. - -It is a UsageError to call `send_data()` or `get_data()` twice with the same -phase name. The relay server may limit the number of phases that may be -exchanged, however it will always allow at least two. - ## Verifier -You can call `w.get_verifier()` before `send_data()/get_data()`: this will -perform the first half of the PAKE negotiation, then return a verifier object -(bytes) which can be converted into a printable representation and manually -compared. When the users are convinced that `get_verifier()` from both sides -are the same, call `send_data()/get_data()` to continue the transfer. If you -call `send_data()/get_data()` before `get_verifier()`, it will perform the -complete transfer without pausing. +For extra protection against guessing attacks, Wormhole can provide a +"Verifier". This is a moderate-length series of bytes (a SHA256 hash) that is +derived from the supposedly-shared session key. If desired, both sides can +display this value, and the humans can manually compare them before allowing +the rest of the protocol to proceed. If they do not match, then the two +programs are not talking to each other (they may both be talking to a +man-in-the-middle attacker), and the protocol should be abandoned. + +To retrieve the verifier, you call `w.get_verifier()` before any calls to +`send()/get()`. Turn this into hex or Base64 to print it, or render it as +ASCII-art, etc. Once the users are convinced that `get_verifier()` from both +sides are the same, call `send()/get()` to continue the protocol. If you call +`send()/get()` before `get_verifier()`, it will perform the complete protocol +without pausing. The Twisted form of `get_verifier()` returns a Deferred that fires with the verifier bytes. ## Generating the Invitation Code -In most situations, the "sending" or "initiating" side will call -`i.get_code()` to generate the invitation code. This returns a string in the -form `NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a +In most situations, the "sending" or "initiating" side will call `get_code()` +to generate the invitation code. This returns a string in the form +`NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a short integer allocated by talking to the rendezvous server. The rest is a randomly-generated selection from the PGP wordlist, providing a default of 16 bits of entropy. The initiating program should display this code to the user, who should transcribe it to the receiving user, who gives it to the Receiver -object by calling `r.set_code()`. The receiving program can also use +object by calling `set_code()`. The receiving program can also use `input_code_with_completion()` to use a readline-based input function: this offers tab completion of allocated channel-ids and known codewords. Alternatively, the human users can agree upon an invitation code themselves, -and provide it to both programs later (with `i.set_code()` and -`r.set_code()`). They should choose a channel-id that is unlikely to already -be in use (3 or more digits are recommended), append a hyphen, and then -include randomly-selected words or characters. Dice, coin flips, shuffled -cards, or repeated sampling of a high-resolution stopwatch are all useful -techniques. +and provide it to both programs later (both sides call `set_code()`). They +should choose a channel-id that is unlikely to already be in use (3 or more +digits are recommended), append a hyphen, and then include randomly-selected +words or characters. Dice, coin flips, shuffled cards, or repeated sampling +of a high-resolution stopwatch are all useful techniques. Note that the code is a human-readable string (the python "unicode" type in python2, "str" in python3). @@ -192,8 +200,8 @@ invitation codes are scoped to the app-id. Note that the app-id must be unicode, not bytes, so on python2 use `u"appid"`. Distinct app-ids reduce the size of the connection-id numbers. If fewer than -ten initiators are active for a given app-id, the connection-id will only -need to contain a single digit, even if some other app-id is currently using +ten Wormholes are active for a given app-id, the connection-id will only need +to contain a single digit, even if some other app-id is currently using thousands of concurrent sessions. ## Rendezvous Relays @@ -245,16 +253,14 @@ You may not be able to hold the Wormhole object in memory for the whole sync process: maybe you allow it to wait for several days, but the program will be restarted during that time. To support this, you can persist the state of the object by calling `data = w.serialize()`, which will return a printable -bytestring (the JSON-encoding of a small dictionary). To restore, use the -`from_serialized(data)` classmethod (e.g. `w = -Wormhole.from_serialized(data)`). +bytestring (the JSON-encoding of a small dictionary). To restore, use `w = +wormhole_from_serialized(data, reactor)`. There is exactly one point at which you can serialize the wormhole: *after* establishing the invitation code, but before waiting for `get_verifier()` or -`get_data()`, or calling `send_data()`. If you are creating a new invitation -code, the correct time is during the callback fired by `get_code()`. If you -are accepting a pre-generated code, the time is just after calling -`set_code()`. +`get()`, or calling `send()`. If you are creating a new invitation code, the +correct time is during the callback fired by `get_code()`. If you are +accepting a pre-generated code, the time is just after calling `set_code()`. To properly checkpoint the process, you should store the first message (returned by `start()`) next to the serialized wormhole instance, so you can @@ -278,9 +284,3 @@ in python3): * transit connection hints (e.g. "host:port") * application identifier * derived-key "purpose" string: `w.derive_key(PURPOSE)` - -## Detailed Example - -```python - -``` diff --git a/events.dot b/events.dot new file mode 100644 index 0000000..f6aa1c3 --- /dev/null +++ b/events.dot @@ -0,0 +1,93 @@ +digraph { + api_get_code [label="get_code" shape="hexagon" color="red"] + api_input_code [label="input_code" shape="hexagon" color="red"] + api_set_code [label="set_code" shape="hexagon" color="red"] + send [label="API\nsend" shape="hexagon" color="red"] + get [label="API\nget" shape="hexagon" color="red"] + close [label="API\nclose" shape="hexagon" color="red"] + + event_connected [label="connected" shape="box"] + event_learned_code [label="learned\ncode" shape="box"] + event_learned_nameplate [label="learned\nnameplate" shape="box"] + event_received_mailbox [label="received\nmailbox" shape="box"] + event_opened_mailbox [label="opened\nmailbox" shape="box"] + event_built_msg1 [label="built\nmsg1" shape="box"] + event_mailbox_used [label="mailbox\nused" shape="box"] + event_learned_PAKE [label="learned\nmsg2" shape="box"] + event_established_key [label="established\nkey" shape="box"] + event_computed_verifier [label="computed\nverifier" shape="box"] + event_received_confirm [label="received\nconfirm" shape="box"] + event_received_message [label="received\nmessage" shape="box"] + event_received_released [label="ack\nreleased" shape="box"] + event_received_closed [label="ack\nclosed" shape="box"] + + event_connected -> api_get_code + event_connected -> api_input_code + api_get_code -> event_learned_code + api_input_code -> event_learned_code + api_set_code -> event_learned_code + + + maybe_build_msg1 [label="build\nmsg1"] + maybe_claim_nameplate [label="claim\nnameplate"] + maybe_send_pake [label="send\npake"] + maybe_send_phase_messages [label="send\nphase\nmessages"] + + event_connected -> maybe_claim_nameplate + event_connected -> maybe_send_pake + + event_built_msg1 -> maybe_send_pake + + event_learned_code -> maybe_build_msg1 + event_learned_code -> event_learned_nameplate + + maybe_build_msg1 -> event_built_msg1 + event_learned_nameplate -> maybe_claim_nameplate + maybe_claim_nameplate -> event_received_mailbox [style="dashed"] + + event_received_mailbox -> event_opened_mailbox + maybe_claim_nameplate -> event_learned_PAKE [style="dashed"] + maybe_claim_nameplate -> event_received_confirm [style="dashed"] + + event_opened_mailbox -> event_learned_PAKE [style="dashed"] + event_learned_PAKE -> event_mailbox_used [style="dashed"] + event_learned_PAKE -> event_received_confirm [style="dashed"] + event_received_confirm -> event_received_message [style="dashed"] + + send -> maybe_send_phase_messages + release_nameplate [label="release\nnameplate"] + event_mailbox_used -> release_nameplate + event_opened_mailbox -> maybe_send_pake + event_opened_mailbox -> maybe_send_phase_messages + + event_learned_PAKE -> event_established_key + event_established_key -> event_computed_verifier + event_established_key -> maybe_send_phase_messages + + check_verifier [label="check\nverifier"] + event_computed_verifier -> check_verifier + event_received_confirm -> check_verifier + + check_verifier -> error + event_received_message -> error + event_received_message -> get + event_established_key -> get + + close -> close_mailbox + close -> release_nameplate + error [label="signal\nerror"] + error -> close_mailbox + error -> release_nameplate + + release_nameplate -> event_received_released [style="dashed"] + close_mailbox [label="close\nmailbox"] + close_mailbox -> event_received_closed [style="dashed"] + + maybe_close_websocket [label="close\nwebsocket"] + event_received_released -> maybe_close_websocket + event_received_closed -> maybe_close_websocket + maybe_close_websocket -> event_websocket_closed [style="dashed"] + event_websocket_closed [label="websocket\nclosed"] + + +} diff --git a/setup.py b/setup.py index 0e4e4e3..d1c239a 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,6 @@ setup(name="magic-wormhole", url="https://github.com/warner/magic-wormhole", package_dir={"": "src"}, packages=["wormhole", - "wormhole.blocking", "wormhole.cli", "wormhole.server", "wormhole.test", @@ -25,11 +24,11 @@ setup(name="magic-wormhole", ["wormhole = wormhole.cli.runner:entry", "wormhole-server = wormhole.server.runner:entry", ]}, - install_requires=["spake2==0.3", "pynacl", "requests", "argparse", - "six", "twisted >= 16.1.0", "hkdf", "tqdm", - "autobahn[twisted]", "pytrie", - # autobahn seems to have a bug, and one plugin throws - # errors unless pytrie is installed + install_requires=["spake2==0.7", "pynacl", "argparse", + "six", + "twisted==16.1.1", # since autobahn pins it + "autobahn[twisted]", + "hkdf", "tqdm", ], extras_require={"tor": ["txtorcon", "ipaddr"]}, test_suite="wormhole.test", diff --git a/src/wormhole/blocking/__init__.py b/src/wormhole/blocking/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/wormhole/blocking/eventsource.py b/src/wormhole/blocking/eventsource.py deleted file mode 100644 index fd7a4a0..0000000 --- a/src/wormhole/blocking/eventsource.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import print_function, unicode_literals -import requests - -class EventSourceFollower: - def __init__(self, url, timeout): - self._resp = requests.get(url, - headers={"accept": "text/event-stream"}, - stream=True, - timeout=timeout) - self._resp.raise_for_status() - self._lines_iter = self._resp.iter_lines(chunk_size=1, - decode_unicode=True) - - def close(self): - self._resp.close() - - def iter_events(self): - # I think Request.iter_lines and .iter_content use chunk_size= in a - # funny way, and nothing happens until at least that much data has - # arrived. So unless we set chunk_size=1, we won't hear about lines - # for a long time. I'd prefer that chunk_size behaved like - # read(size), and gave you 1<=x<=size bytes in response. - eventtype = "message" - current_lines = [] - for line in self._lines_iter: - assert isinstance(line, type(u"")), type(line) - if not line: - # blank line ends the field: deliver event, reset for next - yield (eventtype, "\n".join(current_lines)) - eventtype = "message" - current_lines[:] = [] - continue - if ":" in line: - fieldname, data = line.split(":", 1) - if data.startswith(" "): - data = data[1:] - else: - fieldname = line - data = "" - if fieldname == "event": - eventtype = data - elif fieldname == "data": - current_lines.append(data) - elif fieldname in ("id", "retry"): - # documented but unhandled - pass - else: - #log.msg("weird fieldname", fieldname, data) - pass diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py deleted file mode 100644 index d2a9920..0000000 --- a/src/wormhole/blocking/transcribe.py +++ /dev/null @@ -1,413 +0,0 @@ -from __future__ import print_function -import os, sys, time, re, requests, json, unicodedata -from six.moves.urllib_parse import urlencode -from binascii import hexlify, unhexlify -from spake2 import SPAKE2_Symmetric -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils -from .eventsource import EventSourceFollower -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming -from hkdf import Hkdf -from ..channel_monitor import monitor - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - -SECOND = 1 -MINUTE = 60*SECOND - -CONFMSG_NONCE_LENGTH = 128//8 -CONFMSG_MAC_LENGTH = 256//8 -def make_confmsg(confkey, nonce): - return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) - -def to_bytes(u): - return unicodedata.normalize("NFC", u).encode("utf-8") - -class Channel: - def __init__(self, relay_url, appid, channelid, side, handle_welcome, - wait, timeout, timing): - self._relay_url = relay_url - self._appid = appid - self._channelid = channelid - self._side = side - self._handle_welcome = handle_welcome - self._messages = set() # (phase,body) , body is bytes - self._sent_messages = set() # (phase,body) - self._started = time.time() - self._wait = wait - self._timeout = timeout - self._timing = timing - - def _add_inbound_messages(self, messages): - for msg in messages: - phase = msg["phase"] - body = unhexlify(msg["body"].encode("ascii")) - self._messages.add( (phase, body) ) - - def _find_inbound_message(self, phases): - their_messages = self._messages - self._sent_messages - for phase in phases: - for (their_phase,body) in their_messages: - if their_phase == phase: - return (phase, body) - return None - - def send(self, phase, msg): - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if not isinstance(msg, type(b"")): raise TypeError(type(msg)) - self._sent_messages.add( (phase,msg) ) - payload = {"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "phase": phase, - "body": hexlify(msg).decode("ascii")} - data = json.dumps(payload).encode("utf-8") - with self._timing.add("send %s" % phase): - r = requests.post(self._relay_url+"add", data=data, - timeout=self._timeout) - r.raise_for_status() - resp = r.json() - if "welcome" in resp: - self._handle_welcome(resp["welcome"]) - self._add_inbound_messages(resp["messages"]) - - def get_first_of(self, phases): - if not isinstance(phases, (list, set)): raise TypeError(type(phases)) - for phase in phases: - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - - # For now, server errors cause the client to fail. TODO: don't. This - # will require changing the client to re-post messages when the - # server comes back up. - - # fire with a bytestring of the first message for any 'phase' that - # wasn't one of our own messages. It will either come from - # previously-received messages, or from an EventSource that we attach - # to the corresponding URL - with self._timing.add("get %s" % "/".join(sorted(phases))): - phase_and_body = self._find_inbound_message(phases) - while phase_and_body is None: - remaining = self._started + self._timeout - time.time() - if remaining < 0: - raise Timeout - queryargs = urlencode([("appid", self._appid), - ("channelid", self._channelid)]) - f = EventSourceFollower(self._relay_url+"watch?%s" % queryargs, - remaining) - # we loop here until the connection is lost, or we see the - # message we want - for (eventtype, line) in f.iter_events(): - if eventtype == "welcome": - self._handle_welcome(json.loads(line)) - if eventtype == "message": - data = json.loads(line) - self._add_inbound_messages([data]) - phase_and_body = self._find_inbound_message(phases) - if phase_and_body: - f.close() - break - if not phase_and_body: - time.sleep(self._wait) - return phase_and_body - - def get(self, phase): - (got_phase, body) = self.get_first_of([phase]) - assert got_phase == phase - return body - - def deallocate(self, mood=None): - # only try once, no retries - data = json.dumps({"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "mood": mood}).encode("utf-8") - try: - # ignore POST failure, don't call r.raise_for_status(), set a - # short timeout and ignore failures - with self._timing.add("close"): - r = requests.post(self._relay_url+"deallocate", data=data, - timeout=5) - r.json() - except requests.exceptions.RequestException: - pass - -class ChannelManager: - def __init__(self, relay_url, appid, side, handle_welcome, timing=None, - wait=0.5*SECOND, timeout=3*MINUTE): - self._relay_url = relay_url - self._appid = appid - self._side = side - self._handle_welcome = handle_welcome - self._timing = timing or DebugTiming() - self._wait = wait - self._timeout = timeout - - def list_channels(self): - queryargs = urlencode([("appid", self._appid)]) - with self._timing.add("list"): - r = requests.get(self._relay_url+"list?%s" % queryargs, - timeout=self._timeout) - r.raise_for_status() - data = r.json() - if "welcome" in data: - self._handle_welcome(data["welcome"]) - channelids = data["channelids"] - return channelids - - def allocate(self): - data = json.dumps({"appid": self._appid, - "side": self._side}).encode("utf-8") - with self._timing.add("allocate"): - r = requests.post(self._relay_url+"allocate", data=data, - timeout=self._timeout) - r.raise_for_status() - data = r.json() - if "welcome" in data: - self._handle_welcome(data["welcome"]) - channelid = data["channelid"] - return channelid - - def connect(self, channelid): - return Channel(self._relay_url, self._appid, channelid, self._side, - self._handle_welcome, self._wait, self._timeout, - self._timing) - -def close_on_error(f): # method decorator - # Clients report certain errors as "moods", so the server can make a - # rough count failed connections (due to mismatched passwords, attacks, - # or timeouts). We don't report precondition failures, as those are the - # responsibility/fault of the local application code. We count - # non-precondition errors in case they represent server-side problems. - def _f(self, *args, **kwargs): - try: - return f(self, *args, **kwargs) - except Timeout: - self.close(u"lonely") - raise - except WrongPasswordError: - self.close(u"scary") - raise - except (TypeError, UsageError): - # preconditions don't warrant _close_with_error() - raise - except: - self.close(u"errory") - raise - return _f - -class Wormhole: - motd_displayed = False - version_warning_displayed = False - _send_confirm = True - - def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE, - timing=None): - if not isinstance(appid, type(u"")): raise TypeError(type(appid)) - if not isinstance(relay_url, type(u"")): - raise TypeError(type(relay_url)) - if not relay_url.endswith(u"/"): raise UsageError - self._appid = appid - self._relay_url = relay_url - self._wait = wait - self._timeout = timeout - self._timing = timing or DebugTiming() - side = hexlify(os.urandom(5)).decode("ascii") - self._channel_manager = ChannelManager(relay_url, appid, side, - self.handle_welcome, - self._timing, - self._wait, self._timeout) - self._channel = None - self.code = None - self.key = None - self.verifier = None - self._sent_data = set() # phases - self._got_data = set() - self._got_confirmation = False - self._closed = False - self._timing_started = self._timing.add("wormhole") - - def __enter__(self): - return self - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - def handle_welcome(self, welcome): - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted), - file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - raise ServerError(welcome["error"], self._relay_url) - - def get_code(self, code_length=2): - if self.code is not None: raise UsageError - channelid = self._channel_manager.allocate() - code = codes.make_code(channelid, code_length) - assert isinstance(code, type(u"")), type(code) - self._set_code_and_channelid(code) - self._start() - return code - - def input_code(self, prompt="Enter wormhole code: ", code_length=2): - lister = self._channel_manager.list_channels - # fetch the list of channels ahead of time, to give us a chance to - # discover the welcome message (and warn the user about an obsolete - # client) - initial_channelids = lister() - with self._timing.add("input code", waiting="user"): - code = codes.input_code_with_completion(prompt, - initial_channelids, lister, - code_length) - return code - - def set_code(self, code): # used for human-made pre-generated codes - if not isinstance(code, type(u"")): raise TypeError(type(code)) - if self.code is not None: raise UsageError - self._set_code_and_channelid(code) - self._start() - - def _set_code_and_channelid(self, code): - if self.code is not None: raise UsageError - self._timing.add("code established") - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - self.code = code - channelid = int(mo.group(1)) - self._channel = self._channel_manager.connect(channelid) - monitor.add(self._channel) - - def _start(self): - # allocate the rest now too, so it can be serialized - self.sp = SPAKE2_Symmetric(to_bytes(self.code), - idSymmetric=to_bytes(self._appid)) - self.msg1 = self.sp.start() - - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) - - def _encrypt_data(self, key, data): - assert isinstance(key, type(b"")), type(key) - assert isinstance(data, type(b"")), type(data) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - nonce = utils.random(SecretBox.NONCE_SIZE) - return box.encrypt(data, nonce) - - def _decrypt_data(self, key, encrypted): - assert isinstance(key, type(b"")), type(key) - assert isinstance(encrypted, type(b"")), type(encrypted) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - data = box.decrypt(encrypted) - return data - - - def _get_key(self): - if not self.key: - self._channel.send(u"pake", self.msg1) - pake_msg = self._channel.get(u"pake") - - self.key = self.sp.finish(pake_msg) - self.verifier = self.derive_key(u"wormhole:verifier") - self._timing.add("key established") - - if not self._send_confirm: - return - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - self._channel.send(u"_confirm", confmsg) - - @close_on_error - def get_verifier(self): - if self._closed: raise UsageError - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - self._get_key() - return self.verifier - - @close_on_error - def send_data(self, outbound_data, phase=u"data"): - if not isinstance(outbound_data, type(b"")): - raise TypeError(type(outbound_data)) - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if self._closed: raise UsageError - if phase in self._sent_data: raise UsageError # only call this once - if phase.startswith(u"_"): raise UsageError # reserved for internals - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - with self._timing.add("API send data", phase=phase): - # Without predefined roles, we can't derive predictably unique - # keys for each side, so we use the same key for both. We use - # random nonces to keep the messages distinct, and the Channel - # automatically ignores reflections. - self._sent_data.add(phase) - self._get_key() - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - self._channel.send(phase, outbound_encrypted) - - @close_on_error - def get_data(self, phase=u"data"): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if phase in self._got_data: raise UsageError # only call this once - if phase.startswith(u"_"): raise UsageError # reserved for internals - if self._closed: raise UsageError - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - with self._timing.add("API get data", phase=phase): - self._got_data.add(phase) - self._get_key() - phases = [] - if not self._got_confirmation: - phases.append(u"_confirm") - phases.append(phase) - (got_phase, body) = self._channel.get_first_of(phases) - if got_phase == u"_confirm": - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - raise WrongPasswordError - self._got_confirmation = True - (got_phase, body) = self._channel.get_first_of([phase]) - assert got_phase == phase - try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) - return inbound_data - except CryptoError: - raise WrongPasswordError - - def close(self, mood=u"happy"): - if not isinstance(mood, (type(None), type(u""))): - raise TypeError(type(mood)) - self._closed = True - if self._channel: - self._timing_started.finish(mood=mood) - c, self._channel = self._channel, None - monitor.close(c) - c.deallocate(mood) diff --git a/src/wormhole/cli/cmd_receive.py b/src/wormhole/cli/cmd_receive.py index 0de21cd..2e0447a 100644 --- a/src/wormhole/cli/cmd_receive.py +++ b/src/wormhole/cli/cmd_receive.py @@ -3,7 +3,7 @@ import os, sys, json, binascii, six, tempfile, zipfile from tqdm import tqdm from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue -from ..twisted.transcribe import Wormhole +from ..wormhole import wormhole from ..twisted.transit import TransitReceiver from ..errors import TransferError @@ -45,9 +45,8 @@ class TwistedReceiver: # can lazy-provide an endpoint, and overlap the startup process # with the user handing off the wormhole code yield tor_manager.start() - w = Wormhole(APPID, self.args.relay_url, tor_manager, - timing=self.args.timing, - reactor=self._reactor) + w = wormhole(APPID, self.args.relay_url, self._reactor, + tor_manager, timing=self.args.timing) # I wanted to do this instead: # # try: @@ -65,12 +64,12 @@ class TwistedReceiver: @inlineCallbacks def _go(self, w, tor_manager): yield self.handle_code(w) - verifier = yield w.get_verifier() + verifier = yield w.verify() self.show_verifier(verifier) them_d = yield self.get_data(w) try: if "message" in them_d: - yield self.handle_text(them_d, w) + self.handle_text(them_d, w) returnValue(None) if "file" in them_d: f = self.handle_file(them_d) @@ -90,7 +89,7 @@ class TwistedReceiver: raise RespondError("unknown offer type") except RespondError as r: data = json.dumps({"error": r.response}).encode("utf-8") - yield w.send_data(data) + w.send(data) raise TransferError(r.response) returnValue(None) @@ -100,10 +99,11 @@ class TwistedReceiver: if self.args.zeromode: assert not code code = u"0-" - if not code: - code = yield w.input_code("Enter receive wormhole code: ", - self.args.code_length) - yield w.set_code(code) + if code: + w.set_code(code) + else: + yield w.input_code("Enter receive wormhole code: ", + self.args.code_length) def show_verifier(self, verifier): verifier_hex = binascii.hexlify(verifier).decode("ascii") @@ -113,18 +113,17 @@ class TwistedReceiver: @inlineCallbacks def get_data(self, w): # this may raise WrongPasswordError - them_bytes = yield w.get_data() + them_bytes = yield w.get() them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: raise TransferError(them_d["error"]) returnValue(them_d) - @inlineCallbacks def handle_text(self, them_d, w): # we're receiving a text message self.msg(them_d["message"]) data = json.dumps({"message_ack": "ok"}).encode("utf-8") - yield w.send_data(data, wait=True) + w.send(data) def handle_file(self, them_d): file_data = them_d["file"] @@ -183,12 +182,13 @@ class TwistedReceiver: @inlineCallbacks def establish_transit(self, w, them_d, tor_manager): - transit_key = w.derive_key(APPID+u"/transit-key") transit_receiver = TransitReceiver(self.args.transit_helper, no_listen=self.args.no_listen, tor_manager=tor_manager, reactor=self._reactor, timing=self.args.timing) + transit_key = w.derive_key(APPID+u"/transit-key", + transit_receiver.TRANSIT_KEY_LENGTH) transit_receiver.set_transit_key(transit_key) direct_hints = yield transit_receiver.get_direct_hints() relay_hints = yield transit_receiver.get_relay_hints() @@ -199,7 +199,7 @@ class TwistedReceiver: "relay_connection_hints": relay_hints, }, }).encode("utf-8") - yield w.send_data(data) + w.send(data) # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 9904c58..dba5ee0 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -5,7 +5,7 @@ from twisted.protocols import basic from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, returnValue from ..errors import TransferError -from ..twisted.transcribe import Wormhole +from ..wormhole import wormhole from ..twisted.transit import TransitSender APPID = u"lothar.com/wormhole/text-or-file-xfer" @@ -49,8 +49,8 @@ def send(args, reactor=reactor): # user handing off the wormhole code yield tor_manager.start() - w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing, - reactor=reactor) + w = wormhole(APPID, args.relay_url, reactor, tor_manager, + timing=args.timing) d = _send(reactor, w, args, phase1, fd_to_send, tor_manager) d.addBoth(w.close) @@ -83,7 +83,7 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager): # get the verifier, because that also lets us derive the transit key, # which we want to set before revealing the connection hints to the far # side, so we'll be ready for them when they connect - verifier_bytes = yield w.get_verifier() + verifier_bytes = yield w.verify() verifier = binascii.hexlify(verifier_bytes).decode("ascii") if args.verify: @@ -94,17 +94,18 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager): if ok.lower() == "no": err = "sender rejected verification check, abandoned transfer" reject_data = json.dumps({"error": err}).encode("utf-8") - yield w.send_data(reject_data) + w.send(reject_data) raise TransferError(err) if fd_to_send is not None: - transit_key = w.derive_key(APPID+"/transit-key") + transit_key = w.derive_key(APPID+"/transit-key", + transit_sender.TRANSIT_KEY_LENGTH) transit_sender.set_transit_key(transit_key) my_phase1_bytes = json.dumps(phase1).encode("utf-8") - yield w.send_data(my_phase1_bytes) + w.send(my_phase1_bytes) # this may raise WrongPasswordError - them_phase1_bytes = yield w.get_data() + them_phase1_bytes = yield w.get() them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) diff --git a/src/wormhole/cli/public_relay.py b/src/wormhole/cli/public_relay.py index de7b339..7d82c66 100644 --- a/src/wormhole/cli/public_relay.py +++ b/src/wormhole/cli/public_relay.py @@ -1,5 +1,5 @@ # This is a relay I run on a personal server. If it gets too expensive to # run, I'll shut it down. -RENDEZVOUS_RELAY = u"http://wormhole-relay.petmail.org:3000/wormhole-relay/" -TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:3001" +RENDEZVOUS_RELAY = u"ws://wormhole-relay.petmail.org:4000/" +TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:4001" diff --git a/src/wormhole/codes.py b/src/wormhole/codes.py index fd08694..44f2665 100644 --- a/src/wormhole/codes.py +++ b/src/wormhole/codes.py @@ -4,6 +4,7 @@ from .wordlist import (byte_to_even_word, byte_to_odd_word, even_words_lowercase, odd_words_lowercase) def make_code(channel_id, code_length): + assert isinstance(channel_id, type(u"")), type(channel_id) words = [] for i in range(code_length): # we start with an "odd word" @@ -11,7 +12,7 @@ def make_code(channel_id, code_length): words.append(byte_to_odd_word[os.urandom(1)].lower()) else: words.append(byte_to_even_word[os.urandom(1)].lower()) - return u"%d-%s" % (channel_id, u"-".join(words)) + return u"%s-%s" % (channel_id, u"-".join(words)) def extract_channel_id(code): channel_id = int(code.split("-")[0]) diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index 4d91270..141523d 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -20,6 +20,10 @@ def handle_server_error(func): class Timeout(Exception): pass +class WelcomeError(Exception): + """The server told us to signal an error, probably because our version is + too old to possibly work.""" + class WrongPasswordError(Exception): """ Key confirmation failed. Either you or your correspondent typed the code @@ -37,5 +41,8 @@ class ReflectionAttack(Exception): class UsageError(Exception): """The programmer did something wrong.""" +class WormholeClosedError(UsageError): + """API calls may not be made after close() is called.""" + class TransferError(Exception): """Something bad happened and the transfer failed.""" diff --git a/src/wormhole/server/cli_args.py b/src/wormhole/server/cli_args.py index d68d8db..3f3b02b 100644 --- a/src/wormhole/server/cli_args.py +++ b/src/wormhole/server/cli_args.py @@ -18,9 +18,9 @@ s = parser.add_subparsers(title="subcommands", dest="subcommand") # CLI: run-server sp_start = s.add_parser("start", description="Start a relay server", usage="wormhole server start [opts] [TWISTD-ARGS..]") -sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT", +sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT", help="endpoint specification for the rendezvous port") -sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT", +sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT", help="endpoint specification for the transit-relay port") sp_start.add_argument("--advertise-version", metavar="VERSION", help="version to recommend to clients") diff --git a/src/wormhole/server/database.py b/src/wormhole/server/database.py index d8edee0..163c8cd 100644 --- a/src/wormhole/server/database.py +++ b/src/wormhole/server/database.py @@ -22,7 +22,7 @@ def get_db(dbfile, stderr=sys.stderr): raise DBError("Unable to create/open db file %s: %s" % (dbfile, e)) db.row_factory = sqlite3.Row - VERSION = 1 + VERSION = 2 if must_create: schema = get_schema(VERSION) db.executescript(schema) diff --git a/src/wormhole/server/db-schemas/v1.sql b/src/wormhole/server/db-schemas/v1.sql deleted file mode 100644 index 7d7115a..0000000 --- a/src/wormhole/server/db-schemas/v1.sql +++ /dev/null @@ -1,43 +0,0 @@ - --- note: anything which isn't an boolean, integer, or human-readable unicode --- string, (i.e. binary strings) will be stored as hex - -CREATE TABLE `version` -( - `version` INTEGER -- contains one row, set to 1 -); - -CREATE TABLE `messages` -( - `appid` VARCHAR, - `channelid` INTEGER, - `side` VARCHAR, - `phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string - -- phase="_allocate" and "_deallocate" are used internally - `body` VARCHAR, - `server_rx` INTEGER, - `msgid` VARCHAR -); -CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`); - -CREATE TABLE `usage` -( - `type` VARCHAR, -- "rendezvous" or "transit" - `started` INTEGER, -- seconds since epoch, rounded to one day - `result` VARCHAR, -- happy, scary, lonely, errory, pruney - -- rendezvous moods: - -- "happy": both sides close with mood=happy - -- "scary": any side closes with mood=scary (bad MAC, probably wrong pw) - -- "lonely": any side closes with mood=lonely (no response from 2nd side) - -- "errory": any side closes with mood=errory (other errors) - -- "pruney": channels which get pruned for inactivity - -- "crowded": three or more sides were involved - -- transit moods: - -- "errory": this side have the wrong handshake - -- "lonely": good handshake, but the other side never showed up - -- "happy": both sides gave correct handshake - `total_bytes` INTEGER, -- for transit, total bytes relayed (both directions) - `total_time` INTEGER, -- seconds from start to closed, or None - `waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None -); -CREATE INDEX `usage_idx` ON `usage` (`started`); diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql new file mode 100644 index 0000000..b436897 --- /dev/null +++ b/src/wormhole/server/db-schemas/v2.sql @@ -0,0 +1,103 @@ + +-- note: anything which isn't an boolean, integer, or human-readable unicode +-- string, (i.e. binary strings) will be stored as hex + +CREATE TABLE `version` +( + `version` INTEGER -- contains one row, set to 2 +); + + +-- Wormhole codes use a "nameplate": a short identifier which is only used to +-- reference a specific (long-named) mailbox. The codes only use numeric +-- nameplates, but the protocol and server allow can use arbitrary strings. +CREATE TABLE `nameplates` +( + `app_id` VARCHAR, + `id` VARCHAR, + `mailbox_id` VARCHAR, -- really a foreign key + `side1` VARCHAR, -- side name, or NULL + `side2` VARCHAR, -- side name, or NULL + `crowded` BOOLEAN, -- at some point, three or more sides were involved + `updated` INTEGER, -- time of last activity, used for pruning + -- timing data + `started` INTEGER, -- time when nameplace was opened + `second` INTEGER -- time when second side opened +); +CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`); +CREATE INDEX `nameplates_updated_idx` ON `nameplates` (`app_id`, `updated`); +CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`); + +-- Clients exchange messages through a "mailbox", which has a long (randomly +-- unique) identifier and a queue of messages. +CREATE TABLE `mailboxes` +( + `app_id` VARCHAR, + `id` VARCHAR, + `side1` VARCHAR, -- side name, or NULL + `side2` VARCHAR, -- side name, or NULL + `crowded` BOOLEAN, -- at some point, three or more sides were involved + `first_mood` VARCHAR, + -- timing data for the mailbox itself + `started` INTEGER, -- time when opened + `second` INTEGER -- time when second side opened +); +CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`); + +CREATE TABLE `messages` +( + `app_id` VARCHAR, + `mailbox_id` VARCHAR, + `side` VARCHAR, + `phase` VARCHAR, -- numeric or string + `body` VARCHAR, + `server_rx` INTEGER, + `msg_id` VARCHAR +); +CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`); + +CREATE TABLE `nameplate_usage` +( + `app_id` VARCHAR, + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `total_time` INTEGER, -- seconds from open to last close/prune + `result` VARCHAR -- happy, lonely, pruney, crowded + -- nameplate moods: + -- "happy": two sides open and close + -- "lonely": one side opens and closes (no response from 2nd side) + -- "pruney": channels which get pruned for inactivity + -- "crowded": three or more sides were involved +); +CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`app_id`, `started`); + +CREATE TABLE `mailbox_usage` +( + `app_id` VARCHAR, + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `result` VARCHAR -- happy, scary, lonely, errory, pruney + -- rendezvous moods: + -- "happy": both sides close with mood=happy + -- "scary": any side closes with mood=scary (bad MAC, probably wrong pw) + -- "lonely": any side closes with mood=lonely (no response from 2nd side) + -- "errory": any side closes with mood=errory (other errors) + -- "pruney": channels which get pruned for inactivity + -- "crowded": three or more sides were involved +); +CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`app_id`, `started`); + +CREATE TABLE `transit_usage` +( + `started` INTEGER, -- seconds since epoch, rounded to "blur time" + `total_time` INTEGER, -- seconds from open to last close + `waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None + `total_bytes` INTEGER, -- total bytes relayed (both directions) + `result` VARCHAR -- happy, scary, lonely, errory, pruney + -- transit moods: + -- "errory": one side gave the wrong handshake + -- "lonely": good handshake, but the other side never showed up + -- "happy": both sides gave correct handshake +); +CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`); diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 122f780..d439fbb 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -1,5 +1,6 @@ from __future__ import print_function -import time, random +import os, time, random, base64 +from collections import namedtuple from twisted.python import log from twisted.application import service, internet @@ -12,92 +13,209 @@ MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY EXPIRATION_CHECK_PERIOD = 2*HOUR -ALLOCATE = u"_allocate" -DEALLOCATE = u"_deallocate" +def get_sides(row): + return set([s for s in [row["side1"], row["side2"]] if s]) +def make_sides(sides): + return list(sides) + [None] * (2 - len(sides)) +def generate_mailbox_id(): + return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii") -class Channel: - def __init__(self, app, db, welcome, blur_usage, log_requests, - appid, channelid): + +SideResult = namedtuple("SideResult", ["changed", "empty", "side1", "side2"]) +Unchanged = SideResult(changed=False, empty=False, side1=None, side2=None) +class CrowdedError(Exception): + pass + +def add_side(row, new_side): + old_sides = [s for s in [row["side1"], row["side2"]] if s] + assert old_sides + if new_side in old_sides: + return Unchanged + if len(old_sides) == 2: + raise CrowdedError("too many sides for this thing") + return SideResult(changed=True, empty=False, + side1=old_sides[0], side2=new_side) + +def remove_side(row, side): + old_sides = [s for s in [row["side1"], row["side2"]] if s] + if side not in old_sides: + return Unchanged + remaining_sides = old_sides[:] + remaining_sides.remove(side) + if remaining_sides: + return SideResult(changed=True, empty=False, side1=remaining_sides[0], + side2=None) + return SideResult(changed=True, empty=True, side1=None, side2=None) + +Usage = namedtuple("Usage", ["started", "waiting_time", "total_time", "result"]) +TransitUsage = namedtuple("TransitUsage", + ["started", "waiting_time", "total_time", + "total_bytes", "result"]) + +SidedMessage = namedtuple("SidedMessage", ["side", "phase", "body", + "server_rx", "msg_id"]) + +class Mailbox: + def __init__(self, app, db, blur_usage, log_requests, app_id, mailbox_id): self._app = app self._db = db self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid - self._channelid = channelid - self._listeners = set() # instances with .send_rendezvous_event (that - # takes a JSONable object) and - # .stop_rendezvous_watcher() + self._app_id = app_id + self._mailbox_id = mailbox_id + self._listeners = {} # handle -> (send_f, stop_f) + # "handle" is a hashable object, for deregistration + # send_f() takes a JSONable object, stop_f() has no args - def get_channelid(self): - return self._channelid + def open(self, side, when): + # requires caller to db.commit() + assert isinstance(side, type(u"")), type(side) + db = self._db + row = db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)).fetchone() + try: + sr = add_side(row, side) + except CrowdedError: + db.execute("UPDATE `mailboxes` SET `crowded`=?" + " WHERE `app_id`=? AND `id`=?", + (True, self._app_id, self._mailbox_id)) + db.commit() + raise + if sr.changed: + db.execute("UPDATE `mailboxes` SET" + " `side1`=?, `side2`=?, `second`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, + self._app_id, self._mailbox_id)) def get_messages(self): messages = [] db = self._db for row in db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` ASC", - (self._appid, self._channelid)).fetchall(): - if row["phase"] in (u"_allocate", u"_deallocate"): - continue - messages.append({"phase": row["phase"], "body": row["body"], - "server_rx": row["server_rx"], "id": row["msgid"]}) + (self._app_id, self._mailbox_id)).fetchall(): + sm = SidedMessage(side=row["side"], phase=row["phase"], + body=row["body"], server_rx=row["server_rx"], + msg_id=row["msg_id"]) + messages.append(sm) return messages - def add_listener(self, ep): - self._listeners.add(ep) + def add_listener(self, handle, send_f, stop_f): + self._listeners[handle] = (send_f, stop_f) return self.get_messages() - def remove_listener(self, ep): - self._listeners.discard(ep) + def remove_listener(self, handle): + self._listeners.pop(handle) - def broadcast_message(self, phase, body, server_rx, msgid): - for ep in self._listeners: - ep.send_rendezvous_event({"phase": phase, "body": body, - "server_rx": server_rx, "id": msgid}) + def broadcast_message(self, sm): + for (send_f, stop_f) in self._listeners.values(): + send_f(sm) - def _add_message(self, side, phase, body, server_rx, msgid): + def _add_message(self, sm): + self._db.execute("INSERT INTO `messages`" + " (`app_id`, `mailbox_id`, `side`, `phase`, `body`," + " `server_rx`, `msg_id`)" + " VALUES (?,?,?,?,?, ?,?)", + (self._app_id, self._mailbox_id, sm.side, + sm.phase, sm.body, sm.server_rx, sm.msg_id)) + self._db.commit() + + def add_message(self, sm): + assert isinstance(sm, SidedMessage) + self._add_message(sm) + self.broadcast_message(sm) + + def close(self, side, mood, when): + assert isinstance(side, type(u"")), type(side) db = self._db - db.execute("INSERT INTO `messages`" - " (`appid`, `channelid`, `side`, `phase`, `body`," - " `server_rx`, `msgid`)" - " VALUES (?,?,?,?,?, ?,?)", - (self._appid, self._channelid, side, phase, body, - server_rx, msgid)) - db.commit() + row = db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)).fetchone() + if not row: + return + sr = remove_side(row, side) + if sr.empty: + rows = db.execute("SELECT DISTINCT(`side`) FROM `messages`" + " WHERE `app_id`=? AND `mailbox_id`=?", + (self._app_id, self._mailbox_id)).fetchall() + num_sides = len(rows) + self._summarize_and_store(row, num_sides, mood, when, pruned=False) + self._delete() + db.commit() + elif sr.changed: + db.execute("UPDATE `mailboxes`" + " SET `side1`=?, `side2`=?, `first_mood`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, mood, + self._app_id, self._mailbox_id)) + db.commit() - def allocate(self, side): - self._add_message(side, ALLOCATE, None, time.time(), None) + def _delete(self): + # requires caller to db.commit() + self._db.execute("DELETE FROM `mailboxes`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._mailbox_id)) + self._db.execute("DELETE FROM `messages`" + " WHERE `app_id`=? AND `mailbox_id`=?", + (self._app_id, self._mailbox_id)) - def add_message(self, side, phase, body, server_rx, msgid): - self._add_message(side, phase, body, server_rx, msgid) - self.broadcast_message(phase, body, server_rx, msgid) - return self.get_messages() # for rendezvous_web.py POST /add + # Shut down any listeners, just in case they're still lingering + # around. + for (send_f, stop_f) in self._listeners.values(): + stop_f() - def deallocate(self, side, mood): - self._add_message(side, DEALLOCATE, mood, time.time(), None) - db = self._db - seen = set([row["side"] for row in - db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid))]) - freed = set([row["side"] for row in - db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" - " AND `phase`=?", - (self._appid, self._channelid, DEALLOCATE))]) - if seen - freed: - return False - self.delete_and_summarize() - return True + self._app.free_mailbox(self._mailbox_id) + + def _summarize_and_store(self, row, num_sides, second_mood, delete_time, + pruned): + u = self._summarize(row, num_sides, second_mood, delete_time, pruned) + self._db.execute("INSERT INTO `mailbox_usage`" + " (`app_id`, " + " `started`, `total_time`, `waiting_time`, `result`)" + " VALUES (?, ?,?,?,?)", + (self._app_id, + u.started, u.total_time, u.waiting_time, u.result)) + + def _summarize(self, row, num_sides, second_mood, delete_time, pruned): + started = row["started"] + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + waiting_time = None + if row["second"]: + waiting_time = row["second"] - row["started"] + total_time = delete_time - row["started"] + + if num_sides == 0: + result = u"quiet" + elif num_sides == 1: + result = u"lonely" + else: + result = u"happy" + + moods = set([row["first_mood"], second_mood]) + if u"lonely" in moods: + result = u"lonely" + if u"errory" in moods: + result = u"errory" + if u"scary" in moods: + result = u"scary" + if pruned: + result = u"pruney" + if row["crowded"]: + result = u"crowded" + + return Usage(started=started, waiting_time=waiting_time, + total_time=total_time, result=result) def is_idle(self): if self._listeners: return False c = self._db.execute("SELECT `server_rx` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `mailbox_id`=?" " ORDER BY `server_rx` DESC LIMIT 1", - (self._appid, self._channelid)) + (self._app_id, self._mailbox_id)) rows = c.fetchall() if not rows: return True @@ -106,172 +224,224 @@ class Channel: return True return False - def _store_summary(self, summary): - (started, result, total_time, waiting_time) = summary - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - self._db.execute("INSERT INTO `usage`" - " (`type`, `started`, `result`," - " `total_time`, `waiting_time`)" - " VALUES (?,?,?, ?,?)", - (u"rendezvous", started, result, - total_time, waiting_time)) - self._db.commit() - - def _summarize(self, messages, delete_time): - all_sides = set([m["side"] for m in messages]) - if len(all_sides) == 0: - log.msg("_summarize was given zero messages") # shouldn't happen - return - - started = min([m["server_rx"] for m in messages]) - # 'total_time' is how long the channel was occupied. That ends now, - # both for channels that got pruned for inactivity, and for channels - # that got pruned because of two DEALLOCATE messages - total_time = delete_time - started - - if len(all_sides) == 1: - return (started, "lonely", total_time, None) - if len(all_sides) > 2: - # TODO: it'll be useful to have more detail here - return (started, "crowded", total_time, None) - - # exactly two sides were involved - A_side = sorted(messages, key=lambda m: m["server_rx"])[0]["side"] - B_side = list(all_sides - set([A_side]))[0] - - # How long did the first side wait until the second side showed up? - first_A = min([m["server_rx"] for m in messages if m["side"] == A_side]) - first_B = min([m["server_rx"] for m in messages if m["side"] == B_side]) - waiting_time = first_B - first_A - - # now, were all sides closed? If not, this is "pruney" - A_deallocs = [m for m in messages - if m["phase"] == DEALLOCATE and m["side"] == A_side] - B_deallocs = [m for m in messages - if m["phase"] == DEALLOCATE and m["side"] == B_side] - if not A_deallocs or not B_deallocs: - return (started, "pruney", total_time, None) - - # ok, both sides closed. figure out the mood - A_mood = A_deallocs[0]["body"] # maybe None - B_mood = B_deallocs[0]["body"] # maybe None - mood = "quiet" - if A_mood == u"happy" and B_mood == u"happy": - mood = "happy" - if A_mood == u"lonely" or B_mood == u"lonely": - mood = "lonely" - if A_mood == u"errory" or B_mood == u"errory": - mood = "errory" - if A_mood == u"scary" or B_mood == u"scary": - mood = "scary" - return (started, mood, total_time, waiting_time) - - def delete_and_summarize(self): - db = self._db - c = self._db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" - " ORDER BY `server_rx`", - (self._appid, self._channelid)) - messages = c.fetchall() - summary = self._summarize(messages, time.time()) - self._store_summary(summary) - db.execute("DELETE FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid)) - db.commit() - - # Shut down any listeners, just in case they're still lingering - # around. - for ep in self._listeners: - ep.stop_rendezvous_watcher() - - self._app.free_channel(self._channelid) - def _shutdown(self): # used at test shutdown to accelerate client disconnects - for ep in self._listeners: - ep.stop_rendezvous_watcher() + for (send_f, stop_f) in self._listeners.values(): + stop_f() class AppNamespace: - def __init__(self, db, welcome, blur_usage, log_requests, appid): + def __init__(self, db, welcome, blur_usage, log_requests, app_id): self._db = db self._welcome = welcome self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid - self._channels = {} + self._app_id = app_id + self._mailboxes = {} - def get_allocated(self): + def get_nameplate_ids(self): db = self._db - c = db.execute("SELECT DISTINCT `channelid` FROM `messages`" - " WHERE `appid`=?", (self._appid,)) - return set([row["channelid"] for row in c.fetchall()]) + # TODO: filter this to numeric ids? + c = db.execute("SELECT DISTINCT `id` FROM `nameplates`" + " WHERE `app_id`=?", (self._app_id,)) + return set([row["id"] for row in c.fetchall()]) - def find_available_channelid(self): - allocated = self.get_allocated() + def _find_available_nameplate_id(self): + claimed = self.get_nameplate_ids() for size in range(1,4): # stick to 1-999 for now available = set() - for cid in range(10**(size-1), 10**size): - if cid not in allocated: - available.add(cid) + for id_int in range(10**(size-1), 10**size): + id = u"%d" % id_int + if id not in claimed: + available.add(id) if available: return random.choice(list(available)) - # ouch, 999 currently allocated. Try random ones for a while. + # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): - cid = random.randrange(1000, 1000*1000) - if cid not in allocated: - return cid - raise ValueError("unable to find a free channel-id") + id_int = random.randrange(1000, 1000*1000) + id = u"%d" % id_int + if id not in claimed: + return id + raise ValueError("unable to find a free nameplate-id") - def allocate_channel(self, channelid, side): - channel = self.get_channel(channelid) - channel.allocate(side) - return channel + def allocate_nameplate(self, side, when): + nameplate_id = self._find_available_nameplate_id() + mailbox_id = self.claim_nameplate(nameplate_id, side, when) + del mailbox_id # ignored, they'll learn it from claim() + return nameplate_id - def get_channel(self, channelid): - assert isinstance(channelid, int) - if not channelid in self._channels: + def claim_nameplate(self, nameplate_id, side, when): + # when we're done: + # * there will be one row for the nameplate + # * side1 or side2 will be populated + # * started or second will be populated + # * a mailbox id will be created, but not a mailbox row + # (ids are randomly unique, so we can defer creation until 'open') + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + assert isinstance(side, type(u"")), type(side) + db = self._db + row = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + if row: + mailbox_id = row["mailbox_id"] + try: + sr = add_side(row, side) + except CrowdedError: + db.execute("UPDATE `nameplates` SET `crowded`=?" + " WHERE `app_id`=? AND `id`=?", + (True, self._app_id, nameplate_id)) + db.commit() + raise + if sr.changed: + db.execute("UPDATE `nameplates` SET" + " `side1`=?, `side2`=?, `updated`=?, `second`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, when, + self._app_id, nameplate_id)) + else: if self._log_requests: - log.msg("spawning #%d for appid %s" % (channelid, self._appid)) - self._channels[channelid] = Channel(self, self._db, self._welcome, - self._blur_usage, - self._log_requests, - self._appid, channelid) - return self._channels[channelid] + log.msg("creating nameplate#%s for app_id %s" % + (nameplate_id, self._app_id)) + mailbox_id = generate_mailbox_id() + db.execute("INSERT INTO `nameplates`" + " (`app_id`, `id`, `mailbox_id`, `side1`, `crowded`," + " `updated`, `started`)" + " VALUES(?,?,?,?,?, ?,?)", + (self._app_id, nameplate_id, mailbox_id, side, False, + when, when)) + db.commit() + return mailbox_id - def free_channel(self, channelid): - # called from Channel.delete_and_summarize(), which deletes any + def release_nameplate(self, nameplate_id, side, when): + # when we're done: + # * in the nameplate row, side1 or side2 will be removed + # * if the nameplate is now unused: + # * mailbox.nameplate_closed will be populated + # * the nameplate row will be removed + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + assert isinstance(side, type(u"")), type(side) + db = self._db + row = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + if not row: + return + sr = remove_side(row, side) + if sr.empty: + db.execute("DELETE FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)) + self._summarize_nameplate_and_store(row, when, pruned=False) + db.commit() + elif sr.changed: + db.execute("UPDATE `nameplates`" + " SET `side1`=?, `side2`=?, `updated`=?" + " WHERE `app_id`=? AND `id`=?", + (sr.side1, sr.side2, when, + self._app_id, nameplate_id)) + db.commit() + + def _summarize_nameplate_and_store(self, row, delete_time, pruned): + # requires caller to db.commit() + u = self._summarize_nameplate_usage(row, delete_time, pruned) + self._db.execute("INSERT INTO `nameplate_usage`" + " (`app_id`," + " `started`, `total_time`, `waiting_time`, `result`)" + " VALUES (?, ?,?,?,?)", + (self._app_id, + u.started, u.total_time, u.waiting_time, u.result)) + + def _summarize_nameplate_usage(self, row, delete_time, pruned): + started = row["started"] + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + waiting_time = None + if row["second"]: + waiting_time = row["second"] - row["started"] + total_time = delete_time - row["started"] + result = u"lonely" + if row["second"]: + result = u"happy" + if pruned: + result = u"pruney" + if row["crowded"]: + result = u"crowded" + return Usage(started=started, waiting_time=waiting_time, + total_time=total_time, result=result) + + def _prune_nameplate(self, row, delete_time): + # requires caller to db.commit() + db = self._db + db.execute("DELETE FROM `nameplates` WHERE `app_id`=? AND `id`=?", + (self._app_id, row["id"])) + self._summarize_nameplate_and_store(row, delete_time, pruned=True) + # TODO: make a Nameplate object, keep track of when there's a + # websocket that's watching it, don't prune a nameplate that someone + # is watching, even if they started watching a long time ago + + def prune_nameplates(self, old): + db = self._db + for row in db.execute("SELECT * FROM `nameplates`" + " WHERE `updated` < ?", + (old,)).fetchall(): + self._prune_nameplate(row) + count = db.execute("SELECT COUNT(*) FROM `nameplates`").fetchone()[0] + return count + + def open_mailbox(self, mailbox_id, side, when): + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + db = self._db + if not mailbox_id in self._mailboxes: + if self._log_requests: + log.msg("spawning #%s for app_id %s" % (mailbox_id, + self._app_id)) + db.execute("INSERT INTO `mailboxes`" + " (`app_id`, `id`, `side1`, `crowded`, `started`)" + " VALUES(?,?,?,?,?)", + (self._app_id, mailbox_id, side, False, when)) + db.commit() # XXX + # mailbox.open() does a SELECT to find the old sides + self._mailboxes[mailbox_id] = Mailbox(self, self._db, + self._blur_usage, + self._log_requests, + self._app_id, mailbox_id) + mailbox = self._mailboxes[mailbox_id] + mailbox.open(side, when) + db.commit() + return mailbox + + def free_mailbox(self, mailbox_id): + # called from Mailbox.delete_and_summarize(), which deletes any # messages - if channelid in self._channels: - self._channels.pop(channelid) - if self._log_requests: - log.msg("freed+killed #%d, now have %d DB channels, %d live" % - (channelid, len(self.get_allocated()), len(self._channels))) + if mailbox_id in self._mailboxes: + self._mailboxes.pop(mailbox_id) + #if self._log_requests: + # log.msg("freed+killed #%s, now have %d DB mailboxes, %d live" % + # (mailbox_id, len(self.get_claimed()), len(self._mailboxes))) - def prune_old_channels(self): + def prune_mailboxes(self, old): # For now, pruning is logged even if log_requests is False, to debug # the pruning process, and since pruning is triggered by a timer - # instead of by user action. It does reveal which channels were + # instead of by user action. It does reveal which mailboxes were # present when the pruning process began, though, so in the log run # it should do less logging. log.msg(" channel prune begins") # a channel is deleted when there are no listeners and there have # been no messages added in CHANNEL_EXPIRATION_TIME seconds - channels = set(self.get_allocated()) # these have messages - channels.update(self._channels) # these might have listeners - for channelid in channels: - log.msg(" channel prune checking %d" % channelid) - channel = self.get_channel(channelid) + mailboxes = set(self.get_claimed()) # these have messages + mailboxes.update(self._mailboxes) # these might have listeners + for mailbox_id in mailboxes: + log.msg(" channel prune checking %d" % mailbox_id) + channel = self.get_channel(mailbox_id) if channel.is_idle(): - log.msg(" channel prune expiring %d" % channelid) + log.msg(" channel prune expiring %d" % mailbox_id) channel.delete_and_summarize() # calls self.free_channel - log.msg(" channel prune done, %r left" % (self._channels.keys(),)) - return bool(self._channels) + log.msg(" channel prune done, %r left" % (self._mailboxes.keys(),)) + return bool(self._mailboxes) def _shutdown(self): - for channel in self._channels.values(): + for channel in self._mailboxes.values(): channel._shutdown() class Rendezvous(service.MultiService): @@ -279,7 +449,7 @@ class Rendezvous(service.MultiService): service.MultiService.__init__(self) self._db = db self._welcome = welcome - self._blur_usage = blur_usage + self._blur_usage = None log_requests = blur_usage is None self._log_requests = log_requests self._apps = {} @@ -291,28 +461,31 @@ class Rendezvous(service.MultiService): def get_log_requests(self): return self._log_requests - def get_app(self, appid): - assert isinstance(appid, type(u"")) - if not appid in self._apps: + def get_app(self, app_id): + assert isinstance(app_id, type(u"")) + if not app_id in self._apps: if self._log_requests: - log.msg("spawning appid %s" % (appid,)) - self._apps[appid] = AppNamespace(self._db, self._welcome, + log.msg("spawning app_id %s" % (app_id,)) + self._apps[app_id] = AppNamespace(self._db, self._welcome, self._blur_usage, - self._log_requests, appid) - return self._apps[appid] + self._log_requests, app_id) + return self._apps[app_id] - def prune(self): - # As with AppNamespace.prune_old_channels, we log for now. + def prune(self, old=None): + # As with AppNamespace.prune_old_mailboxes, we log for now. log.msg("beginning app prune") - c = self._db.execute("SELECT DISTINCT `appid` FROM `messages`") - apps = set([row["appid"] for row in c.fetchall()]) # these have messages + if old is None: + old = time.time() - CHANNEL_EXPIRATION_TIME + c = self._db.execute("SELECT DISTINCT `app_id` FROM `messages`") + apps = set([row["app_id"] for row in c.fetchall()]) # these have messages apps.update(self._apps) # these might have listeners - for appid in apps: - log.msg(" app prune checking %r" % (appid,)) - still_active = self.get_app(appid).prune_old_channels() + for app_id in apps: + log.msg(" app prune checking %r" % (app_id,)) + app = self.get_app(app_id) + still_active = app.prune_nameplates(old) + app.prune_mailboxes(old) if not still_active: - log.msg("prune pops app %r" % (appid,)) - self._apps.pop(appid) + log.msg("prune pops app %r" % (app_id,)) + self._apps.pop(app_id) log.msg("app prune ends, %d remaining apps" % len(self._apps)) def stopService(self): diff --git a/src/wormhole/server/rendezvous_web.py b/src/wormhole/server/rendezvous_web.py deleted file mode 100644 index c89654e..0000000 --- a/src/wormhole/server/rendezvous_web.py +++ /dev/null @@ -1,223 +0,0 @@ -import json, time -from twisted.web import server, resource -from twisted.python import log - -def json_response(request, data): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - return (json.dumps(data)+"\n").encode("utf-8") - -class EventsProtocol: - def __init__(self, request): - self.request = request - - def sendComment(self, comment): - # this is ignored by clients, but can keep the connection open in the - # face of firewall/NAT timeouts. It also helps unit tests, since - # apparently twisted.web.client.Agent doesn't consider the connection - # to be established until it sees the first byte of the reponse body. - self.request.write(b": " + comment + b"\n\n") - - def sendEvent(self, data, name=None, id=None, retry=None): - if name: - self.request.write(b"event: " + name.encode("utf-8") + b"\n") - # e.g. if name=foo, then the client web page should do: - # (new EventSource(url)).addEventListener("foo", handlerfunc) - # Note that this basically defaults to "message". - if id: - self.request.write(b"id: " + id.encode("utf-8") + b"\n") - if retry: - self.request.write(b"retry: " + retry + b"\n") # milliseconds - for line in data.splitlines(): - self.request.write(b"data: " + line.encode("utf-8") + b"\n") - self.request.write(b"\n") - - def stop(self): - self.request.finish() - - def send_rendezvous_event(self, data): - data = data.copy() - data["sent"] = time.time() - self.sendEvent(json.dumps(data)) - def stop_rendezvous_watcher(self): - self.stop() - -# note: no versions of IE (including the current IE11) support EventSource - -# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) -# ("-" indicates a deprecated URL) -# GET /list?appid= -> {channelids: [INT..]} -# POST /allocate {appid:,side:} -> {channelid: INT} -# these return all messages (base64) for appid=/channelid= : -# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} -# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} -#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. -# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}.. -# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} -# all JSON responses include a "welcome:{..}" key - -class RelayResource(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self._welcome = rendezvous.get_welcome() - -class ChannelLister(RelayResource): - def render_GET(self, request): - if b"appid" not in request.args: - e = NeedToUpgradeErrorResource(self._welcome) - return e.get_message() - appid = request.args[b"appid"][0].decode("utf-8") - #print("LIST", appid) - app = self._rendezvous.get_app(appid) - allocated = app.get_allocated() - data = {"welcome": self._welcome, "channelids": sorted(allocated), - "sent": time.time()} - return json_response(request, data) - -class Allocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - #print("ALLOCATE", appid, side) - app = self._rendezvous.get_app(appid) - channelid = app.find_available_channelid() - app.allocate_channel(channelid, side) - if self._rendezvous.get_log_requests(): - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(app.get_allocated()))) - response = {"welcome": self._welcome, "channelid": channelid, - "sent": time.time()} - return json_response(request, response) - - def getChild(self, path, req): - # wormhole-0.4.0 "send" started with "POST /allocate/SIDE". - # wormhole-0.5.0 changed that to "POST /allocate". We catch the old - # URL here to deliver a nicer error message (with upgrade - # instructions) than an ugly 404. - return NeedToUpgradeErrorResource(self._welcome) - -class NeedToUpgradeErrorResource(resource.Resource): - def __init__(self, welcome): - resource.Resource.__init__(self) - w = welcome.copy() - w["error"] = "Sorry, you must upgrade your client to use this server." - message = {"welcome": w} - self._message = (json.dumps(message)+"\n").encode("utf-8") - def get_message(self): - return self._message - def render_POST(self, request): - return self._message - def render_GET(self, request): - return self._message - def getChild(self, path, req): - return self - -class Adder(RelayResource): - def render_POST(self, request): - #content = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - phase = data["phase"] - if not isinstance(phase, type(u"")): - raise TypeError("phase must be string, not %s" % type(phase)) - body = data["body"] - #print("ADD", appid, channelid, side, phase, body) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - messages = channel.add_message(side, phase, body, time.time(), None) - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - -class GetterOrWatcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - #print("GET", appid, channelid) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - messages = channel.get_messages() - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Watcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - raise TypeError("/watch is for EventSource only") - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Deallocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - mood = data.get("mood") - #print("DEALLOCATE", appid, channelid, side) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - deleted = channel.deallocate(side, mood) - response = {"status": "waiting", "sent": time.time()} - if deleted: - response = {"status": "deleted", "sent": time.time()} - return json_response(request, response) - - -class WebRendezvous(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self.putChild(b"list", ChannelLister(rendezvous)) - self.putChild(b"allocate", Allocator(rendezvous)) - self.putChild(b"add", Adder(rendezvous)) - self.putChild(b"get", GetterOrWatcher(rendezvous)) - self.putChild(b"watch", Watcher(rendezvous)) - self.putChild(b"deallocate", Deallocator(rendezvous)) - - def getChild(self, path, req): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - # give a nicer error message to old clients - if (len(req.postpath) >= 2 - and req.postpath[1] in (b"post", b"poll", b"deallocate")): - welcome = self._rendezvous.get_welcome() - return NeedToUpgradeErrorResource(welcome) - return resource.NoResource("No such child resource.") diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 93abf2c..2a19ab2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -2,23 +2,48 @@ import json, time from twisted.internet import reactor from twisted.python import log from autobahn.twisted import websocket +from .rendezvous import CrowdedError, SidedMessage -# Each WebSocket connection is bound to one "appid", one "side", and one -# "channelid". The connection's appid and side are set by the "bind" message -# (which must be the first message on the connection). The channelid is set -# by either a "allocate" message (where the server picks the channelid), or -# by a "claim" message (where the client picks it). All three values must be -# set before any other message (watch, add, deallocate) can be sent. +# The WebSocket allows the client to send "commands" to the server, and the +# server to send "responses" to the client. Note that commands and responses +# are not necessarily one-to-one. All commands provoke an "ack" response +# (with a copy of the original message) for timing, testing, and +# synchronization purposes. All commands and responses are JSON-encoded. -# All websocket messages are JSON-encoded. The client can send us "inbound" -# messages (marked as "->" below), which may (or may not) provoke immediate -# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed -# correlation between requests and responses. In this list, "A -> B" means -# that some time after A is received, at least one message of type B will be -# sent out. +# Each WebSocket connection is bound to one "appid" and one "side", which are +# set by the "bind" command (which must be the first command on the +# connection), and must be set before any other command will be accepted. -# All outbound messages include a "sent" key, which is a float (seconds since -# epoch) with the server clock just before the outbound message was written +# Each connection can be bound to a single "mailbox" (a two-sided +# store-and-forward queue, identified by the "mailbox id": a long, randomly +# unique string identifier) by using the "open" command. This protects the +# mailbox from idle closure, enables the "add" command (to put new messages +# in the queue), and triggers delivery of past and future messages via the +# "message" response. The "close" command removes the binding (but note that +# it does not enable the subsequent binding of a second mailbox). When the +# last side closes a mailbox, its contents are deleted. + +# Additionally, the connection can be bound a single "nameplate", which is +# short identifier that makes up the first component of a wormhole code. Each +# nameplate points to a single long-id "mailbox". The "allocate" message +# determines the shortest available numeric nameplate, reserves it, and +# returns the nameplate id. "list" returns a list of all numeric nameplates +# which currently have only one side active (i.e. they are waiting for a +# partner). The "claim" message reserves an arbitrary nameplate id (perhaps +# the receiver of a wormhole connection typed in a code they got from the +# sender, or perhaps the two sides agreed upon a code offline and are both +# typing it in), and the "release" message releases it. When every side that +# has claimed the nameplate has also released it, the nameplate is +# deallocated (but they will probably keep the underlying mailbox open). + +# Inbound (client to server) commands are marked as "->" below. Unrecognized +# inbound keys will be ignored. Outbound (server to client) responses use +# "<-". There is no guaranteed correlation between requests and responses. In +# this list, "A -> B" means that some time after A is received, at least one +# message of type B will be sent out (probably). + +# All responses include a "server_tx" key, which is a float (seconds since +# epoch) with the server clock just before the outbound response was written # to the socket. # connection -> welcome @@ -27,16 +52,24 @@ from autobahn.twisted import websocket # motd: all clients display message, then continue normally # error: all clients display mesage, then terminate with error # -> {type: "bind", appid:, side:} -# -> {type: "list"} -> channelids -# <- {type: "channelids", channelids: [int..]} -# -> {type: "allocate"} -> allocated -# <- {type: "allocated", channelid: int} -# -> {type: "claim", channelid: int} -# -> {type: "watch"} -> message # sends old messages and more in future -# <- {type: "message", message: {phase:, body:}} # body is hex -# -> {type: "add", phase: str, body: hex} # may send echo -# -> {type: "deallocate", mood: str} -> deallocated -# <- {type: "deallocated", status: waiting|deleted} +# +# -> {type: "list"} -> nameplates +# <- {type: "nameplates", nameplates: [{id: str,..},..]} +# -> {type: "allocate"} -> nameplate, mailbox +# <- {type: "allocated", nameplate: str} +# -> {type: "claim", nameplate: str} -> mailbox +# <- {type: "claimed", mailbox: str} +# -> {type: "release"} +# <- {type: "released"} +# +# -> {type: "open", mailbox: str} -> message +# sends old messages now, and subscribes to deliver future messages +# <- {type: "message", side:, phase:, body:, msg_id:}} # body is hex +# -> {type: "add", phase: str, body: hex} # will send echo in a "message" +# +# -> {type: "close", mood: str} -> closed +# <- {type: "closed"} +# # <- {type: "error", error: str, orig: {}} # in response to malformed msgs # for tests that need to know when a message has been processed: @@ -52,8 +85,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): websocket.WebSocketServerProtocol.__init__(self) self._app = None self._side = None - self._channel = None - self._watching = False + self._did_allocate = False # only one allocate() per websocket + self._nameplate_id = None + self._mailbox = None def onConnect(self, request): rv = self.factory.rendezvous @@ -71,10 +105,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): try: if "type" not in msg: raise Error("missing 'type'") - if "id" in msg: - # Only ack clients modern enough to include [id]. Older ones - # won't recognize the message, then they'll abort. - self.send("ack", id=msg["id"]) + self.send("ack", id=msg.get("id")) mtype = msg["type"] if mtype == "ping": @@ -83,33 +114,27 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): return self.handle_bind(msg) if not self._app: - raise Error("Must bind first") + raise Error("must bind first") if mtype == "list": return self.handle_list() if mtype == "allocate": - return self.handle_allocate() + return self.handle_allocate(server_rx) if mtype == "claim": - return self.handle_claim(msg) + return self.handle_claim(msg, server_rx) + if mtype == "release": + return self.handle_release(server_rx) - if not self._channel: - raise Error("Must set channel first") - if mtype == "watch": - return self.handle_watch(self._channel, msg) + if mtype == "open": + return self.handle_open(msg, server_rx) if mtype == "add": - return self.handle_add(self._channel, msg, server_rx) - if mtype == "deallocate": - return self.handle_deallocate(self._channel, msg) + return self.handle_add(msg, server_rx) + if mtype == "close": + return self.handle_close(msg, server_rx) - raise Error("Unknown type") + raise Error("unknown type") except Error as e: self.send("error", error=e._explain, orig=msg) - def send_rendezvous_event(self, event): - self.send("message", message=event) - - def stop_rendezvous_watcher(self): - self._reactor.callLater(0, self.transport.loseConnection) - def handle_ping(self, msg): if "ping" not in msg: raise Error("ping requires 'ping'") @@ -125,46 +150,79 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = self.factory.rendezvous.get_app(msg["appid"]) self._side = msg["side"] + def handle_list(self): - channelids = sorted(self._app.get_allocated()) - self.send("channelids", channelids=channelids) + nameplate_ids = sorted(self._app.get_nameplate_ids()) + # provide room to add nameplate attributes later (like which wordlist + # is used for each, maybe how many words) + nameplates = [{"id": nid} for nid in nameplate_ids] + self.send("nameplates", nameplates=nameplates) - def handle_allocate(self): - if self._channel: - raise Error("Already bound to a channelid") - channelid = self._app.find_available_channelid() - self._channel = self._app.allocate_channel(channelid, self._side) - self.send("allocated", channelid=channelid) + def handle_allocate(self, server_rx): + if self._did_allocate: + raise Error("you already allocated one, don't be greedy") + nameplate_id = self._app.allocate_nameplate(self._side, server_rx) + assert isinstance(nameplate_id, type(u"")) + self._did_allocate = True + self.send("allocated", nameplate=nameplate_id) - def handle_claim(self, msg): - if "channelid" not in msg: - raise Error("claim requires 'channelid'") - # we allow allocate+claim as long as they match - if self._channel is not None: - old_cid = self._channel.get_channelid() - if msg["channelid"] != old_cid: - raise Error("Already bound to channelid %d" % old_cid) - self._channel = self._app.allocate_channel(msg["channelid"], self._side) + def handle_claim(self, msg, server_rx): + if "nameplate" not in msg: + raise Error("claim requires 'nameplate'") + nameplate_id = msg["nameplate"] + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + self._nameplate_id = nameplate_id + try: + mailbox_id = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + except CrowdedError: + raise Error("crowded") + self.send("claimed", mailbox=mailbox_id) - def handle_watch(self, channel, msg): - if self._watching: - raise Error("already watching") - self._watching = True - for old_message in channel.add_listener(self): - self.send_rendezvous_event(old_message) + def handle_release(self, server_rx): + if not self._nameplate_id: + raise Error("must claim a nameplate before releasing it") + self._app.release_nameplate(self._nameplate_id, self._side, server_rx) + self._nameplate_id = None + self.send("released") - def handle_add(self, channel, msg, server_rx): + + def handle_open(self, msg, server_rx): + if self._mailbox: + raise Error("you already have a mailbox open") + if "mailbox" not in msg: + raise Error("open requires 'mailbox'") + mailbox_id = msg["mailbox"] + assert isinstance(mailbox_id, type(u"")) + self._mailbox = self._app.open_mailbox(mailbox_id, self._side, + server_rx) + def _send(sm): + self.send("message", side=sm.side, phase=sm.phase, + body=sm.body, server_rx=sm.server_rx, id=sm.msg_id) + def _stop(): + pass + for old_sm in self._mailbox.add_listener(self, _send, _stop): + _send(old_sm) + + def handle_add(self, msg, server_rx): + if not self._mailbox: + raise Error("must open mailbox before adding") if "phase" not in msg: raise Error("missing 'phase'") if "body" not in msg: raise Error("missing 'body'") msgid = msg.get("id") # optional - channel.add_message(self._side, msg["phase"], msg["body"], - server_rx, msgid) + sm = SidedMessage(side=self._side, phase=msg["phase"], + body=msg["body"], server_rx=server_rx, + msg_id=msgid) + self._mailbox.add_message(sm) - def handle_deallocate(self, channel, msg): - deleted = channel.deallocate(self._side, msg.get("mood")) - self.send("deallocated", status="deleted" if deleted else "waiting") + def handle_close(self, msg, server_rx): + if not self._mailbox: + raise Error("must open mailbox before closing") + self._mailbox.close(self._side, msg.get("mood"), server_rx) + self._mailbox = None + self.send("closed") def send(self, mtype, **kwargs): kwargs["type"] = mtype diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index e182e1d..0e65ee4 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -8,7 +8,6 @@ from .endpoint_service import ServerEndpointService from .. import __version__ from .database import get_db from .rendezvous import Rendezvous -from .rendezvous_web import WebRendezvous from .rendezvous_websocket import WebSocketRendezvousFactory from .transit_server import Transit @@ -49,12 +48,8 @@ class RelayServer(service.MultiService): rendezvous = Rendezvous(db, welcome, blur_usage) rendezvous.setServiceParent(self) # for the pruning timer - root = Root() - wr = WebRendezvous(rendezvous) - root.putChild(b"wormhole-relay", wr) - wsrf = WebSocketRendezvousFactory(None, rendezvous) - wr.putChild(b"ws", WebSocketResource(wsrf)) + root = WebSocketResource(wsrf) site = PrivacyEnhancedSite(root) if blur_usage: @@ -75,7 +70,6 @@ class RelayServer(service.MultiService): self._db = db self._rendezvous = rendezvous self._root = root - self._rendezvous_web = wr self._rendezvous_web_service = rendezvous_web_service self._rendezvous_websocket = wsrf if transit_port: diff --git a/src/wormhole/server/transit_server.py b/src/wormhole/server/transit_server.py index 2412c89..7c55b44 100644 --- a/src/wormhole/server/transit_server.py +++ b/src/wormhole/server/transit_server.py @@ -186,12 +186,12 @@ class Transit(protocol.ServerFactory, service.MultiService): if self._blur_usage: started = self._blur_usage * (started // self._blur_usage) total_bytes = blur_size(total_bytes) - self._db.execute("INSERT INTO `usage`" - " (`type`, `started`, `result`, `total_bytes`," - " `total_time`, `waiting_time`)" - " VALUES (?,?,?,?, ?,?)", - (u"transit", started, result, total_bytes, - total_time, waiting_time)) + self._db.execute("INSERT INTO `transit_usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?, ?,?)", + (started, total_time, waiting_time, + total_bytes, result)) self._db.commit() def transitFinished(self, p, token, description): diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 382cbfe..48c0685 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -17,8 +17,7 @@ class ServerBase: s.setServiceParent(self.sp) self._rendezvous = s._rendezvous self._transit_server = s._transit - self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport - self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws" + self.relayurl = u"ws://127.0.0.1:%d/" % relayport self.rdv_ws_port = relayport # ws://127.0.0.1:%d/wormhole-relay/ws self.transit = u"tcp:127.0.0.1:%d" % transitport diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py deleted file mode 100644 index 7ba8222..0000000 --- a/src/wormhole/test/test_blocking.py +++ /dev/null @@ -1,446 +0,0 @@ -from __future__ import print_function -import json -from twisted.trial import unittest -from twisted.internet.defer import gatherResults, succeed -from twisted.internet.threads import deferToThread -from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, - WrongPasswordError) -from ..blocking.eventsource import EventSourceFollower -from .common import ServerBase - -APPID = u"appid" - -class Channel(ServerBase, unittest.TestCase): - def ignore(self, welcome): - pass - - def test_allocate(self): - cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) - d = deferToThread(cm.list_channels) - def _got_channels(channels): - self.failUnlessEqual(channels, []) - d.addCallback(_got_channels) - d.addCallback(lambda _: deferToThread(cm.allocate)) - def _allocated(channelid): - self.failUnlessEqual(type(channelid), int) - self._channelid = channelid - d.addCallback(_allocated) - d.addCallback(lambda _: deferToThread(cm.connect, self._channelid)) - def _connected(c): - self._channel = c - d.addCallback(_connected) - d.addCallback(lambda _: deferToThread(self._channel.deallocate, - u"happy")) - return d - - def test_messages(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) - d.addCallback(lambda _: deferToThread(c2.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) - d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2")) - d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # it's legal to fetch a phase multiple times, should be idempotent - d.addCallback(lambda _: deferToThread(c1.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # deallocating one side is not enough to destroy the channel - d.addCallback(lambda _: deferToThread(c2.deallocate)) - def _not_yet(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 1) - d.addCallback(_not_yet) - # but deallocating both will make the messages go away - d.addCallback(lambda _: deferToThread(c1.deallocate, u"sad")) - def _gone(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 0) - d.addCallback(_gone) - - return d - - def test_get_multiple_phases(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1") - self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7]) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) - - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", - u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", - u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - - d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2")) - d.addCallback(lambda _: deferToThread(c2.get, u"phase2")) - - # if both are present, it should prefer the first one we asked for - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", - u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", - u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase2", b"msg2"))) - - return d - - def test_appid_independence(self): - APPID_A = u"appid_A" - APPID_B = u"appid_B" - cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) - cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) - c1a = cm1a.connect(1) - c2a = cm2a.connect(1) - cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) - cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) - c1b = cm1b.connect(1) - c2b = cm2b.connect(1) - - d = succeed(None) - d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a")) - d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b")) - d.addCallback(lambda _: deferToThread(c2a.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) - d.addCallback(lambda _: deferToThread(c2b.get, u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) - return d - -class _DoBothMixin: - def doBoth(self, call1, call2): - f1 = call1[0] - f1args = call1[1:] - f2 = call2[0] - f2args = call2[1:] - return gatherResults([deferToThread(f1, *f1args), - deferToThread(f2, *f2args)], True) - -class Blocking(_DoBothMixin, ServerBase, unittest.TestCase): - # we need Twisted to run the server, but we run the sender and receiver - # with deferToThread() - - def test_basic(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_same_message(self): - # the two sides use random nonces for their messages, so it's ok for - # both to try and send the same body: they'll result in distinct - # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data"], - [w2.send_data, b"data"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data") - self.assertEqual(dataY, b"data") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.send_data, b"data1"], - [w2.get_data]) - d.addCallback(_got_code) - def _sent(res): - (_, dataY) = res - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.get_data], [w2.send_data, b"data2"]) - d.addCallback(_sent) - def _done(dl): - (dataX, _) = dl - self.assertEqual(dataX, b"data2") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - d = self.doBoth([w1.send_data, b"data1"], [w2.send_data, b"data2"]) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - d = self.doBoth([w1.send_data, b"data1", u"p1"], - [w2.send_data, b"data2", u"p1"]) - d.addCallback(lambda _: - self.doBoth([w1.send_data, b"data3", u"p2"], - [w2.send_data, b"data4", u"p2"])) - d.addCallback(lambda _: - self.doBoth([w1.get_data, u"p2"], - [w2.get_data, u"p1"])) - def _got_1(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data4") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.get_data, u"p1"], - [w2.get_data, u"p2"]) - d.addCallback(_got_1) - def _got_2(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data3") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_got_2) - return d - - def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - - # make sure we can detect WrongPasswordError even if one side only - # does get_data() and not send_data(), like "wormhole receive" does - d = deferToThread(w1.get_code) - d.addCallback(lambda code: w2.set_code(code+"not")) - - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get_data. So we need both sides to be - # running at the same time for this test. - def _w1_sends(): - w1.send_data(b"data1") - def _w2_gets(): - self.assertRaises(WrongPasswordError, w2.get_data) - d.addCallback(lambda _: self.doBoth([_w1_sends], [_w2_gets])) - - # and now w1 should have enough information to throw too - d.addCallback(lambda _: deferToThread(self.assertRaises, - WrongPasswordError, w1.get_data)) - def _done(_): - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - d = deferToThread(w1.get_code) - d.addCallback(lambda code: w2.set_code(code)) - d.addCallback(lambda _: self.doBoth([w1.send_data, b"data1"], - [w2.get_data])) - d.addCallback(lambda dl: self.assertEqual(dl[1], b"data1")) - d.addCallback(lambda _: self.doBoth([w1.get_data], - [w2.send_data, b"data2"])) - d.addCallback(lambda dl: self.assertEqual(dl[0], b"data2")) - d.addCallback(lambda _: self.doBoth([w1.close], [w2.close])) - return d - - def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code) - return self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(_got_code) - def _check_verifier(res): - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failUnlessEqual(v1, v2) - return self.doBoth([w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_check_verifier) - def _sent(res): - return self.doBoth([w1.get_data], [w2.get_data]) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_verifier_mismatch(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - w2.set_code(code+"not") - return self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(_got_code) - def _check_verifier(res): - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failIfEqual(v1, v2) - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_check_verifier) - return d - - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.get_verifier) - self.assertRaises(UsageError, w1.get_data) - self.assertRaises(UsageError, w1.send_data, b"data") - w1.set_code(u"123-purple-elephant") - self.assertRaises(UsageError, w1.set_code, u"123-nope") - self.assertRaises(UsageError, w1.get_code) - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w2.get_code) - def _done(code): - self.assertRaises(UsageError, w2.get_code) - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - - def test_repeat_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2 = Wormhole(APPID, self.relayurl) - w2.set_code(u"123-purple-elephant") - # we must let them establish a key before we can send data - d = self.doBoth([w1.get_verifier], [w2.get_verifier]) - d.addCallback(lambda _: - deferToThread(w1.send_data, b"data1", phase=u"1")) - def _sent(res): - # underscore-prefixed phases are reserved - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1") - self.assertRaises(UsageError, w1.get_data, phase=u"_1") - # you can't send twice to the same phase - self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1") - # but you can send to a different one - return deferToThread(w1.send_data, b"data2", phase=u"2") - d.addCallback(_sent) - d.addCallback(lambda _: deferToThread(w2.get_data, phase=u"1")) - def _got1(res): - self.failUnlessEqual(res, b"data1") - # and you can't read twice from the same phase - self.assertRaises(UsageError, w2.get_data, phase=u"1") - # but you can read from a different one - return deferToThread(w2.get_data, phase=u"2") - d.addCallback(_got1) - def _got2(res): - self.failUnlessEqual(res, b"data2") - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_got2) - return d - - def test_serialize(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.serialize) # too early - w2 = Wormhole(APPID, self.relayurl) - d = deferToThread(w1.get_code) - def _got_code(code): - self.assertRaises(UsageError, w2.serialize) # too early - w2.set_code(code) - w2.serialize() # ok - s = w1.serialize() - self.assertEqual(type(s), type("")) - unpacked = json.loads(s) # this is supposed to be JSON - self.assertEqual(type(unpacked), dict) - self.new_w1 = Wormhole.from_serialized(s) - return self.doBoth([self.new_w1.send_data, b"data1"], - [w2.send_data, b"data2"]) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth(self.new_w1.get_data(), w2.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - self.assertRaises(UsageError, w2.serialize) # too late - return self.doBoth([w1.close], [w2.close]) - d.addCallback(_done) - return d - test_serialize.skip = "not yet implemented for the blocking flavor" - -data1 = u"""\ -event: welcome -data: one and a -data: two -data:. - -data: three - -: this line is ignored -event: e2 -: this line is ignored too -i am a dataless field name -data: four - -""" - -class NoNetworkESF(EventSourceFollower): - def __init__(self, text): - self._lines_iter = iter(text.splitlines()) - -class EventSourceClient(unittest.TestCase): - def test_parser(self): - events = [] - f = NoNetworkESF(data1) - events = list(f.iter_events()) - self.failUnlessEqual(events, - [(u"welcome", u"one and a\ntwo\n."), - (u"message", u"three"), - (u"e2", u"four"), - ]) diff --git a/src/wormhole/test/test_interop.py b/src/wormhole/test/test_interop.py deleted file mode 100644 index 23daa9c..0000000 --- a/src/wormhole/test/test_interop.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import print_function -from twisted.trial import unittest -from twisted.internet.defer import gatherResults -from twisted.internet.threads import deferToThread -from ..twisted.transcribe import Wormhole as twisted_Wormhole -from ..blocking.transcribe import Wormhole as blocking_Wormhole -from .common import ServerBase - -# make sure the two implementations (Twisted-style and blocking-style) can -# interoperate - -APPID = u"appid" - -class Basic(ServerBase, unittest.TestCase): - - def doBoth(self, call1, d2): - f1 = call1[0] - f1args = call1[1:] - return gatherResults([deferToThread(f1, *f1args), d2], True) - - def test_twisted_to_blocking(self): - tw = twisted_Wormhole(APPID, self.relayurl) - bw = blocking_Wormhole(APPID, self.relayurl) - d = tw.get_code() - def _got_code(code): - bw.set_code(code) - return self.doBoth([bw.send_data, b"data2"], tw.send_data(b"data1")) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([bw.get_data], tw.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data1") - self.assertEqual(dataY, b"data2") - return self.doBoth([bw.close], tw.close()) - d.addCallback(_done) - return d - - def test_blocking_to_twisted(self): - bw = blocking_Wormhole(APPID, self.relayurl) - tw = twisted_Wormhole(APPID, self.relayurl) - d = deferToThread(bw.get_code) - def _got_code(code): - tw.set_code(code) - return self.doBoth([bw.send_data, b"data1"], tw.send_data(b"data2")) - d.addCallback(_got_code) - def _sent(res): - return self.doBoth([bw.get_data], tw.get_data()) - d.addCallback(_sent) - def _done(dl): - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - return self.doBoth([bw.close], tw.close()) - d.addCallback(_done) - return d diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index b39f2c5..71fd33a 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -453,7 +453,7 @@ class Cleanup(ServerBase, unittest.TestCase): yield send_d yield receive_d - cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) @inlineCallbacks @@ -482,6 +482,7 @@ class Cleanup(ServerBase, unittest.TestCase): yield self.assertFailure(send_d, WrongPasswordError) yield self.assertFailure(receive_d, WrongPasswordError) - cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated() + cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids() self.assertEqual(len(cids), 0) + self.flushLoggedErrors(WrongPasswordError) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d9010f1..437e8bd 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,59 +1,286 @@ from __future__ import print_function import json, itertools from binascii import hexlify -import requests -from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import protocol, reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.threads import deferToThread from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web.client import getPage, Agent, readBody from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase from ..server import rendezvous, transit_server -from ..twisted.eventsource import EventSource +from ..server.rendezvous import Usage, SidedMessage -class Reachable(ServerBase, unittest.TestCase): +class Server(ServerBase, unittest.TestCase): + def test_apps(self): + app1 = self._rendezvous.get_app(u"appid1") + self.assertIdentical(app1, self._rendezvous.get_app(u"appid1")) + app2 = self._rendezvous.get_app(u"appid2") + self.assertNotIdentical(app1, app2) - def test_getPage(self): - # client.getPage requires bytes URL, returns bytes - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - d = getPage(url) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d + def test_nameplate_allocation(self): + app = self._rendezvous.get_app(u"appid") + nids = set() + # this takes a second, and claims all the short-numbered nameplates + def add(): + nameplate_id = app.allocate_nameplate(u"side1", 0) + self.assertEqual(type(nameplate_id), type(u"")) + nid = int(nameplate_id) + nids.add(nid) + for i in range(9): add() + self.assertNotIn(0, nids) + self.assertEqual(set(range(1,10)), nids) - def test_agent(self): - url = self.relayurl.replace("wormhole-relay/", "").encode("ascii") - agent = Agent(reactor) - d = agent.request(b"GET", url) - def _check(resp): - self.failUnlessEqual(resp.code, 200) - return readBody(resp) - d.addCallback(_check) - def _got(res): - self.failUnlessEqual(res, b"Wormhole Relay\n") - d.addCallback(_got) - return d + for i in range(100-10): add() + self.assertEqual(len(nids), 99) + self.assertEqual(set(range(1,100)), nids) - def test_requests(self): - # requests requires bytes URL, returns unicode - url = self.relayurl.replace("wormhole-relay/", "") - def _get(url): - r = requests.get(url) - r.raise_for_status() - return r.text - d = deferToThread(_get, url) - def _got(res): - self.failUnlessEqual(res, "Wormhole Relay\n") - d.addCallback(_got) - return d + for i in range(1000-100): add() + self.assertEqual(len(nids), 999) + self.assertEqual(set(range(1,1000)), nids) + + add() + self.assertEqual(len(nids), 1000) + biggest = max(nids) + self.assert_(1000 <= biggest < 1000000, biggest) + + def _nameplate(self, app, nameplate_id): + return app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`=?", + (nameplate_id,)).fetchone() + + def test_nameplate(self): + app = self._rendezvous.get_app(u"appid") + nameplate_id = app.allocate_nameplate(u"side1", 0) + self.assertEqual(type(nameplate_id), type(u"")) + nid = int(nameplate_id) + self.assert_(0 < nid < 10, nid) + self.assertEqual(app.get_nameplate_ids(), set([nameplate_id])) + # allocate also does a claim + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + mailbox_id = app.claim_nameplate(nameplate_id, u"side1", 1) + self.assertEqual(type(mailbox_id), type(u"")) + # duplicate claims by the same side are combined + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + + mailbox_id2 = app.claim_nameplate(nameplate_id, u"side1", 2) + self.assertEqual(mailbox_id, mailbox_id2) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # claim by the second side is new + mailbox_id3 = app.claim_nameplate(nameplate_id, u"side2", 3) + self.assertEqual(mailbox_id, mailbox_id3) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 3) + + # a third claim marks the nameplate as "crowded", but leaves the two + # existing claims alone + self.assertRaises(rendezvous.CrowdedError, + app.claim_nameplate, nameplate_id, u"side3", 0) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + + # releasing a non-existent nameplate is ignored + app.release_nameplate(nameplate_id+u"not", u"side4", 0) + + # releasing a side that never claimed the nameplate is ignored + app.release_nameplate(nameplate_id, u"side4", 0) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + + # releasing one side leaves the second claim + app.release_nameplate(nameplate_id, u"side1", 5) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + # releasing one side multiple times is ignored + app.release_nameplate(nameplate_id, u"side1", 5) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + # releasing the second side frees the nameplate, and adds usage + app.release_nameplate(nameplate_id, u"side2", 6) + row = self._nameplate(app, nameplate_id) + self.assertEqual(row, None) + usage = app._db.execute("SELECT * FROM `nameplate_usage`").fetchone() + self.assertEqual(usage["app_id"], u"appid") + self.assertEqual(usage["started"], 0) + self.assertEqual(usage["waiting_time"], 3) + self.assertEqual(usage["total_time"], 6) + self.assertEqual(usage["result"], u"crowded") + + + def _mailbox(self, app, mailbox_id): + return app._db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`='appid' AND `id`=?", + (mailbox_id,)).fetchone() + + def test_mailbox(self): + app = self._rendezvous.get_app(u"appid") + mailbox_id = u"mid" + m1 = app.open_mailbox(mailbox_id, u"side1", 0) + + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # opening the same mailbox twice, by the same side, gets the same + # object + self.assertIdentical(m1, app.open_mailbox(mailbox_id, u"side1", 1)) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], None) + + # opening a second side gets the same object, and adds a new claim + self.assertIdentical(m1, app.open_mailbox(mailbox_id, u"side2", 2)) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], False) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # a third open marks it as crowded + self.assertRaises(rendezvous.CrowdedError, + app.open_mailbox, mailbox_id, u"side3", 3) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing a side that never claimed the mailbox is ignored + m1.close(u"side4", u"mood", 4) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side1") + self.assertEqual(row["side2"], u"side2") + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing one side leaves the second claim + m1.close(u"side1", u"mood", 5) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + # closing one side multiple is ignored + m1.close(u"side1", u"mood", 6) + row = self._mailbox(app, mailbox_id) + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + self.assertEqual(row["crowded"], True) + self.assertEqual(row["started"], 0) + self.assertEqual(row["second"], 2) + + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + m1.add_listener("handle1", l1.append, stop1_f) + + # closing the second side frees the mailbox, and adds usage + m1.close(u"side2", u"mood", 7) + self.assertEqual(stop1, [True]) + + row = self._mailbox(app, mailbox_id) + self.assertEqual(row, None) + usage = app._db.execute("SELECT * FROM `mailbox_usage`").fetchone() + self.assertEqual(usage["app_id"], u"appid") + self.assertEqual(usage["started"], 0) + self.assertEqual(usage["waiting_time"], 2) + self.assertEqual(usage["total_time"], 7) + self.assertEqual(usage["result"], u"crowded") + + def _messages(self, app): + c = app._db.execute("SELECT * FROM `messages`" + " WHERE `app_id`='appid' AND `mailbox_id`='mid'") + return c.fetchall() + + def test_messages(self): + app = self._rendezvous.get_app(u"appid") + mailbox_id = u"mid" + m1 = app.open_mailbox(mailbox_id, u"side1", 0) + m1.add_message(SidedMessage(side=u"side1", phase=u"phase", + body=u"body", server_rx=1, + msg_id=u"msgid")) + msgs = self._messages(app) + self.assertEqual(len(msgs), 1) + self.assertEqual(msgs[0]["body"], u"body") + + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + l2 = []; stop2 = []; stop2_f = lambda: stop2.append(True) + old = m1.add_listener("handle1", l1.append, stop1_f) + self.assertEqual(len(old), 1) + self.assertEqual(old[0].side, u"side1") + self.assertEqual(old[0].body, u"body") + + m1.add_message(SidedMessage(side=u"side1", phase=u"phase2", + body=u"body2", server_rx=1, + msg_id=u"msgid")) + self.assertEqual(len(l1), 1) + self.assertEqual(l1[0].body, u"body2") + old = m1.add_listener("handle2", l2.append, stop2_f) + self.assertEqual(len(old), 2) + + m1.add_message(SidedMessage(side=u"side1", phase=u"phase3", + body=u"body3", server_rx=1, + msg_id=u"msgid")) + self.assertEqual(len(l1), 2) + self.assertEqual(l1[-1].body, u"body3") + self.assertEqual(len(l2), 1) + self.assertEqual(l2[-1].body, u"body3") + + m1.remove_listener("handle1") + + m1.add_message(SidedMessage(side=u"side1", phase=u"phase4", + body=u"body4", server_rx=1, + msg_id=u"msgid")) + self.assertEqual(len(l1), 2) + self.assertEqual(l1[-1].body, u"body3") + self.assertEqual(len(l2), 2) + self.assertEqual(l2[-1].body, u"body4") + + m1._shutdown() + self.assertEqual(stop1, []) + self.assertEqual(stop2, [True]) + + # message adds are not idempotent: clients filter duplicates + m1.add_message(SidedMessage(side=u"side1", phase=u"phase", + body=u"body", server_rx=1, + msg_id=u"msgid")) + msgs = self._messages(app) + self.assertEqual(len(msgs), 5) + self.assertEqual(msgs[-1]["body"], u"body") -def unjson(data): - return json.loads(data.decode("utf-8")) def strip_message(msg): m2 = msg.copy() @@ -64,323 +291,6 @@ def strip_message(msg): def strip_messages(messages): return [strip_message(m) for m in messages] -class WebAPI(ServerBase, unittest.TestCase): - def build_url(self, path, appid, channelid): - url = self.relayurl+path - queryargs = [] - if appid: - queryargs.append(("appid", appid)) - if channelid: - queryargs.append(("channelid", channelid)) - if queryargs: - url += "?" + urlencode(queryargs) - return url - - def get(self, path, appid=None, channelid=None): - url = self.build_url(path, appid, channelid) - d = getPage(url.encode("ascii")) - d.addCallback(unjson) - return d - - def post(self, path, data): - url = self.relayurl+path - d = getPage(url.encode("ascii"), method=b"POST", - postdata=json.dumps(data).encode("utf-8")) - d.addCallback(unjson) - return d - - def check_welcome(self, data): - self.failUnlessIn("welcome", data) - self.failUnlessEqual(data["welcome"], {"current_version": __version__}) - - def test_allocate_1(self): - d = self.get("list", "app1") - def _check_list_1(data): - self.check_welcome(data) - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_1) - - d.addCallback(lambda _: self.post("allocate", {"appid": "app1", - "side": "abc"})) - def _allocated(data): - data.pop("sent", None) - self.failUnlessEqual(set(data.keys()), - set(["welcome", "channelid"])) - self.failUnlessIsInstance(data["channelid"], int) - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_2(data): - self.failUnlessEqual(data["channelids"], [self.cid]) - d.addCallback(_check_list_2) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - def _check_deallocate(res): - self.failUnlessEqual(res["status"], "deleted") - d.addCallback(_check_deallocate) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_3(data): - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_3) - - return d - - def test_allocate_2(self): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - # second caller increases the number of known sides to 2 - d.addCallback(lambda _: self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def", - "phase": "1", - "body": ""})) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [self.cid])) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "NOT"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "deleted")) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [])) - - return d - - UPGRADE_ERROR = "Sorry, you must upgrade your client to use this server." - def test_old_allocate(self): - # 0.4.0 used "POST /allocate/SIDE". - # 0.5.0 replaced it with "POST /allocate". - # test that an old client gets a useful error message, not a 404. - d = self.post("allocate/abc", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_list(self): - # 0.4.0 used "GET /list". - # 0.5.0 replaced it with "GET /list?appid=" - d = self.get("list", {}) # no appid - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_post(self): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - d = self.post("1/abc/post/pake", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def add_message(self, message, side="abc", phase="1"): - return self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": side, - "phase": phase, - "body": message}) - - def parse_messages(self, messages): - out = set() - for m in messages: - self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"])) - self.failUnlessIsInstance(m["phase"], type(u"")) - self.failUnlessIsInstance(m["body"], type(u"")) - out.add( (m["phase"], m["body"]) ) - return out - - def check_messages(self, one, two): - # Comparing lists-of-dicts is non-trivial in python3 because we can - # neither sort them (dicts are uncomparable), nor turn them into sets - # (dicts are unhashable). This is close enough. - self.failUnlessEqual(len(one), len(two), (one,two)) - for d in one: - self.failUnlessIn(d, two) - - def test_message(self): - # exercise POST /add - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.add_message("msg1A")) - def _check1(data): - self.check_welcome(data) - self.failUnlessEqual(strip_messages(data["messages"]), - [{"phase": "1", "body": "msg1A"}]) - d.addCallback(_check1) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check1) - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check2(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check2) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check2) - - # adding a duplicate message is not an error, is ignored by clients - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check3(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check3) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check3) - - d.addCallback(lambda _: self.add_message("msg2A", side="abc", - phase="2")) - def _check4(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B"), - ("2", "msg2A"), - ])) - d.addCallback(_check4) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check4) - - return d - - def test_watch_message(self): - # exercise GET /get (the EventSource version) - # this API is scheduled to be removed after 0.6.0 - return self._do_watch("get") - - def test_watch(self): - # exercise GET /watch (the EventSource version) - return self._do_watch("watch") - - def _do_watch(self, endpoint_name): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - url = self.build_url(endpoint_name, "app1", self.cid) - self.o = OneEventAtATime(url, parser=json.loads) - return self.o.wait_for_connection() - d.addCallback(_allocated) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_welcome(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "welcome") - self.failUnlessEqual(data, {"current_version": __version__}) - d.addCallback(_check_welcome) - d.addCallback(lambda _: self.add_message("msg1A")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg1(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1A"}) - d.addCallback(_check_msg1) - - d.addCallback(lambda _: self.add_message("msg1B")) - d.addCallback(lambda _: self.add_message("msg2A", phase="2")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg2(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1B"}) - d.addCallback(_check_msg2) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg3(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "2", "body": "msg2A"}) - d.addCallback(_check_msg3) - - d.addCallback(lambda _: self.o.close()) - d.addCallback(lambda _: self.o.wait_for_disconnection()) - return d - -class OneEventAtATime: - def __init__(self, url, parser=lambda e: e): - self.parser = parser - self.d = None - self._connected = False - self.connected_d = defer.Deferred() - self.disconnected_d = defer.Deferred() - self.events = [] - self.es = EventSource(url, self.handler, when_connected=self.connected) - d = self.es.start() - d.addBoth(self.disconnected) - - def close(self): - self.es.cancel() - - def wait_for_next_event(self): - assert not self.d - if self.events: - event = self.events.pop(0) - return defer.succeed(event) - self.d = defer.Deferred() - return self.d - - def handler(self, eventtype, data): - event = (eventtype, self.parser(data)) - if self.d: - assert not self.events - d,self.d = self.d,None - d.callback(event) - return - self.events.append(event) - - def wait_for_connection(self): - return self.connected_d - def connected(self): - self._connected = True - self.connected_d.callback(None) - - def wait_for_disconnection(self): - return self.disconnected_d - def disconnected(self, why): - if not self._connected: - self.connected_d.errback(why) - self.disconnected_d.callback((why,)) - class WSClient(websocket.WebSocketClientProtocol): def __init__(self): websocket.WebSocketClientProtocol.__init__(self) @@ -425,6 +335,10 @@ class WSClient(websocket.WebSocketClientProtocol): payload = json.dumps(kwargs).encode("utf-8") self.sendMessage(payload, False) + def send_notype(self, **kwargs): + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + @inlineCallbacks def sync(self): ping = next(self.ping_counter) @@ -520,7 +434,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): @inlineCallbacks def make_client(self): - f = WSFactory(self.rdv_ws_url) + f = WSFactory(self.relayurl) f.d = defer.Deferred() reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f) c = yield f.d @@ -539,289 +453,332 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(self._rendezvous._apps, {}) @inlineCallbacks - def test_allocate_1(self): + def test_bind(self): c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) + yield c1.next_non_ack() + + c1.send(u"bind", appid=u"appid") # missing side= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"bind requires 'side'") + + c1.send(u"bind", side=u"side") # missing appid= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"bind requires 'appid'") + c1.send(u"bind", appid=u"appid", side=u"side") yield c1.sync() self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) + + c1.send(u"bind", appid=u"appid", side=u"side") # duplicate + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"already bound") + + c1.send_notype(other="misc") # missing 'type' + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'type'") + + c1.send("___unknown") # unknown type + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"unknown type") + + c1.send("ping") # missing 'ping' + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"ping requires 'ping'") + + @inlineCallbacks + def test_list(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + + c1.send(u"list") # too early, must bind first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must bind first") + + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"list") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"nameplates") + self.assertEqual(m[u"nameplates"], []) + app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_allocated(), set()) - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) - - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) - self.assertEqual(app.get_allocated(), set([cid])) - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) + nameplate_id1 = app.allocate_nameplate(u"side", 0) + app.claim_nameplate(u"np2", u"side", 0) c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c1.send(u"deallocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") - self.assertEqual(msg["status"], u"deleted") - self.assertEqual(app.get_allocated(), set()) - - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"nameplates") + nids = set() + for n in m[u"nameplates"]: + self.assertEqual(type(n), dict) + self.assertEqual(list(n.keys()), [u"id"]) + nids.add(n[u"id"]) + self.assertEqual(nids, set([nameplate_id1, u"np2"])) @inlineCallbacks - def test_allocate_2(self): + def test_allocate(self): c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) + yield c1.next_non_ack() + + c1.send(u"allocate") # too early, must bind first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must bind first") + c1.send(u"bind", appid=u"appid", side=u"side") - yield c1.sync() app = self._rendezvous.get_app(u"appid") - self.assertEqual(app.get_allocated(), set()) c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - self.failUnlessIsInstance(cid, int) - self.assertEqual(app.get_allocated(), set([cid])) - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"allocated") + nameplate_id = m[u"nameplate"] - # second caller increases the number of known sides to 2 - c2 = yield self.make_client() - msg = yield c2.next_non_ack() - self.check_welcome(msg) - c2.send(u"bind", appid=u"appid", side=u"side-2") - c2.send(u"claim", channelid=cid) - c2.send(u"add", phase="1", body="") - yield c2.sync() + nids = app.get_nameplate_ids() + self.assertEqual(len(nids), 1) + self.assertEqual(nameplate_id, list(nids)[0]) - self.assertEqual(app.get_allocated(), set([cid])) - self.assertEqual(strip_messages(channel.get_messages()), - [{"phase": "1", "body": ""}]) - - c1.send(u"list") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c2.send(u"list") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], [cid]) - - c1.send(u"deallocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") - self.assertEqual(msg["status"], u"waiting") - - c2.send(u"deallocate") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"deallocated") - self.assertEqual(msg["status"], u"deleted") - - c2.send(u"list") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], u"channelids") - self.assertEqual(msg["channelids"], []) - - @inlineCallbacks - def test_allocate_and_claim(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid) + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"you already allocated one, don't be greedy") + + c1.send(u"claim", nameplate=nameplate_id) # allocate+claim is ok yield c1.sync() - # there should no error - self.assertEqual(c1.errors, []) + row = app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`=?", + (nameplate_id,)).fetchone() + self.assertEqual(row["side1"], u"side") + self.assertEqual(row["side2"], None) @inlineCallbacks - def test_allocate_and_claim_different(self): + def test_claim(self): c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) + yield c1.next_non_ack() c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] - c1.send(u"claim", channelid=cid+1) - yield c1.sync() - # that should signal an error - self.assertEqual(len(c1.errors), 1, c1.errors) - msg = c1.errors[0] - self.assertEqual(msg["type"], "error") - self.assertEqual(msg["error"], "Already bound to channelid %d" % cid) - self.assertEqual(msg["orig"], {"type": "claim", "channelid": cid+1}) - - @inlineCallbacks - def test_message(self): - c1 = yield self.make_client() - msg = yield c1.next_non_ack() - self.check_welcome(msg) - c1.send(u"bind", appid=u"appid", side=u"side") - c1.send(u"allocate") - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], u"allocated") - cid = msg["channelid"] app = self._rendezvous.get_app(u"appid") - channel = app.get_channel(cid) - self.assertEqual(channel.get_messages(), []) - c1.send(u"watch") - yield c1.sync() - self.assertEqual(len(channel._listeners), 1) - c1.strip_acks() - self.assertEqual(c1.events, []) + c1.send(u"claim") # missing nameplate= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"claim requires 'nameplate'") - c1.send(u"add", phase="1", body="msg1A") - yield c1.sync() - c1.strip_acks() - self.assertEqual(strip_messages(channel.get_messages()), - [{"phase": "1", "body": "msg1A"}]) - self.assertEqual(len(c1.events), 1) # echo should be sent right away - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1A"}) - self.assertIn("server_tx", msg) - self.assertIsInstance(msg["server_tx"], float) + c1.send(u"claim", nameplate=u"np1") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"claimed") + mailbox_id = m[u"mailbox"] + self.assertEqual(type(mailbox_id), type(u"")) - c1.send(u"add", phase="1", body="msg1B") - c1.send(u"add", phase="2", body="msg2A") + nids = app.get_nameplate_ids() + self.assertEqual(len(nids), 1) + self.assertEqual(u"np1", list(nids)[0]) - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1B"}) + # claiming a nameplate will assign a random mailbox id, but won't + # create the mailbox itself + mailboxes = app._db.execute("SELECT * FROM `mailboxes`" + " WHERE `app_id`='appid'").fetchall() + self.assertEqual(len(mailboxes), 0) - msg = yield c1.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) + @inlineCallbacks + def test_claim_crowded(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") - self.assertEqual(strip_messages(channel.get_messages()), [ - {"phase": "1", "body": "msg1A"}, - {"phase": "1", "body": "msg1B"}, - {"phase": "2", "body": "msg2A"}, - ]) + app.claim_nameplate(u"np1", u"side1", 0) + app.claim_nameplate(u"np1", u"side2", 0) - # second client should see everything - c2 = yield self.make_client() - msg = yield c2.next_non_ack() - self.check_welcome(msg) - c2.send(u"bind", appid=u"appid", side=u"side") - c2.send(u"claim", channelid=cid) - # 'watch' triggers delivery of old messages, in temporal order - c2.send(u"watch") + # the third claim will signal crowding + c1.send(u"claim", nameplate=u"np1") + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"crowded") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1A"}) + @inlineCallbacks + def test_release(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "1", "body": "msg1B"}) + app.claim_nameplate(u"np1", u"side2", 0) - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) + c1.send(u"release") # didn't do claim first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"must claim a nameplate before releasing it") - # adding a duplicate is not an error, and clients will ignore it - c1.send(u"add", phase="2", body="msg2A") + c1.send(u"claim", nameplate=u"np1") + yield c1.next_non_ack() - # the duplicate message *does* get stored, and delivered - msg = yield c2.next_non_ack() - self.assertEqual(msg["type"], "message") - self.assertEqual(strip_message(msg["message"]), - {"phase": "2", "body": "msg2A"}) + c1.send(u"release") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"released") + + row = app._db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`='appid' AND `id`='np1'").fetchone() + self.assertEqual(row["side1"], u"side2") + self.assertEqual(row["side2"], None) + + c1.send(u"release") # no longer claimed + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], + u"must claim a nameplate before releasing it") + + @inlineCallbacks + def test_open(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + + c1.send(u"open") # missing mailbox= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"open requires 'mailbox'") + + mb1 = app.open_mailbox(u"mb1", u"side2", 0) + mb1.add_message(SidedMessage(side=u"side2", phase=u"phase", + body=u"body", server_rx=0, + msg_id=u"msgid")) + + c1.send(u"open", mailbox=u"mb1") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"body"], u"body") + + mb1.add_message(SidedMessage(side=u"side2", phase=u"phase2", + body=u"body2", server_rx=0, + msg_id=u"msgid")) + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"body"], u"body2") + + c1.send(u"open", mailbox=u"mb1") + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"you already have a mailbox open") + + @inlineCallbacks + def test_add(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + app = self._rendezvous.get_app(u"appid") + mb1 = app.open_mailbox(u"mb1", u"side2", 0) + l1 = []; stop1 = []; stop1_f = lambda: stop1.append(True) + mb1.add_listener("handle1", l1.append, stop1_f) + + c1.send(u"add") # didn't open first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before adding") + + c1.send(u"open", mailbox=u"mb1") + + c1.send(u"add", body=u"body") # missing phase= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'phase'") + + c1.send(u"add", phase=u"phase") # missing body= + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"missing 'body'") + + c1.send(u"add", phase=u"phase", body=u"body") + m = yield c1.next_non_ack() # echoed back + self.assertEqual(m[u"type"], u"message") + self.assertEqual(m[u"body"], u"body") + + self.assertEqual(len(l1), 1) + self.assertEqual(l1[0].body, u"body") + + @inlineCallbacks + def test_close(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send(u"bind", appid=u"appid", side=u"side") + + c1.send(u"close", mood=u"mood") # must open first + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before closing") + + c1.send(u"open", mailbox=u"mb1") + c1.send(u"close", mood=u"mood") + m = yield c1.next_non_ack() + self.assertEqual(m[u"type"], u"closed") + + c1.send(u"close", mood=u"mood") # already closed + err = yield c1.next_non_ack() + self.assertEqual(err[u"type"], u"error") + self.assertEqual(err[u"error"], u"must open mailbox before closing") class Summary(unittest.TestCase): - def test_summarize(self): - c = rendezvous.Channel(None, None, None, None, False, None, None) - A = rendezvous.ALLOCATE - D = rendezvous.DEALLOCATE + def test_mailbox(self): + c = rendezvous.Mailbox(None, None, None, False, None, None) + # starts at time 1, maybe gets second open at time 3, closes at 5 + base_row = {u"started": 1, u"second": None, + u"first_mood": None, u"crowded": False} + def summ(num_sides, second_mood=None, pruned=False, **kwargs): + row = base_row.copy() + row.update(kwargs) + return c._summarize(row, num_sides, second_mood, 5, pruned) - messages = [{"server_rx": 1, "side": "a", "phase": A}] - self.failUnlessEqual(c._summarize(messages, 2), - (1, "lonely", 1, None)) + self.assertEqual(summ(1), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, u"lonely"), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, u"errory"), Usage(1, None, 4, u"errory")) + self.assertEqual(summ(1, crowded=True), Usage(1, None, 4, u"crowded")) - messages = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "a", "phase": D, "body": "lonely"}, - ] - self.failUnlessEqual(c._summarize(messages, 3), - (1, "lonely", 2, None)) + self.assertEqual(summ(2, first_mood=u"happy", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"happy")) - messages = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "b", "phase": A}, - {"server_rx": 3, "side": "c", "phase": A}, - ] - self.failUnlessEqual(c._summarize(messages, 4), - (1, "crowded", 3, None)) + self.assertEqual(summ(2, first_mood=u"errory", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"errory")) - base = [{"server_rx": 1, "side": "a", "phase": A}, - {"server_rx": 2, "side": "a", "phase": "pake", "body": "msg1"}, - {"server_rx": 10, "side": "b", "phase": "pake", "body": "msg2"}, - {"server_rx": 11, "side": "b", "phase": "data", "body": "msg3"}, - {"server_rx": 20, "side": "a", "phase": "data", "body": "msg4"}, - ] - def make_moods(A_mood, B_mood): - return base + [ - {"server_rx": 21, "side": "a", "phase": D, "body": A_mood}, - {"server_rx": 30, "side": "b", "phase": D, "body": B_mood}, - ] + self.assertEqual(summ(2, first_mood=u"happy", + second=3, second_mood=u"errory"), + Usage(1, 2, 4, u"errory")) - self.failUnlessEqual(c._summarize(make_moods("happy", "happy"), 41), - (1, "happy", 40, 9)) + self.assertEqual(summ(2, first_mood=u"scary", + second=3, second_mood=u"happy"), + Usage(1, 2, 4, u"scary")) - self.failUnlessEqual(c._summarize(make_moods("scary", "happy"), 41), - (1, "scary", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "scary"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(2, first_mood=u"scary", + second=3, second_mood=u"errory"), + Usage(1, 2, 4, u"scary")) - self.failUnlessEqual(c._summarize(make_moods("lonely", "happy"), 41), - (1, "lonely", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "lonely"), 41), - (1, "lonely", 40, 9)) + self.assertEqual(summ(2, first_mood=u"happy", second=3, pruned=True), + Usage(1, 2, 4, u"pruney")) - self.failUnlessEqual(c._summarize(make_moods("errory", "happy"), 41), - (1, "errory", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("happy", "errory"), 41), - (1, "errory", 40, 9)) + def test_nameplate(self): + a = rendezvous.AppNamespace(None, None, None, False, None) + # starts at time 1, maybe gets second open at time 3, closes at 5 + base_row = {u"started": 1, u"second": None, u"crowded": False} + def summ(num_sides, pruned=False, **kwargs): + row = base_row.copy() + row.update(kwargs) + return a._summarize_nameplate_usage(row, 5, pruned) - # scary trumps other moods - self.failUnlessEqual(c._summarize(make_moods("scary", "lonely"), 41), - (1, "scary", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods("scary", "errory"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(1), Usage(1, None, 4, u"lonely")) + self.assertEqual(summ(1, crowded=True), Usage(1, None, 4, u"crowded")) - # older clients don't send a mood - self.failUnlessEqual(c._summarize(make_moods(None, None), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "happy"), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "happy"), 41), - (1, "quiet", 40, 9)) - self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41), - (1, "scary", 40, 9)) + self.assertEqual(summ(2, second=3), Usage(1, 2, 4, u"happy")) + + self.assertEqual(summ(2, second=3, pruned=True), + Usage(1, 2, 4, u"pruney")) class Accumulator(protocol.Protocol): def __init__(self): diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py deleted file mode 100644 index 1c186da..0000000 --- a/src/wormhole/test/test_twisted.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import print_function -import json -from twisted.trial import unittest -from twisted.internet.defer import gatherResults, inlineCallbacks -from ..twisted.transcribe import Wormhole, UsageError, WrongPasswordError -from .common import ServerBase - -APPID = u"appid" - -class Basic(ServerBase, unittest.TestCase): - - def doBoth(self, d1, d2): - return gatherResults([d1, d2], True) - - @inlineCallbacks - def test_basic(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_same_message(self): - # the two sides use random nonces for their messages, so it's ok for - # both to try and send the same body: they'll result in distinct - # encrypted messages - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - yield self.doBoth(w1.send_data(b"data"), w2.send_data(b"data")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data") - self.assertEqual(dataY, b"data") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_interleaved(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - res = yield self.doBoth(w1.send_data(b"data1"), w2.get_data()) - (_, dataY) = res - self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2")) - (dataX, _) = dl - self.assertEqual(dataX, b"data2") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_fixed_code(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - - @inlineCallbacks - def test_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2.set_code(u"123-purple-elephant") - yield self.doBoth(w1.send_data(b"data1", u"p1"), - w2.send_data(b"data2", u"p1")) - yield self.doBoth(w1.send_data(b"data3", u"p2"), - w2.send_data(b"data4", u"p2")) - dl = yield self.doBoth(w1.get_data(u"p2"), - w2.get_data(u"p1")) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data4") - self.assertEqual(dataY, b"data1") - dl = yield self.doBoth(w1.get_data(u"p1"), - w2.get_data(u"p2")) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data3") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_wrong_password(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code+"not") - - # w2 can't throw WrongPasswordError until it sees a CONFIRM message, - # and w1 won't send CONFIRM until it sees a PAKE message, which w2 - # won't send until we call get_data. So we need both sides to be - # running at the same time for this test. - d1 = w1.send_data(b"data1") - # at this point, w1 should be waiting for w2.PAKE - - yield self.assertFailure(w2.get_data(), WrongPasswordError) - # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, - # send w2.CONFIRM, then wait for w1.DATA. - # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. - # * w1 might also get w2.CONFIRM, and may notice the error before it - # sends w1.CONFIRM, in which case the wait=True will signal an - # error inside _get_master_key() (inside send_data), and d1 will - # errback. - # * but w1 might not see w2.CONFIRM yet, in which case it won't - # errback until we do w1.get_data() - # * w2 gets w1.CONFIRM, notices the error, records it. - # * w2 (waiting for w1.DATA) wakes up, sees the error, throws - # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not - # have arrived by then - try: - yield d1 - except WrongPasswordError: - pass - - # When we ask w1 to get_data(), one of two things might happen: - # * if w2.CONFIRM arrived already, it will have recorded the error. - # When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the - # error before sleeping, and throw WrongPasswordError - # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM - # arrives, we notice and record the error, and wake up, and throw - - # Note that we didn't do w2.send_data(), so we're hoping that w1 will - # have enough information to detect the error before it sleeps - # (waiting for w2.DATA). Checking for the error both before sleeping - # and after waking up makes this happen. - - # so now w1 should have enough information to throw too - yield self.assertFailure(w1.get_data(), WrongPasswordError) - - # both sides are closed automatically upon error, but it's still - # legal to call .close(), and should be idempotent - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_no_confirm(self): - # newer versions (which check confirmations) should will work with - # older versions (that don't send confirmations) - w1 = Wormhole(APPID, self.relayurl) - w1._send_confirm = False - w2 = Wormhole(APPID, self.relayurl) - - code = yield w1.get_code() - w2.set_code(code) - dl = yield self.doBoth(w1.send_data(b"data1"), w2.get_data()) - self.assertEqual(dl[1], b"data1") - dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2")) - self.assertEqual(dl[0], b"data2") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_verifier(self): - w1 = Wormhole(APPID, self.relayurl) - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - w2.set_code(code) - res = yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - v1, v2 = res - self.failUnlessEqual(type(v1), type(b"")) - self.failUnlessEqual(v1, v2) - yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) - dl = yield self.doBoth(w1.get_data(), w2.get_data()) - (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_errors(self): - w1 = Wormhole(APPID, self.relayurl) - yield self.assertFailure(w1.get_verifier(), UsageError) - yield self.assertFailure(w1.send_data(b"data"), UsageError) - yield self.assertFailure(w1.get_data(), UsageError) - w1.set_code(u"123-purple-elephant") - yield self.assertRaises(UsageError, w1.set_code, u"123-nope") - yield self.assertFailure(w1.get_code(), UsageError) - w2 = Wormhole(APPID, self.relayurl) - yield w2.get_code() - yield self.assertFailure(w2.get_code(), UsageError) - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_repeat_phases(self): - w1 = Wormhole(APPID, self.relayurl) - w1.set_code(u"123-purple-elephant") - w2 = Wormhole(APPID, self.relayurl) - w2.set_code(u"123-purple-elephant") - # we must let them establish a key before we can send data - yield self.doBoth(w1.get_verifier(), w2.get_verifier()) - yield w1.send_data(b"data1", phase=u"1") - # underscore-prefixed phases are reserved - yield self.assertFailure(w1.send_data(b"data1", phase=u"_1"), - UsageError) - yield self.assertFailure(w1.get_data(phase=u"_1"), UsageError) - # you can't send twice to the same phase - yield self.assertFailure(w1.send_data(b"data1", phase=u"1"), - UsageError) - # but you can send to a different one - yield w1.send_data(b"data2", phase=u"2") - res = yield w2.get_data(phase=u"1") - self.failUnlessEqual(res, b"data1") - # and you can't read twice from the same phase - yield self.assertFailure(w2.get_data(phase=u"1"), UsageError) - # but you can read from a different one - res = yield w2.get_data(phase=u"2") - self.failUnlessEqual(res, b"data2") - yield self.doBoth(w1.close(), w2.close()) - - @inlineCallbacks - def test_serialize(self): - w1 = Wormhole(APPID, self.relayurl) - self.assertRaises(UsageError, w1.serialize) # too early - w2 = Wormhole(APPID, self.relayurl) - code = yield w1.get_code() - self.assertRaises(UsageError, w2.serialize) # too early - w2.set_code(code) - w2.serialize() # ok - s = w1.serialize() - self.assertEqual(type(s), type("")) - unpacked = json.loads(s) # this is supposed to be JSON - self.assertEqual(type(unpacked), dict) - - self.new_w1 = Wormhole.from_serialized(s) - yield self.doBoth(self.new_w1.send_data(b"data1"), - w2.send_data(b"data2")) - dl = yield self.doBoth(self.new_w1.get_data(), w2.get_data()) - (dataX, dataY) = dl - self.assertEqual((dataX, dataY), (b"data2", b"data1")) - self.assertRaises(UsageError, w2.serialize) # too late - yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], - True) - diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py new file mode 100644 index 0000000..8948535 --- /dev/null +++ b/src/wormhole/test/test_wormhole.py @@ -0,0 +1,789 @@ +from __future__ import print_function +import os, json, re, gc +from binascii import hexlify, unhexlify +import mock +from twisted.trial import unittest +from twisted.internet import reactor +from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks +from .common import ServerBase +from .. import wormhole +from ..errors import WrongPasswordError, WelcomeError, UsageError +from spake2 import SPAKE2_Symmetric +from ..timing import DebugTiming +from nacl.secret import SecretBox + +APPID = u"appid" + +class MockWebSocket: + def __init__(self): + self._payloads = [] + def sendMessage(self, payload, is_binary): + assert not is_binary + self._payloads.append(payload) + + def outbound(self): + out = [] + while self._payloads: + p = self._payloads.pop(0) + out.append(json.loads(p.decode("utf-8"))) + return out + +def response(w, **kwargs): + payload = json.dumps(kwargs).encode("utf-8") + w._ws_dispatch_response(payload) + +class Welcome(unittest.TestCase): + def test_tolerate_no_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) + w.handle_welcome({}) + + def test_print_motd(self): + w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"motd": u"message of\nthe day"}) + self.assertEqual(stderr.method_calls, + [mock.call.write(u"Server (at relay_url) says:\n" + " message of\n the day"), + mock.call.write(u"\n")]) + # motd is only displayed once + with mock.patch("sys.stderr") as stderr2: + w.handle_welcome({u"motd": u"second message"}) + self.assertEqual(stderr2.method_calls, []) + + def test_current_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"2.0"}) + self.assertEqual(stderr.method_calls, []) + + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + exp1 = (u"Warning: errors may occur unless both sides are" + " running the same version") + exp2 = (u"Server claims 3.0 is current, but ours is 2.0") + self.assertEqual(stderr.method_calls, + [mock.call.write(exp1), + mock.call.write(u"\n"), + mock.call.write(exp2), + mock.call.write(u"\n"), + ]) + + # warning is only displayed once + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_non_release_version(self): + w = wormhole._WelcomeHandler(u"relay_url", u"2.0-dirty", None) + with mock.patch("sys.stderr") as stderr: + w.handle_welcome({u"current_version": u"3.0"}) + self.assertEqual(stderr.method_calls, []) + + def test_signal_error(self): + se = mock.Mock() + w = wormhole._WelcomeHandler(u"relay_url", u"2.0", se) + w.handle_welcome({}) + self.assertEqual(se.mock_calls, []) + + w.handle_welcome({u"error": u"oops"}) + self.assertEqual(len(se.mock_calls), 1) + self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs + we = se.mock_calls[0][1][0] + self.assertIsInstance(we, WelcomeError) + self.assertEqual(we.args, (u"oops",)) + # alas WelcomeError instances don't compare against each other + #self.assertEqual(se.mock_calls, [mock.call(WelcomeError(u"oops"))]) + +class InputCode(unittest.TestCase): + def test_list(self): + send_command = mock.Mock() + ic = wormhole._InputCode(None, u"prompt", 2, send_command, + DebugTiming()) + d = ic._list() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"list")]) + ic._response_handle_nameplates({u"type": u"nameplates", + u"nameplates": [{u"id": u"123"}]}) + res = self.successResultOf(d) + self.assertEqual(res, [u"123"]) + +class GetCode(unittest.TestCase): + def test_get(self): + send_command = mock.Mock() + gc = wormhole._GetCode(2, send_command, DebugTiming()) + d = gc.go() + self.assertNoResult(d) + self.assertEqual(send_command.mock_calls, [mock.call(u"allocate")]) + # TODO: nameplate attributes get added and checked here + gc._response_handle_allocated({u"type": u"allocated", + u"nameplate": u"123"}) + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + +class Basic(unittest.TestCase): + def tearDown(self): + # flush out any errorful Deferreds left dangling in cycles + gc.collect() + + def check_out(self, out, **kwargs): + # Assert that each kwarg is present in the 'out' dict. Ignore other + # keys ('msgid' in particular) + for key, value in kwargs.items(): + self.assertIn(key, out) + self.assertEqual(out[key], value, (out, key, value)) + + def check_outbound(self, ws, types): + out = ws.outbound() + self.assertEqual(len(out), len(types), (out, types)) + for i,t in enumerate(types): + self.assertEqual(out[i][u"type"], t, (i,t,out)) + return out + + def make_pake(self, code, side, msg1): + sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code), + idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + key = sp2.finish(msg1) + return key, msg2_hex + + def test_create(self): + wormhole._Wormhole(APPID, u"relay_url", reactor, None, None) + + def test_basic(self): + # We don't call w._start(), so this doesn't create a WebSocket + # connection. We provide a mock connection instead. If we wanted to + # exercise _connect, we'd mock out WSFactory. + # w._connect = lambda self: None + # w._event_connected(mock_ws) + # w._event_ws_opened() + # w._ws_dispatch_response(payload) + + timing = DebugTiming() + with mock.patch("wormhole.wormhole._WelcomeHandler") as wh_c: + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + wh = wh_c.return_value + self.assertEqual(w._ws_url, u"relay_url") + self.assertTrue(w._flag_need_nameplate) + self.assertTrue(w._flag_need_to_build_msg1) + self.assertTrue(w._flag_need_to_send_PAKE) + + v = w.verify() + + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + out = ws.outbound() + self.assertEqual(len(out), 0) + + w._event_ws_opened(None) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"bind", appid=APPID, side=w._side) + self.assertIn(u"id", out[0]) + + # WelcomeHandler should get called upon 'welcome' response. Its full + # behavior is exercised in 'Welcome' above. + WELCOME = {u"foo": u"bar"} + response(w, type="welcome", welcome=WELCOME) + self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)]) + + # because we're connected, setting the code also claims the mailbox + CODE = u"123-foo-bar" + w.set_code(CODE) + self.assertFalse(w._flag_need_to_build_msg1) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"claim", nameplate=u"123") + + # the server reveals the linked mailbox + response(w, type=u"claimed", mailbox=u"mb456") + + # that triggers event_learned_mailbox, which should send open() and + # PAKE + self.assertEqual(w._mailbox_state, wormhole.OPEN) + out = ws.outbound() + self.assertEqual(len(out), 2) + self.check_out(out[0], type=u"open", mailbox=u"mb456") + self.check_out(out[1], type=u"add", phase=u"pake") + self.assertNoResult(v) + + # server echoes back all "add" messages + response(w, type=u"message", phase=u"pake", body=out[1][u"body"], + side=w._side) + self.assertNoResult(v) + + # next we build the simulated peer's PAKE operation + side2 = w._side + u"other" + msg1 = unhexlify(out[1][u"body"].encode("ascii")) + key, msg2_hex = self.make_pake(CODE, side2, msg1) + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2) + + # hearing the peer's PAKE (msg2) makes us release the nameplate, send + # the confirmation message, delivered the verifier, and sends any + # queued phase messages + self.assertFalse(w._flag_need_to_see_mailbox_used) + self.assertEqual(w._key, key) + out = ws.outbound() + self.assertEqual(len(out), 2, out) + self.check_out(out[0], type=u"release") + self.check_out(out[1], type=u"add", phase=u"confirm") + verifier = self.successResultOf(v) + self.assertEqual(verifier, + w.derive_key(u"wormhole:verifier", SecretBox.KEY_SIZE)) + + # hearing a valid confirmation message doesn't throw an error + confkey = w.derive_key(u"wormhole:confirmation", SecretBox.KEY_SIZE) + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + confirm2 = wormhole.make_confmsg(confkey, nonce) + confirm2_hex = hexlify(confirm2).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=confirm2_hex, + side=side2) + + # an outbound message can now be sent immediately + w.send(b"phase0-outbound") + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"add", phase=u"0") + # decrypt+check the outbound message + p0_outbound = unhexlify(out[0][u"body"].encode("ascii")) + msgkey0 = w._derive_phase_key(w._side, u"0") + p0_plaintext = w._decrypt_data(msgkey0, p0_outbound) + self.assertEqual(p0_plaintext, b"phase0-outbound") + + # get() waits for the inbound message to arrive + md = w.get() + self.assertNoResult(md) + self.assertIn(u"0", w._receive_waiters) + self.assertNotIn(u"0", w._received_messages) + msgkey1 = w._derive_phase_key(side2, u"0") + p0_inbound = w._encrypt_data(msgkey1, b"phase0-inbound") + p0_inbound_hex = hexlify(p0_inbound).decode("ascii") + response(w, type=u"message", phase=u"0", body=p0_inbound_hex, + side=side2) + p0_in = self.successResultOf(md) + self.assertEqual(p0_in, b"phase0-inbound") + self.assertNotIn(u"0", w._receive_waiters) + self.assertIn(u"0", w._received_messages) + + # receiving an inbound message will queue it until get() is called + msgkey2 = w._derive_phase_key(side2, u"1") + p1_inbound = w._encrypt_data(msgkey2, b"phase1-inbound") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"1", body=p1_inbound_hex, + side=side2) + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + p1_in = self.successResultOf(w.get()) + self.assertEqual(p1_in, b"phase1-inbound") + self.assertIn(u"1", w._received_messages) + self.assertNotIn(u"1", w._receive_waiters) + + d = w.close() + self.assertNoResult(d) + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"close", mood=u"happy") + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) + self.assertEqual(self.successResultOf(d), None) + + def test_close_wait_0(self): + # Close before the connection is established. The connection still + # gets established, but it is then torn down before sending anything. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + + d = w.close() + self.assertNoResult(d) + + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_1(self): + # close before even claiming the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + + d = w.close() + self.check_outbound(ws, [u"bind"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_2(self): + # Close after claiming the nameplate, but before opening the mailbox. + # The 'claimed' response arrives before we close. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + response(w, type=u"claimed", mailbox=u"mb123") + + d = w.close() + self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"closed") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_3(self): + # close after claiming the nameplate, but before opening the mailbox + # The 'claimed' response arrives after we start to close. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + self.check_outbound(ws, [u"bind", u"claim"]) + + d = w.close() + response(w, type=u"claimed", mailbox=u"mb123") + self.check_outbound(ws, [u"release"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_4(self): + # close after both claiming the nameplate and opening the mailbox + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) + + d = w.close() + self.check_outbound(ws, [u"release", u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"closed") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_wait_5(self): + # close after claiming the nameplate, opening the mailbox, then + # releasing the nameplate + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + CODE = u"123-foo-bar" + w.set_code(CODE) + response(w, type=u"claimed", mailbox=u"mb456") + + w._key = b"" + msgkey = w._derive_phase_key(u"side2", u"misc") + p1_inbound = w._encrypt_data(msgkey, b"") + p1_inbound_hex = hexlify(p1_inbound).decode("ascii") + response(w, type=u"message", phase=u"misc", side=u"side2", + body=p1_inbound_hex) + self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", + u"release"]) + + d = w.close() + self.check_outbound(ws, [u"close"]) + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"released") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, []) + + response(w, type=u"closed") + self.assertNoResult(d) + self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + + w._ws_closed(True, None, None) + self.successResultOf(d) + + def test_close_errbacks(self): + # make sure the Deferreds returned by verify() and get() are properly + # errbacked upon close + pass + + def test_get_code_mock(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() # TODO: mock w._ws_send_command instead + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + gc_c = mock.Mock() + gc = gc_c.return_value = mock.Mock() + gc_d = gc.go.return_value = Deferred() + with mock.patch("wormhole.wormhole._GetCode", gc_c): + d = w.get_code() + self.assertNoResult(d) + + gc_d.callback(u"123-foo-bar") + code = self.successResultOf(d) + self.assertEqual(code, u"123-foo-bar") + + def test_get_code_real(self): + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + d = w.get_code() + + out = ws.outbound() + self.assertEqual(len(out), 1) + self.check_out(out[0], type=u"allocate") + # TODO: nameplate attributes go here + self.assertNoResult(d) + + response(w, type=u"allocated", nameplate=u"123") + code = self.successResultOf(d) + self.assertIsInstance(code, type(u"")) + self.assert_(code.startswith(u"123-")) + pieces = code.split(u"-") + self.assertEqual(len(pieces), 3) # nameplate plus two words + self.assert_(re.search(r'^\d+-\w+-\w+$', code), code) + + def test_verifier(self): + # make sure verify() can be called both before and after the verifier + # is computed + pass + + def test_api_errors(self): + # doing things you're not supposed to do + pass + + def test_welcome_error(self): + # A welcome message could arrive at any time, with an [error] key + # that should make us halt. In practice, though, this gets sent as + # soon as the connection is established, which limits the possible + # states in which we might see it. + + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + self.check_outbound(ws, [u"bind"]) + + d1 = w.get() + d2 = w.verify() + d3 = w.get_code() + # TODO (tricky): test w.input_code + + self.assertNoResult(d1) + self.assertNoResult(d2) + self.assertNoResult(d3) + + w._signal_error(WelcomeError(u"you are not actually welcome"), u"pouty") + self.failureResultOf(d1, WelcomeError) + self.failureResultOf(d2, WelcomeError) + self.failureResultOf(d3, WelcomeError) + + # once the error is signalled, all API calls should fail + self.assertRaises(WelcomeError, w.send, u"foo") + self.assertRaises(WelcomeError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WelcomeError) + self.failureResultOf(w.verify(), WelcomeError) + + def test_confirm_error(self): + # we should only receive the "confirm" message after we receive the + # PAKE message, by which point we should know the key. If the + # confirmation message doesn't decrypt, we signal an error. + timing = DebugTiming() + w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) + w._drop_connection = mock.Mock() + ws = MockWebSocket() + w._event_connected(ws) + w._event_ws_opened(None) + w.set_code(u"123-foo-bar") + response(w, type=u"claimed", mailbox=u"mb456") + + d1 = w.get() + d2 = w.verify() + self.assertNoResult(d1) + self.assertNoResult(d2) + + out = ws.outbound() + # [u"bind", u"claim", u"open", u"add"] + self.assertEqual(len(out), 4) + self.assertEqual(out[3][u"type"], u"add") + + sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID)) + msg2 = sp2.start() + msg2_hex = hexlify(msg2).decode("ascii") + response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2") + self.assertNoResult(d1) + self.successResultOf(d2) # early verify is unaffected + # TODO: change verify() to wait for "confirm" + + # sending a random confirm message will cause a confirmation error + confkey = w.derive_key(u"WRONG", SecretBox.KEY_SIZE) + nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH) + badconfirm = wormhole.make_confmsg(confkey, nonce) + badconfirm_hex = hexlify(badconfirm).decode("ascii") + response(w, type=u"message", phase=u"confirm", body=badconfirm_hex, + side=u"s2") + + self.failureResultOf(d1, WrongPasswordError) + + # once the error is signalled, all API calls should fail + self.assertRaises(WrongPasswordError, w.send, u"foo") + self.assertRaises(WrongPasswordError, + w.derive_key, u"foo", SecretBox.KEY_SIZE) + self.failureResultOf(w.get(), WrongPasswordError) + self.failureResultOf(w.verify(), WrongPasswordError) + + +# event orderings to exercise: +# +# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2, +# learn_phase1 +# * normal receiver (argv[2]=code): set_code, connected, learn_msg1, +# learn_phase1, send_phase1, +# * normal receiver (readline): connected, input_code +# * +# * set_code, then connected +# * connected, receive_pake, send_phase, set_code + +class Wormholes(ServerBase, unittest.TestCase): + # integration test, with a real server + + def doBoth(self, d1, d2): + return gatherResults([d1, d2], True) + + @inlineCallbacks + def test_basic(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_same_message(self): + # the two sides use random nonces for their messages, so it's ok for + # both to try and send the same body: they'll result in distinct + # encrypted messages + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data") + w2.send(b"data") + dataX = yield w1.get() + dataY = yield w2.get() + self.assertEqual(dataX, b"data") + self.assertEqual(dataY, b"data") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_interleaved(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + dataY = yield w2.get() + self.assertEqual(dataY, b"data1") + d = w1.get() + w2.send(b"data2") + dataX = yield d + self.assertEqual(dataX, b"data2") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_unidirectional(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + w1.send(b"data1") + dataY = yield w2.get() + self.assertEqual(dataY, b"data1") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_early(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w1.send(b"data1") + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + d = w2.get() + w1.set_code(u"123-abc-def") + w2.set_code(u"123-abc-def") + dataY = yield d + self.assertEqual(dataY, b"data1") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_fixed_code(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + w1.send(b"data1"), w2.send(b"data2") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield w1.close() + yield w2.close() + + + @inlineCallbacks + def test_multiple_messages(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + w1.set_code(u"123-purple-elephant") + w2.set_code(u"123-purple-elephant") + w1.send(b"data1"), w2.send(b"data2") + w1.send(b"data3"), w2.send(b"data4") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + dl = yield self.doBoth(w1.get(), w2.get()) + (dataX, dataY) = dl + self.assertEqual(dataX, b"data4") + self.assertEqual(dataY, b"data3") + yield w1.close() + yield w2.close() + + @inlineCallbacks + def test_wrong_password(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code+"not") + # That's enough to allow both sides to discover the mismatch, but + # only after the confirmation message gets through. API calls that + # don't wait will appear to work until the mismatched confirmation + # message arrives. + w1.send(b"should still work") + w2.send(b"should still work") + + # API calls that wait (i.e. get) will errback + yield self.assertFailure(w2.get(), WrongPasswordError) + yield self.assertFailure(w1.get(), WrongPasswordError) + + yield w1.close() + yield w2.close() + self.flushLoggedErrors(WrongPasswordError) + + @inlineCallbacks + def test_verifier(self): + w1 = wormhole.wormhole(APPID, self.relayurl, reactor) + w2 = wormhole.wormhole(APPID, self.relayurl, reactor) + code = yield w1.get_code() + w2.set_code(code) + v1 = yield w1.verify() + v2 = yield w2.verify() + self.failUnlessEqual(type(v1), type(b"")) + self.failUnlessEqual(v1, v2) + w1.send(b"data1") + w2.send(b"data2") + dataX = yield w1.get() + dataY = yield w2.get() + self.assertEqual(dataX, b"data2") + self.assertEqual(dataY, b"data1") + yield w1.close() + yield w2.close() + +class Errors(ServerBase, unittest.TestCase): + @inlineCallbacks + def test_codes_1(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + # definitely too early + self.assertRaises(UsageError, w.derive_key, u"purpose", 12) + + w.set_code(u"123-purple-elephant") + # code can only be set once + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close() + + @inlineCallbacks + def test_codes_2(self): + w = wormhole.wormhole(APPID, self.relayurl, reactor) + yield w.get_code() + self.assertRaises(UsageError, w.set_code, u"123-nope") + yield self.assertFailure(w.get_code(), UsageError) + yield self.assertFailure(w.input_code(), UsageError) + yield w.close() + diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource.py deleted file mode 100644 index 19272ca..0000000 --- a/src/wormhole/twisted/eventsource.py +++ /dev/null @@ -1,238 +0,0 @@ -#import sys -from twisted.python import log, failure -from twisted.internet import reactor, defer, protocol -from twisted.application import service -from twisted.protocols import basic -from twisted.web.client import Agent, ResponseDone -from twisted.web.http_headers import Headers -from cgi import parse_header -from .eventual import eventually - -#if sys.version_info[0] == 2: -# to_unicode = unicode -#else: -# to_unicode = str - -class EventSourceParser(basic.LineOnlyReceiver): - # http://www.w3.org/TR/eventsource/ - delimiter = b"\n" - - def __init__(self, handler): - self.current_field = None - self.current_lines = [] - self.handler = handler - self.done_deferred = defer.Deferred() - self.eventtype = u"message" - self.encoding = "utf-8" - - def set_encoding(self, encoding): - self.encoding = encoding - - def connectionLost(self, why): - if why.check(ResponseDone): - why = None - self.done_deferred.callback(why) - - def dataReceived(self, data): - # exceptions here aren't being logged properly, and tests will hang - # rather than halt. I suspect twisted.web._newclient's - # HTTP11ClientProtocol.dataReceived(), which catches everything and - # responds with self._giveUp() but doesn't log.err. - try: - basic.LineOnlyReceiver.dataReceived(self, data) - except: - log.err() - raise - - def lineReceived(self, line): - #line = to_unicode(line, self.encoding) - line = line.decode(self.encoding) - if not line: - # blank line ends the field: deliver event, reset for next - self.eventReceived(self.eventtype, "\n".join(self.current_lines)) - self.eventtype = u"message" - self.current_lines[:] = [] - return - if u":" in line: - fieldname, data = line.split(u":", 1) - if data.startswith(u" "): - data = data[1:] - else: - fieldname = line - data = u"" - if fieldname == u"event": - self.eventtype = data - elif fieldname == u"data": - self.current_lines.append(data) - elif fieldname in (u"id", u"retry"): - # documented but unhandled - pass - else: - log.msg("weird fieldname", fieldname, data) - - def eventReceived(self, eventtype, data): - self.handler(eventtype, data) - -class EventSourceError(Exception): - pass - -# es = EventSource(url, handler) -# d = es.start() -# es.cancel() - -class EventSource: # TODO: service.Service - def __init__(self, url, handler, when_connected=None, agent=None): - assert isinstance(url, type(u"")) - self.url = url - self.handler = handler - self.when_connected = when_connected - self.started = False - self.cancelled = False - self.proto = EventSourceParser(self.handler) - if not agent: - agent = Agent(reactor) - self.agent = agent - - def start(self): - assert not self.started, "single-use" - self.started = True - assert self.url - d = self.agent.request(b"GET", self.url.encode("utf-8"), - Headers({b"accept": [b"text/event-stream"]})) - d.addCallback(self._connected) - return d - - def _connected(self, resp): - if resp.code != 200: - raise EventSourceError("%d: %s" % (resp.code, resp.phrase)) - if self.when_connected: - self.when_connected() - default_ct = "text/event-stream; charset=utf-8" - ct_headers = resp.headers.getRawHeaders("content-type", [default_ct]) - ct, ct_params = parse_header(ct_headers[0]) - assert ct == "text/event-stream", ct - self.proto.set_encoding(ct_params.get("charset", "utf-8")) - resp.deliverBody(self.proto) - if self.cancelled: - self.kill_connection() - return self.proto.done_deferred - - def cancel(self): - self.cancelled = True - if not self.proto.transport: - # _connected hasn't been called yet, but that self.cancelled - # should take care of it when the connection is established - def kill(data): - # this should kill it as soon as any data is delivered - raise ValueError("dead") - self.proto.dataReceived = kill # just in case - return - self.kill_connection() - - def kill_connection(self): - if (hasattr(self.proto.transport, "_producer") - and self.proto.transport._producer): - # This is gross and fragile. We need a clean way to stop the - # client connection. p.transport is a - # twisted.web._newclient.TransportProxyProducer , and its - # ._producer is the tcp.Port. - self.proto.transport._producer.loseConnection() - else: - log.err("get_events: unable to stop connection") - # oh well - #err = EventSourceError("unable to cancel") - try: - self.proto.done_deferred.callback(None) - except defer.AlreadyCalledError: - pass - - -class Connector: - # behave enough like an IConnector to appease ReconnectingClientFactory - def __init__(self, res): - self.res = res - def connect(self): - self.res._maybeStart() - def stopConnecting(self): - self.res._stop_eventsource() - -class ReconnectingEventSource(service.MultiService, - protocol.ReconnectingClientFactory): - def __init__(self, url, handler, agent=None): - service.MultiService.__init__(self) - # we don't use any of the basic Factory/ClientFactory methods of - # this, just the ReconnectingClientFactory.retry, stopTrying, and - # resetDelay methods. - - self.url = url - self.handler = handler - self.agent = agent - # IService provides self.running, toggled by {start,stop}Service. - # self.active is toggled by {,de}activate. If both .running and - # .active are True, then we want to have an outstanding EventSource - # and will start one if necessary. If either is False, then we don't - # want one to be outstanding, and will initiate shutdown. - self.active = False - self.connector = Connector(self) - self.es = None # set we have an outstanding EventSource - self.when_stopped = [] # list of Deferreds - - def isStopped(self): - return not self.es - - def startService(self): - service.MultiService.startService(self) # sets self.running - self._maybeStart() - - def stopService(self): - # clears self.running - d = defer.maybeDeferred(service.MultiService.stopService, self) - d.addCallback(self._maybeStop) - return d - - def activate(self): - assert not self.active - self.active = True - self._maybeStart() - - def deactivate(self): - assert self.active # XXX - self.active = False - return self._maybeStop() - - def _maybeStart(self): - if not (self.active and self.running): - return - self.continueTrying = True - self.es = EventSource(self.url, self.handler, self.resetDelay, - agent=self.agent) - d = self.es.start() - d.addBoth(self._stopped) - - def _stopped(self, res): - self.es = None - # we might have stopped because of a connection error, or because of - # an intentional shutdown. - if self.active and self.running: - # we still want to be connected, so schedule a reconnection - if isinstance(res, failure.Failure): - log.err(res) - self.retry() # will eventually call _maybeStart - return - # intentional shutdown - self.stopTrying() - for d in self.when_stopped: - eventually(d.callback, None) - self.when_stopped = [] - - def _stop_eventsource(self): - if self.es: - eventually(self.es.cancel) - - def _maybeStop(self, _=None): - self.stopTrying() # cancels timer, calls _stop_eventsource() - if not self.es: - return defer.succeed(None) - d = defer.Deferred() - self.when_stopped.append(d) - return d diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py deleted file mode 100644 index 199fa50..0000000 --- a/src/wormhole/twisted/transcribe.py +++ /dev/null @@ -1,560 +0,0 @@ -from __future__ import print_function -import os, sys, json, re, unicodedata -from six.moves.urllib_parse import urlparse -from binascii import hexlify, unhexlify -from twisted.internet import reactor, defer, endpoints, error -from twisted.internet.threads import deferToThread, blockingCallFromThread -from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.python import log -from autobahn.twisted import websocket -from nacl.secret import SecretBox -from nacl.exceptions import CryptoError -from nacl import utils -from spake2 import SPAKE2_Symmetric -from .. import __version__ -from .. import codes -from ..errors import ServerError, Timeout, WrongPasswordError, UsageError -from ..timing import DebugTiming -from hkdf import Hkdf - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - -CONFMSG_NONCE_LENGTH = 128//8 -CONFMSG_MAC_LENGTH = 256//8 -def make_confmsg(confkey, nonce): - return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) - -def to_bytes(u): - return unicodedata.normalize("NFC", u).encode("utf-8") - -class WSClient(websocket.WebSocketClientProtocol): - def onOpen(self): - self.wormhole_open = True - self.factory.d.callback(self) - - def onMessage(self, payload, isBinary): - assert not isBinary - self.wormhole._ws_dispatch_msg(payload) - - def onClose(self, wasClean, code, reason): - if self.wormhole_open: - self.wormhole._ws_closed(wasClean, code, reason) - else: - # we closed before establishing a connection (onConnect) or - # finishing WebSocket negotiation (onOpen): errback - self.factory.d.errback(error.ConnectError(reason)) - -class WSFactory(websocket.WebSocketClientFactory): - protocol = WSClient - def buildProtocol(self, addr): - proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) - proto.wormhole = self.wormhole - proto.wormhole_open = False - return proto - -class Wormhole: - motd_displayed = False - version_warning_displayed = False - _send_confirm = True - - def __init__(self, appid, relay_url, tor_manager=None, timing=None, - reactor=reactor): - if not isinstance(appid, type(u"")): raise TypeError(type(appid)) - if not isinstance(relay_url, type(u"")): - raise TypeError(type(relay_url)) - if not relay_url.endswith(u"/"): raise UsageError - self._appid = appid - self._relay_url = relay_url - self._ws_url = relay_url.replace("http:", "ws:") + "ws" - self._tor_manager = tor_manager - self._timing = timing or DebugTiming() - self._reactor = reactor - self._side = hexlify(os.urandom(5)).decode("ascii") - self._code = None - self._channelid = None - self._key = None - self._started_get_code = False - self._sent_messages = set() # (phase, body_bytes) - self._delivered_messages = set() # (phase, body_bytes) - self._received_messages = {} # phase -> body_bytes - self._sent_phases = set() # phases, to prohibit double-send - self._got_phases = set() # phases, to prohibit double-read - self._sleepers = [] - self._confirmation_failed = False - self._closed = False - self._deallocated_status = None - self._timing_started = self._timing.add("wormhole") - self._ws = None - self._ws_t = None # timing Event - self._ws_channel_claimed = False - self._error = None - - def _make_endpoint(self, hostname, port): - if self._tor_manager: - return self._tor_manager.get_endpoint_for(hostname, port) - # note: HostnameEndpoints have a default 30s timeout - return endpoints.HostnameEndpoint(self._reactor, hostname, port) - - @inlineCallbacks - def _get_websocket(self): - if not self._ws: - # TODO: if we lose the connection, make a new one - #from twisted.python import log - #log.startLogging(sys.stderr) - assert self._side - assert not self._ws_channel_claimed - p = urlparse(self._ws_url) - f = WSFactory(self._ws_url) - f.wormhole = self - f.d = defer.Deferred() - # TODO: if hostname="localhost", I get three factories starting - # and stopping (maybe 127.0.0.1, ::1, and something else?), and - # an error in the factory is masked. - ep = self._make_endpoint(p.hostname, p.port or 80) - # .connect errbacks if the TCP connection fails - self._ws = yield ep.connect(f) - self._ws_t = self._timing.add("websocket") - # f.d is errbacked if WebSocket negotiation fails - yield f.d # WebSocket drops data sent before onOpen() fires - self._ws_send(u"bind", appid=self._appid, side=self._side) - # the socket is connected, and bound, but no channel has been claimed - returnValue(self._ws) - - @inlineCallbacks - def _ws_send(self, mtype, **kwargs): - ws = yield self._get_websocket() - # msgid is used by misc/dump-timing.py to correlate our sends with - # their receives, and vice versa. They are also correlated with the - # ACKs we get back from the server (which we otherwise ignore). There - # are so few messages, 16 bits is enough to be mostly-unique. - kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") - kwargs["type"] = mtype - payload = json.dumps(kwargs).encode("utf-8") - self._timing.add("ws_send", _side=self._side, **kwargs) - ws.sendMessage(payload, False) - - def _ws_dispatch_msg(self, payload): - msg = json.loads(payload.decode("utf-8")) - self._timing.add("ws_receive", _side=self._side, message=msg) - mtype = msg["type"] - meth = getattr(self, "_ws_handle_"+mtype, None) - if not meth: - # make tests fail, but real application will ignore it - log.err(ValueError("Unknown inbound message type %r" % (msg,))) - return - return meth(msg) - - def _ws_handle_ack(self, msg): - pass - - def _ws_handle_welcome(self, msg): - welcome = msg["welcome"] - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % - (self._ws_url, motd_formatted), file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - return self._signal_error(welcome["error"]) - - @inlineCallbacks - def _sleep(self, wake_on_error=True): - if wake_on_error and self._error: - # don't sleep if the bed's already on fire, unless we're waiting - # for the fire department to respond, in which case sure, keep on - # sleeping - raise self._error - d = defer.Deferred() - self._sleepers.append(d) - yield d - if wake_on_error and self._error: - raise self._error - - def _wakeup(self): - sleepers = self._sleepers - self._sleepers = [] - for d in sleepers: - d.callback(None) - # NOTE: callers should avoid reentrancy themselves. An - # eventual-send would be safer here, but it makes synchronizing - # unit tests annoying. - - def _signal_error(self, error): - assert isinstance(error, Exception) - self._error = error - self._wakeup() - - def _ws_handle_error(self, msg): - err = ServerError("%s: %s" % (msg["error"], msg["orig"]), - self._ws_url) - return self._signal_error(err) - - @inlineCallbacks - def _claim_channel_and_watch(self): - assert self._channelid is not None - yield self._get_websocket() - if not self._ws_channel_claimed: - yield self._ws_send(u"claim", channelid=self._channelid) - self._ws_channel_claimed = True - yield self._ws_send(u"watch") - - # entry point 1: generate a new code - @inlineCallbacks - def get_code(self, code_length=2): # rename to allocate_code()? create_? - if self._code is not None: raise UsageError - if self._started_get_code: raise UsageError - self._started_get_code = True - with self._timing.add("API get_code"): - with self._timing.add("allocate"): - yield self._ws_send(u"allocate") - while self._channelid is None: - yield self._sleep() - code = codes.make_code(self._channelid, code_length) - assert isinstance(code, type(u"")), type(code) - self._set_code(code) - self._start() - returnValue(code) - - def _ws_handle_allocated(self, msg): - if self._channelid is not None: - return self._signal_error("got duplicate channelid") - self._channelid = msg["channelid"] - self._wakeup() - - def _start(self): - # allocate the rest now too, so it can be serialized - with self._timing.add("pake1", waiting="crypto"): - self._sp = SPAKE2_Symmetric(to_bytes(self._code), - idSymmetric=to_bytes(self._appid)) - self._msg1 = self._sp.start() - - # entry point 2a: interactively type in a code, with completion - @inlineCallbacks - def input_code(self, prompt="Enter wormhole code: ", code_length=2): - def _lister(): - return blockingCallFromThread(self._reactor, self._list_channels) - # fetch the list of channels ahead of time, to give us a chance to - # discover the welcome message (and warn the user about an obsolete - # client) - # - # TODO: send the request early, show the prompt right away, hide the - # latency in the user's indecision and slow typing. If we're lucky - # the answer will come back before they hit TAB. - with self._timing.add("API input_code"): - initial_channelids = yield self._list_channels() - with self._timing.add("input code", waiting="user"): - t = self._reactor.addSystemEventTrigger("before", "shutdown", - self._warn_readline) - code = yield deferToThread(codes.input_code_with_completion, - prompt, - initial_channelids, _lister, - code_length) - self._reactor.removeSystemEventTrigger(t) - returnValue(code) # application will give this to set_code() - - def _warn_readline(self): - # When our process receives a SIGINT, Twisted's SIGINT handler will - # stop the reactor and wait for all threads to terminate before the - # process exits. However, if we were waiting for - # input_code_with_completion() when SIGINT happened, the readline - # thread will be blocked waiting for something on stdin. Trick the - # user into satisfying the blocking read so we can exit. - print("\nCommand interrupted: please press Return to quit", - file=sys.stderr) - - # Other potential approaches to this problem: - # * hard-terminate our process with os._exit(1), but make sure the - # tty gets reset to a normal mode ("cooked"?) first, so that the - # next shell command the user types is echoed correctly - # * track down the thread (t.p.threadable.getThreadID from inside the - # thread), get a cffi binding to pthread_kill, deliver SIGINT to it - # * allocate a pty pair (pty.openpty), replace sys.stdin with the - # slave, build a pty bridge that copies bytes (and other PTY - # things) from the real stdin to the master, then close the slave - # at shutdown, so readline sees EOF - # * write tab-completion and basic editing (TTY raw mode, - # backspace-is-erase) without readline, probably with curses or - # twisted.conch.insults - # * write a separate program to get codes (maybe just "wormhole - # --internal-get-code"), run it as a subprocess, let it inherit - # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It - # needs an RPC mechanism (over some extra file descriptors) to ask - # us to fetch the current channelid list. - # - # Note that hard-terminating our process with os.kill(os.getpid(), - # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread - # doesn't see the signal, and we must still wait for stdin to make - # readline finish. - - @inlineCallbacks - def _list_channels(self): - with self._timing.add("list"): - self._latest_channelids = None - yield self._ws_send(u"list") - while self._latest_channelids is None: - yield self._sleep() - returnValue(self._latest_channelids) - - def _ws_handle_channelids(self, msg): - self._latest_channelids = msg["channelids"] - self._wakeup() - - # entry point 2b: paste in a fully-formed code - def set_code(self, code): - if not isinstance(code, type(u"")): raise TypeError(type(code)) - if self._code is not None: raise UsageError - mo = re.search(r'^(\d+)-', code) - if not mo: - raise ValueError("code (%s) must start with NN-" % code) - with self._timing.add("API set_code"): - self._channelid = int(mo.group(1)) - self._set_code(code) - self._start() - - def _set_code(self, code): - if self._code is not None: raise UsageError - self._timing.add("code established") - self._code = code - - def serialize(self): - # I can only be serialized after get_code/set_code and before - # get_verifier/get_data - if self._code is None: raise UsageError - if self._key is not None: raise UsageError - if self._sent_phases: raise UsageError - if self._got_phases: raise UsageError - data = { - "appid": self._appid, - "relay_url": self._relay_url, - "code": self._code, - "channelid": self._channelid, - "side": self._side, - "spake2": json.loads(self._sp.serialize().decode("ascii")), - "msg1": hexlify(self._msg1).decode("ascii"), - } - return json.dumps(data) - - # entry point 3: resume a previously-serialized session - @classmethod - def from_serialized(klass, data): - d = json.loads(data) - self = klass(d["appid"], d["relay_url"]) - self._side = d["side"] - self._channelid = d["channelid"] - self._set_code(d["code"]) - sp_data = json.dumps(d["spake2"]).encode("ascii") - self._sp = SPAKE2_Symmetric.from_serialized(sp_data) - self._msg1 = unhexlify(d["msg1"].encode("ascii")) - return self - - @inlineCallbacks - def get_verifier(self): - if self._closed: raise UsageError - if self._code is None: raise UsageError - with self._timing.add("API get_verifier"): - yield self._get_master_key() - # If the caller cares about the verifier, then they'll probably - # also willing to wait a moment to see the _confirm message. Each - # side sends this as soon as it sees the other's PAKE message. So - # the sender should see this hot on the heels of the inbound PAKE - # message (a moment after _get_master_key() returns). The - # receiver will see this a round-trip after they send their PAKE - # (because the sender is using wait=True inside _get_master_key, - # below: otherwise the sender might go do some blocking call). - yield self._msg_get(u"_confirm") - returnValue(self._verifier) - - @inlineCallbacks - def _get_master_key(self): - # TODO: prevent multiple invocation - if not self._key: - yield self._claim_channel_and_watch() - yield self._msg_send(u"pake", self._msg1) - pake_msg = yield self._msg_get(u"pake") - - with self._timing.add("pake2", waiting="crypto"): - self._key = self._sp.finish(pake_msg) - self._verifier = self.derive_key(u"wormhole:verifier") - self._timing.add("key established") - - if self._send_confirm: - # both sides send different (random) confirmation messages - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - yield self._msg_send(u"_confirm", confmsg, wait=True) - - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - self._sent_messages.add( (phase, body) ) - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send(u"add", phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while (phase, body) not in self._delivered_messages: - yield self._sleep() - t.finish() - - def _ws_handle_message(self, msg): - m = msg["message"] - phase = m["phase"] - body = unhexlify(m["body"].encode("ascii")) - if (phase, body) in self._sent_messages: - self._delivered_messages.add( (phase, body) ) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - if phase in self._received_messages: - # a channel collision would cause this - err = ServerError("got duplicate phase %s" % phase, self._ws_url) - return self._signal_error(err) - self._received_messages[phase] = body - if phase == u"_confirm": - # TODO: we might not have a master key yet, if the caller wasn't - # waiting in _get_master_key() when a back-to-back pake+_confirm - # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - # this makes all API calls fail - return self._signal_error(WrongPasswordError()) - # now notify anyone waiting on it - self._wakeup() - - @inlineCallbacks - def _msg_get(self, phase): - with self._timing.add("get", phase=phase): - while phase not in self._received_messages: - yield self._sleep() # we can wait a long time here - # that will throw an error if something goes wrong - msg = self._received_messages[phase] - returnValue(msg) - - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): - if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - if self._key is None: - # call after get_verifier() or get_data() - raise UsageError - return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) - - def _encrypt_data(self, key, data): - assert isinstance(key, type(b"")), type(key) - assert isinstance(data, type(b"")), type(data) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - nonce = utils.random(SecretBox.NONCE_SIZE) - return box.encrypt(data, nonce) - - def _decrypt_data(self, key, encrypted): - assert isinstance(key, type(b"")), type(key) - assert isinstance(encrypted, type(b"")), type(encrypted) - assert len(key) == SecretBox.KEY_SIZE, len(key) - box = SecretBox(key) - data = box.decrypt(encrypted) - return data - - @inlineCallbacks - def send_data(self, outbound_data, phase=u"data", wait=False): - if not isinstance(outbound_data, type(b"")): - raise TypeError(type(outbound_data)) - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if self._closed: raise UsageError - if self._code is None: - raise UsageError("You must set_code() before send_data()") - if phase.startswith(u"_"): raise UsageError # reserved for internals - if phase in self._sent_phases: raise UsageError # only call this once - self._sent_phases.add(phase) - with self._timing.add("API send_data", phase=phase, wait=wait): - # Without predefined roles, we can't derive predictably unique - # keys for each side, so we use the same key for both. We use - # random nonces to keep the messages distinct, and we - # automatically ignore reflections. - yield self._get_master_key() - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - yield self._msg_send(phase, outbound_encrypted, wait) - - @inlineCallbacks - def get_data(self, phase=u"data"): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if self._closed: raise UsageError - if self._code is None: raise UsageError - if phase.startswith(u"_"): raise UsageError # reserved for internals - if phase in self._got_phases: raise UsageError # only call this once - self._got_phases.add(phase) - with self._timing.add("API get_data", phase=phase): - yield self._get_master_key() - body = yield self._msg_get(phase) # we can wait a long time here - try: - data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_data = self._decrypt_data(data_key, body) - returnValue(inbound_data) - except CryptoError: - raise WrongPasswordError - - def _ws_closed(self, wasClean, code, reason): - self._ws = None - self._ws_t.finish() - # TODO: schedule reconnect, unless we're done - - @inlineCallbacks - def close(self, f=None, mood=None): - """Do d.addBoth(w.close) at the end of your chain.""" - if self._closed: - returnValue(None) - self._closed = True - if not self._ws: - returnValue(None) - - if mood is None: - mood = u"happy" - if f: - if f.check(Timeout): - mood = u"lonely" - elif f.check(WrongPasswordError): - mood = u"scary" - elif f.check(TypeError, UsageError): - # preconditions don't warrant reporting mood - pass - else: - mood = u"errory" # other errors do - if not isinstance(mood, (type(None), type(u""))): - raise TypeError(type(mood)) - - with self._timing.add("API close"): - yield self._deallocate(mood) - # TODO: mark WebSocket as don't-reconnect - self._ws.transport.loseConnection() # probably flushes - del self._ws - self._ws_t.finish() - self._timing_started.finish(mood=mood) - returnValue(f) - - @inlineCallbacks - def _deallocate(self, mood): - with self._timing.add("deallocate"): - yield self._ws_send(u"deallocate", mood=mood) - while self._deallocated_status is None: - yield self._sleep(wake_on_error=False) - # TODO: set a timeout, don't wait forever for an ack - # TODO: if the connection is lost, let it go - returnValue(self._deallocated_status) - - def _ws_handle_deallocated(self, msg): - self._deallocated_status = msg["status"] - self._wakeup() diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index d0eef5e..cbc8724 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -548,6 +548,7 @@ def there_can_be_only_one(contenders): class Common: RELAY_DELAY = 2.0 + TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE def __init__(self, transit_relay, no_listen=False, tor_manager=None, reactor=reactor, timing=None): diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py new file mode 100644 index 0000000..d84af3d --- /dev/null +++ b/src/wormhole/wormhole.py @@ -0,0 +1,845 @@ +from __future__ import print_function, absolute_import +import os, sys, json, re, unicodedata +from six.moves.urllib_parse import urlparse +from binascii import hexlify, unhexlify +from twisted.internet import defer, endpoints, error +from twisted.internet.threads import deferToThread, blockingCallFromThread +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.python import log +from autobahn.twisted import websocket +from nacl.secret import SecretBox +from nacl.exceptions import CryptoError +from nacl import utils +from spake2 import SPAKE2_Symmetric +from hashlib import sha256 +from . import __version__ +from . import codes +#from .errors import ServerError, Timeout +from .errors import (WrongPasswordError, UsageError, WelcomeError, + WormholeClosedError) +from .timing import DebugTiming +from hkdf import Hkdf + +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + +CONFMSG_NONCE_LENGTH = 128//8 +CONFMSG_MAC_LENGTH = 256//8 +def make_confmsg(confkey, nonce): + return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce) + +def to_bytes(u): + return unicodedata.normalize("NFC", u).encode("utf-8") + +# We send the following messages through the relay server to the far side (by +# sending "add" commands to the server, and getting "message" responses): +# +# phase=setup: +# * unauthenticated version strings (but why?) +# * early warmup for connection hints ("I can do tor, spin up HS") +# * wordlist l10n identifier +# phase=pake: just the SPAKE2 'start' message (binary) +# phase=confirm: key verification (HKDF(key, nonce)+nonce) +# phase=1,2,3,..: application messages + +class WSClient(websocket.WebSocketClientProtocol): + def onOpen(self): + self.wormhole_open = True + self.factory.d.callback(self) + + def onMessage(self, payload, isBinary): + assert not isBinary + self.wormhole._ws_dispatch_response(payload) + + def onClose(self, wasClean, code, reason): + if self.wormhole_open: + self.wormhole._ws_closed(wasClean, code, reason) + else: + # we closed before establishing a connection (onConnect) or + # finishing WebSocket negotiation (onOpen): errback + self.factory.d.errback(error.ConnectError(reason)) + +class WSFactory(websocket.WebSocketClientFactory): + protocol = WSClient + def buildProtocol(self, addr): + proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) + proto.wormhole = self.wormhole + proto.wormhole_open = False + return proto + + +class _GetCode: + def __init__(self, code_length, send_command, timing): + self._code_length = code_length + self._send_command = send_command + self._timing = timing + self._allocated_d = defer.Deferred() + + @inlineCallbacks + def go(self): + with self._timing.add("allocate"): + self._send_command(u"allocate") + nameplate_id = yield self._allocated_d + code = codes.make_code(nameplate_id, self._code_length) + assert isinstance(code, type(u"")), type(code) + returnValue(code) + + def _response_handle_allocated(self, msg): + nid = msg["nameplate"] + assert isinstance(nid, type(u"")), type(nid) + self._allocated_d.callback(nid) + +class _InputCode: + def __init__(self, reactor, prompt, code_length, send_command, timing): + self._reactor = reactor + self._prompt = prompt + self._code_length = code_length + self._send_command = send_command + self._timing = timing + + @inlineCallbacks + def _list(self): + self._lister_d = defer.Deferred() + self._send_command(u"list") + nameplates = yield self._lister_d + self._lister_d = None + returnValue(nameplates) + + def _list_blocking(self): + return blockingCallFromThread(self._reactor, self._list) + + @inlineCallbacks + def go(self): + # fetch the list of nameplates ahead of time, to give us a chance to + # discover the welcome message (and warn the user about an obsolete + # client) + # + # TODO: send the request early, show the prompt right away, hide the + # latency in the user's indecision and slow typing. If we're lucky + # the answer will come back before they hit TAB. + + initial_nameplate_ids = yield self._list() + with self._timing.add("input code", waiting="user"): + t = self._reactor.addSystemEventTrigger("before", "shutdown", + self._warn_readline) + code = yield deferToThread(codes.input_code_with_completion, + self._prompt, + initial_nameplate_ids, + self._list_blocking, + self._code_length) + self._reactor.removeSystemEventTrigger(t) + returnValue(code) + + def _response_handle_nameplates(self, msg): + nameplates = msg["nameplates"] + assert isinstance(nameplates, list), type(nameplates) + nids = [] + for n in nameplates: + assert isinstance(n, dict), type(n) + nameplate_id = n[u"id"] + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + nids.append(nameplate_id) + self._lister_d.callback(nids) + + def _warn_readline(self): + # When our process receives a SIGINT, Twisted's SIGINT handler will + # stop the reactor and wait for all threads to terminate before the + # process exits. However, if we were waiting for + # input_code_with_completion() when SIGINT happened, the readline + # thread will be blocked waiting for something on stdin. Trick the + # user into satisfying the blocking read so we can exit. + print("\nCommand interrupted: please press Return to quit", + file=sys.stderr) + + # Other potential approaches to this problem: + # * hard-terminate our process with os._exit(1), but make sure the + # tty gets reset to a normal mode ("cooked"?) first, so that the + # next shell command the user types is echoed correctly + # * track down the thread (t.p.threadable.getThreadID from inside the + # thread), get a cffi binding to pthread_kill, deliver SIGINT to it + # * allocate a pty pair (pty.openpty), replace sys.stdin with the + # slave, build a pty bridge that copies bytes (and other PTY + # things) from the real stdin to the master, then close the slave + # at shutdown, so readline sees EOF + # * write tab-completion and basic editing (TTY raw mode, + # backspace-is-erase) without readline, probably with curses or + # twisted.conch.insults + # * write a separate program to get codes (maybe just "wormhole + # --internal-get-code"), run it as a subprocess, let it inherit + # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It + # needs an RPC mechanism (over some extra file descriptors) to ask + # us to fetch the current nameplate_id list. + # + # Note that hard-terminating our process with os.kill(os.getpid(), + # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread + # doesn't see the signal, and we must still wait for stdin to make + # readline finish. + +class _WelcomeHandler: + def __init__(self, url, current_version, signal_error): + self._ws_url = url + self._version_warning_displayed = False + self._motd_displayed = False + self._current_version = current_version + self._signal_error = signal_error + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self._motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % + (self._ws_url, motd_formatted), file=sys.stderr) + self._motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("current_version" in welcome + and "-" not in self._current_version + and not self._version_warning_displayed + and welcome["current_version"] != self._current_version): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], self._current_version), + file=sys.stderr) + self._version_warning_displayed = True + + if "error" in welcome: + return self._signal_error(WelcomeError(welcome["error"])) + +# states for nameplates, mailboxes, and the websocket connection +(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing") + + +class _Wormhole: + def __init__(self, appid, relay_url, reactor, tor_manager, timing): + self._appid = appid + self._ws_url = relay_url + self._reactor = reactor + self._tor_manager = tor_manager + self._timing = timing + + self._welcomer = _WelcomeHandler(self._ws_url, __version__, + self._signal_error) + self._side = hexlify(os.urandom(5)).decode("ascii") + self._connection_state = CLOSED + self._connection_waiters = [] + self._started_get_code = False + self._get_code = None + self._started_input_code = False + self._code = None + self._nameplate_id = None + self._nameplate_state = CLOSED + self._mailbox_id = None + self._mailbox_state = CLOSED + self._flag_need_nameplate = True + self._flag_need_to_see_mailbox_used = True + self._flag_need_to_build_msg1 = True + self._flag_need_to_send_PAKE = True + self._key = None + self._close_called = False # the close() API has been called + self._closing = False # we've started shutdown + self._disconnect_waiter = defer.Deferred() + self._error = None + + self._get_verifier_called = False + self._verifier = None + self._verifier_waiter = None + + self._next_send_phase = 0 + # send() queues plaintext here, waiting for a connection and the key + self._plaintext_to_send = [] # (phase, plaintext) + self._sent_phases = set() # to detect double-send + + self._next_receive_phase = 0 + self._receive_waiters = {} # phase -> Deferred + self._received_messages = {} # phase -> plaintext + + # API METHODS for applications to call + + # You must use at least one of these entry points, to establish the + # wormhole code. Other APIs will stall or be queued until we have one. + + # entry point 1: generate a new code. returns a Deferred + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? + return self._API_get_code(code_length) + + # entry point 2: interactively type in a code, with completion. returns + # Deferred + def input_code(self, prompt="Enter wormhole code: ", code_length=2): + return self._API_input_code(prompt, code_length) + + # entry point 3: paste in a fully-formed code. No return value. + def set_code(self, code): + self._API_set_code(code) + + # todo: restore-saved-state entry points + + def verify(self): + """Returns a Deferred that fires when we've heard back from the other + side, and have confirmed that they used the right wormhole code. When + successful, the Deferred fires with a "verifier" (a bytestring) which + can be compared out-of-band before making additional API calls. If + they used the wrong wormhole code, the Deferred errbacks with + WrongPasswordError. + """ + return self._API_verify() + + def send(self, outbound_data): + return self._API_send(outbound_data) + + def get(self): + return self._API_get() + + def derive_key(self, purpose, length): + """Derive a new key from the established wormhole channel for some + other purpose. This is a deterministic randomized function of the + session key and the 'purpose' string (unicode/py3-string). This + cannot be called until verify() or get() has fired. + """ + return self._API_derive_key(purpose, length) + + def close(self, res=None): + """Collapse the wormhole, freeing up server resources and flushing + all pending messages. Returns a Deferred that fires when everything + is done. It fires with any argument close() was given, to enable use + as a d.addBoth() handler: + + w = wormhole(...) + d = w.get() + .. + d.addBoth(w.close) + return d + + Another reasonable approach is to use inlineCallbacks: + + @inlineCallbacks + def pair(self, code): + w = wormhole(...) + try: + them = yield w.get() + finally: + yield w.close() + """ + return self._API_close(res) + + # INTERNAL METHODS beyond here + + def _start(self): + d = self._connect() # causes stuff to happen + d.addErrback(log.err) + return d # fires when connection is established, if you care + + + + def _make_endpoint(self, hostname, port): + if self._tor_manager: + return self._tor_manager.get_endpoint_for(hostname, port) + # note: HostnameEndpoints have a default 30s timeout + return endpoints.HostnameEndpoint(self._reactor, hostname, port) + + def _connect(self): + # TODO: if we lose the connection, make a new one, re-establish the + # state + assert self._side + self._connection_state = OPENING + p = urlparse(self._ws_url) + f = WSFactory(self._ws_url) + f.wormhole = self + f.d = defer.Deferred() + # TODO: if hostname="localhost", I get three factories starting + # and stopping (maybe 127.0.0.1, ::1, and something else?), and + # an error in the factory is masked. + ep = self._make_endpoint(p.hostname, p.port or 80) + # .connect errbacks if the TCP connection fails + d = ep.connect(f) + d.addCallback(self._event_connected) + # f.d is errbacked if WebSocket negotiation fails, and the WebSocket + # drops any data sent before onOpen() fires, so we must wait for it + d.addCallback(lambda _: f.d) + d.addCallback(self._event_ws_opened) + return d + + def _event_connected(self, ws): + self._ws = ws + self._ws_t = self._timing.add("websocket") + + def _event_ws_opened(self, _): + self._connection_state = OPEN + if self._closing: + return self._maybe_finished_closing() + self._ws_send_command(u"bind", appid=self._appid, side=self._side) + self._maybe_claim_nameplate() + self._maybe_send_pake() + waiters, self._connection_waiters = self._connection_waiters, [] + for d in waiters: + d.callback(None) + + def _when_connected(self): + if self._connection_state == OPEN: + return defer.succeed(None) + d = defer.Deferred() + self._connection_waiters.append(d) + return d + + def _ws_send_command(self, mtype, **kwargs): + # msgid is used by misc/dump-timing.py to correlate our sends with + # their receives, and vice versa. They are also correlated with the + # ACKs we get back from the server (which we otherwise ignore). There + # are so few messages, 16 bits is enough to be mostly-unique. + if self.DEBUG: print("SEND", mtype) + kwargs["id"] = hexlify(os.urandom(2)).decode("ascii") + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + self._timing.add("ws_send", _side=self._side, **kwargs) + self._ws.sendMessage(payload, False) + + DEBUG=False + def _ws_dispatch_response(self, payload): + msg = json.loads(payload.decode("utf-8")) + if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg) + self._timing.add("ws_receive", _side=self._side, message=msg) + mtype = msg["type"] + meth = getattr(self, "_response_handle_"+mtype, None) + if not meth: + # make tests fail, but real application will ignore it + log.err(ValueError("Unknown inbound message type %r" % (msg,))) + return + return meth(msg) + + def _response_handle_ack(self, msg): + pass + + def _response_handle_welcome(self, msg): + self._welcomer.handle_welcome(msg["welcome"]) + + # entry point 1: generate a new code + @inlineCallbacks + def _API_get_code(self, code_length): + if self._code is not None: raise UsageError + if self._started_get_code: raise UsageError + self._started_get_code = True + with self._timing.add("API get_code"): + yield self._when_connected() + gc = _GetCode(code_length, self._ws_send_command, self._timing) + self._get_code = gc + self._response_handle_allocated = gc._response_handle_allocated + # TODO: signal_error + code = yield gc.go() + self._get_code = None + self._nameplate_state = OPEN + self._event_learned_code(code) + returnValue(code) + + # entry point 2: interactively type in a code, with completion + @inlineCallbacks + def _API_input_code(self, prompt, code_length): + if self._code is not None: raise UsageError + if self._started_input_code: raise UsageError + self._started_input_code = True + with self._timing.add("API input_code"): + yield self._when_connected() + ic = _InputCode(self._reactor, prompt, code_length, + self._ws_send_command, self._timing) + self._response_handle_nameplates = ic._response_handle_nameplates + # TODO: signal_error + code = yield ic.go() + self._event_learned_code(code) + returnValue(None) + + # entry point 3: paste in a fully-formed code + def _API_set_code(self, code): + self._timing.add("API set_code") + if not isinstance(code, type(u"")): raise TypeError(type(code)) + if self._code is not None: raise UsageError + self._event_learned_code(code) + + # TODO: entry point 4: restore pre-contact saved state (we haven't heard + # from the peer yet, so we still need the nameplate) + + # TODO: entry point 5: restore post-contact saved state (so we don't need + # or use the nameplate, only the mailbox) + def _restore_post_contact_state(self, state): + # ... + self._flag_need_nameplate = False + #self._mailbox_id = X(state) + self._event_learned_mailbox() + + def _event_learned_code(self, code): + self._timing.add("code established") + self._code = code + mo = re.search(r'^(\d+)-', code) + if not mo: + raise ValueError("code (%s) must start with NN-" % code) + nid = mo.group(1) + assert isinstance(nid, type(u"")), type(nid) + self._nameplate_id = nid + # fire more events + self._maybe_build_msg1() + self._event_learned_nameplate() + + def _maybe_build_msg1(self): + if not (self._code and self._flag_need_to_build_msg1): + return + with self._timing.add("pake1", waiting="crypto"): + self._sp = SPAKE2_Symmetric(to_bytes(self._code), + idSymmetric=to_bytes(self._appid)) + self._msg1 = self._sp.start() + self._flag_need_to_build_msg1 = False + self._event_built_msg1() + + def _event_built_msg1(self): + self._maybe_send_pake() + + # every _maybe_X starts with a set of conditions + # for each such condition Y, every _event_Y must call _maybe_X + + def _event_learned_nameplate(self): + self._maybe_claim_nameplate() + + def _maybe_claim_nameplate(self): + if not (self._nameplate_id and self._connection_state == OPEN): + return + self._ws_send_command(u"claim", nameplate=self._nameplate_id) + self._nameplate_state = OPEN + + def _response_handle_claimed(self, msg): + mailbox_id = msg["mailbox"] + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + self._mailbox_id = mailbox_id + self._event_learned_mailbox() + + def _event_learned_mailbox(self): + if not self._mailbox_id: raise UsageError + assert self._mailbox_state == CLOSED, self._mailbox_state + if self._closing: + return + self._ws_send_command(u"open", mailbox=self._mailbox_id) + self._mailbox_state = OPEN + # causes old messages to be sent now, and subscribes to new messages + self._maybe_send_pake() + self._maybe_send_phase_messages() + + def _maybe_send_pake(self): + # TODO: deal with reentrant call + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN + and self._flag_need_to_send_PAKE): + return + self._msg_send(u"pake", self._msg1) + self._flag_need_to_send_PAKE = False + + def _event_received_pake(self, pake_msg): + with self._timing.add("pake2", waiting="crypto"): + self._key = self._sp.finish(pake_msg) + self._event_established_key() + + def _derive_confirmation_key(self): + return self._derive_key(b"wormhole:confirmation") + + def _event_established_key(self): + self._timing.add("key established") + + # both sides send different (random) confirmation messages + confkey = self._derive_confirmation_key() + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + self._msg_send(u"confirm", confmsg) + + verifier = self._derive_key(b"wormhole:verifier") + self._event_computed_verifier(verifier) + + self._maybe_send_phase_messages() + + def _API_verify(self): + # TODO: rename "verify()", make it stall until confirm received. If + # you want to discover WrongPasswordError before doing send(), call + # verify() first. If you also want to deny a successful MitM (and + # have some other way to check a long verifier), use the return value + # of verify(). + if self._error: return defer.fail(self._error) + if self._get_verifier_called: raise UsageError + self._get_verifier_called = True + if self._verifier: + return defer.succeed(self._verifier) + # TODO: maybe have this wait on _event_received_confirm too + self._verifier_waiter = defer.Deferred() + return self._verifier_waiter + + def _event_computed_verifier(self, verifier): + self._verifier = verifier + if self._verifier_waiter: + self._verifier_waiter.callback(verifier) + + def _event_received_confirm(self, body): + # TODO: we might not have a master key yet, if the caller wasn't + # waiting in _get_master_key() when a back-to-back pake+_confirm + # message pair arrived. + confkey = self._derive_confirmation_key() + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + # this makes all API calls fail + if self.DEBUG: print("CONFIRM FAILED") + return self._signal_error(WrongPasswordError(), u"scary") + + + def _API_send(self, outbound_data): + if self._error: raise self._error + if not isinstance(outbound_data, type(b"")): + raise TypeError(type(outbound_data)) + phase = self._next_send_phase + self._next_send_phase += 1 + self._plaintext_to_send.append( (phase, outbound_data) ) + with self._timing.add("API send", phase=phase): + self._maybe_send_phase_messages() + + def _derive_phase_key(self, side, phase): + assert isinstance(side, type(u"")), type(side) + assert isinstance(phase, type(u"")), type(phase) + side_bytes = side.encode("ascii") + phase_bytes = phase.encode("ascii") + purpose = (b"wormhole:phase:" + + sha256(side_bytes).digest() + + sha256(phase_bytes).digest()) + return self._derive_key(purpose) + + def _maybe_send_phase_messages(self): + # TODO: deal with reentrant call + if not (self._connection_state == OPEN + and self._mailbox_state == OPEN + and self._key): + return + plaintexts = self._plaintext_to_send + self._plaintext_to_send = [] + for pm in plaintexts: + (phase_int, plaintext) = pm + assert isinstance(phase_int, int), type(phase_int) + phase = u"%d" % phase_int + data_key = self._derive_phase_key(self._side, phase) + encrypted = self._encrypt_data(data_key, plaintext) + self._msg_send(phase, encrypted) + + def _encrypt_data(self, key, data): + # Without predefined roles, we can't derive predictably unique keys + # for each side, so we use the same key for both. We use random + # nonces to keep the messages distinct, and we automatically ignore + # reflections. + # TODO: HKDF(side, nonce, key) ?? include 'side' to prevent + # reflections, since we no longer compare messages + assert isinstance(key, type(b"")), type(key) + assert isinstance(data, type(b"")), type(data) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + nonce = utils.random(SecretBox.NONCE_SIZE) + return box.encrypt(data, nonce) + + def _msg_send(self, phase, body): + if phase in self._sent_phases: raise UsageError + assert self._mailbox_state == OPEN, self._mailbox_state + self._sent_phases.add(phase) + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + self._timing.add("add", phase=phase) + self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + + + def _event_mailbox_used(self): + if self.DEBUG: print("_event_mailbox_used") + if self._flag_need_to_see_mailbox_used: + self._maybe_release_nameplate() + self._flag_need_to_see_mailbox_used = False + + def _API_derive_key(self, purpose, length): + if self._error: raise self._error + if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) + return self._derive_key(to_bytes(purpose), length) + + def _derive_key(self, purpose, length=SecretBox.KEY_SIZE): + if not isinstance(purpose, type(b"")): raise TypeError(type(purpose)) + if self._key is None: + raise UsageError # call derive_key after get_verifier() or get() + return HKDF(self._key, length, CTXinfo=purpose) + + def _response_handle_message(self, msg): + side = msg["side"] + phase = msg["phase"] + assert isinstance(phase, type(u"")), type(phase) + body = unhexlify(msg["body"].encode("ascii")) + if side == self._side: + return + self._event_received_peer_message(side, phase, body) + + def _event_received_peer_message(self, side, phase, body): + # any message in the mailbox means we no longer need the nameplate + self._event_mailbox_used() + #if phase in self._received_messages: + # # a nameplate collision would cause this + # err = ServerError("got duplicate phase %s" % phase, self._ws_url) + # return self._signal_error(err) + #self._received_messages[phase] = body + if phase == u"pake": + self._event_received_pake(body) + return + if phase == u"confirm": + self._event_received_confirm(body) + return + + # It's a phase message, aimed at the application above us. Decrypt + # and deliver upstairs, notifying anyone waiting on it + try: + data_key = self._derive_phase_key(side, phase) + plaintext = self._decrypt_data(data_key, body) + except CryptoError: + e = WrongPasswordError() + self._signal_error(e, u"scary") # flunk all other API calls + # make tests fail, if they aren't explicitly catching it + if self.DEBUG: print("CryptoError in msg received") + log.err(e) + if self.DEBUG: print(" did log.err", e) + return # ignore this message + self._received_messages[phase] = plaintext + if phase in self._receive_waiters: + d = self._receive_waiters.pop(phase) + d.callback(plaintext) + + def _decrypt_data(self, key, encrypted): + assert isinstance(key, type(b"")), type(key) + assert isinstance(encrypted, type(b"")), type(encrypted) + assert len(key) == SecretBox.KEY_SIZE, len(key) + box = SecretBox(key) + data = box.decrypt(encrypted) + return data + + def _API_get(self): + if self._error: return defer.fail(self._error) + phase = u"%d" % self._next_receive_phase + self._next_receive_phase += 1 + with self._timing.add("API get", phase=phase): + if phase in self._received_messages: + return defer.succeed(self._received_messages[phase]) + d = self._receive_waiters[phase] = defer.Deferred() + return d + + def _signal_error(self, error, mood): + if self.DEBUG: print("_signal_error", error, mood) + if self._error: + return + self._maybe_close(error, mood) + if self.DEBUG: print("_signal_error done") + + @inlineCallbacks + def _API_close(self, res, mood=u"happy"): + if self.DEBUG: print("close") + if self._close_called: raise UsageError + self._close_called = True + self._maybe_close(WormholeClosedError(), mood) + if self.DEBUG: print("waiting for disconnect") + yield self._disconnect_waiter + returnValue(res) + + def _maybe_close(self, error, mood): + if self._closing: + return + + # ordering constraints: + # * must wait for nameplate/mailbox acks before closing the websocket + # * must mark APIs for failure before errbacking Deferreds + # * since we give up control + # * must mark self._closing before errbacking Deferreds + # * since caller may call close() when we give up control + # * and close() will reenter _maybe_close + + self._error = error # causes new API calls to fail + + # since we're about to give up control by errbacking any API + # Deferreds, set self._closing, to make sure that a new call to + # close() isn't going to confuse anything + self._closing = True + + # now errback all API deferreds except close(): get_code, + # input_code, verify, get + for d in self._connection_waiters: # input_code, get_code (early) + if self.DEBUG: print("EB cw") + d.errback(error) + if self._get_code: # get_code (late) + if self.DEBUG: print("EB gc") + self._get_code._allocated_d.errback(error) + if self._verifier_waiter and not self._verifier_waiter.called: + if self.DEBUG: print("EB VW") + self._verifier_waiter.errback(error) + for d in self._receive_waiters.values(): + if self.DEBUG: print("EB RW") + d.errback(error) + # Release nameplate and close mailbox, if either was claimed/open. + # Since _closing is True when both ACKs come back, the handlers will + # close the websocket. When *that* finishes, _disconnect_waiter() + # will fire. + self._maybe_release_nameplate() + self._maybe_close_mailbox(mood) + # In the off chance we got closed before we even claimed the + # nameplate, give _maybe_finished_closing a chance to run now. + self._maybe_finished_closing() + + def _maybe_release_nameplate(self): + if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state) + if self._nameplate_state == OPEN: + if self.DEBUG: print(" sending release") + self._ws_send_command(u"release") + self._nameplate_state = CLOSING + + def _response_handle_released(self, msg): + self._nameplate_state = CLOSED + self._maybe_finished_closing() + + def _maybe_close_mailbox(self, mood): + if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state) + if self._mailbox_state == OPEN: + if self.DEBUG: print(" sending close") + self._ws_send_command(u"close", mood=mood) + self._mailbox_state = CLOSING + + def _response_handle_closed(self, msg): + self._mailbox_state = CLOSED + self._maybe_finished_closing() + + def _maybe_finished_closing(self): + if self.DEBUG: print("_maybe_finished_closing", self._closing, self._nameplate_state, self._mailbox_state, self._connection_state) + if not self._closing: + return + if (self._nameplate_state == CLOSED + and self._mailbox_state == CLOSED + and self._connection_state == OPEN): + self._connection_state = CLOSING + self._drop_connection() + + def _drop_connection(self): + # separate method so it can be overridden by tests + self._ws.transport.loseConnection() # probably flushes output + # calls _ws_closed() when done + + def _ws_closed(self, wasClean, code, reason): + # For now (until we add reconnection), losing the websocket means + # losing everything. Make all API callers fail. Help someone waiting + # in close() to finish + self._connection_state = CLOSED + self._disconnect_waiter.callback(None) + self._maybe_finished_closing() + + # what needs to happen when _ws_closed() happens unexpectedly + # * errback all API deferreds + # * maybe: cause new API calls to fail + # * obviously can't release nameplate or close mailbox + # * can't re-close websocket + # * close(wait=True) callers should fire right away + +def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): + timing = timing or DebugTiming() + w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) + w._start() + return w + +def wormhole_from_serialized(data, reactor, timing=None): + timing = timing or DebugTiming() + w = _Wormhole.from_serialized(data, reactor, timing) + return w diff --git a/tox.ini b/tox.ini index 7ee4127..326afa5 100644 --- a/tox.ini +++ b/tox.ini @@ -19,6 +19,7 @@ skip_missing_interpreters = True [testenv] deps = pyflakes >= 1.2.3 + mock {env:EXTRA_DEPENDENCY:} commands = pyflakes setup.py src