txwormhole: use websockets, not HTTP

This should speed up the protocol, since we don't have to wait for
acks (HTTP responses) unless we really want to. It also makes it easier
to have multiple messages in flight at once. The protocol is still
compatible with the old HTTP version (which is still used by the
blocking flavor), but requires an updated Rendezvous server that speaks
websockets.

set_code() no longer touches the network: it just stores the code and
channelid for later. We hold off doing 'claim' and 'watch' until we need
messages, triggered by get_verifier() or get_data() or send_data().

We check for error before sleeping, not just after waking. This makes it
possible to detect a WrongPasswordError in get_data() even if the other
side hasn't done a corresponding send_data(), as long as the other side
finished PAKE (and thus sent a CONFIRM message). The unit test was doing
just this, and was hanging.
This commit is contained in:
Brian Warner 2016-04-20 12:19:09 -07:00
parent a57845eb0b
commit 34f4c284b0
2 changed files with 334 additions and 497 deletions

View File

@ -1,25 +1,20 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, re, unicodedata import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlencode from six.moves.urllib_parse import urlparse
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from zope.interface import implementer from twisted.internet import reactor, defer, endpoints, error
from twisted.internet import reactor, defer
from twisted.internet.threads import deferToThread, blockingCallFromThread from twisted.internet.threads import deferToThread, blockingCallFromThread
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.web import client as web_client from autobahn.twisted import websocket
from twisted.web import error as web_error
from twisted.web.iweb import IBodyProducer
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
from nacl import utils from nacl import utils
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from .eventsource import ReconnectingEventSource
from wormhole import __version__ from wormhole import __version__
from wormhole import codes from wormhole import codes
from wormhole.errors import ServerError, Timeout, WrongPasswordError, UsageError from wormhole.errors import ServerError, Timeout, WrongPasswordError, UsageError
from wormhole.timing import DebugTiming from wormhole.timing import DebugTiming
from hkdf import Hkdf from hkdf import Hkdf
from wormhole.channel_monitor import monitor
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen) return Hkdf(salt, skm).expand(CTXinfo, outlen)
@ -32,207 +27,6 @@ def make_confmsg(confkey, nonce):
def to_bytes(u): def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8") return unicodedata.normalize("NFC", u).encode("utf-8")
@implementer(IBodyProducer)
class DataProducer:
def __init__(self, data):
self.data = data
self.length = len(data)
def startProducing(self, consumer):
consumer.write(self.data)
return defer.succeed(None)
def stopProducing(self):
pass
def pauseProducing(self):
pass
def resumeProducing(self):
pass
def post_json(agent, url, request_body):
# POST a JSON body to a URL, parsing the response as JSON
data = json.dumps(request_body).encode("utf-8")
d = agent.request(b"POST", url.encode("utf-8"),
bodyProducer=DataProducer(data))
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
return resp
d.addCallback(_check_error)
d.addCallback(web_client.readBody)
d.addCallback(lambda data: json.loads(data.decode("utf-8")))
return d
def get_json(agent, url):
# GET from a URL, parsing the response as JSON
d = agent.request(b"GET", url.encode("utf-8"))
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
return resp
d.addCallback(_check_error)
d.addCallback(web_client.readBody)
d.addCallback(lambda data: json.loads(data.decode("utf-8")))
return d
class Channel:
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
agent, timing):
self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side
self._handle_welcome = handle_welcome
self._agent = agent
self._timing = timing
self._messages = set() # (phase,body) , body is bytes
self._sent_messages = set() # (phase,body)
def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phases):
their_messages = self._messages - self._sent_messages
for phase in phases:
for (their_phase,body) in their_messages:
if their_phase == phase:
return (phase, body)
return None
def send(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if not isinstance(msg, type(b"")): raise TypeError(type(msg))
self._sent_messages.add( (phase,msg) )
assert isinstance(self._side, type(u"")), type(self._side)
payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
_sent = self._timing.add_event("send %s" % phase)
d = post_json(self._agent, self._relay_url+"add", payload)
def _maybe_handle_welcome(resp):
self._timing.finish_event(_sent, resp.get("sent"))
if "welcome" in resp:
self._handle_welcome(resp["welcome"])
return resp
d.addCallback(_maybe_handle_welcome)
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d
def get_first_of(self, phases):
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
for phase in phases:
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# fire with a bytestring of the first message for any 'phase' that
# wasn't one of our own messages. It will either come from
# previously-received messages, or from an EventSource that we attach
# to the corresponding URL
_sent = self._timing.add_event("get %s" % "/".join(sorted(phases)))
phase_and_body = self._find_inbound_message(phases)
if phase_and_body is not None:
self._timing.finish_event(_sent)
return defer.succeed(phase_and_body)
d = defer.Deferred()
msgs = []
def _handle(name, line):
if name == "welcome":
self._handle_welcome(json.loads(line))
if name == "message":
data = json.loads(line)
self._add_inbound_messages([data])
phase_and_body = self._find_inbound_message(phases)
if phase_and_body is not None and not msgs:
msgs.append(phase_and_body)
self._timing.finish_event(_sent, data.get("sent"))
d.callback(None)
queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
es = ReconnectingEventSource(self._relay_url+"watch?%s" % queryargs,
_handle, self._agent)
es.startService() # TODO: .setServiceParent(self)
es.activate()
d.addCallback(lambda _: es.deactivate())
d.addCallback(lambda _: es.stopService())
d.addCallback(lambda _: msgs[0])
return d
@inlineCallbacks
def get(self, phase):
res = yield self.get_first_of([phase])
(got_phase, body) = res
assert got_phase == phase
returnValue(body)
def deallocate(self, mood=None):
# only try once, no retries
_sent = self._timing.add_event("close")
d = post_json(self._agent, self._relay_url+"deallocate",
{"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"mood": mood})
def _done(resp):
self._timing.finish_event(_sent, resp.get("sent"))
d.addCallback(_done)
d.addBoth(lambda _: None) # ignore POST failure
return d
class ChannelManager:
def __init__(self, relay, appid, side, handle_welcome, tor_manager=None,
timing=None, reactor=reactor):
assert isinstance(relay, type(u""))
self._relay = relay
self._appid = appid
self._side = side
self._handle_welcome = handle_welcome
self._pool = web_client.HTTPConnectionPool(reactor, True) # persistent
if tor_manager:
print("ChannelManager using tor")
epf = tor_manager.get_web_agent_endpoint_factory()
agent = web_client.Agent.usingEndpointFactory(reactor, epf,
pool=self._pool)
else:
agent = web_client.Agent(reactor, pool=self._pool)
self._agent = agent
self._timing = timing or DebugTiming()
@inlineCallbacks
def allocate(self):
url = self._relay + "allocate"
_sent = self._timing.add_event("allocate")
data = yield post_json(self._agent, url, {"appid": self._appid,
"side": self._side})
if "welcome" in data:
self._handle_welcome(data["welcome"])
self._timing.finish_event(_sent, data.get("sent"))
returnValue(data["channelid"])
@inlineCallbacks
def list_channels(self):
queryargs = urlencode([("appid", self._appid)])
url = self._relay + u"list?%s" % queryargs
_sent = self._timing.add_event("list")
r = yield get_json(self._agent, url)
self._timing.finish_event(_sent, r.get("sent"))
returnValue(r["channelids"])
def connect(self, channelid):
return Channel(self._relay, self._appid, channelid, self._side,
self._handle_welcome, self._agent, self._timing)
@inlineCallbacks
def shutdown(self):
_sent = self._timing.add_event("pool shutdown")
yield self._pool.closeCachedConnections()
self._timing.finish_event(_sent)
def close_on_error(meth): # method decorator def close_on_error(meth): # method decorator
# Clients report certain errors as "moods", so the server can make a # Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks, # rough count failed connections (due to mismatched passwords, attacks,
@ -257,6 +51,31 @@ def close_on_error(meth): # method decorator
return d return d
return _wrapper return _wrapper
class WSClient(websocket.WebSocketClientProtocol):
def onOpen(self):
self.wormhole_open = True
self.factory.d.callback(self)
def onMessage(self, payload, isBinary):
assert not isBinary
self.wormhole._ws_dispatch_msg(payload)
def onClose(self, wasClean, code, reason):
if self.wormhole_open:
self.wormhole._ws_closed(wasClean, code, reason)
else:
# we closed before establishing a connection (onConnect) or
# finishing WebSocket negotiation (onOpen): errback
self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto.wormhole = self.wormhole
proto.wormhole_open = False
return proto
class Wormhole: class Wormhole:
motd_displayed = False motd_displayed = False
version_warning_displayed = False version_warning_displayed = False
@ -270,35 +89,81 @@ class Wormhole:
if not relay_url.endswith(u"/"): raise UsageError if not relay_url.endswith(u"/"): raise UsageError
self._appid = appid self._appid = appid
self._relay_url = relay_url self._relay_url = relay_url
self._ws_url = relay_url.replace("http:", "ws:") + "ws"
self._tor_manager = tor_manager self._tor_manager = tor_manager
self._timing = timing or DebugTiming() self._timing = timing or DebugTiming()
self._reactor = reactor self._reactor = reactor
self._set_side(hexlify(os.urandom(5)).decode("ascii")) self._side = hexlify(os.urandom(5)).decode("ascii")
self.code = None self._code = None
self.key = None self._channelid = None
self._key = None
self._started_get_code = False self._started_get_code = False
self._sent_data = set() # phases self._sent_messages = set() # (phase, body_bytes)
self._got_data = set() self._delivered_messages = set() # (phase, body_bytes)
self._got_confirmation = False self._received_messages = {} # phase -> body_bytes
self._sent_phases = set() # phases, to prohibit double-send
self._got_phases = set() # phases, to prohibit double-read
self._sleepers = []
self._confirmation_failed = False
self._closed = False self._closed = False
self._deallocated_status = None
self._timing_started = self._timing.add_event("wormhole") self._timing_started = self._timing.add_event("wormhole")
self._ws = None
self._ws_channel_claimed = False
self._error = None
def _set_side(self, side): def _make_endpoint(self, hostname, port):
self._side = side if self._tor_manager:
self._channel_manager = ChannelManager(self._relay_url, self._appid, return self._tor_manager.endpointForURI()
self._side, self.handle_welcome, return endpoints.HostnameEndpoint(self._reactor, hostname, port) # 30s
self._tor_manager,
self._timing,
reactor=self._reactor)
self._channel = None
def handle_welcome(self, welcome): @inlineCallbacks
def _get_websocket(self):
if not self._ws:
# TODO: if we lose the connection, make a new one
#from twisted.python import log
#log.startLogging(sys.stderr)
assert self._side
assert not self._ws_channel_claimed
p = urlparse(self._ws_url)
f = WSFactory(self._ws_url)
f.wormhole = self
f.d = defer.Deferred()
# TODO: if hostname="localhost", I get three factories starting
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
# an error in the factory is masked.
ep = self._make_endpoint(p.hostname, p.port or 80)
# .connect errbacks if the TCP connection fails
self._ws = yield ep.connect(f)
# f.d is errbacked if WebSocket negotiation fails
yield f.d # WebSocket drops data sent before onOpen() fires
self._ws_send(u"bind", appid=self._appid, side=self._side)
# the socket is connected, and bound, but no channel has been claimed
returnValue(self._ws)
@inlineCallbacks
def _ws_send(self, mtype, **kwargs):
ws = yield self._get_websocket()
kwargs["type"] = mtype
payload = json.dumps(kwargs).encode("utf-8")
ws.sendMessage(payload, False)
def _ws_dispatch_msg(self, payload):
msg = json.loads(payload.decode("utf-8"))
mtype = msg["type"]
meth = getattr(self, "_ws_handle_"+mtype, None)
if not meth:
raise ValueError("Unknown inbound message type %r" % (msg,))
return meth(msg)
def _ws_handle_welcome(self, msg):
welcome = msg["welcome"]
if ("motd" in welcome and if ("motd" in welcome and
not self.motd_displayed): not self.motd_displayed):
motd_lines = welcome["motd"].splitlines() motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines) motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % print("Server (at %s) says:\n %s" %
(self._relay_url, motd_formatted), file=sys.stderr) (self._ws_url, motd_formatted), file=sys.stderr)
self.motd_displayed = True self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not # Only warn if we're running a release version (e.g. 0.0.6, not
@ -312,94 +177,232 @@ class Wormhole:
self.version_warning_displayed = True self.version_warning_displayed = True
if "error" in welcome: if "error" in welcome:
raise ServerError(welcome["error"], self._relay_url) return self._signal_error(welcome["error"])
@inlineCallbacks @inlineCallbacks
def get_code(self, code_length=2): def _sleep(self):
if self.code is not None: raise UsageError if self._error: # don't sleep if the bed's already on fire
raise self._error
d = defer.Deferred()
self._sleepers.append(d)
yield d
if self._error:
raise self._error
def _wakeup(self):
sleepers = self._sleepers
self._sleepers = []
for d in sleepers:
d.callback(None)
# NOTE: callers should avoid reentrancy themselves. An
# eventual-send would be safer here, but it makes synchronizing
# unit tests annoying.
def _signal_error(self, error):
assert isinstance(error, Exception)
self._error = error
self._wakeup()
def _ws_handle_error(self, msg):
err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
self._ws_url)
return self._signal_error(err)
@inlineCallbacks
def _claim_channel_and_watch(self):
assert self._channelid is not None
yield self._get_websocket()
if not self._ws_channel_claimed:
yield self._ws_send(u"claim", channelid=self._channelid)
self._ws_channel_claimed = True
yield self._ws_send(u"watch")
# entry point 1: generate a new code
@inlineCallbacks
def get_code(self, code_length=2): # rename to allocate_code()? create_?
if self._code is not None: raise UsageError
if self._started_get_code: raise UsageError if self._started_get_code: raise UsageError
self._started_get_code = True self._started_get_code = True
channelid = yield self._channel_manager.allocate() _sent = self._timing.add_event("allocate")
code = codes.make_code(channelid, code_length) yield self._ws_send(u"allocate")
while self._channelid is None:
yield self._sleep()
self._timing.finish_event(_sent)
code = codes.make_code(self._channelid, code_length)
assert isinstance(code, type(u"")), type(code) assert isinstance(code, type(u"")), type(code)
self._set_code_and_channelid(code) self._set_code(code)
self._start() self._start()
returnValue(code) returnValue(code)
def _ws_handle_allocated(self, msg):
if self._channelid is not None:
return self._signal_error("got duplicate channelid")
self._channelid = msg["channelid"]
self._wakeup()
def _start(self):
# allocate the rest now too, so it can be serialized
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
idSymmetric=to_bytes(self._appid))
self._msg1 = self._sp.start()
# entry point 2a: interactively type in a code, with completion
@inlineCallbacks @inlineCallbacks
def input_code(self, prompt="Enter wormhole code: ", code_length=2): def input_code(self, prompt="Enter wormhole code: ", code_length=2):
def _lister(): def _lister():
return blockingCallFromThread(self._reactor, return blockingCallFromThread(self._reactor, self._list_channels)
self._channel_manager.list_channels)
# fetch the list of channels ahead of time, to give us a chance to # fetch the list of channels ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete # discover the welcome message (and warn the user about an obsolete
# client) # client)
initial_channelids = yield self._channel_manager.list_channels() #
# TODO: send the request early, show the prompt right away, hide the
# latency in the user's indecision and slow typing. If we're lucky
# the answer will come back before they hit TAB.
initial_channelids = yield self._list_channels()
_start = self._timing.add_event("input code", waiting="user") _start = self._timing.add_event("input code", waiting="user")
code = yield deferToThread(codes.input_code_with_completion, code = yield deferToThread(codes.input_code_with_completion,
prompt, prompt,
initial_channelids, _lister, initial_channelids, _lister,
code_length) code_length)
self._timing.finish_event(_start) self._timing.finish_event(_start)
returnValue(code) returnValue(code) # application will give this to set_code()
@inlineCallbacks
def _list_channels(self):
_sent = self._timing.add_event("list")
self._latest_channelids = None
yield self._ws_send(u"list")
while self._latest_channelids is None:
yield self._sleep()
self._timing.finish_event(_sent)
returnValue(self._latest_channelids)
def _ws_handle_channelids(self, msg):
self._latest_channelids = msg["channelids"]
self._wakeup()
# entry point 2b: paste in a fully-formed code
def set_code(self, code): def set_code(self, code):
if not isinstance(code, type(u"")): raise TypeError(type(code)) if not isinstance(code, type(u"")): raise TypeError(type(code))
if self.code is not None: raise UsageError if self._code is not None: raise UsageError
self._set_code_and_channelid(code)
self._start()
def _set_code_and_channelid(self, code):
if self.code is not None: raise UsageError
self._timing.add_event("code established")
mo = re.search(r'^(\d+)-', code) mo = re.search(r'^(\d+)-', code)
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
self.code = code self._channelid = int(mo.group(1))
channelid = int(mo.group(1)) self._set_code(code)
self._channel = self._channel_manager.connect(channelid) self._start()
monitor.add(self._channel)
def _start(self): def _set_code(self, code):
# allocate the rest now too, so it can be serialized if self._code is not None: raise UsageError
self.sp = SPAKE2_Symmetric(to_bytes(self.code), self._timing.add_event("code established")
idSymmetric=to_bytes(self._appid)) self._code = code
self.msg1 = self.sp.start()
def serialize(self): def serialize(self):
# I can only be serialized after get_code/set_code and before # I can only be serialized after get_code/set_code and before
# get_verifier/get_data # get_verifier/get_data
if self.code is None: raise UsageError if self._code is None: raise UsageError
if self.key is not None: raise UsageError if self._key is not None: raise UsageError
if self._sent_data: raise UsageError if self._sent_phases: raise UsageError
if self._got_data: raise UsageError if self._got_phases: raise UsageError
data = { data = {
"appid": self._appid, "appid": self._appid,
"relay_url": self._relay_url, "relay_url": self._relay_url,
"code": self.code, "code": self._code,
"channelid": self._channelid,
"side": self._side, "side": self._side,
"spake2": json.loads(self.sp.serialize().decode("ascii")), "spake2": json.loads(self._sp.serialize().decode("ascii")),
"msg1": hexlify(self.msg1).decode("ascii"), "msg1": hexlify(self._msg1).decode("ascii"),
} }
return json.dumps(data) return json.dumps(data)
# entry point 3: resume a previously-serialized session
@classmethod @classmethod
def from_serialized(klass, data): def from_serialized(klass, data):
d = json.loads(data) d = json.loads(data)
self = klass(d["appid"], d["relay_url"]) self = klass(d["appid"], d["relay_url"])
self._set_side(d["side"]) self._side = d["side"]
self._set_code_and_channelid(d["code"]) self._channelid = d["channelid"]
self._set_code(d["code"])
sp_data = json.dumps(d["spake2"]).encode("ascii") sp_data = json.dumps(d["spake2"]).encode("ascii")
self.sp = SPAKE2_Symmetric.from_serialized(sp_data) self._sp = SPAKE2_Symmetric.from_serialized(sp_data)
self.msg1 = unhexlify(d["msg1"].encode("ascii")) self._msg1 = unhexlify(d["msg1"].encode("ascii"))
return self return self
@close_on_error
@inlineCallbacks
def get_verifier(self):
if self._closed: raise UsageError
if self._code is None: raise UsageError
yield self._get_master_key()
returnValue(self._verifier)
@inlineCallbacks
def _get_master_key(self):
# TODO: prevent multiple invocation
if not self._key:
yield self._claim_channel_and_watch()
yield self._msg_send(u"pake", self._msg1)
pake_msg = yield self._msg_get(u"pake")
self._key = self._sp.finish(pake_msg)
self._verifier = self.derive_key(u"wormhole:verifier")
self._timing.add_event("key established")
if self._send_confirm:
# both sides send different (random) confirmation messages
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
yield self._msg_send(u"_confirm", confmsg)
@inlineCallbacks
def _msg_send(self, phase, body, wait=False):
self._sent_messages.add( (phase, body) )
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
yield self._ws_send(u"add", phase=phase,
body=hexlify(body).decode("ascii"))
if wait:
while (phase, body) not in self._delivered_messages:
yield self._sleep()
def _ws_handle_message(self, msg):
m = msg["message"]
phase = m["phase"]
body = unhexlify(m["body"].encode("ascii"))
if (phase, body) in self._sent_messages:
self._delivered_messages.add( (phase, body) ) # ack by server
self._wakeup()
return # ignore echoes of our outbound messages
if phase in self._received_messages:
# a channel collision would cause this
err = ServerError("got duplicate phase %s" % phase, self._ws_url)
return self._signal_error(err)
self._received_messages[phase] = body
if phase == u"_confirm":
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
# this makes all API calls fail
return self._signal_error(WrongPasswordError())
# now notify anyone waiting on it
self._wakeup()
@inlineCallbacks
def _msg_get(self, phase):
_start = self._timing.add_event("get(%s)" % phase)
while phase not in self._received_messages:
yield self._sleep() # we can wait a long time here
# that will throw an error if something goes wrong
self._timing.finish_event(_start)
returnValue(self._received_messages[phase])
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self.key is None: if self._key is None:
# call after get_verifier() or get_data() # call after get_verifier() or get_data()
raise UsageError raise UsageError
return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) return HKDF(self._key, length, CTXinfo=to_bytes(purpose))
def _encrypt_data(self, key, data): def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key) assert isinstance(key, type(b"")), type(key)
@ -417,35 +420,6 @@ class Wormhole:
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
return data return data
@inlineCallbacks
def _get_key(self):
# TODO: prevent multiple invocation
if self.key:
returnValue(self.key)
yield self._channel.send(u"pake", self.msg1)
pake_msg = yield self._channel.get(u"pake")
key = self.sp.finish(pake_msg)
self.key = key
self.verifier = self.derive_key(u"wormhole:verifier")
self._timing.add_event("key established")
if not self._send_confirm:
returnValue(key)
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
yield self._channel.send(u"_confirm", confmsg)
returnValue(key)
@close_on_error
@inlineCallbacks
def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError
yield self._get_key()
returnValue(self.verifier)
@close_on_error @close_on_error
@inlineCallbacks @inlineCallbacks
def send_data(self, outbound_data, phase=u"data", wait=False): def send_data(self, outbound_data, phase=u"data", wait=False):
@ -453,52 +427,35 @@ class Wormhole:
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError if self._closed: raise UsageError
if phase in self._sent_data: raise UsageError # only call this once if self._code is None:
raise UsageError("You must set_code() before send_data()")
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if self.code is None: raise UsageError if phase in self._sent_phases: raise UsageError # only call this once
if self._channel is None: raise UsageError self._sent_phases.add(phase)
_sent = self._timing.add_event("API send data", phase=phase) _sent = self._timing.add_event("API send data", phase=phase, wait=wait)
# Without predefined roles, we can't derive predictably unique keys # Without predefined roles, we can't derive predictably unique keys
# for each side, so we use the same key for both. We use random # for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and the Channel automatically # nonces to keep the messages distinct, and we automatically ignore
# ignores reflections. # reflections.
self._sent_data.add(phase) yield self._get_master_key()
yield self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._channel.send(phase, outbound_encrypted) yield self._msg_send(phase, outbound_encrypted, wait)
# Since that always waits for the server to ack the POST, we always
# behave as if wait=True.
self._timing.finish_event(_sent) self._timing.finish_event(_sent)
@close_on_error @close_on_error
@inlineCallbacks @inlineCallbacks
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals
if self._closed: raise UsageError if self._closed: raise UsageError
if self.code is None: raise UsageError if self._code is None: raise UsageError
if self._channel is None: raise UsageError if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._got_phases: raise UsageError # only call this once
self._got_phases.add(phase)
_sent = self._timing.add_event("API get data", phase=phase) _sent = self._timing.add_event("API get data", phase=phase)
self._got_data.add(phase) yield self._get_master_key()
yield self._get_key() body = yield self._msg_get(phase) # we can wait a long time here
phases = []
if not self._got_confirmation:
phases.append(u"_confirm")
phases.append(phase)
phase_and_body = yield self._channel.get_first_of(phases)
(got_phase, body) = phase_and_body
if got_phase == u"_confirm":
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
raise WrongPasswordError
self._got_confirmation = True
phase_and_body = yield self._channel.get_first_of([phase])
(got_phase, body) = phase_and_body
self._timing.finish_event(_sent) self._timing.finish_event(_sent)
assert got_phase == phase
try: try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body) inbound_data = self._decrypt_data(data_key, body)
@ -506,16 +463,36 @@ class Wormhole:
except CryptoError: except CryptoError:
raise WrongPasswordError raise WrongPasswordError
def _ws_closed(self, wasClean, code, reason):
self._ws = None
# TODO: schedule reconnect, unless we're done
@inlineCallbacks @inlineCallbacks
def close(self, res=None, mood=u"happy"): def close(self, res=None, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))): if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood)) raise TypeError(type(mood))
if self._closed:
returnValue(None)
self._closed = True self._closed = True
if not self._channel: if not self._ws:
returnValue(None) returnValue(None)
self._timing.finish_event(self._timing_started, mood=mood) self._timing.finish_event(self._timing_started, mood=mood)
c, self._channel = self._channel, None yield self._deallocate(mood)
monitor.close(c) # TODO: mark WebSocket as don't-reconnect
yield c.deallocate(mood) self._ws.transport.loseConnection() # probably flushes
yield self._channel_manager.shutdown() del self._ws
@inlineCallbacks
def _deallocate(self, mood=None):
_sent = self._timing.add_event("close")
yield self._ws_send(u"deallocate", mood=mood)
while self._deallocated_status is None:
yield self._sleep()
self._timing.finish_event(_sent)
# TODO: set a timeout, don't wait forever for an ack
# TODO: if the connection is lost, let it go
returnValue(self._deallocated_status)
def _ws_handle_deallocated(self, msg):
self._deallocated_status = msg["status"]
self._wakeup()

