split transcribe.py into two layers: comms and crypto
This commit is contained in:
parent
617bb03ad5
commit
7a28400586
|
@ -24,6 +24,106 @@ MINUTE = 60*SECOND
|
||||||
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
# POST /CID/deallocate {side: SIDE} -> {status: waiting | deleted}
|
||||||
# all JSON responses include a "welcome:{..}" key
|
# all JSON responses include a "welcome:{..}" key
|
||||||
|
|
||||||
|
class Channel:
|
||||||
|
def __init__(self, relay, channel_id, side, handle_welcome):
|
||||||
|
self._channel_url = "%s%d" % (relay, channel_id)
|
||||||
|
self._side = side
|
||||||
|
self._handle_welcome = handle_welcome
|
||||||
|
self._messages = set() # (phase,body) , body is bytes
|
||||||
|
self._sent_messages = set() # (phase,body)
|
||||||
|
self._started = time.time()
|
||||||
|
self._wait = 0.5*SECOND
|
||||||
|
self._timeout = 3*MINUTE
|
||||||
|
|
||||||
|
def _add_inbound_messages(self, messages):
|
||||||
|
for msg in messages:
|
||||||
|
phase = msg["phase"]
|
||||||
|
body = unhexlify(msg["body"].encode("ascii"))
|
||||||
|
self._messages.add( (phase, body) )
|
||||||
|
|
||||||
|
def _find_inbound_message(self, phase):
|
||||||
|
for (their_phase,body) in self._messages - self._sent_messages:
|
||||||
|
if their_phase == phase:
|
||||||
|
return body
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send(self, phase, msg):
|
||||||
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||||
|
# against the rendezvous server being temporarily offline.
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
|
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||||
|
self._sent_messages.add( (phase,msg) )
|
||||||
|
payload = {"side": self._side,
|
||||||
|
"phase": phase,
|
||||||
|
"body": hexlify(msg).decode("ascii")}
|
||||||
|
data = json.dumps(payload).encode("utf-8")
|
||||||
|
r = requests.post(self._channel_url, data=data)
|
||||||
|
r.raise_for_status()
|
||||||
|
resp = r.json()
|
||||||
|
self._add_inbound_messages(resp["messages"])
|
||||||
|
|
||||||
|
def get(self, phase):
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
|
# For now, server errors cause the client to fail. TODO: don't. This
|
||||||
|
# will require changing the client to re-post messages when the
|
||||||
|
# server comes back up.
|
||||||
|
|
||||||
|
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||||
|
# one of ours. It will either come from previously-received messages,
|
||||||
|
# or from an EventSource that we attach to the corresponding URL
|
||||||
|
body = self._find_inbound_message(phase)
|
||||||
|
while body is None:
|
||||||
|
remaining = self._started + self._timeout - time.time()
|
||||||
|
if remaining < 0:
|
||||||
|
return Timeout
|
||||||
|
f = EventSourceFollower(self._channel_url, remaining)
|
||||||
|
# we loop here until the connection is lost, or we see the
|
||||||
|
# message we want
|
||||||
|
for (eventtype, data) in f.iter_events():
|
||||||
|
if eventtype == "welcome":
|
||||||
|
self._handle_welcome(json.loads(data))
|
||||||
|
if eventtype == "message":
|
||||||
|
self._add_inbound_messages([json.loads(data)])
|
||||||
|
body = self._find_inbound_message(phase)
|
||||||
|
if body:
|
||||||
|
f.close()
|
||||||
|
break
|
||||||
|
if not body:
|
||||||
|
time.sleep(self._wait)
|
||||||
|
return body
|
||||||
|
|
||||||
|
def deallocate(self):
|
||||||
|
# only try once, no retries
|
||||||
|
data = json.dumps({"side": self._side}).encode("utf-8")
|
||||||
|
requests.post(self._channel_url+"/deallocate", data=data)
|
||||||
|
# ignore POST failure, don't call r.raise_for_status()
|
||||||
|
|
||||||
|
class ChannelManager:
|
||||||
|
def __init__(self, relay, side, handle_welcome):
|
||||||
|
self._relay = relay
|
||||||
|
self._side = side
|
||||||
|
self._handle_welcome = handle_welcome
|
||||||
|
|
||||||
|
def list_channels(self):
|
||||||
|
r = requests.get(self._relay + "list")
|
||||||
|
r.raise_for_status()
|
||||||
|
channel_ids = r.json()["channel-ids"]
|
||||||
|
return channel_ids
|
||||||
|
|
||||||
|
def allocate(self):
|
||||||
|
data = json.dumps({"side": self._side}).encode("utf-8")
|
||||||
|
r = requests.post(self._relay + "allocate", data=data)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "welcome" in data:
|
||||||
|
self._handle_welcome(data["welcome"])
|
||||||
|
channel_id = data["channel-id"]
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
def connect(self, channel_id):
|
||||||
|
return Channel(self._relay, channel_id, self._side,
|
||||||
|
self._handle_welcome)
|
||||||
|
|
||||||
class Wormhole:
|
class Wormhole:
|
||||||
motd_displayed = False
|
motd_displayed = False
|
||||||
version_warning_displayed = False
|
version_warning_displayed = False
|
||||||
|
@ -33,16 +133,12 @@ class Wormhole:
|
||||||
self.appid = appid
|
self.appid = appid
|
||||||
self.relay = relay
|
self.relay = relay
|
||||||
if not self.relay.endswith("/"): raise UsageError
|
if not self.relay.endswith("/"): raise UsageError
|
||||||
self.started = time.time()
|
side = hexlify(os.urandom(5)).decode("ascii")
|
||||||
self.wait = 0.5*SECOND
|
self._channel_manager = ChannelManager(relay, side,
|
||||||
self.timeout = 3*MINUTE
|
self.handle_welcome)
|
||||||
self.side = None
|
|
||||||
self.code = None
|
self.code = None
|
||||||
self.key = None
|
self.key = None
|
||||||
self.verifier = None
|
self.verifier = None
|
||||||
self._channel_url = None
|
|
||||||
self._messages = set() # (phase,body) , body is bytes
|
|
||||||
self._sent_messages = set() # (phase,body)
|
|
||||||
|
|
||||||
def handle_welcome(self, welcome):
|
def handle_welcome(self, welcome):
|
||||||
if ("motd" in welcome and
|
if ("motd" in welcome and
|
||||||
|
@ -66,43 +162,25 @@ class Wormhole:
|
||||||
if "error" in welcome:
|
if "error" in welcome:
|
||||||
raise ServerError(welcome["error"], self.relay)
|
raise ServerError(welcome["error"], self.relay)
|
||||||
|
|
||||||
def _allocate_channel(self):
|
|
||||||
data = json.dumps({"side": self.side}).encode("utf-8")
|
|
||||||
r = requests.post(self.relay + "allocate", data=data)
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
if "welcome" in data:
|
|
||||||
self.handle_welcome(data["welcome"])
|
|
||||||
channel_id = data["channel-id"]
|
|
||||||
return channel_id
|
|
||||||
|
|
||||||
def get_code(self, code_length=2):
|
def get_code(self, code_length=2):
|
||||||
if self.code is not None: raise UsageError
|
if self.code is not None: raise UsageError
|
||||||
self.side = hexlify(os.urandom(5)).decode("ascii")
|
channel_id = self._channel_manager.allocate()
|
||||||
channel_id = self._allocate_channel() # allocate channel
|
|
||||||
code = codes.make_code(channel_id, code_length)
|
code = codes.make_code(channel_id, code_length)
|
||||||
assert isinstance(code, str), type(code)
|
assert isinstance(code, str), type(code)
|
||||||
self._set_code_and_channel_id(code)
|
self._set_code_and_channel_id(code)
|
||||||
self._start()
|
self._start()
|
||||||
return code
|
return code
|
||||||
|
|
||||||
def list_channels(self):
|
|
||||||
r = requests.get(self.relay + "list")
|
|
||||||
r.raise_for_status()
|
|
||||||
channel_ids = r.json()["channel-ids"]
|
|
||||||
return channel_ids
|
|
||||||
|
|
||||||
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
|
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
|
||||||
code = codes.input_code_with_completion(prompt, self.list_channels,
|
lister = self._channel_manager.list_channels
|
||||||
|
code = codes.input_code_with_completion(prompt, lister,
|
||||||
code_length)
|
code_length)
|
||||||
return code
|
return code
|
||||||
|
|
||||||
def set_code(self, code): # used for human-made pre-generated codes
|
def set_code(self, code): # used for human-made pre-generated codes
|
||||||
if not isinstance(code, str): raise UsageError
|
if not isinstance(code, str): raise UsageError
|
||||||
if self.code is not None: raise UsageError
|
if self.code is not None: raise UsageError
|
||||||
if self.side is not None: raise UsageError
|
|
||||||
self._set_code_and_channel_id(code)
|
self._set_code_and_channel_id(code)
|
||||||
self.side = hexlify(os.urandom(5)).decode("ascii")
|
|
||||||
self._start()
|
self._start()
|
||||||
|
|
||||||
def _set_code_and_channel_id(self, code):
|
def _set_code_and_channel_id(self, code):
|
||||||
|
@ -110,9 +188,9 @@ class Wormhole:
|
||||||
mo = re.search(r'^(\d+)-', code)
|
mo = re.search(r'^(\d+)-', code)
|
||||||
if not mo:
|
if not mo:
|
||||||
raise ValueError("code (%s) must start with NN-" % code)
|
raise ValueError("code (%s) must start with NN-" % code)
|
||||||
self.channel_id = int(mo.group(1))
|
|
||||||
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
|
||||||
self.code = code
|
self.code = code
|
||||||
|
channel_id = int(mo.group(1))
|
||||||
|
self.channel = self._channel_manager.connect(channel_id)
|
||||||
|
|
||||||
def _start(self):
|
def _start(self):
|
||||||
# allocate the rest now too, so it can be serialized
|
# allocate the rest now too, so it can be serialized
|
||||||
|
@ -120,63 +198,6 @@ class Wormhole:
|
||||||
idSymmetric=self.appid)
|
idSymmetric=self.appid)
|
||||||
self.msg1 = self.sp.start()
|
self.msg1 = self.sp.start()
|
||||||
|
|
||||||
def _add_inbound_messages(self, messages):
|
|
||||||
for msg in messages:
|
|
||||||
phase = msg["phase"]
|
|
||||||
body = unhexlify(msg["body"].encode("ascii"))
|
|
||||||
self._messages.add( (phase, body) )
|
|
||||||
|
|
||||||
def _find_inbound_message(self, phase):
|
|
||||||
for (their_phase,body) in self._messages - self._sent_messages:
|
|
||||||
if their_phase == phase:
|
|
||||||
return body
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _send_message(self, phase, msg):
|
|
||||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
|
||||||
# against the rendezvous server being temporarily offline.
|
|
||||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
|
||||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
|
||||||
self._sent_messages.add( (phase,msg) )
|
|
||||||
payload = {"side": self.side,
|
|
||||||
"phase": phase,
|
|
||||||
"body": hexlify(msg).decode("ascii")}
|
|
||||||
data = json.dumps(payload).encode("utf-8")
|
|
||||||
r = requests.post(self._channel_url, data=data)
|
|
||||||
r.raise_for_status()
|
|
||||||
resp = r.json()
|
|
||||||
self._add_inbound_messages(resp["messages"])
|
|
||||||
|
|
||||||
def _get_message(self, phase):
|
|
||||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
|
||||||
# For now, server errors cause the client to fail. TODO: don't. This
|
|
||||||
# will require changing the client to re-post messages when the
|
|
||||||
# server comes back up.
|
|
||||||
|
|
||||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
|
||||||
# one of ours. It will either come from previously-received messages,
|
|
||||||
# or from an EventSource that we attach to the corresponding URL
|
|
||||||
body = self._find_inbound_message(phase)
|
|
||||||
while body is None:
|
|
||||||
remaining = self.started + self.timeout - time.time()
|
|
||||||
if remaining < 0:
|
|
||||||
return Timeout
|
|
||||||
f = EventSourceFollower(self._channel_url, remaining)
|
|
||||||
# we loop here until the connection is lost, or we see the
|
|
||||||
# message we want
|
|
||||||
for (eventtype, data) in f.iter_events():
|
|
||||||
if eventtype == "welcome":
|
|
||||||
self.handle_welcome(json.loads(data))
|
|
||||||
if eventtype == "message":
|
|
||||||
self._add_inbound_messages([json.loads(data)])
|
|
||||||
body = self._find_inbound_message(phase)
|
|
||||||
if body:
|
|
||||||
f.close()
|
|
||||||
break
|
|
||||||
if not body:
|
|
||||||
time.sleep(self.wait)
|
|
||||||
return body
|
|
||||||
|
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
if not isinstance(purpose, type(b"")): raise UsageError
|
if not isinstance(purpose, type(b"")): raise UsageError
|
||||||
return HKDF(self.key, length, CTXinfo=purpose)
|
return HKDF(self.key, length, CTXinfo=purpose)
|
||||||
|
@ -200,14 +221,14 @@ class Wormhole:
|
||||||
|
|
||||||
def _get_key(self):
|
def _get_key(self):
|
||||||
if not self.key:
|
if not self.key:
|
||||||
self._send_message(u"pake", self.msg1)
|
self.channel.send(u"pake", self.msg1)
|
||||||
pake_msg = self._get_message(u"pake")
|
pake_msg = self.channel.get(u"pake")
|
||||||
self.key = self.sp.finish(pake_msg)
|
self.key = self.sp.finish(pake_msg)
|
||||||
self.verifier = self.derive_key(self.appid+b":Verifier")
|
self.verifier = self.derive_key(self.appid+b":Verifier")
|
||||||
|
|
||||||
def get_verifier(self):
|
def get_verifier(self):
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
if self.channel_id is None: raise UsageError
|
if self.channel is None: raise UsageError
|
||||||
self._get_key()
|
self._get_key()
|
||||||
return self.verifier
|
return self.verifier
|
||||||
|
|
||||||
|
@ -215,12 +236,12 @@ class Wormhole:
|
||||||
# only call this once
|
# only call this once
|
||||||
if not isinstance(outbound_data, type(b"")): raise UsageError
|
if not isinstance(outbound_data, type(b"")): raise UsageError
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
if self.channel_id is None: raise UsageError
|
if self.channel is None: raise UsageError
|
||||||
try:
|
try:
|
||||||
self._get_key()
|
self._get_key()
|
||||||
return self._get_data2(outbound_data)
|
return self._get_data2(outbound_data)
|
||||||
finally:
|
finally:
|
||||||
self._deallocate()
|
self.channel.deallocate()
|
||||||
|
|
||||||
def _get_data2(self, outbound_data):
|
def _get_data2(self, outbound_data):
|
||||||
# Without predefined roles, we can't derive predictably unique keys
|
# Without predefined roles, we can't derive predictably unique keys
|
||||||
|
@ -229,9 +250,9 @@ class Wormhole:
|
||||||
data_key = self.derive_key(b"data-key")
|
data_key = self.derive_key(b"data-key")
|
||||||
|
|
||||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||||
self._send_message(u"data", outbound_encrypted)
|
self.channel.send(u"data", outbound_encrypted)
|
||||||
|
|
||||||
inbound_encrypted = self._get_message(u"data")
|
inbound_encrypted = self.channel.get(u"data")
|
||||||
# _find_inbound_message() ignores any inbound message that matches
|
# _find_inbound_message() ignores any inbound message that matches
|
||||||
# something we previously sent out, so we don't need to explicitly
|
# something we previously sent out, so we don't need to explicitly
|
||||||
# check for reflection. A reflection attack will just not progress.
|
# check for reflection. A reflection attack will just not progress.
|
||||||
|
@ -240,9 +261,3 @@ class Wormhole:
|
||||||
return inbound_data
|
return inbound_data
|
||||||
except CryptoError:
|
except CryptoError:
|
||||||
raise WrongPasswordError
|
raise WrongPasswordError
|
||||||
|
|
||||||
def _deallocate(self):
|
|
||||||
# only try once, no retries
|
|
||||||
data = json.dumps({"side": self.side}).encode("utf-8")
|
|
||||||
requests.post(self._channel_url+"/deallocate", data=data)
|
|
||||||
# ignore POST failure, don't call r.raise_for_status()
|
|
||||||
|
|
|
@ -32,6 +32,112 @@ class DataProducer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def post_json(agent, url, request_body):
|
||||||
|
# POST a JSON body to a URL, parsing the response as JSON
|
||||||
|
data = json.dumps(request_body).encode("utf-8")
|
||||||
|
d = agent.request("POST", url, bodyProducer=DataProducer(data))
|
||||||
|
def _check_error(resp):
|
||||||
|
if resp.code != 200:
|
||||||
|
raise web_error.Error(resp.code, resp.phrase)
|
||||||
|
return resp
|
||||||
|
d.addCallback(_check_error)
|
||||||
|
d.addCallback(web_client.readBody)
|
||||||
|
d.addCallback(lambda data: json.loads(data))
|
||||||
|
return d
|
||||||
|
|
||||||
|
class Channel:
|
||||||
|
def __init__(self, relay, channel_id, side, handle_welcome,
|
||||||
|
agent):
|
||||||
|
self._channel_url = "%s%d" % (relay, channel_id)
|
||||||
|
self._side = side
|
||||||
|
self._handle_welcome = handle_welcome
|
||||||
|
self._agent = agent
|
||||||
|
self._messages = set() # (phase,body) , body is bytes
|
||||||
|
self._sent_messages = set() # (phase,body)
|
||||||
|
|
||||||
|
def _add_inbound_messages(self, messages):
|
||||||
|
for msg in messages:
|
||||||
|
phase = msg["phase"]
|
||||||
|
body = unhexlify(msg["body"].encode("ascii"))
|
||||||
|
self._messages.add( (phase, body) )
|
||||||
|
|
||||||
|
def _find_inbound_message(self, phase):
|
||||||
|
for (their_phase,body) in self._messages - self._sent_messages:
|
||||||
|
if their_phase == phase:
|
||||||
|
return body
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send(self, phase, msg):
|
||||||
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||||
|
# against the rendezvous server being temporarily offline.
|
||||||
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
||||||
|
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
||||||
|
self._sent_messages.add( (phase,msg) )
|
||||||
|
payload = {"side": self._side,
|
||||||
|
"phase": phase,
|
||||||
|
"body": hexlify(msg).decode("ascii")}
|
||||||
|
d = post_json(self._agent, self._channel_url, payload)
|
||||||
|
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||||
|
return d
|
||||||
|
|
||||||
|
def get(self, phase):
|
||||||
|
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||||
|
# one of ours. It will either come from previously-received messages,
|
||||||
|
# or from an EventSource that we attach to the corresponding URL
|
||||||
|
body = self._find_inbound_message(phase)
|
||||||
|
if body is not None:
|
||||||
|
return defer.succeed(body)
|
||||||
|
d = defer.Deferred()
|
||||||
|
msgs = []
|
||||||
|
def _handle(name, data):
|
||||||
|
if name == "welcome":
|
||||||
|
self._handle_welcome(json.loads(data))
|
||||||
|
if name == "message":
|
||||||
|
self._add_inbound_messages([json.loads(data)])
|
||||||
|
body = self._find_inbound_message(phase)
|
||||||
|
if body is not None and not msgs:
|
||||||
|
msgs.append(body)
|
||||||
|
d.callback(None)
|
||||||
|
# TODO: use agent=self._agent
|
||||||
|
es = ReconnectingEventSource(self._channel_url, _handle)
|
||||||
|
es.startService() # TODO: .setServiceParent(self)
|
||||||
|
es.activate()
|
||||||
|
d.addCallback(lambda _: es.deactivate())
|
||||||
|
d.addCallback(lambda _: es.stopService())
|
||||||
|
d.addCallback(lambda _: msgs[0])
|
||||||
|
return d
|
||||||
|
|
||||||
|
def deallocate(self, res):
|
||||||
|
# only try once, no retries
|
||||||
|
d = post_json(self._agent, self._channel_url+"/deallocate",
|
||||||
|
{"side": self._side})
|
||||||
|
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
||||||
|
return d
|
||||||
|
|
||||||
|
class ChannelManager:
|
||||||
|
def __init__(self, relay, side, handle_welcome):
|
||||||
|
self._relay = relay
|
||||||
|
self._side = side
|
||||||
|
self._handle_welcome = handle_welcome
|
||||||
|
self._agent = web_client.Agent(reactor)
|
||||||
|
|
||||||
|
def allocate(self):
|
||||||
|
url = self._relay + "allocate"
|
||||||
|
d = post_json(self._agent, url, {"side": self._side})
|
||||||
|
def _got_channel(data):
|
||||||
|
if "welcome" in data:
|
||||||
|
self._handle_welcome(data["welcome"])
|
||||||
|
return data["channel-id"]
|
||||||
|
d.addCallback(_got_channel)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def list_channels(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def connect(self, channel_id):
|
||||||
|
return Channel(self._relay, channel_id, self._side,
|
||||||
|
self._handle_welcome, self._agent)
|
||||||
|
|
||||||
class Wormhole:
|
class Wormhole:
|
||||||
motd_displayed = False
|
motd_displayed = False
|
||||||
version_warning_displayed = False
|
version_warning_displayed = False
|
||||||
|
@ -40,14 +146,15 @@ class Wormhole:
|
||||||
if not isinstance(appid, type(b"")): raise UsageError
|
if not isinstance(appid, type(b"")): raise UsageError
|
||||||
self.appid = appid
|
self.appid = appid
|
||||||
self.relay = relay
|
self.relay = relay
|
||||||
self.agent = web_client.Agent(reactor)
|
self._set_side(hexlify(os.urandom(5)).decode("ascii"))
|
||||||
self.side = None
|
|
||||||
self.code = None
|
self.code = None
|
||||||
self.key = None
|
self.key = None
|
||||||
self._started_get_code = False
|
self._started_get_code = False
|
||||||
self._channel_url = None
|
|
||||||
self._messages = set() # (phase,body) , body is bytes
|
def _set_side(self, side):
|
||||||
self._sent_messages = set() # (phase,body)
|
self._side = side
|
||||||
|
self._channel_manager = ChannelManager(self.relay, self._side,
|
||||||
|
self.handle_welcome)
|
||||||
|
|
||||||
def handle_welcome(self, welcome):
|
def handle_welcome(self, welcome):
|
||||||
if ("motd" in welcome and
|
if ("motd" in welcome and
|
||||||
|
@ -71,35 +178,11 @@ class Wormhole:
|
||||||
if "error" in welcome:
|
if "error" in welcome:
|
||||||
raise ServerError(welcome["error"], self.relay)
|
raise ServerError(welcome["error"], self.relay)
|
||||||
|
|
||||||
def _post_json(self, url, post_json):
|
|
||||||
# POST a JSON body to a URL, parsing the response as JSON
|
|
||||||
data = json.dumps(post_json).encode("utf-8")
|
|
||||||
d = self.agent.request("POST", url, bodyProducer=DataProducer(data))
|
|
||||||
def _check_error(resp):
|
|
||||||
if resp.code != 200:
|
|
||||||
raise web_error.Error(resp.code, resp.phrase)
|
|
||||||
return resp
|
|
||||||
d.addCallback(_check_error)
|
|
||||||
d.addCallback(web_client.readBody)
|
|
||||||
d.addCallback(lambda data: json.loads(data))
|
|
||||||
return d
|
|
||||||
|
|
||||||
def _allocate_channel(self):
|
|
||||||
url = self.relay + "allocate"
|
|
||||||
d = self._post_json(url, {"side": self.side})
|
|
||||||
def _got_channel(data):
|
|
||||||
if "welcome" in data:
|
|
||||||
self.handle_welcome(data["welcome"])
|
|
||||||
return data["channel-id"]
|
|
||||||
d.addCallback(_got_channel)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def get_code(self, code_length=2):
|
def get_code(self, code_length=2):
|
||||||
if self.code is not None: raise UsageError
|
if self.code is not None: raise UsageError
|
||||||
if self._started_get_code: raise UsageError
|
if self._started_get_code: raise UsageError
|
||||||
self._started_get_code = True
|
self._started_get_code = True
|
||||||
self.side = hexlify(os.urandom(5))
|
d = self._channel_manager.allocate()
|
||||||
d = self._allocate_channel()
|
|
||||||
def _got_channel_id(channel_id):
|
def _got_channel_id(channel_id):
|
||||||
code = codes.make_code(channel_id, code_length)
|
code = codes.make_code(channel_id, code_length)
|
||||||
assert isinstance(code, str), type(code)
|
assert isinstance(code, str), type(code)
|
||||||
|
@ -112,9 +195,7 @@ class Wormhole:
|
||||||
def set_code(self, code):
|
def set_code(self, code):
|
||||||
if not isinstance(code, str): raise UsageError
|
if not isinstance(code, str): raise UsageError
|
||||||
if self.code is not None: raise UsageError
|
if self.code is not None: raise UsageError
|
||||||
if self.side is not None: raise UsageError
|
|
||||||
self._set_code_and_channel_id(code)
|
self._set_code_and_channel_id(code)
|
||||||
self.side = hexlify(os.urandom(5))
|
|
||||||
self._start()
|
self._start()
|
||||||
|
|
||||||
def _set_code_and_channel_id(self, code):
|
def _set_code_and_channel_id(self, code):
|
||||||
|
@ -122,9 +203,9 @@ class Wormhole:
|
||||||
mo = re.search(r'^(\d+)-', code)
|
mo = re.search(r'^(\d+)-', code)
|
||||||
if not mo:
|
if not mo:
|
||||||
raise ValueError("code (%s) must start with NN-" % code)
|
raise ValueError("code (%s) must start with NN-" % code)
|
||||||
self.channel_id = int(mo.group(1))
|
|
||||||
self._channel_url = "%s%d" % (self.relay, self.channel_id)
|
|
||||||
self.code = code
|
self.code = code
|
||||||
|
channel_id = int(mo.group(1))
|
||||||
|
self.channel = self._channel_manager.connect(channel_id)
|
||||||
|
|
||||||
def _start(self):
|
def _start(self):
|
||||||
# allocate the rest now too, so it can be serialized
|
# allocate the rest now too, so it can be serialized
|
||||||
|
@ -141,7 +222,7 @@ class Wormhole:
|
||||||
"appid": self.appid,
|
"appid": self.appid,
|
||||||
"relay": self.relay,
|
"relay": self.relay,
|
||||||
"code": self.code,
|
"code": self.code,
|
||||||
"side": self.side,
|
"side": self._side,
|
||||||
"spake2": json.loads(self.sp.serialize()),
|
"spake2": json.loads(self.sp.serialize()),
|
||||||
"msg1": self.msg1.encode("hex"),
|
"msg1": self.msg1.encode("hex"),
|
||||||
}
|
}
|
||||||
|
@ -151,64 +232,12 @@ class Wormhole:
|
||||||
def from_serialized(klass, data):
|
def from_serialized(klass, data):
|
||||||
d = json.loads(data)
|
d = json.loads(data)
|
||||||
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
|
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
|
||||||
|
self._set_side(d["side"].encode("ascii"))
|
||||||
self._set_code_and_channel_id(d["code"].encode("ascii"))
|
self._set_code_and_channel_id(d["code"].encode("ascii"))
|
||||||
self.side = d["side"].encode("ascii")
|
|
||||||
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
|
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
|
||||||
self.msg1 = d["msg1"].decode("hex")
|
self.msg1 = d["msg1"].decode("hex")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _add_inbound_messages(self, messages):
|
|
||||||
for msg in messages:
|
|
||||||
phase = msg["phase"]
|
|
||||||
body = unhexlify(msg["body"].encode("ascii"))
|
|
||||||
self._messages.add( (phase, body) )
|
|
||||||
|
|
||||||
def _find_inbound_message(self, phase):
|
|
||||||
for (their_phase,body) in self._messages - self._sent_messages:
|
|
||||||
if their_phase == phase:
|
|
||||||
return body
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _send_message(self, phase, msg):
|
|
||||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
|
||||||
# against the rendezvous server being temporarily offline.
|
|
||||||
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
|
||||||
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
|
||||||
self._sent_messages.add( (phase,msg) )
|
|
||||||
payload = {"side": self.side,
|
|
||||||
"phase": phase,
|
|
||||||
"body": hexlify(msg).decode("ascii")}
|
|
||||||
d = self._post_json(self._channel_url, payload)
|
|
||||||
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
|
||||||
return d
|
|
||||||
|
|
||||||
def _get_message(self, phase):
|
|
||||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
|
||||||
# one of ours. It will either come from previously-received messages,
|
|
||||||
# or from an EventSource that we attach to the corresponding URL
|
|
||||||
body = self._find_inbound_message(phase)
|
|
||||||
if body is not None:
|
|
||||||
return defer.succeed(body)
|
|
||||||
d = defer.Deferred()
|
|
||||||
msgs = []
|
|
||||||
def _handle(name, data):
|
|
||||||
if name == "welcome":
|
|
||||||
self.handle_welcome(json.loads(data))
|
|
||||||
if name == "message":
|
|
||||||
self._add_inbound_messages([json.loads(data)])
|
|
||||||
body = self._find_inbound_message(phase)
|
|
||||||
if body is not None and not msgs:
|
|
||||||
msgs.append(body)
|
|
||||||
d.callback(None)
|
|
||||||
# TODO: use agent=self.agent
|
|
||||||
es = ReconnectingEventSource(self._channel_url, _handle)
|
|
||||||
es.startService() # TODO: .setServiceParent(self)
|
|
||||||
es.activate()
|
|
||||||
d.addCallback(lambda _: es.deactivate())
|
|
||||||
d.addCallback(lambda _: es.stopService())
|
|
||||||
d.addCallback(lambda _: msgs[0])
|
|
||||||
return d
|
|
||||||
|
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
if self.key is None:
|
if self.key is None:
|
||||||
# call after get_verifier() or get_data()
|
# call after get_verifier() or get_data()
|
||||||
|
@ -237,8 +266,8 @@ class Wormhole:
|
||||||
# TODO: prevent multiple invocation
|
# TODO: prevent multiple invocation
|
||||||
if self.key:
|
if self.key:
|
||||||
return defer.succeed(self.key)
|
return defer.succeed(self.key)
|
||||||
d = self._send_message(u"pake", self.msg1)
|
d = self.channel.send(u"pake", self.msg1)
|
||||||
d.addCallback(lambda _: self._get_message(u"pake"))
|
d.addCallback(lambda _: self.channel.get(u"pake"))
|
||||||
def _got_pake(pake_msg):
|
def _got_pake(pake_msg):
|
||||||
key = self.sp.finish(pake_msg)
|
key = self.sp.finish(pake_msg)
|
||||||
self.key = key
|
self.key = key
|
||||||
|
@ -259,7 +288,7 @@ class Wormhole:
|
||||||
if self.code is None: raise UsageError
|
if self.code is None: raise UsageError
|
||||||
d = self._get_key()
|
d = self._get_key()
|
||||||
d.addCallback(self._get_data2, outbound_data)
|
d.addCallback(self._get_data2, outbound_data)
|
||||||
d.addBoth(self._deallocate)
|
d.addBoth(self.channel.deallocate)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _get_data2(self, key, outbound_data):
|
def _get_data2(self, key, outbound_data):
|
||||||
|
@ -269,9 +298,9 @@ class Wormhole:
|
||||||
data_key = self.derive_key(b"data-key")
|
data_key = self.derive_key(b"data-key")
|
||||||
|
|
||||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||||
d = self._send_message(u"data", outbound_encrypted)
|
d = self.channel.send(u"data", outbound_encrypted)
|
||||||
|
|
||||||
d.addCallback(lambda _: self._get_message(u"data"))
|
d.addCallback(lambda _: self.channel.get(u"data"))
|
||||||
def _got_data(inbound_encrypted):
|
def _got_data(inbound_encrypted):
|
||||||
#if inbound_encrypted == outbound_encrypted:
|
#if inbound_encrypted == outbound_encrypted:
|
||||||
# raise ReflectionAttack
|
# raise ReflectionAttack
|
||||||
|
@ -282,10 +311,3 @@ class Wormhole:
|
||||||
raise WrongPasswordError
|
raise WrongPasswordError
|
||||||
d.addCallback(_got_data)
|
d.addCallback(_got_data)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _deallocate(self, res):
|
|
||||||
# only try once, no retries
|
|
||||||
d = self._post_json(self._channel_url+"/deallocate",
|
|
||||||
{"side": self.side})
|
|
||||||
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
|
||||||
return d
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user