From 05aa5ca76ee7995c99003178bfbd1acab5a4bc14 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 20 May 2016 13:51:05 -0700 Subject: [PATCH] WIP Wormhole --- src/wormhole/twisted/transcribe.py | 241 ++++++++++++++++++----------- 1 file changed, 148 insertions(+), 93 deletions(-) diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index d8f39d6..555499b 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -28,6 +28,17 @@ def make_confmsg(confkey, nonce): def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") +# We send the following messages through the relay server to the far side (by +# sending "add" commands to the server, and getting "message" responses): +# +# phase=setup: +# * unauthenticated version strings (but why?) +# * early warmup for connection hints ("I can do tor, spin up HS") +# * wordlist l10n identifier +# phase=pake: just the SPAKE2 'start' message (binary) +# phase=confirm: key verification (HKDF(key, nonce)+nonce) +# phase=1,2,3,..: application messages + class WSClient(websocket.WebSocketClientProtocol): def onOpen(self): self.wormhole_open = True @@ -35,7 +46,7 @@ class WSClient(websocket.WebSocketClientProtocol): def onMessage(self, payload, isBinary): assert not isBinary - self.wormhole._ws_dispatch_msg(payload) + self.wormhole._ws_dispatch_response(payload) def onClose(self, wasClean, code, reason): if self.wormhole_open: @@ -70,9 +81,20 @@ class _Wormhole: self._tor_manager = tor_manager self._timing = timing or DebugTiming() self._reactor = reactor + + self._ws_connected = defer.Deferred() # XXX + self._side = hexlify(os.urandom(5)).decode("ascii") self._code = None - self._channelid = None + + self._nameplate_id = None + self._nameplate_claimed = False + self._nameplate_released = False + + self._mailbox_id = None + self._mailbox_opened = False + self._mailbox_closed = False + self._key = None self._started_get_code = False self._next_outbound_phase = 0 @@ -88,7 +110,6 @@ class _Wormhole: self._timing_started = self._timing.add("wormhole") self._ws = None self._ws_t = None # timing Event - self._ws_channel_claimed = False self._error = None def _make_endpoint(self, hostname, port): @@ -100,11 +121,9 @@ class _Wormhole: @inlineCallbacks def _get_websocket(self): if not self._ws: - # TODO: if we lose the connection, make a new one - #from twisted.python import log - #log.startLogging(sys.stderr) + # TODO: if we lose the connection, make a new one, re-establish + # the state assert self._side - assert not self._ws_channel_claimed p = urlparse(self._ws_url) f = WSFactory(self._ws_url) f.wormhole = self @@ -118,12 +137,12 @@ class _Wormhole: self._ws_t = self._timing.add("websocket") # f.d is errbacked if WebSocket negotiation fails yield f.d # WebSocket drops data sent before onOpen() fires - self._ws_send(u"bind", appid=self._appid, side=self._side) - # the socket is connected, and bound, but no channel has been claimed + self._ws_send_command(u"bind", appid=self._appid, side=self._side) + # the socket is connected, and bound, but no nameplate has been claimed returnValue(self._ws) @inlineCallbacks - def _ws_send(self, mtype, **kwargs): + def _ws_send_command(self, mtype, **kwargs): ws = yield self._get_websocket() # msgid is used by misc/dump-timing.py to correlate our sends with # their receives, and vice versa. They are also correlated with the @@ -135,7 +154,7 @@ class _Wormhole: self._timing.add("ws_send", _side=self._side, **kwargs) ws.sendMessage(payload, False) - def _ws_dispatch_msg(self, payload): + def _ws_dispatch_response(self, payload): msg = json.loads(payload.decode("utf-8")) self._timing.add("ws_receive", _side=self._side, message=msg) mtype = msg["type"] @@ -205,36 +224,120 @@ class _Wormhole: return self._signal_error(err) @inlineCallbacks - def _claim_channel_and_watch(self): - assert self._channelid is not None - yield self._get_websocket() - if not self._ws_channel_claimed: - yield self._ws_send(u"claim", channelid=self._channelid) - self._ws_channel_claimed = True - yield self._ws_send(u"watch", channelid=self._channelid) + def _claim_nameplate(self): + if not self._nameplate_id: raise UsageError + if self._nameplate_claimed: raise UsageError + yield self._ws_send_command(u"claim", nameplate=self._nameplate_id) + # provokes "claimed" response + + def _ws_handle_claimed(self, msg): + mailbox_id = msg["mailbox"] + assert isinstance(mailbox_id, type(u"")), type(mailbox_id) + self._mailbox_id = mailbox_id + self._open_mailbox() + + @inlineCallbacks + def _release_nameplate(self): + if not self._nameplate_claimed: raise UsageError + if self._nameplate_released: raise UsageError + yield self._ws_send_command(u"release") + self._nameplate_released = True + + + @inlineCallbacks + def _open_mailbox(self): + if not self._mailbox_id: raise UsageError + if self._mailbox_opened: raise UsageError + yield self._ws_send_command(u"open", mailbox=self._mailbox_id) + self._mailbox_opened = True + # causes old messages to be sent now, and subscribes to new messages + + @inlineCallbacks + def _close_mailbox(self): + if not self._mailbox_id: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + yield self._ws_send_command(u"close") + self._mailbox_closed = True + + + @inlineCallbacks + def _msg_send(self, phase, body, wait=False): + if phase in self._sent_messages: raise UsageError + if not self._mailbox_opened: raise UsageError + if self._mailbox_closed: raise UsageError + self._sent_messages[phase] = body + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + t = self._timing.add("add", phase=phase, wait=wait) + yield self._ws_send_command(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + if wait: + while phase not in self._delivered_messages: + yield self._sleep() + t.finish() + + def _ws_handle_message(self, msg): + # any message in the mailbox means we no longer need the nameplate + if not self._nameplate_released: + self._release_nameplate() # XXX returns Deferred + + m = msg["message"] + phase = m["phase"] + body = unhexlify(m["body"].encode("ascii")) + if phase in self._sent_messages and self._sent_messages[phase] == body: + self._delivered_messages.add(phase) # ack by server + self._wakeup() + return # ignore echoes of our outbound messages + if phase in self._received_messages: + # a nameplate collision would cause this + err = ServerError("got duplicate phase %s" % phase, self._ws_url) + return self._signal_error(err) + self._received_messages[phase] = body + if phase == u"confirm": + # TODO: we might not have a master key yet, if the caller wasn't + # waiting in _get_master_key() when a back-to-back pake+_confirm + # message pair arrived. + confkey = self.derive_key(u"wormhole:confirmation") + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + # this makes all API calls fail + return self._signal_error(WrongPasswordError()) + # now notify anyone waiting on it + self._wakeup() + + @inlineCallbacks + def _msg_get(self, phase): + with self._timing.add("get", phase=phase): + while phase not in self._received_messages: + yield self._sleep() # we can wait a long time here + # that will throw an error if something goes wrong + msg = self._received_messages[phase] + returnValue(msg) # entry point 1: generate a new code @inlineCallbacks - def get_code(self, code_length=2): # rename to allocate_code()? create_? + def get_code(self, code_length=2): # XX rename to allocate_code()? create_? if self._code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True with self._timing.add("API get_code"): with self._timing.add("allocate"): - yield self._ws_send(u"allocate") - while self._channelid is None: + yield self._ws_send_command(u"allocate") + while self._nameplate_id is None: yield self._sleep() - code = codes.make_code(self._channelid, code_length) + code = codes.make_code(self._nameplate_id, code_length) assert isinstance(code, type(u"")), type(code) self._set_code(code) self._start() returnValue(code) def _ws_handle_allocated(self, msg): - if self._channelid is not None: - return self._signal_error("got duplicate channelid") - self._channelid = msg["channelid"] - assert isinstance(self._channelid, type(u"")), type(self._channelid) + if self._nameplate_id is not None: + return self._signal_error("got duplicate 'allocated' response") + nid = msg["nameplate"] + assert isinstance(nid, type(u"")), type(nid) + self._nameplate_id = nid self._wakeup() def _start(self): @@ -248,8 +351,8 @@ class _Wormhole: @inlineCallbacks def input_code(self, prompt="Enter wormhole code: ", code_length=2): def _lister(): - return blockingCallFromThread(self._reactor, self._list_channels) - # fetch the list of channels ahead of time, to give us a chance to + return blockingCallFromThread(self._reactor, self._list_nameplates) + # fetch the list of nameplates ahead of time, to give us a chance to # discover the welcome message (and warn the user about an obsolete # client) # @@ -257,13 +360,13 @@ class _Wormhole: # latency in the user's indecision and slow typing. If we're lucky # the answer will come back before they hit TAB. with self._timing.add("API input_code"): - initial_channelids = yield self._list_channels() + initial_nameplate_ids = yield self._list_nameplates() with self._timing.add("input code", waiting="user"): t = self._reactor.addSystemEventTrigger("before", "shutdown", self._warn_readline) code = yield deferToThread(codes.input_code_with_completion, prompt, - initial_channelids, _lister, + initial_nameplate_ids, _lister, code_length) self._reactor.removeSystemEventTrigger(t) returnValue(code) # application will give this to set_code() @@ -295,7 +398,7 @@ class _Wormhole: # --internal-get-code"), run it as a subprocess, let it inherit # stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It # needs an RPC mechanism (over some extra file descriptors) to ask - # us to fetch the current channelid list. + # us to fetch the current nameplate_id list. # # Note that hard-terminating our process with os.kill(os.getpid(), # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread @@ -303,16 +406,16 @@ class _Wormhole: # readline finish. @inlineCallbacks - def _list_channels(self): + def _list_nameplates(self): with self._timing.add("list"): - self._latest_channelids = None - yield self._ws_send(u"list") - while self._latest_channelids is None: + self._latest_nameplate_ids = None + yield self._ws_send_command(u"list") + while self._latest_nameplate_ids is None: yield self._sleep() - returnValue(self._latest_channelids) + returnValue(self._latest_nameplate_ids) - def _ws_handle_channelids(self, msg): - self._latest_channelids = msg["channelids"] + def _ws_handle_nameplates(self, msg): + self._latest_nameplate_ids = msg["nameplates"] self._wakeup() # entry point 2b: paste in a fully-formed code @@ -323,8 +426,8 @@ class _Wormhole: if not mo: raise ValueError("code (%s) must start with NN-" % code) with self._timing.add("API set_code"): - self._channelid = mo.group(1) - assert isinstance(self._channelid, type(u"")), type(self._channelid) + self._nameplate_id = mo.group(1) + assert isinstance(self._nameplate_id, type(u"")), type(self._nameplate_id) self._set_code(code) self._start() @@ -344,7 +447,7 @@ class _Wormhole: "appid": self._appid, "relay_url": self._relay_url, "code": self._code, - "channelid": self._channelid, + "nameplate_id": self._nameplate_id, "side": self._side, "spake2": json.loads(self._sp.serialize().decode("ascii")), "msg1": hexlify(self._msg1).decode("ascii"), @@ -357,7 +460,7 @@ class _Wormhole: d = json.loads(data) self = klass(d["appid"], d["relay_url"], reactor) self._side = d["side"] - self._channelid = d["channelid"] + self._nameplate_id = d["nameplate_id"] self._set_code(d["code"]) sp_data = json.dumps(d["spake2"]).encode("ascii") self._sp = SPAKE2_Symmetric.from_serialized(sp_data) @@ -385,7 +488,7 @@ class _Wormhole: def _get_master_key(self): # TODO: prevent multiple invocation if not self._key: - yield self._claim_channel_and_watch() + yield self._claim_nameplate_and_watch() yield self._msg_send(u"pake", self._msg1) pake_msg = yield self._msg_get(u"pake") @@ -401,54 +504,6 @@ class _Wormhole: confmsg = make_confmsg(confkey, nonce) yield self._msg_send(u"confirm", confmsg, wait=True) - @inlineCallbacks - def _msg_send(self, phase, body, wait=False): - if phase in self._sent_messages: raise UsageError - self._sent_messages[phase] = body - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - t = self._timing.add("add", phase=phase, wait=wait) - yield self._ws_send(u"add", channelid=self._channelid, phase=phase, - body=hexlify(body).decode("ascii")) - if wait: - while phase not in self._delivered_messages: - yield self._sleep() - t.finish() - - def _ws_handle_message(self, msg): - m = msg["message"] - phase = m["phase"] - body = unhexlify(m["body"].encode("ascii")) - if phase in self._sent_messages and self._sent_messages[phase] == body: - self._delivered_messages.add(phase) # ack by server - self._wakeup() - return # ignore echoes of our outbound messages - if phase in self._received_messages: - # a channel collision would cause this - err = ServerError("got duplicate phase %s" % phase, self._ws_url) - return self._signal_error(err) - self._received_messages[phase] = body - if phase == u"confirm": - # TODO: we might not have a master key yet, if the caller wasn't - # waiting in _get_master_key() when a back-to-back pake+_confirm - # message pair arrived. - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - # this makes all API calls fail - return self._signal_error(WrongPasswordError()) - # now notify anyone waiting on it - self._wakeup() - - @inlineCallbacks - def _msg_get(self, phase): - with self._timing.add("get", phase=phase): - while phase not in self._received_messages: - yield self._sleep() # we can wait a long time here - # that will throw an error if something goes wrong - msg = self._received_messages[phase] - returnValue(msg) - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if self._key is None: @@ -548,8 +603,7 @@ class _Wormhole: @inlineCallbacks def _release(self, mood): with self._timing.add("release"): - yield self._ws_send(u"release", channelid=self._channelid, - mood=mood) + yield self._ws_send_command(u"release", mood=mood) while self._released_status is None: yield self._sleep(wake_on_error=False) # TODO: set a timeout, don't wait forever for an ack @@ -562,6 +616,7 @@ class _Wormhole: def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): w = _Wormhole(appid, relay_url, reactor, tor_manager, timing) + w._start() return w def wormhole_from_serialized(data, reactor):