add 'mock', building out test_wormhole
This commit is contained in:
parent
0ee56e12b0
commit
3da52b0a3e
74
src/wormhole/test/test_wormhole.py
Normal file
74
src/wormhole/test/test_wormhole.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import json
|
||||||
|
import mock
|
||||||
|
from twisted.trial import unittest
|
||||||
|
from twisted.internet import reactor
|
||||||
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||||
|
#from ..twisted.transcribe import (wormhole, wormhole_from_serialized,
|
||||||
|
# UsageError, WrongPasswordError)
|
||||||
|
#from .common import ServerBase
|
||||||
|
from ..wormhole import _Wormhole, _WelcomeHandler
|
||||||
|
from ..timing import DebugTiming
|
||||||
|
|
||||||
|
APPID = u"appid"
|
||||||
|
|
||||||
|
class MockWebSocket:
|
||||||
|
def __init__(self):
|
||||||
|
self._payloads = []
|
||||||
|
def sendMessage(self, payload, is_binary):
|
||||||
|
assert not is_binary
|
||||||
|
self._payloads.append(payload)
|
||||||
|
|
||||||
|
def outbound(self):
|
||||||
|
out = []
|
||||||
|
while self._payloads:
|
||||||
|
p = self._payloads.pop(0)
|
||||||
|
out.append(json.loads(p.decode("utf-8")))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def response(w, **kwargs):
|
||||||
|
payload = json.dumps(kwargs).encode("utf-8")
|
||||||
|
w._ws_dispatch_response(payload)
|
||||||
|
|
||||||
|
class Welcome(unittest.TestCase):
|
||||||
|
def test_no_current_version(self):
|
||||||
|
# WelcomeHandler should tolerate lack of ["current_version"]
|
||||||
|
w = _WelcomeHandler(u"relay_url", u"current_version")
|
||||||
|
w.handle_welcome({})
|
||||||
|
|
||||||
|
|
||||||
|
class Basic(unittest.TestCase):
|
||||||
|
def test_create(self):
|
||||||
|
w = _Wormhole(APPID, u"relay_url", reactor, None, None)
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
# We don't call w._start(), so this doesn't create a WebSocket
|
||||||
|
# connection. We provide a mock connection instead.
|
||||||
|
timing = DebugTiming()
|
||||||
|
with mock.patch("wormhole.wormhole._WelcomeHandler") as whc:
|
||||||
|
w = _Wormhole(APPID, u"relay_url", reactor, None, timing)
|
||||||
|
wh = whc.return_value
|
||||||
|
#w._welcomer = mock.Mock()
|
||||||
|
# w._connect = lambda self: None
|
||||||
|
# w._event_connected(mock_ws)
|
||||||
|
# w._event_ws_opened()
|
||||||
|
# w._ws_dispatch_response(payload)
|
||||||
|
self.assertEqual(w._ws_url, u"relay_url")
|
||||||
|
ws = MockWebSocket()
|
||||||
|
w._event_connected(ws)
|
||||||
|
out = ws.outbound()
|
||||||
|
self.assertEqual(len(out), 0)
|
||||||
|
|
||||||
|
w._event_ws_opened(None)
|
||||||
|
out = ws.outbound()
|
||||||
|
self.assertEqual(len(out), 1)
|
||||||
|
self.assertEqual(out[0]["type"], u"bind")
|
||||||
|
self.assertEqual(out[0]["appid"], APPID)
|
||||||
|
self.assertEqual(out[0]["side"], w._side)
|
||||||
|
self.assertIn(u"id", out[0])
|
||||||
|
|
||||||
|
# WelcomeHandler should get called upon 'welcome' response
|
||||||
|
WELCOME = {u"foo": u"bar"}
|
||||||
|
response(w, type="welcome", welcome=WELCOME)
|
||||||
|
self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)])
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function, absolute_import
|
||||||
import os, sys, json, re, unicodedata
|
import os, sys, json, re, unicodedata
|
||||||
from six.moves.urllib_parse import urlparse
|
from six.moves.urllib_parse import urlparse
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
|
@ -11,10 +11,11 @@ from nacl.secret import SecretBox
|
||||||
from nacl.exceptions import CryptoError
|
from nacl.exceptions import CryptoError
|
||||||
from nacl import utils
|
from nacl import utils
|
||||||
from spake2 import SPAKE2_Symmetric
|
from spake2 import SPAKE2_Symmetric
|
||||||
from .. import __version__
|
from . import __version__
|
||||||
from .. import codes
|
from . import codes
|
||||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
#from .errors import ServerError, Timeout
|
||||||
from ..timing import DebugTiming
|
from .errors import WrongPasswordError, UsageError
|
||||||
|
#from .timing import DebugTiming
|
||||||
from hkdf import Hkdf
|
from hkdf import Hkdf
|
||||||
|
|
||||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||||
|
@ -80,7 +81,7 @@ class _GetCode:
|
||||||
assert isinstance(code, type(u"")), type(code)
|
assert isinstance(code, type(u"")), type(code)
|
||||||
returnValue(code)
|
returnValue(code)
|
||||||
|
|
||||||
def _ws_handle_allocated(self, msg):
|
def _response_handle_allocated(self, msg):
|
||||||
nid = msg["nameplate"]
|
nid = msg["nameplate"]
|
||||||
assert isinstance(nid, type(u"")), type(nid)
|
assert isinstance(nid, type(u"")), type(nid)
|
||||||
self._allocated_d.callback(nid)
|
self._allocated_d.callback(nid)
|
||||||
|
@ -125,12 +126,16 @@ class _InputCode:
|
||||||
self._reactor.removeSystemEventTrigger(t)
|
self._reactor.removeSystemEventTrigger(t)
|
||||||
returnValue(code)
|
returnValue(code)
|
||||||
|
|
||||||
def _ws_handle_nameplates(self, msg):
|
def _response_handle_nameplates(self, msg):
|
||||||
nameplates = msg["nameplates"]
|
nameplates = msg["nameplates"]
|
||||||
assert isinstance(nameplates, list), type(nameplates)
|
assert isinstance(nameplates, list), type(nameplates)
|
||||||
for nameplate_id in nameplates:
|
nids = []
|
||||||
|
for n in nameplates:
|
||||||
|
assert isinstance(n, dict), type(n)
|
||||||
|
nameplate_id = n[u"id"]
|
||||||
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
|
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
|
||||||
self._lister_d.callback(nameplates)
|
nids.append(nameplate_id)
|
||||||
|
self._lister_d.callback(nids)
|
||||||
|
|
||||||
def _warn_readline(self):
|
def _warn_readline(self):
|
||||||
# When our process receives a SIGINT, Twisted's SIGINT handler will
|
# When our process receives a SIGINT, Twisted's SIGINT handler will
|
||||||
|
@ -166,11 +171,53 @@ class _InputCode:
|
||||||
# doesn't see the signal, and we must still wait for stdin to make
|
# doesn't see the signal, and we must still wait for stdin to make
|
||||||
# readline finish.
|
# readline finish.
|
||||||
|
|
||||||
|
class _WelcomeHandler:
|
||||||
|
def __init__(self, url, current_version):
|
||||||
|
self._ws_url = url
|
||||||
|
self._version_warning_displayed = False
|
||||||
|
self._motd_displayed = False
|
||||||
|
self._current_version = current_version
|
||||||
|
|
||||||
|
def handle_welcome(self, welcome):
|
||||||
|
if ("motd" in welcome and
|
||||||
|
not self._motd_displayed):
|
||||||
|
motd_lines = welcome["motd"].splitlines()
|
||||||
|
motd_formatted = "\n ".join(motd_lines)
|
||||||
|
print("Server (at %s) says:\n %s" %
|
||||||
|
(self._ws_url, motd_formatted), file=sys.stderr)
|
||||||
|
self._motd_displayed = True
|
||||||
|
|
||||||
|
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||||
|
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
||||||
|
if ("current_version" in welcome
|
||||||
|
and "-" not in self._current_version
|
||||||
|
and not self._version_warning_displayed
|
||||||
|
and welcome["current_version"] != self._current_version):
|
||||||
|
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
||||||
|
print("Server claims %s is current, but ours is %s"
|
||||||
|
% (welcome["current_version"], self._current_version),
|
||||||
|
file=sys.stderr)
|
||||||
|
self._version_warning_displayed = True
|
||||||
|
|
||||||
|
if "error" in welcome:
|
||||||
|
return self._signal_error(welcome["error"])
|
||||||
|
|
||||||
|
|
||||||
class _Wormhole:
|
class _Wormhole:
|
||||||
def __init__(self):
|
def __init__(self, appid, relay_url, reactor, tor_manager, timing):
|
||||||
|
self._appid = appid
|
||||||
|
self._ws_url = relay_url
|
||||||
|
self._reactor = reactor
|
||||||
|
self._tor_manager = tor_manager
|
||||||
|
self._timing = timing
|
||||||
|
|
||||||
|
self._welcomer = _WelcomeHandler(self._ws_url, __version__)
|
||||||
|
self._side = hexlify(os.urandom(5)).decode("ascii")
|
||||||
self._connected = None
|
self._connected = None
|
||||||
|
self._nameplate_id = None
|
||||||
|
self._mailbox_id = None
|
||||||
|
self._mailbox_opened = False
|
||||||
|
self._mailbox_closed = False
|
||||||
self._flag_need_mailbox = True
|
self._flag_need_mailbox = True
|
||||||
self._flag_need_to_see_mailbox_used = True
|
self._flag_need_to_see_mailbox_used = True
|
||||||
self._flag_need_to_build_msg1 = True
|
self._flag_need_to_build_msg1 = True
|
||||||
|
@ -179,9 +226,11 @@ class _Wormhole:
|
||||||
self._flag_need_key = True # rename to not self._key
|
self._flag_need_key = True # rename to not self._key
|
||||||
|
|
||||||
self._next_send_phase = 0
|
self._next_send_phase = 0
|
||||||
|
self._plaintext_to_send = [] # (phase, plaintext, deferred)
|
||||||
self._phase_messages_to_send = [] # not yet acked by server
|
self._phase_messages_to_send = [] # not yet acked by server
|
||||||
|
|
||||||
self._next_receive_phase = 0
|
self._next_receive_phase = 0
|
||||||
|
self._receive_waiters = {} # phase -> Deferred
|
||||||
self._phase_messages_received = {} # phase -> message
|
self._phase_messages_received = {} # phase -> message
|
||||||
|
|
||||||
|
|
||||||
|
@ -213,10 +262,11 @@ class _Wormhole:
|
||||||
d.addCallback(self._event_connected)
|
d.addCallback(self._event_connected)
|
||||||
# f.d is errbacked if WebSocket negotiation fails, and the WebSocket
|
# f.d is errbacked if WebSocket negotiation fails, and the WebSocket
|
||||||
# drops any data sent before onOpen() fires, so we must wait for it
|
# drops any data sent before onOpen() fires, so we must wait for it
|
||||||
|
d.addCallback(lambda _: f.d)
|
||||||
d.addCallback(self._event_ws_opened)
|
d.addCallback(self._event_ws_opened)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _event_connected(self, ws, f):
|
def _event_connected(self, ws):
|
||||||
self._ws = ws
|
self._ws = ws
|
||||||
self._ws_t = self._timing.add("websocket")
|
self._ws_t = self._timing.add("websocket")
|
||||||
|
|
||||||
|
@ -224,30 +274,35 @@ class _Wormhole:
|
||||||
self._connected = True
|
self._connected = True
|
||||||
self._ws_send_command(u"bind", appid=self._appid, side=self._side)
|
self._ws_send_command(u"bind", appid=self._appid, side=self._side)
|
||||||
self._maybe_get_mailbox()
|
self._maybe_get_mailbox()
|
||||||
|
self._maybe_send_pake()
|
||||||
|
|
||||||
def _ws_handle_welcome(self, msg):
|
def _ws_send_command(self, mtype, **kwargs):
|
||||||
welcome = msg["welcome"]
|
# msgid is used by misc/dump-timing.py to correlate our sends with
|
||||||
if ("motd" in welcome and
|
# their receives, and vice versa. They are also correlated with the
|
||||||
not self.motd_displayed):
|
# ACKs we get back from the server (which we otherwise ignore). There
|
||||||
motd_lines = welcome["motd"].splitlines()
|
# are so few messages, 16 bits is enough to be mostly-unique.
|
||||||
motd_formatted = "\n ".join(motd_lines)
|
kwargs["id"] = hexlify(os.urandom(2)).decode("ascii")
|
||||||
print("Server (at %s) says:\n %s" %
|
kwargs["type"] = mtype
|
||||||
(self._ws_url, motd_formatted), file=sys.stderr)
|
payload = json.dumps(kwargs).encode("utf-8")
|
||||||
self.motd_displayed = True
|
self._timing.add("ws_send", _side=self._side, **kwargs)
|
||||||
|
self._ws.sendMessage(payload, False)
|
||||||
|
|
||||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
def _ws_dispatch_response(self, payload):
|
||||||
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
msg = json.loads(payload.decode("utf-8"))
|
||||||
if ("-" not in __version__ and
|
self._timing.add("ws_receive", _side=self._side, message=msg)
|
||||||
not self.version_warning_displayed and
|
mtype = msg["type"]
|
||||||
welcome["current_version"] != __version__):
|
meth = getattr(self, "_response_handle_"+mtype, None)
|
||||||
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
if not meth:
|
||||||
print("Server claims %s is current, but ours is %s"
|
# make tests fail, but real application will ignore it
|
||||||
% (welcome["current_version"], __version__), file=sys.stderr)
|
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
|
||||||
self.version_warning_displayed = True
|
return
|
||||||
|
return meth(msg)
|
||||||
|
|
||||||
if "error" in welcome:
|
def _response_handle_ack(self, msg):
|
||||||
return self._signal_error(welcome["error"])
|
pass
|
||||||
|
|
||||||
|
def _response_handle_welcome(self, msg):
|
||||||
|
self._welcomer.handle_welcome(msg["welcome"])
|
||||||
|
|
||||||
# entry point 1: generate a new code
|
# entry point 1: generate a new code
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
|
@ -257,7 +312,7 @@ class _Wormhole:
|
||||||
self._started_get_code = True
|
self._started_get_code = True
|
||||||
with self._timing.add("API get_code"):
|
with self._timing.add("API get_code"):
|
||||||
gc = _GetCode(code_length, self._ws_send_command)
|
gc = _GetCode(code_length, self._ws_send_command)
|
||||||
self._ws_handle_allocated = gc._ws_handle_allocated
|
self._response_handle_allocated = gc._response_handle_allocated
|
||||||
code = yield gc.go()
|
code = yield gc.go()
|
||||||
self._event_learned_code(code)
|
self._event_learned_code(code)
|
||||||
returnValue(code)
|
returnValue(code)
|
||||||
|
@ -269,9 +324,9 @@ class _Wormhole:
|
||||||
if self._started_input_code: raise UsageError
|
if self._started_input_code: raise UsageError
|
||||||
self._started_input_code = True
|
self._started_input_code = True
|
||||||
with self._timing.add("API input_code"):
|
with self._timing.add("API input_code"):
|
||||||
gc = _InputCode(prompt, code_length, self._ws_send_command)
|
ic = _InputCode(prompt, code_length, self._ws_send_command)
|
||||||
self._ws_handle_nameplates = gc._ws_handle_nameplates
|
self._response_handle_nameplates = ic._response_handle_nameplates
|
||||||
code = yield gc.go()
|
code = yield ic.go()
|
||||||
self._event_learned_code(code)
|
self._event_learned_code(code)
|
||||||
returnValue(None)
|
returnValue(None)
|
||||||
|
|
||||||
|
@ -320,15 +375,12 @@ class _Wormhole:
|
||||||
return
|
return
|
||||||
self._ws_send_command(u"claim", nameplate=self._nameplate_id)
|
self._ws_send_command(u"claim", nameplate=self._nameplate_id)
|
||||||
|
|
||||||
def _ws_handle_claimed(self, msg):
|
def _response_handle_claimed(self, msg):
|
||||||
mailbox_id = msg["mailbox"]
|
mailbox_id = msg["mailbox"]
|
||||||
assert isinstance(mailbox_id, type(u"")), type(mailbox_id)
|
assert isinstance(mailbox_id, type(u"")), type(mailbox_id)
|
||||||
self._mailbox_id = mailbox_id
|
self._mailbox_id = mailbox_id
|
||||||
self._event_learned_mailbox()
|
self._event_learned_mailbox()
|
||||||
|
|
||||||
def _event_welcome(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _event_learned_mailbox(self):
|
def _event_learned_mailbox(self):
|
||||||
self._flag_need_mailbox = False
|
self._flag_need_mailbox = False
|
||||||
if not self._mailbox_id: raise UsageError
|
if not self._mailbox_id: raise UsageError
|
||||||
|
@ -340,7 +392,7 @@ class _Wormhole:
|
||||||
|
|
||||||
def _maybe_send_pake(self):
|
def _maybe_send_pake(self):
|
||||||
# TODO: deal with reentrant call
|
# TODO: deal with reentrant call
|
||||||
if not (self._connected and self._mailbox
|
if not (self._connected and self._mailbox_opened
|
||||||
and self._flag_need_to_send_PAKE):
|
and self._flag_need_to_send_PAKE):
|
||||||
return
|
return
|
||||||
d = self._msg_send(u"pake", self._msg1)
|
d = self._msg_send(u"pake", self._msg1)
|
||||||
|
@ -349,42 +401,11 @@ class _Wormhole:
|
||||||
d.addCallback(_pake_sent)
|
d.addCallback(_pake_sent)
|
||||||
d.addErrback(log.err)
|
d.addErrback(log.err)
|
||||||
|
|
||||||
def _maybe_send_phase_messages(self):
|
|
||||||
# TODO: deal with reentrant call
|
|
||||||
if not (self._connected and self._mailbox and self._key):
|
|
||||||
return
|
|
||||||
for pm in self._phase_messages_to_send:
|
|
||||||
(phase, message) = pm
|
|
||||||
d = self._msg_send(phase, message)
|
|
||||||
def _phase_message_sent(res, pm=pm):
|
|
||||||
try:
|
|
||||||
self._phase_messages_to_send.remove(pm)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
d.addCallback(_phase_message_sent)
|
|
||||||
d.addErrback(log.err)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _event_received_message(self, msg):
|
|
||||||
pass
|
|
||||||
def _event_mailbox_used(self):
|
|
||||||
if self._flag_need_to_see_mailbox_used:
|
|
||||||
self._ws_send_command(u"release")
|
|
||||||
self._flag_need_to_see_mailbox_used = False
|
|
||||||
|
|
||||||
def _event_learned_PAKE(self, pake_msg):
|
def _event_learned_PAKE(self, pake_msg):
|
||||||
with self._timing.add("pake2", waiting="crypto"):
|
with self._timing.add("pake2", waiting="crypto"):
|
||||||
self._key = self._sp.finish(pake_msg)
|
self._key = self._sp.finish(pake_msg)
|
||||||
self._event_established_key()
|
self._event_established_key()
|
||||||
|
|
||||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
|
||||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
|
||||||
if self._key is None:
|
|
||||||
# call after get_verifier() or get()
|
|
||||||
raise UsageError
|
|
||||||
return HKDF(self._key, length, CTXinfo=to_bytes(purpose))
|
|
||||||
|
|
||||||
def _event_established_key(self):
|
def _event_established_key(self):
|
||||||
self._timing.add("key established")
|
self._timing.add("key established")
|
||||||
if self._send_confirm:
|
if self._send_confirm:
|
||||||
|
@ -395,7 +416,8 @@ class _Wormhole:
|
||||||
self._msg_send(u"confirm", confmsg, wait=True)
|
self._msg_send(u"confirm", confmsg, wait=True)
|
||||||
verifier = self.derive_key(u"wormhole:verifier")
|
verifier = self.derive_key(u"wormhole:verifier")
|
||||||
self._event_computed_verifier(verifier)
|
self._event_computed_verifier(verifier)
|
||||||
pass
|
self._maybe_send_phase_messages()
|
||||||
|
|
||||||
def _event_computed_verifier(self, verifier):
|
def _event_computed_verifier(self, verifier):
|
||||||
self._verifier = verifier
|
self._verifier = verifier
|
||||||
d, self._verifier_waiter = self._verifier_waiter, None
|
d, self._verifier_waiter = self._verifier_waiter, None
|
||||||
|
@ -412,13 +434,85 @@ class _Wormhole:
|
||||||
# this makes all API calls fail
|
# this makes all API calls fail
|
||||||
return self._signal_error(WrongPasswordError())
|
return self._signal_error(WrongPasswordError())
|
||||||
|
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def send(self, outbound_data, wait=False):
|
||||||
|
if not isinstance(outbound_data, type(b"")):
|
||||||
|
raise TypeError(type(outbound_data))
|
||||||
|
if self._closed: raise UsageError
|
||||||
|
phase = self._next_send_phase
|
||||||
|
self._next_send_phase += 1
|
||||||
|
d = defer.Deferred()
|
||||||
|
self._plaintext_to_send.append( (phase, outbound_data, d) )
|
||||||
|
with self._timing.add("API send", phase=phase, wait=wait):
|
||||||
|
self._maybe_send_phase_messages()
|
||||||
|
if wait:
|
||||||
|
yield d
|
||||||
|
|
||||||
|
def _maybe_send_phase_messages(self):
|
||||||
|
# TODO: deal with reentrant call
|
||||||
|
if not (self._connected and self._mailbox_opened and self._key):
|
||||||
|
return
|
||||||
|
plaintexts = self._plaintext_to_send
|
||||||
|
self._plaintext_to_send = []
|
||||||
|
for pm in plaintexts:
|
||||||
|
(phase, plaintext, wait_d) = pm
|
||||||
|
data_key = self.derive_key(u"wormhole:phase:%d" % phase)
|
||||||
|
encrypted = self._encrypt_data(data_key, plaintext)
|
||||||
|
d = self._msg_send(phase, encrypted)
|
||||||
|
d.addBoth(wait_d.callback)
|
||||||
|
d.addErrback(log.err)
|
||||||
|
|
||||||
|
def _encrypt_data(self, key, data):
|
||||||
|
# Without predefined roles, we can't derive predictably unique keys
|
||||||
|
# for each side, so we use the same key for both. We use random
|
||||||
|
# nonces to keep the messages distinct, and we automatically ignore
|
||||||
|
# reflections.
|
||||||
|
assert isinstance(key, type(b"")), type(key)
|
||||||
|
assert isinstance(data, type(b"")), type(data)
|
||||||
|
assert len(key) == SecretBox.KEY_SIZE, len(key)
|
||||||
|
box = SecretBox(key)
|
||||||
|
nonce = utils.random(SecretBox.NONCE_SIZE)
|
||||||
|
return box.encrypt(data, nonce)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def _msg_send(self, phase, body, wait=False):
|
||||||
|
if phase in self._sent_messages: raise UsageError
|
||||||
|
if not self._mailbox_opened: raise UsageError
|
||||||
|
if self._mailbox_closed: raise UsageError
|
||||||
|
self._sent_messages[phase] = body
|
||||||
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||||
|
# against the rendezvous server being temporarily offline.
|
||||||
|
t = self._timing.add("add", phase=phase, wait=wait)
|
||||||
|
yield self._ws_send_command(u"add", phase=phase,
|
||||||
|
body=hexlify(body).decode("ascii"))
|
||||||
|
if wait:
|
||||||
|
while phase not in self._delivered_messages:
|
||||||
|
yield self._sleep()
|
||||||
|
t.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def _event_received_message(self, msg):
|
||||||
|
pass
|
||||||
|
def _event_mailbox_used(self):
|
||||||
|
if self._flag_need_to_see_mailbox_used:
|
||||||
|
self._ws_send_command(u"release")
|
||||||
|
self._flag_need_to_see_mailbox_used = False
|
||||||
|
|
||||||
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||||
|
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||||
|
if self._key is None:
|
||||||
|
# call after get_verifier() or get()
|
||||||
|
raise UsageError
|
||||||
|
return HKDF(self._key, length, CTXinfo=to_bytes(purpose))
|
||||||
|
|
||||||
def _event_received_phase_message(self, phase, message):
|
def _event_received_phase_message(self, phase, message):
|
||||||
self._phase_messages_received[phase] = message
|
self._phase_messages_received[phase] = message
|
||||||
if phase in self._phase_message_waiters:
|
if phase in self._phase_message_waiters:
|
||||||
d = self._phase_message_waiters.pop(phase)
|
d = self._phase_message_waiters.pop(phase)
|
||||||
d.callback(message)
|
d.callback(message)
|
||||||
|
|
||||||
def _ws_handle_message(self, msg):
|
def _response_handle_message(self, msg):
|
||||||
side = msg["side"]
|
side = msg["side"]
|
||||||
phase = msg["phase"]
|
phase = msg["phase"]
|
||||||
body = unhexlify(msg["body"].encode("ascii"))
|
body = unhexlify(msg["body"].encode("ascii"))
|
||||||
|
@ -443,12 +537,35 @@ class _Wormhole:
|
||||||
if phase == u"confirm":
|
if phase == u"confirm":
|
||||||
self._event_received_confirm(body)
|
self._event_received_confirm(body)
|
||||||
# now notify anyone waiting on it
|
# now notify anyone waiting on it
|
||||||
self._wakeup()
|
try:
|
||||||
|
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||||
|
inbound_data = self._decrypt_data(data_key, body)
|
||||||
|
except CryptoError:
|
||||||
|
raise WrongPasswordError
|
||||||
|
self._phase_messages_received[phase] = inbound_data
|
||||||
|
if phase in self._receive_waiters:
|
||||||
|
d = self._receive_waiters.pop(phase)
|
||||||
|
d.callback(inbound_data)
|
||||||
|
|
||||||
def _event_asked_to_send_phase_message(self, phase, message):
|
def _decrypt_data(self, key, encrypted):
|
||||||
pm = (phase, message)
|
assert isinstance(key, type(b"")), type(key)
|
||||||
self._phase_messages_to_send.append(pm)
|
assert isinstance(encrypted, type(b"")), type(encrypted)
|
||||||
self._maybe_send_phase_messages()
|
assert len(key) == SecretBox.KEY_SIZE, len(key)
|
||||||
|
box = SecretBox(key)
|
||||||
|
data = box.decrypt(encrypted)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def get(self):
|
||||||
|
if self._closed: raise UsageError
|
||||||
|
if self._code is None: raise UsageError
|
||||||
|
phase = self._next_receive_phase
|
||||||
|
self._next_receive_phase += 1
|
||||||
|
with self._timing.add("API get", phase=phase):
|
||||||
|
if phase in self._phase_messages_received:
|
||||||
|
returnValue(self._phase_messages_received[phase])
|
||||||
|
d = self._receive_waiters[phase] = defer.Deferred()
|
||||||
|
yield d
|
||||||
|
|
||||||
def _event_asked_to_close(self):
|
def _event_asked_to_close(self):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue
Block a user