diff --git a/src/txwormhole/transcribe.py b/src/txwormhole/transcribe.py index 5654ba6..4750acd 100644 --- a/src/txwormhole/transcribe.py +++ b/src/txwormhole/transcribe.py @@ -1,25 +1,20 @@ from __future__ import print_function import os, sys, json, re, unicodedata -from six.moves.urllib_parse import urlencode +from six.moves.urllib_parse import urlparse from binascii import hexlify, unhexlify -from zope.interface import implementer -from twisted.internet import reactor, defer +from twisted.internet import reactor, defer, endpoints, error from twisted.internet.threads import deferToThread, blockingCallFromThread from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.web import client as web_client -from twisted.web import error as web_error -from twisted.web.iweb import IBodyProducer +from autobahn.twisted import websocket from nacl.secret import SecretBox from nacl.exceptions import CryptoError from nacl import utils from spake2 import SPAKE2_Symmetric -from .eventsource import ReconnectingEventSource from wormhole import __version__ from wormhole import codes from wormhole.errors import ServerError, Timeout, WrongPasswordError, UsageError from wormhole.timing import DebugTiming from hkdf import Hkdf -from wormhole.channel_monitor import monitor def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) @@ -32,207 +27,6 @@ def make_confmsg(confkey, nonce): def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") -@implementer(IBodyProducer) -class DataProducer: - def __init__(self, data): - self.data = data - self.length = len(data) - def startProducing(self, consumer): - consumer.write(self.data) - return defer.succeed(None) - def stopProducing(self): - pass - def pauseProducing(self): - pass - def resumeProducing(self): - pass - - -def post_json(agent, url, request_body): - # POST a JSON body to a URL, parsing the response as JSON - data = json.dumps(request_body).encode("utf-8") - d = agent.request(b"POST", url.encode("utf-8"), - bodyProducer=DataProducer(data)) - def _check_error(resp): - if resp.code != 200: - raise web_error.Error(resp.code, resp.phrase) - return resp - d.addCallback(_check_error) - d.addCallback(web_client.readBody) - d.addCallback(lambda data: json.loads(data.decode("utf-8"))) - return d - -def get_json(agent, url): - # GET from a URL, parsing the response as JSON - d = agent.request(b"GET", url.encode("utf-8")) - def _check_error(resp): - if resp.code != 200: - raise web_error.Error(resp.code, resp.phrase) - return resp - d.addCallback(_check_error) - d.addCallback(web_client.readBody) - d.addCallback(lambda data: json.loads(data.decode("utf-8"))) - return d - -class Channel: - def __init__(self, relay_url, appid, channelid, side, handle_welcome, - agent, timing): - self._relay_url = relay_url - self._appid = appid - self._channelid = channelid - self._side = side - self._handle_welcome = handle_welcome - self._agent = agent - self._timing = timing - self._messages = set() # (phase,body) , body is bytes - self._sent_messages = set() # (phase,body) - - def _add_inbound_messages(self, messages): - for msg in messages: - phase = msg["phase"] - body = unhexlify(msg["body"].encode("ascii")) - self._messages.add( (phase, body) ) - - def _find_inbound_message(self, phases): - their_messages = self._messages - self._sent_messages - for phase in phases: - for (their_phase,body) in their_messages: - if their_phase == phase: - return (phase, body) - return None - - def send(self, phase, msg): - # TODO: retry on failure, with exponential backoff. We're guarding - # against the rendezvous server being temporarily offline. - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if not isinstance(msg, type(b"")): raise TypeError(type(msg)) - self._sent_messages.add( (phase,msg) ) - assert isinstance(self._side, type(u"")), type(self._side) - payload = {"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "phase": phase, - "body": hexlify(msg).decode("ascii")} - _sent = self._timing.add_event("send %s" % phase) - d = post_json(self._agent, self._relay_url+"add", payload) - def _maybe_handle_welcome(resp): - self._timing.finish_event(_sent, resp.get("sent")) - if "welcome" in resp: - self._handle_welcome(resp["welcome"]) - return resp - d.addCallback(_maybe_handle_welcome) - d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) - return d - - def get_first_of(self, phases): - if not isinstance(phases, (list, set)): raise TypeError(type(phases)) - for phase in phases: - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - - # fire with a bytestring of the first message for any 'phase' that - # wasn't one of our own messages. It will either come from - # previously-received messages, or from an EventSource that we attach - # to the corresponding URL - _sent = self._timing.add_event("get %s" % "/".join(sorted(phases))) - - phase_and_body = self._find_inbound_message(phases) - if phase_and_body is not None: - self._timing.finish_event(_sent) - return defer.succeed(phase_and_body) - d = defer.Deferred() - msgs = [] - def _handle(name, line): - if name == "welcome": - self._handle_welcome(json.loads(line)) - if name == "message": - data = json.loads(line) - self._add_inbound_messages([data]) - phase_and_body = self._find_inbound_message(phases) - if phase_and_body is not None and not msgs: - msgs.append(phase_and_body) - self._timing.finish_event(_sent, data.get("sent")) - d.callback(None) - queryargs = urlencode([("appid", self._appid), - ("channelid", self._channelid)]) - es = ReconnectingEventSource(self._relay_url+"watch?%s" % queryargs, - _handle, self._agent) - es.startService() # TODO: .setServiceParent(self) - es.activate() - d.addCallback(lambda _: es.deactivate()) - d.addCallback(lambda _: es.stopService()) - d.addCallback(lambda _: msgs[0]) - return d - - @inlineCallbacks - def get(self, phase): - res = yield self.get_first_of([phase]) - (got_phase, body) = res - assert got_phase == phase - returnValue(body) - - def deallocate(self, mood=None): - # only try once, no retries - _sent = self._timing.add_event("close") - d = post_json(self._agent, self._relay_url+"deallocate", - {"appid": self._appid, - "channelid": self._channelid, - "side": self._side, - "mood": mood}) - def _done(resp): - self._timing.finish_event(_sent, resp.get("sent")) - d.addCallback(_done) - d.addBoth(lambda _: None) # ignore POST failure - return d - -class ChannelManager: - def __init__(self, relay, appid, side, handle_welcome, tor_manager=None, - timing=None, reactor=reactor): - assert isinstance(relay, type(u"")) - self._relay = relay - self._appid = appid - self._side = side - self._handle_welcome = handle_welcome - self._pool = web_client.HTTPConnectionPool(reactor, True) # persistent - if tor_manager: - print("ChannelManager using tor") - epf = tor_manager.get_web_agent_endpoint_factory() - agent = web_client.Agent.usingEndpointFactory(reactor, epf, - pool=self._pool) - else: - agent = web_client.Agent(reactor, pool=self._pool) - self._agent = agent - self._timing = timing or DebugTiming() - - @inlineCallbacks - def allocate(self): - url = self._relay + "allocate" - _sent = self._timing.add_event("allocate") - data = yield post_json(self._agent, url, {"appid": self._appid, - "side": self._side}) - if "welcome" in data: - self._handle_welcome(data["welcome"]) - self._timing.finish_event(_sent, data.get("sent")) - returnValue(data["channelid"]) - - @inlineCallbacks - def list_channels(self): - queryargs = urlencode([("appid", self._appid)]) - url = self._relay + u"list?%s" % queryargs - _sent = self._timing.add_event("list") - r = yield get_json(self._agent, url) - self._timing.finish_event(_sent, r.get("sent")) - returnValue(r["channelids"]) - - def connect(self, channelid): - return Channel(self._relay, self._appid, channelid, self._side, - self._handle_welcome, self._agent, self._timing) - - @inlineCallbacks - def shutdown(self): - _sent = self._timing.add_event("pool shutdown") - yield self._pool.closeCachedConnections() - self._timing.finish_event(_sent) - def close_on_error(meth): # method decorator # Clients report certain errors as "moods", so the server can make a # rough count failed connections (due to mismatched passwords, attacks, @@ -257,6 +51,31 @@ def close_on_error(meth): # method decorator return d return _wrapper +class WSClient(websocket.WebSocketClientProtocol): + def onOpen(self): + self.wormhole_open = True + self.factory.d.callback(self) + + def onMessage(self, payload, isBinary): + assert not isBinary + self.wormhole._ws_dispatch_msg(payload) + + def onClose(self, wasClean, code, reason): + if self.wormhole_open: + self.wormhole._ws_closed(wasClean, code, reason) + else: + # we closed before establishing a connection (onConnect) or + # finishing WebSocket negotiation (onOpen): errback + self.factory.d.errback(error.ConnectError(reason)) + +class WSFactory(websocket.WebSocketClientFactory): + protocol = WSClient + def buildProtocol(self, addr): + proto = websocket.WebSocketClientFactory.buildProtocol(self, addr) + proto.wormhole = self.wormhole + proto.wormhole_open = False + return proto + class Wormhole: motd_displayed = False version_warning_displayed = False @@ -270,35 +89,81 @@ class Wormhole: if not relay_url.endswith(u"/"): raise UsageError self._appid = appid self._relay_url = relay_url + self._ws_url = relay_url.replace("http:", "ws:") + "ws" self._tor_manager = tor_manager self._timing = timing or DebugTiming() self._reactor = reactor - self._set_side(hexlify(os.urandom(5)).decode("ascii")) - self.code = None - self.key = None + self._side = hexlify(os.urandom(5)).decode("ascii") + self._code = None + self._channelid = None + self._key = None self._started_get_code = False - self._sent_data = set() # phases - self._got_data = set() - self._got_confirmation = False + self._sent_messages = set() # (phase, body_bytes) + self._delivered_messages = set() # (phase, body_bytes) + self._received_messages = {} # phase -> body_bytes + self._sent_phases = set() # phases, to prohibit double-send + self._got_phases = set() # phases, to prohibit double-read + self._sleepers = [] + self._confirmation_failed = False self._closed = False + self._deallocated_status = None self._timing_started = self._timing.add_event("wormhole") + self._ws = None + self._ws_channel_claimed = False + self._error = None - def _set_side(self, side): - self._side = side - self._channel_manager = ChannelManager(self._relay_url, self._appid, - self._side, self.handle_welcome, - self._tor_manager, - self._timing, - reactor=self._reactor) - self._channel = None + def _make_endpoint(self, hostname, port): + if self._tor_manager: + return self._tor_manager.endpointForURI() + return endpoints.HostnameEndpoint(self._reactor, hostname, port) # 30s - def handle_welcome(self, welcome): + @inlineCallbacks + def _get_websocket(self): + if not self._ws: + # TODO: if we lose the connection, make a new one + #from twisted.python import log + #log.startLogging(sys.stderr) + assert self._side + assert not self._ws_channel_claimed + p = urlparse(self._ws_url) + f = WSFactory(self._ws_url) + f.wormhole = self + f.d = defer.Deferred() + # TODO: if hostname="localhost", I get three factories starting + # and stopping (maybe 127.0.0.1, ::1, and something else?), and + # an error in the factory is masked. + ep = self._make_endpoint(p.hostname, p.port or 80) + # .connect errbacks if the TCP connection fails + self._ws = yield ep.connect(f) + # f.d is errbacked if WebSocket negotiation fails + yield f.d # WebSocket drops data sent before onOpen() fires + self._ws_send(u"bind", appid=self._appid, side=self._side) + # the socket is connected, and bound, but no channel has been claimed + returnValue(self._ws) + + @inlineCallbacks + def _ws_send(self, mtype, **kwargs): + ws = yield self._get_websocket() + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + ws.sendMessage(payload, False) + + def _ws_dispatch_msg(self, payload): + msg = json.loads(payload.decode("utf-8")) + mtype = msg["type"] + meth = getattr(self, "_ws_handle_"+mtype, None) + if not meth: + raise ValueError("Unknown inbound message type %r" % (msg,)) + return meth(msg) + + def _ws_handle_welcome(self, msg): + welcome = msg["welcome"] if ("motd" in welcome and not self.motd_displayed): motd_lines = welcome["motd"].splitlines() motd_formatted = "\n ".join(motd_lines) print("Server (at %s) says:\n %s" % - (self._relay_url, motd_formatted), file=sys.stderr) + (self._ws_url, motd_formatted), file=sys.stderr) self.motd_displayed = True # Only warn if we're running a release version (e.g. 0.0.6, not @@ -312,94 +177,232 @@ class Wormhole: self.version_warning_displayed = True if "error" in welcome: - raise ServerError(welcome["error"], self._relay_url) + return self._signal_error(welcome["error"]) @inlineCallbacks - def get_code(self, code_length=2): - if self.code is not None: raise UsageError + def _sleep(self): + if self._error: # don't sleep if the bed's already on fire + raise self._error + d = defer.Deferred() + self._sleepers.append(d) + yield d + if self._error: + raise self._error + + def _wakeup(self): + sleepers = self._sleepers + self._sleepers = [] + for d in sleepers: + d.callback(None) + # NOTE: callers should avoid reentrancy themselves. An + # eventual-send would be safer here, but it makes synchronizing + # unit tests annoying. + + def _signal_error(self, error): + assert isinstance(error, Exception) + self._error = error + self._wakeup() + + def _ws_handle_error(self, msg): + err = ServerError("%s: %s" % (msg["error"], msg["orig"]), + self._ws_url) + return self._signal_error(err) + + @inlineCallbacks + def _claim_channel_and_watch(self): + assert self._channelid is not None + yield self._get_websocket() + if not self._ws_channel_claimed: + yield self._ws_send(u"claim", channelid=self._channelid) + self._ws_channel_claimed = True + yield self._ws_send(u"watch") + + # entry point 1: generate a new code + @inlineCallbacks + def get_code(self, code_length=2): # rename to allocate_code()? create_? + if self._code is not None: raise UsageError if self._started_get_code: raise UsageError self._started_get_code = True - channelid = yield self._channel_manager.allocate() - code = codes.make_code(channelid, code_length) + _sent = self._timing.add_event("allocate") + yield self._ws_send(u"allocate") + while self._channelid is None: + yield self._sleep() + self._timing.finish_event(_sent) + code = codes.make_code(self._channelid, code_length) assert isinstance(code, type(u"")), type(code) - self._set_code_and_channelid(code) + self._set_code(code) self._start() returnValue(code) + def _ws_handle_allocated(self, msg): + if self._channelid is not None: + return self._signal_error("got duplicate channelid") + self._channelid = msg["channelid"] + self._wakeup() + + def _start(self): + # allocate the rest now too, so it can be serialized + self._sp = SPAKE2_Symmetric(to_bytes(self._code), + idSymmetric=to_bytes(self._appid)) + self._msg1 = self._sp.start() + + # entry point 2a: interactively type in a code, with completion @inlineCallbacks def input_code(self, prompt="Enter wormhole code: ", code_length=2): def _lister(): - return blockingCallFromThread(self._reactor, - self._channel_manager.list_channels) + return blockingCallFromThread(self._reactor, self._list_channels) # fetch the list of channels ahead of time, to give us a chance to # discover the welcome message (and warn the user about an obsolete # client) - initial_channelids = yield self._channel_manager.list_channels() + # + # TODO: send the request early, show the prompt right away, hide the + # latency in the user's indecision and slow typing. If we're lucky + # the answer will come back before they hit TAB. + initial_channelids = yield self._list_channels() _start = self._timing.add_event("input code", waiting="user") code = yield deferToThread(codes.input_code_with_completion, prompt, initial_channelids, _lister, code_length) self._timing.finish_event(_start) - returnValue(code) + returnValue(code) # application will give this to set_code() + @inlineCallbacks + def _list_channels(self): + _sent = self._timing.add_event("list") + self._latest_channelids = None + yield self._ws_send(u"list") + while self._latest_channelids is None: + yield self._sleep() + self._timing.finish_event(_sent) + returnValue(self._latest_channelids) + + def _ws_handle_channelids(self, msg): + self._latest_channelids = msg["channelids"] + self._wakeup() + + # entry point 2b: paste in a fully-formed code def set_code(self, code): if not isinstance(code, type(u"")): raise TypeError(type(code)) - if self.code is not None: raise UsageError - self._set_code_and_channelid(code) - self._start() - - def _set_code_and_channelid(self, code): - if self.code is not None: raise UsageError - self._timing.add_event("code established") + if self._code is not None: raise UsageError mo = re.search(r'^(\d+)-', code) if not mo: raise ValueError("code (%s) must start with NN-" % code) - self.code = code - channelid = int(mo.group(1)) - self._channel = self._channel_manager.connect(channelid) - monitor.add(self._channel) + self._channelid = int(mo.group(1)) + self._set_code(code) + self._start() - def _start(self): - # allocate the rest now too, so it can be serialized - self.sp = SPAKE2_Symmetric(to_bytes(self.code), - idSymmetric=to_bytes(self._appid)) - self.msg1 = self.sp.start() + def _set_code(self, code): + if self._code is not None: raise UsageError + self._timing.add_event("code established") + self._code = code def serialize(self): # I can only be serialized after get_code/set_code and before # get_verifier/get_data - if self.code is None: raise UsageError - if self.key is not None: raise UsageError - if self._sent_data: raise UsageError - if self._got_data: raise UsageError + if self._code is None: raise UsageError + if self._key is not None: raise UsageError + if self._sent_phases: raise UsageError + if self._got_phases: raise UsageError data = { "appid": self._appid, "relay_url": self._relay_url, - "code": self.code, + "code": self._code, + "channelid": self._channelid, "side": self._side, - "spake2": json.loads(self.sp.serialize().decode("ascii")), - "msg1": hexlify(self.msg1).decode("ascii"), + "spake2": json.loads(self._sp.serialize().decode("ascii")), + "msg1": hexlify(self._msg1).decode("ascii"), } return json.dumps(data) + # entry point 3: resume a previously-serialized session @classmethod def from_serialized(klass, data): d = json.loads(data) self = klass(d["appid"], d["relay_url"]) - self._set_side(d["side"]) - self._set_code_and_channelid(d["code"]) + self._side = d["side"] + self._channelid = d["channelid"] + self._set_code(d["code"]) sp_data = json.dumps(d["spake2"]).encode("ascii") - self.sp = SPAKE2_Symmetric.from_serialized(sp_data) - self.msg1 = unhexlify(d["msg1"].encode("ascii")) + self._sp = SPAKE2_Symmetric.from_serialized(sp_data) + self._msg1 = unhexlify(d["msg1"].encode("ascii")) return self + @close_on_error + @inlineCallbacks + def get_verifier(self): + if self._closed: raise UsageError + if self._code is None: raise UsageError + yield self._get_master_key() + returnValue(self._verifier) + + @inlineCallbacks + def _get_master_key(self): + # TODO: prevent multiple invocation + if not self._key: + yield self._claim_channel_and_watch() + yield self._msg_send(u"pake", self._msg1) + pake_msg = yield self._msg_get(u"pake") + + self._key = self._sp.finish(pake_msg) + self._verifier = self.derive_key(u"wormhole:verifier") + self._timing.add_event("key established") + + if self._send_confirm: + # both sides send different (random) confirmation messages + confkey = self.derive_key(u"wormhole:confirmation") + nonce = os.urandom(CONFMSG_NONCE_LENGTH) + confmsg = make_confmsg(confkey, nonce) + yield self._msg_send(u"_confirm", confmsg) + + @inlineCallbacks + def _msg_send(self, phase, body, wait=False): + self._sent_messages.add( (phase, body) ) + # TODO: retry on failure, with exponential backoff. We're guarding + # against the rendezvous server being temporarily offline. + yield self._ws_send(u"add", phase=phase, + body=hexlify(body).decode("ascii")) + if wait: + while (phase, body) not in self._delivered_messages: + yield self._sleep() + + def _ws_handle_message(self, msg): + m = msg["message"] + phase = m["phase"] + body = unhexlify(m["body"].encode("ascii")) + if (phase, body) in self._sent_messages: + self._delivered_messages.add( (phase, body) ) # ack by server + self._wakeup() + return # ignore echoes of our outbound messages + if phase in self._received_messages: + # a channel collision would cause this + err = ServerError("got duplicate phase %s" % phase, self._ws_url) + return self._signal_error(err) + self._received_messages[phase] = body + if phase == u"_confirm": + confkey = self.derive_key(u"wormhole:confirmation") + nonce = body[:CONFMSG_NONCE_LENGTH] + if body != make_confmsg(confkey, nonce): + # this makes all API calls fail + return self._signal_error(WrongPasswordError()) + # now notify anyone waiting on it + self._wakeup() + + @inlineCallbacks + def _msg_get(self, phase): + _start = self._timing.add_event("get(%s)" % phase) + while phase not in self._received_messages: + yield self._sleep() # we can wait a long time here + # that will throw an error if something goes wrong + self._timing.finish_event(_start) + returnValue(self._received_messages[phase]) + def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) - if self.key is None: + if self._key is None: # call after get_verifier() or get_data() raise UsageError - return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) + return HKDF(self._key, length, CTXinfo=to_bytes(purpose)) def _encrypt_data(self, key, data): assert isinstance(key, type(b"")), type(key) @@ -417,35 +420,6 @@ class Wormhole: data = box.decrypt(encrypted) return data - @inlineCallbacks - def _get_key(self): - # TODO: prevent multiple invocation - if self.key: - returnValue(self.key) - yield self._channel.send(u"pake", self.msg1) - pake_msg = yield self._channel.get(u"pake") - - key = self.sp.finish(pake_msg) - self.key = key - self.verifier = self.derive_key(u"wormhole:verifier") - self._timing.add_event("key established") - - if not self._send_confirm: - returnValue(key) - confkey = self.derive_key(u"wormhole:confirmation") - nonce = os.urandom(CONFMSG_NONCE_LENGTH) - confmsg = make_confmsg(confkey, nonce) - yield self._channel.send(u"_confirm", confmsg) - returnValue(key) - - @close_on_error - @inlineCallbacks - def get_verifier(self): - if self._closed: raise UsageError - if self.code is None: raise UsageError - yield self._get_key() - returnValue(self.verifier) - @close_on_error @inlineCallbacks def send_data(self, outbound_data, phase=u"data", wait=False): @@ -453,52 +427,35 @@ class Wormhole: raise TypeError(type(outbound_data)) if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if self._closed: raise UsageError - if phase in self._sent_data: raise UsageError # only call this once + if self._code is None: + raise UsageError("You must set_code() before send_data()") if phase.startswith(u"_"): raise UsageError # reserved for internals - if self.code is None: raise UsageError - if self._channel is None: raise UsageError - _sent = self._timing.add_event("API send data", phase=phase) + if phase in self._sent_phases: raise UsageError # only call this once + self._sent_phases.add(phase) + _sent = self._timing.add_event("API send data", phase=phase, wait=wait) # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random - # nonces to keep the messages distinct, and the Channel automatically - # ignores reflections. - self._sent_data.add(phase) - yield self._get_key() + # nonces to keep the messages distinct, and we automatically ignore + # reflections. + yield self._get_master_key() data_key = self.derive_key(u"wormhole:phase:%s" % phase) outbound_encrypted = self._encrypt_data(data_key, outbound_data) - yield self._channel.send(phase, outbound_encrypted) - # Since that always waits for the server to ack the POST, we always - # behave as if wait=True. + yield self._msg_send(phase, outbound_encrypted, wait) self._timing.finish_event(_sent) @close_on_error @inlineCallbacks def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - if phase in self._got_data: raise UsageError # only call this once - if phase.startswith(u"_"): raise UsageError # reserved for internals if self._closed: raise UsageError - if self.code is None: raise UsageError - if self._channel is None: raise UsageError + if self._code is None: raise UsageError + if phase.startswith(u"_"): raise UsageError # reserved for internals + if phase in self._got_phases: raise UsageError # only call this once + self._got_phases.add(phase) _sent = self._timing.add_event("API get data", phase=phase) - self._got_data.add(phase) - yield self._get_key() - phases = [] - if not self._got_confirmation: - phases.append(u"_confirm") - phases.append(phase) - phase_and_body = yield self._channel.get_first_of(phases) - (got_phase, body) = phase_and_body - if got_phase == u"_confirm": - confkey = self.derive_key(u"wormhole:confirmation") - nonce = body[:CONFMSG_NONCE_LENGTH] - if body != make_confmsg(confkey, nonce): - raise WrongPasswordError - self._got_confirmation = True - phase_and_body = yield self._channel.get_first_of([phase]) - (got_phase, body) = phase_and_body + yield self._get_master_key() + body = yield self._msg_get(phase) # we can wait a long time here self._timing.finish_event(_sent) - assert got_phase == phase try: data_key = self.derive_key(u"wormhole:phase:%s" % phase) inbound_data = self._decrypt_data(data_key, body) @@ -506,16 +463,36 @@ class Wormhole: except CryptoError: raise WrongPasswordError + def _ws_closed(self, wasClean, code, reason): + self._ws = None + # TODO: schedule reconnect, unless we're done + @inlineCallbacks def close(self, res=None, mood=u"happy"): if not isinstance(mood, (type(None), type(u""))): raise TypeError(type(mood)) + if self._closed: + returnValue(None) self._closed = True - if not self._channel: + if not self._ws: returnValue(None) self._timing.finish_event(self._timing_started, mood=mood) - c, self._channel = self._channel, None - monitor.close(c) - yield c.deallocate(mood) - yield self._channel_manager.shutdown() + yield self._deallocate(mood) + # TODO: mark WebSocket as don't-reconnect + self._ws.transport.loseConnection() # probably flushes + del self._ws + @inlineCallbacks + def _deallocate(self, mood=None): + _sent = self._timing.add_event("close") + yield self._ws_send(u"deallocate", mood=mood) + while self._deallocated_status is None: + yield self._sleep() + self._timing.finish_event(_sent) + # TODO: set a timeout, don't wait forever for an ack + # TODO: if the connection is lost, let it go + returnValue(self._deallocated_status) + + def _ws_handle_deallocated(self, msg): + self._deallocated_status = msg["status"] + self._wakeup() diff --git a/tests/test_twisted.py b/tests/test_twisted.py index be766d0..a1e5f34 100644 --- a/tests/test_twisted.py +++ b/tests/test_twisted.py @@ -1,136 +1,12 @@ from __future__ import print_function import json from twisted.trial import unittest -from twisted.internet.defer import gatherResults, succeed, inlineCallbacks -from txwormhole.transcribe import (Wormhole, UsageError, ChannelManager, - WrongPasswordError) -from txwormhole.eventsource import EventSourceParser +from twisted.internet.defer import gatherResults, inlineCallbacks +from txwormhole.transcribe import Wormhole, UsageError, WrongPasswordError from .common import ServerBase APPID = u"appid" -class Channel(ServerBase, unittest.TestCase): - def ignore(self, welcome): - pass - - def test_allocate(self): - cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore) - d = cm.list_channels() - def _got_channels(channels): - self.failUnlessEqual(channels, []) - d.addCallback(_got_channels) - d.addCallback(lambda _: cm.allocate()) - def _allocated(channelid): - self.failUnlessEqual(type(channelid), int) - self._channelid = channelid - d.addCallback(_allocated) - d.addCallback(lambda _: cm.connect(self._channelid)) - def _connected(c): - self._channel = c - d.addCallback(_connected) - d.addCallback(lambda _: self._channel.deallocate(u"happy")) - d.addCallback(lambda _: cm.shutdown()) - return d - - def test_messages(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - d = succeed(None) - d.addCallback(lambda _: c1.send(u"phase1", b"msg1")) - d.addCallback(lambda _: c2.get(u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1")) - d.addCallback(lambda _: c2.send(u"phase1", b"msg2")) - d.addCallback(lambda _: c1.get(u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # it's legal to fetch a phase multiple times, should be idempotent - d.addCallback(lambda _: c1.get(u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2")) - # deallocating one side is not enough to destroy the channel - d.addCallback(lambda _: c2.deallocate()) - def _not_yet(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 1) - d.addCallback(_not_yet) - # but deallocating both will make the messages go away - d.addCallback(lambda _: c1.deallocate(u"sad")) - def _gone(_): - self._rendezvous.prune() - self.failUnlessEqual(len(self._rendezvous._apps), 0) - d.addCallback(_gone) - - d.addCallback(lambda _: cm1.shutdown()) - d.addCallback(lambda _: cm2.shutdown()) - - return d - - def test_get_multiple_phases(self): - cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) - cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) - c1 = cm1.connect(1) - c2 = cm2.connect(1) - - self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1") - self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7]) - - d = succeed(None) - d.addCallback(lambda _: c1.send(u"phase1", b"msg1")) - - d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - - d.addCallback(lambda _: c1.send(u"phase2", b"msg2")) - d.addCallback(lambda _: c2.get(u"phase2")) - - # if both are present, it should prefer the first one we asked for - d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase1", b"msg1"))) - d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"])) - d.addCallback(lambda phase_and_body: - self.failUnlessEqual(phase_and_body, - (u"phase2", b"msg2"))) - - d.addCallback(lambda _: cm1.shutdown()) - d.addCallback(lambda _: cm2.shutdown()) - - return d - - def test_appid_independence(self): - APPID_A = u"appid_A" - APPID_B = u"appid_B" - cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore) - cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore) - c1a = cm1a.connect(1) - c2a = cm2a.connect(1) - cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore) - cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore) - c1b = cm1b.connect(1) - c2b = cm2b.connect(1) - - d = succeed(None) - d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a")) - d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b")) - d.addCallback(lambda _: c2a.get(u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a")) - d.addCallback(lambda _: c2b.get(u"phase1")) - d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b")) - - d.addCallback(lambda _: cm1a.shutdown()) - d.addCallback(lambda _: cm2a.shutdown()) - d.addCallback(lambda _: cm1b.shutdown()) - d.addCallback(lambda _: cm2b.shutdown()) - return d - class Basic(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): @@ -226,10 +102,32 @@ class Basic(ServerBase, unittest.TestCase): # and w1 won't send CONFIRM until it sees a PAKE message, which w2 # won't send until we call get_data. So we need both sides to be # running at the same time for this test. - yield self.doBoth(w1.send_data(b"data1"), - self.assertFailure(w2.get_data(), WrongPasswordError)) + d1 = w1.send_data(b"data1") + # at this point, w1 should be waiting for w2.PAKE - # and now w1 should have enough information to throw too + yield self.assertFailure(w2.get_data(), WrongPasswordError) + # * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key, + # send w2.CONFIRM, then wait for w1.DATA. + # * w1 will get w2.PAKE, compute a key, send w1.CONFIRM. + # * w2 gets w1.CONFIRM, notices the error, records it. + # * w2 (waiting for w1.DATA) wakes up, sees the error, throws + # * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not + # have arrived by then + yield d1 + + # When we ask w1 to get_data(), one of two things might happen: + # * if w2.CONFIRM arrived already, it will have recorded the error. + # When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the + # error before sleeping, and throw WrongPasswordError + # * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM + # arrives, we notice and record the error, and wake up, and throw + + # Note that we didn't do w2.send_data(), so we're hoping that w1 will + # have enough information to detect the error before it sleeps + # (waiting for w2.DATA). Checking for the error both before sleeping + # and after waking up makes this happen. + + # so now w1 should have enough information to throw too yield self.assertFailure(w1.get_data(), WrongPasswordError) # both sides are closed automatically upon error, but it's still @@ -349,41 +247,3 @@ class Basic(ServerBase, unittest.TestCase): yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], True) - -data1 = b"""\ -event: welcome -data: one and a -data: two -data:. - -data: three - -: this line is ignored -event: e2 -: this line is ignored too -i am a dataless field name -data: four - -""" - -class FakeTransport: - disconnecting = False - -class EventSourceClient(unittest.TestCase): - def test_parser(self): - events = [] - p = EventSourceParser(lambda t,d: events.append((t,d))) - p.transport = FakeTransport() - p.dataReceived(data1) - self.failUnlessEqual(events, - [(u"welcome", u"one and a\ntwo\n."), - (u"message", u"three"), - (u"e2", u"four"), - ]) - -# new py3 support in 15.5.0: web.client.Agent, w.c.downloadPage, twistd - -# However trying 'wormhole server start' with py3/twisted-15.5.0 throws an -# error in t.i._twistd_unix.UnixApplicationRunner.postApplication, it calls -# os.write with str, not bytes. This file does not cover that test (testing -# daemonization is hard).