From 7a2840058677ac461e6f7f8c65b08925c916e105 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 3 Oct 2015 17:46:11 -0700 Subject: [PATCH] split transcribe.py into two layers: comms and crypto --- src/wormhole/blocking/transcribe.py | 215 ++++++++++++++------------- src/wormhole/twisted/transcribe.py | 222 +++++++++++++++------------- 2 files changed, 237 insertions(+), 200 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 8fd84d3..e580253 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -24,6 +24,106 @@ MINUTE = 60*SECOND # POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted} # all JSON responses include a "welcome:{..}" key +class Channel: + def __init__(self, relay, channel_id, side, handle_welcome): + self._channel_url = "%s%d" % (relay, channel_id) + self._side = side + self._handle_welcome = handle_welcome + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) + self._started = time.time() + self._wait = 0.5*SECOND + self._timeout = 3*MINUTE + + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def send(self, phase, msg): + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) + if not isinstance(msg, type(b"")): raise UsageError(type(msg)) + self._sent_messages.add( (phase,msg) ) + payload = {"side": self._side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + data = json.dumps(payload).encode("utf-8") + r = requests.post(self._channel_url, data=data) + r.raise_for_status() + resp = r.json() + self._add_inbound_messages(resp["messages"]) + + def get(self, phase): + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) + # For now, server errors cause the client to fail. TODO: don't. This + # will require changing the client to re-post messages when the + # server comes back up. + + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + while body is None: + remaining = self._started + self._timeout - time.time() + if remaining < 0: + return Timeout + f = EventSourceFollower(self._channel_url, remaining) + # we loop here until the connection is lost, or we see the + # message we want + for (eventtype, data) in f.iter_events(): + if eventtype == "welcome": + self._handle_welcome(json.loads(data)) + if eventtype == "message": + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body: + f.close() + break + if not body: + time.sleep(self._wait) + return body + + def deallocate(self): + # only try once, no retries + data = json.dumps({"side": self._side}).encode("utf-8") + requests.post(self._channel_url+"/deallocate", data=data) + # ignore POST failure, don't call r.raise_for_status() + +class ChannelManager: + def __init__(self, relay, side, handle_welcome): + self._relay = relay + self._side = side + self._handle_welcome = handle_welcome + + def list_channels(self): + r = requests.get(self._relay + "list") + r.raise_for_status() + channel_ids = r.json()["channel-ids"] + return channel_ids + + def allocate(self): + data = json.dumps({"side": self._side}).encode("utf-8") + r = requests.post(self._relay + "allocate", data=data) + r.raise_for_status() + data = r.json() + if "welcome" in data: + self._handle_welcome(data["welcome"]) + channel_id = data["channel-id"] + return channel_id + + def connect(self, channel_id): + return Channel(self._relay, channel_id, self._side, + self._handle_welcome) + class Wormhole: motd_displayed = False version_warning_displayed = False @@ -33,16 +133,12 @@ class Wormhole: self.appid = appid self.relay = relay if not self.relay.endswith("/"): raise UsageError - self.started = time.time() - self.wait = 0.5*SECOND - self.timeout = 3*MINUTE - self.side = None + side = hexlify(os.urandom(5)).decode("ascii") + self._channel_manager = ChannelManager(relay, side, + self.handle_welcome) self.code = None self.key = None self.verifier = None - self._channel_url = None - self._messages = set() # (phase,body) , body is bytes - self._sent_messages = set() # (phase,body) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -66,43 +162,25 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _allocate_channel(self): - data = json.dumps({"side": self.side}).encode("utf-8") - r = requests.post(self.relay + "allocate", data=data) - r.raise_for_status() - data = r.json() - if "welcome" in data: - self.handle_welcome(data["welcome"]) - channel_id = data["channel-id"] - return channel_id - def get_code(self, code_length=2): if self.code is not None: raise UsageError - self.side = hexlify(os.urandom(5)).decode("ascii") - channel_id = self._allocate_channel() # allocate channel + channel_id = self._channel_manager.allocate() code = codes.make_code(channel_id, code_length) assert isinstance(code, str), type(code) self._set_code_and_channel_id(code) self._start() return code - def list_channels(self): - r = requests.get(self.relay + "list") - r.raise_for_status() - channel_ids = r.json()["channel-ids"] - return channel_ids - def input_code(self, prompt="Enter wormhole code: ", code_length=2): - code = codes.input_code_with_completion(prompt, self.list_channels, + lister = self._channel_manager.list_channels + code = codes.input_code_with_completion(prompt, lister, code_length) return code def set_code(self, code): # used for human-made pre-generated codes if not isinstance(code, str): raise UsageError if self.code is not None: raise UsageError - if self.side is not None: raise UsageError self._set_code_and_channel_id(code) - self.side = hexlify(os.urandom(5)).decode("ascii") self._start() def _set_code_and_channel_id(self, code): @@ -110,9 +188,9 @@ class Wormhole: mo = re.search(r'^(\d+)-', code) if not mo: raise ValueError("code (%s) must start with NN-" % code) - self.channel_id = int(mo.group(1)) - self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code + channel_id = int(mo.group(1)) + self.channel = self._channel_manager.connect(channel_id) def _start(self): # allocate the rest now too, so it can be serialized @@ -120,63 +198,6 @@ class Wormhole: idSymmetric=self.appid) self.msg1 = self.sp.start() - def _add_inbound_messages(self, messages): - for msg in messages: - phase = msg["phase"] - body = unhexlify(msg["body"].encode("ascii")) - self._messages.add( (phase, body) ) - - def _find_inbound_message(self, phase): - for (their_phase,body) in self._messages - self._sent_messages: - if their_phase == phase: - return body - return None - - def _send_message(self, phase, msg): - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - if not isinstance(phase, type(u"")): raise UsageError(type(phase)) - if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - self._sent_messages.add( (phase,msg) ) - payload = {"side": self.side, - "phase": phase, - "body": hexlify(msg).decode("ascii")} - data = json.dumps(payload).encode("utf-8") - r = requests.post(self._channel_url, data=data) - r.raise_for_status() - resp = r.json() - self._add_inbound_messages(resp["messages"]) - - def _get_message(self, phase): - if not isinstance(phase, type(u"")): raise UsageError(type(phase)) - # For now, server errors cause the client to fail. TODO: don't. This - # will require changing the client to re-post messages when the - # server comes back up. - - # fire with a bytestring of the first message for 'phase' that wasn't - # one of ours. It will either come from previously-received messages, - # or from an EventSource that we attach to the corresponding URL - body = self._find_inbound_message(phase) - while body is None: - remaining = self.started + self.timeout - time.time() - if remaining < 0: - return Timeout - f = EventSourceFollower(self._channel_url, remaining) - # we loop here until the connection is lost, or we see the - # message we want - for (eventtype, data) in f.iter_events(): - if eventtype == "welcome": - self.handle_welcome(json.loads(data)) - if eventtype == "message": - self._add_inbound_messages([json.loads(data)]) - body = self._find_inbound_message(phase) - if body: - f.close() - break - if not body: - time.sleep(self.wait) - return body - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(b"")): raise UsageError return HKDF(self.key, length, CTXinfo=purpose) @@ -200,14 +221,14 @@ class Wormhole: def _get_key(self): if not self.key: - self._send_message(u"pake", self.msg1) - pake_msg = self._get_message(u"pake") + self.channel.send(u"pake", self.msg1) + pake_msg = self.channel.get(u"pake") self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(self.appid+b":Verifier") def get_verifier(self): if self.code is None: raise UsageError - if self.channel_id is None: raise UsageError + if self.channel is None: raise UsageError self._get_key() return self.verifier @@ -215,12 +236,12 @@ class Wormhole: # only call this once if not isinstance(outbound_data, type(b"")): raise UsageError if self.code is None: raise UsageError - if self.channel_id is None: raise UsageError + if self.channel is None: raise UsageError try: self._get_key() return self._get_data2(outbound_data) finally: - self._deallocate() + self.channel.deallocate() def _get_data2(self, outbound_data): # Without predefined roles, we can't derive predictably unique keys @@ -229,9 +250,9 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - self._send_message(u"data", outbound_encrypted) + self.channel.send(u"data", outbound_encrypted) - inbound_encrypted = self._get_message(u"data") + inbound_encrypted = self.channel.get(u"data") # _find_inbound_message() ignores any inbound message that matches # something we previously sent out, so we don't need to explicitly # check for reflection. A reflection attack will just not progress. @@ -240,9 +261,3 @@ class Wormhole: return inbound_data except CryptoError: raise WrongPasswordError - - def _deallocate(self): - # only try once, no retries - data = json.dumps({"side": self.side}).encode("utf-8") - requests.post(self._channel_url+"/deallocate", data=data) - # ignore POST failure, don't call r.raise_for_status() diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index a681560..eaa479f 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -32,6 +32,112 @@ class DataProducer: pass +def post_json(agent, url, request_body): + # POST a JSON body to a URL, parsing the response as JSON + data = json.dumps(request_body).encode("utf-8") + d = agent.request("POST", url, bodyProducer=DataProducer(data)) + def _check_error(resp): + if resp.code != 200: + raise web_error.Error(resp.code, resp.phrase) + return resp + d.addCallback(_check_error) + d.addCallback(web_client.readBody) + d.addCallback(lambda data: json.loads(data)) + return d + +class Channel: + def __init__(self, relay, channel_id, side, handle_welcome, + agent): + self._channel_url = "%s%d" % (relay, channel_id) + self._side = side + self._handle_welcome = handle_welcome + self._agent = agent + self._messages = set() # (phase,body) , body is bytes + self._sent_messages = set() # (phase,body) + + def _add_inbound_messages(self, messages): + for msg in messages: + phase = msg["phase"] + body = unhexlify(msg["body"].encode("ascii")) + self._messages.add( (phase, body) ) + + def _find_inbound_message(self, phase): + for (their_phase,body) in self._messages - self._sent_messages: + if their_phase == phase: + return body + return None + + def send(self, phase, msg): + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + if not isinstance(phase, type(u"")): raise UsageError(type(phase)) + if not isinstance(msg, type(b"")): raise UsageError(type(msg)) + self._sent_messages.add( (phase,msg) ) + payload = {"side": self._side, + "phase": phase, + "body": hexlify(msg).decode("ascii")} + d = post_json(self._agent, self._channel_url, payload) + d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) + return d + + def get(self, phase): + # fire with a bytestring of the first message for 'phase' that wasn't + # one of ours. It will either come from previously-received messages, + # or from an EventSource that we attach to the corresponding URL + body = self._find_inbound_message(phase) + if body is not None: + return defer.succeed(body) + d = defer.Deferred() + msgs = [] + def _handle(name, data): + if name == "welcome": + self._handle_welcome(json.loads(data)) + if name == "message": + self._add_inbound_messages([json.loads(data)]) + body = self._find_inbound_message(phase) + if body is not None and not msgs: + msgs.append(body) + d.callback(None) + # TODO: use agent=self._agent + es = ReconnectingEventSource(self._channel_url, _handle) + es.startService() # TODO: .setServiceParent(self) + es.activate() + d.addCallback(lambda _: es.deactivate()) + d.addCallback(lambda _: es.stopService()) + d.addCallback(lambda _: msgs[0]) + return d + + def deallocate(self, res): + # only try once, no retries + d = post_json(self._agent, self._channel_url+"/deallocate", + {"side": self._side}) + d.addBoth(lambda _: res) # ignore POST failure, pass-through result + return d + +class ChannelManager: + def __init__(self, relay, side, handle_welcome): + self._relay = relay + self._side = side + self._handle_welcome = handle_welcome + self._agent = web_client.Agent(reactor) + + def allocate(self): + url = self._relay + "allocate" + d = post_json(self._agent, url, {"side": self._side}) + def _got_channel(data): + if "welcome" in data: + self._handle_welcome(data["welcome"]) + return data["channel-id"] + d.addCallback(_got_channel) + return d + + def list_channels(self): + raise NotImplementedError + + def connect(self, channel_id): + return Channel(self._relay, channel_id, self._side, + self._handle_welcome, self._agent) + class Wormhole: motd_displayed = False version_warning_displayed = False @@ -40,14 +146,15 @@ class Wormhole: if not isinstance(appid, type(b"")): raise UsageError self.appid = appid self.relay = relay - self.agent = web_client.Agent(reactor) - self.side = None + self._set_side(hexlify(os.urandom(5)).decode("ascii")) self.code = None self.key = None self._started_get_code = False - self._channel_url = None - self._messages = set() # (phase,body) , body is bytes - self._sent_messages = set() # (phase,body) + + def _set_side(self, side): + self._side = side + self._channel_manager = ChannelManager(self.relay, self._side, + self.handle_welcome) def handle_welcome(self, welcome): if ("motd" in welcome and @@ -71,35 +178,11 @@ class Wormhole: if "error" in welcome: raise ServerError(welcome["error"], self.relay) - def _post_json(self, url, post_json): - # POST a JSON body to a URL, parsing the response as JSON - data = json.dumps(post_json).encode("utf-8") - d = self.agent.request("POST", url, bodyProducer=DataProducer(data)) - def _check_error(resp): - if resp.code != 200: - raise web_error.Error(resp.code, resp.phrase) - return resp - d.addCallback(_check_error) - d.addCallback(web_client.readBody) - d.addCallback(lambda data: json.loads(data)) - return d - - def _allocate_channel(self): - url = self.relay + "allocate" - d = self._post_json(url, {"side": self.side}) - def _got_channel(data): - if "welcome" in data: - self.handle_welcome(data["welcome"]) - return data["channel-id"] - d.addCallback(_got_channel) - return d - def get_code(self, code_length=2): if self.code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True - self.side = hexlify(os.urandom(5)) - d = self._allocate_channel() + d = self._channel_manager.allocate() def _got_channel_id(channel_id): code = codes.make_code(channel_id, code_length) assert isinstance(code, str), type(code) @@ -112,9 +195,7 @@ class Wormhole: def set_code(self, code): if not isinstance(code, str): raise UsageError if self.code is not None: raise UsageError - if self.side is not None: raise UsageError self._set_code_and_channel_id(code) - self.side = hexlify(os.urandom(5)) self._start() def _set_code_and_channel_id(self, code): @@ -122,9 +203,9 @@ class Wormhole: mo = re.search(r'^(\d+)-', code) if not mo: raise ValueError("code (%s) must start with NN-" % code) - self.channel_id = int(mo.group(1)) - self._channel_url = "%s%d" % (self.relay, self.channel_id) self.code = code + channel_id = int(mo.group(1)) + self.channel = self._channel_manager.connect(channel_id) def _start(self): # allocate the rest now too, so it can be serialized @@ -141,7 +222,7 @@ class Wormhole: "appid": self.appid, "relay": self.relay, "code": self.code, - "side": self.side, + "side": self._side, "spake2": json.loads(self.sp.serialize()), "msg1": self.msg1.encode("hex"), } @@ -151,64 +232,12 @@ class Wormhole: def from_serialized(klass, data): d = json.loads(data) self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii")) + self._set_side(d["side"].encode("ascii")) self._set_code_and_channel_id(d["code"].encode("ascii")) - self.side = d["side"].encode("ascii") self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"])) self.msg1 = d["msg1"].decode("hex") return self - def _add_inbound_messages(self, messages): - for msg in messages: - phase = msg["phase"] - body = unhexlify(msg["body"].encode("ascii")) - self._messages.add( (phase, body) ) - - def _find_inbound_message(self, phase): - for (their_phase,body) in self._messages - self._sent_messages: - if their_phase == phase: - return body - return None - - def _send_message(self, phase, msg): - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - if not isinstance(phase, type(u"")): raise UsageError(type(phase)) - if not isinstance(msg, type(b"")): raise UsageError(type(msg)) - self._sent_messages.add( (phase,msg) ) - payload = {"side": self.side, - "phase": phase, - "body": hexlify(msg).decode("ascii")} - d = self._post_json(self._channel_url, payload) - d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) - return d - - def _get_message(self, phase): - # fire with a bytestring of the first message for 'phase' that wasn't - # one of ours. It will either come from previously-received messages, - # or from an EventSource that we attach to the corresponding URL - body = self._find_inbound_message(phase) - if body is not None: - return defer.succeed(body) - d = defer.Deferred() - msgs = [] - def _handle(name, data): - if name == "welcome": - self.handle_welcome(json.loads(data)) - if name == "message": - self._add_inbound_messages([json.loads(data)]) - body = self._find_inbound_message(phase) - if body is not None and not msgs: - msgs.append(body) - d.callback(None) - # TODO: use agent=self.agent - es = ReconnectingEventSource(self._channel_url, _handle) - es.startService() # TODO: .setServiceParent(self) - es.activate() - d.addCallback(lambda _: es.deactivate()) - d.addCallback(lambda _: es.stopService()) - d.addCallback(lambda _: msgs[0]) - return d - def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if self.key is None: # call after get_verifier() or get_data() @@ -237,8 +266,8 @@ class Wormhole: # TODO: prevent multiple invocation if self.key: return defer.succeed(self.key) - d = self._send_message(u"pake", self.msg1) - d.addCallback(lambda _: self._get_message(u"pake")) + d = self.channel.send(u"pake", self.msg1) + d.addCallback(lambda _: self.channel.get(u"pake")) def _got_pake(pake_msg): key = self.sp.finish(pake_msg) self.key = key @@ -259,7 +288,7 @@ class Wormhole: if self.code is None: raise UsageError d = self._get_key() d.addCallback(self._get_data2, outbound_data) - d.addBoth(self._deallocate) + d.addBoth(self.channel.deallocate) return d def _get_data2(self, key, outbound_data): @@ -269,9 +298,9 @@ class Wormhole: data_key = self.derive_key(b"data-key") outbound_encrypted = self._encrypt_data(data_key, outbound_data) - d = self._send_message(u"data", outbound_encrypted) + d = self.channel.send(u"data", outbound_encrypted) - d.addCallback(lambda _: self._get_message(u"data")) + d.addCallback(lambda _: self.channel.get(u"data")) def _got_data(inbound_encrypted): #if inbound_encrypted == outbound_encrypted: # raise ReflectionAttack @@ -282,10 +311,3 @@ class Wormhole: raise WrongPasswordError d.addCallback(_got_data) return d - - def _deallocate(self, res): - # only try once, no retries - d = self._post_json(self._channel_url+"/deallocate", - {"side": self.side}) - d.addBoth(lambda _: res) # ignore POST failure, pass-through result - return d