View File

@ -1,136 +1,12 @@
from __future__ import print_function from __future__ import print_function
import json import json
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks from twisted.internet.defer import gatherResults, inlineCallbacks
from txwormhole.transcribe import (Wormhole, UsageError, ChannelManager, from txwormhole.transcribe import Wormhole, UsageError, WrongPasswordError
WrongPasswordError)
from txwormhole.eventsource import EventSourceParser
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = cm.list_channels()
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: cm.allocate())
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: cm.connect(self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: self._channel.deallocate(u"happy"))
d.addCallback(lambda _: cm.shutdown())
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: c2.send(u"phase1", b"msg2"))
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: c1.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: c2.deallocate())
def _not_yet(_):
self._rendezvous.prune()
self.failUnlessEqual(len(self._rendezvous._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: c1.deallocate(u"sad"))
def _gone(_):
self._rendezvous.prune()
self.failUnlessEqual(len(self._rendezvous._apps), 0)
d.addCallback(_gone)
d.addCallback(lambda _: cm1.shutdown())
d.addCallback(lambda _: cm2.shutdown())
return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c1.send(u"phase2", b"msg2"))
d.addCallback(lambda _: c2.get(u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
d.addCallback(lambda _: cm1.shutdown())
d.addCallback(lambda _: cm2.shutdown())
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a"))
d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b"))
d.addCallback(lambda _: c2a.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: c2b.get(u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
d.addCallback(lambda _: cm1a.shutdown())
d.addCallback(lambda _: cm2a.shutdown())
d.addCallback(lambda _: cm1b.shutdown())
d.addCallback(lambda _: cm2b.shutdown())
return d
class Basic(ServerBase, unittest.TestCase): class Basic(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2): def doBoth(self, d1, d2):
@ -226,10 +102,32 @@ class Basic(ServerBase, unittest.TestCase):
# and w1 won't send CONFIRM until it sees a PAKE message, which w2 # and w1 won't send CONFIRM until it sees a PAKE message, which w2
# won't send until we call get_data. So we need both sides to be # won't send until we call get_data. So we need both sides to be
# running at the same time for this test. # running at the same time for this test.
yield self.doBoth(w1.send_data(b"data1"), d1 = w1.send_data(b"data1")
self.assertFailure(w2.get_data(), WrongPasswordError)) # at this point, w1 should be waiting for w2.PAKE
# and now w1 should have enough information to throw too yield self.assertFailure(w2.get_data(), WrongPasswordError)
# * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key,
# send w2.CONFIRM, then wait for w1.DATA.
# * w1 will get w2.PAKE, compute a key, send w1.CONFIRM.
# * w2 gets w1.CONFIRM, notices the error, records it.
# * w2 (waiting for w1.DATA) wakes up, sees the error, throws
# * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not
# have arrived by then
yield d1
# When we ask w1 to get_data(), one of two things might happen:
# * if w2.CONFIRM arrived already, it will have recorded the error.
# When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the
# error before sleeping, and throw WrongPasswordError
# * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM
# arrives, we notice and record the error, and wake up, and throw
# Note that we didn't do w2.send_data(), so we're hoping that w1 will
# have enough information to detect the error before it sleeps
# (waiting for w2.DATA). Checking for the error both before sleeping
# and after waking up makes this happen.
# so now w1 should have enough information to throw too
yield self.assertFailure(w1.get_data(), WrongPasswordError) yield self.assertFailure(w1.get_data(), WrongPasswordError)
# both sides are closed automatically upon error, but it's still # both sides are closed automatically upon error, but it's still
@ -349,41 +247,3 @@ class Basic(ServerBase, unittest.TestCase):
yield gatherResults([w1.close(), w2.close(), self.new_w1.close()], yield gatherResults([w1.close(), w2.close(), self.new_w1.close()],
True) True)
data1 = b"""\
event: welcome
data: one and a
data: two
data:.
data: three
: this line is ignored
event: e2
: this line is ignored too
i am a dataless field name
data: four
"""
class FakeTransport:
disconnecting = False
class EventSourceClient(unittest.TestCase):
def test_parser(self):
events = []
p = EventSourceParser(lambda t,d: events.append((t,d)))
p.transport = FakeTransport()
p.dataReceived(data1)
self.failUnlessEqual(events,
[(u"welcome", u"one and a\ntwo\n."),
(u"message", u"three"),
(u"e2", u"four"),
])
# new py3 support in 15.5.0: web.client.Agent, w.c.downloadPage, twistd
# However trying 'wormhole server start' with py3/twisted-15.5.0 throws an
# error in t.i._twistd_unix.UnixApplicationRunner.postApplication, it calls
# os.write with str, not bytes. This file does not cover that test (testing
# daemonization is hard).