txwormhole: use websockets, not HTTP
This should speed up the protocol, since we don't have to wait for acks (HTTP responses) unless we really want to. It also makes it easier to have multiple messages in flight at once. The protocol is still compatible with the old HTTP version (which is still used by the blocking flavor), but requires an updated Rendezvous server that speaks websockets. set_code() no longer touches the network: it just stores the code and channelid for later. We hold off doing 'claim' and 'watch' until we need messages, triggered by get_verifier() or get_data() or send_data(). We check for error before sleeping, not just after waking. This makes it possible to detect a WrongPasswordError in get_data() even if the other side hasn't done a corresponding send_data(), as long as the other side finished PAKE (and thus sent a CONFIRM message). The unit test was doing just this, and was hanging.
This commit is contained in:
parent
a57845eb0b
commit
34f4c284b0
|
@ -1,25 +1,20 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, json, re, unicodedata
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from six.moves.urllib_parse import urlparse
|
||||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet import reactor, defer, endpoints, error
|
||||
from twisted.internet.threads import deferToThread, blockingCallFromThread
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.web import client as web_client
|
||||
from twisted.web import error as web_error
|
||||
from twisted.web.iweb import IBodyProducer
|
||||
from autobahn.twisted import websocket
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl import utils
|
||||
from spake2 import SPAKE2_Symmetric
|
||||
from .eventsource import ReconnectingEventSource
|
||||
from wormhole import __version__
|
||||
from wormhole import codes
|
||||
from wormhole.errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from wormhole.timing import DebugTiming
|
||||
from hkdf import Hkdf
|
||||
from wormhole.channel_monitor import monitor
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
@ -32,207 +27,6 @@ def make_confmsg(confkey, nonce):
|
|||
def to_bytes(u):
|
||||
return unicodedata.normalize("NFC", u).encode("utf-8")
|
||||
|
||||
@implementer(IBodyProducer)
|
||||
class DataProducer:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.length = len(data)
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.data)
|
||||
return defer.succeed(None)
|
||||
def stopProducing(self):
|
||||
pass
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
def resumeProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
def post_json(agent, url, request_body):
|
||||
# POST a JSON body to a URL, parsing the response as JSON
|
||||
data = json.dumps(request_body).encode("utf-8")
|
||||
d = agent.request(b"POST", url.encode("utf-8"),
|
||||
bodyProducer=DataProducer(data))
|
||||
def _check_error(resp):
|
||||
if resp.code != 200:
|
||||
raise web_error.Error(resp.code, resp.phrase)
|
||||
return resp
|
||||
d.addCallback(_check_error)
|
||||
d.addCallback(web_client.readBody)
|
||||
d.addCallback(lambda data: json.loads(data.decode("utf-8")))
|
||||
return d
|
||||
|
||||
def get_json(agent, url):
|
||||
# GET from a URL, parsing the response as JSON
|
||||
d = agent.request(b"GET", url.encode("utf-8"))
|
||||
def _check_error(resp):
|
||||
if resp.code != 200:
|
||||
raise web_error.Error(resp.code, resp.phrase)
|
||||
return resp
|
||||
d.addCallback(_check_error)
|
||||
d.addCallback(web_client.readBody)
|
||||
d.addCallback(lambda data: json.loads(data.decode("utf-8")))
|
||||
return d
|
||||
|
||||
class Channel:
|
||||
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
|
||||
agent, timing):
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._channelid = channelid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._agent = agent
|
||||
self._timing = timing
|
||||
self._messages = set() # (phase,body) , body is bytes
|
||||
self._sent_messages = set() # (phase,body)
|
||||
|
||||
def _add_inbound_messages(self, messages):
|
||||
for msg in messages:
|
||||
phase = msg["phase"]
|
||||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phases):
|
||||
their_messages = self._messages - self._sent_messages
|
||||
for phase in phases:
|
||||
for (their_phase,body) in their_messages:
|
||||
if their_phase == phase:
|
||||
return (phase, body)
|
||||
return None
|
||||
|
||||
def send(self, phase, msg):
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise TypeError(type(msg))
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
assert isinstance(self._side, type(u"")), type(self._side)
|
||||
payload = {"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
_sent = self._timing.add_event("send %s" % phase)
|
||||
d = post_json(self._agent, self._relay_url+"add", payload)
|
||||
def _maybe_handle_welcome(resp):
|
||||
self._timing.finish_event(_sent, resp.get("sent"))
|
||||
if "welcome" in resp:
|
||||
self._handle_welcome(resp["welcome"])
|
||||
return resp
|
||||
d.addCallback(_maybe_handle_welcome)
|
||||
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||
return d
|
||||
|
||||
def get_first_of(self, phases):
|
||||
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
|
||||
for phase in phases:
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
|
||||
# fire with a bytestring of the first message for any 'phase' that
|
||||
# wasn't one of our own messages. It will either come from
|
||||
# previously-received messages, or from an EventSource that we attach
|
||||
# to the corresponding URL
|
||||
_sent = self._timing.add_event("get %s" % "/".join(sorted(phases)))
|
||||
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body is not None:
|
||||
self._timing.finish_event(_sent)
|
||||
return defer.succeed(phase_and_body)
|
||||
d = defer.Deferred()
|
||||
msgs = []
|
||||
def _handle(name, line):
|
||||
if name == "welcome":
|
||||
self._handle_welcome(json.loads(line))
|
||||
if name == "message":
|
||||
data = json.loads(line)
|
||||
self._add_inbound_messages([data])
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body is not None and not msgs:
|
||||
msgs.append(phase_and_body)
|
||||
self._timing.finish_event(_sent, data.get("sent"))
|
||||
d.callback(None)
|
||||
queryargs = urlencode([("appid", self._appid),
|
||||
("channelid", self._channelid)])
|
||||
es = ReconnectingEventSource(self._relay_url+"watch?%s" % queryargs,
|
||||
_handle, self._agent)
|
||||
es.startService() # TODO: .setServiceParent(self)
|
||||
es.activate()
|
||||
d.addCallback(lambda _: es.deactivate())
|
||||
d.addCallback(lambda _: es.stopService())
|
||||
d.addCallback(lambda _: msgs[0])
|
||||
return d
|
||||
|
||||
@inlineCallbacks
|
||||
def get(self, phase):
|
||||
res = yield self.get_first_of([phase])
|
||||
(got_phase, body) = res
|
||||
assert got_phase == phase
|
||||
returnValue(body)
|
||||
|
||||
def deallocate(self, mood=None):
|
||||
# only try once, no retries
|
||||
_sent = self._timing.add_event("close")
|
||||
d = post_json(self._agent, self._relay_url+"deallocate",
|
||||
{"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"mood": mood})
|
||||
def _done(resp):
|
||||
self._timing.finish_event(_sent, resp.get("sent"))
|
||||
d.addCallback(_done)
|
||||
d.addBoth(lambda _: None) # ignore POST failure
|
||||
return d
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self, relay, appid, side, handle_welcome, tor_manager=None,
|
||||
timing=None, reactor=reactor):
|
||||
assert isinstance(relay, type(u""))
|
||||
self._relay = relay
|
||||
self._appid = appid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._pool = web_client.HTTPConnectionPool(reactor, True) # persistent
|
||||
if tor_manager:
|
||||
print("ChannelManager using tor")
|
||||
epf = tor_manager.get_web_agent_endpoint_factory()
|
||||
agent = web_client.Agent.usingEndpointFactory(reactor, epf,
|
||||
pool=self._pool)
|
||||
else:
|
||||
agent = web_client.Agent(reactor, pool=self._pool)
|
||||
self._agent = agent
|
||||
self._timing = timing or DebugTiming()
|
||||
|
||||
@inlineCallbacks
|
||||
def allocate(self):
|
||||
url = self._relay + "allocate"
|
||||
_sent = self._timing.add_event("allocate")
|
||||
data = yield post_json(self._agent, url, {"appid": self._appid,
|
||||
"side": self._side})
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
self._timing.finish_event(_sent, data.get("sent"))
|
||||
returnValue(data["channelid"])
|
||||
|
||||
@inlineCallbacks
|
||||
def list_channels(self):
|
||||
queryargs = urlencode([("appid", self._appid)])
|
||||
url = self._relay + u"list?%s" % queryargs
|
||||
_sent = self._timing.add_event("list")
|
||||
r = yield get_json(self._agent, url)
|
||||
self._timing.finish_event(_sent, r.get("sent"))
|
||||
returnValue(r["channelids"])
|
||||
|
||||
def connect(self, channelid):
|
||||
return Channel(self._relay, self._appid, channelid, self._side,
|
||||
self._handle_welcome, self._agent, self._timing)
|
||||
|
||||
@inlineCallbacks
|
||||
def shutdown(self):
|
||||
_sent = self._timing.add_event("pool shutdown")
|
||||
yield self._pool.closeCachedConnections()
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
def close_on_error(meth): # method decorator
|
||||
# Clients report certain errors as "moods", so the server can make a
|
||||
# rough count failed connections (due to mismatched passwords, attacks,
|
||||
|
@ -257,6 +51,31 @@ def close_on_error(meth): # method decorator
|
|||
return d
|
||||
return _wrapper
|
||||
|
||||
class WSClient(websocket.WebSocketClientProtocol):
|
||||
def onOpen(self):
|
||||
self.wormhole_open = True
|
||||
self.factory.d.callback(self)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
assert not isBinary
|
||||
self.wormhole._ws_dispatch_msg(payload)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
if self.wormhole_open:
|
||||
self.wormhole._ws_closed(wasClean, code, reason)
|
||||
else:
|
||||
# we closed before establishing a connection (onConnect) or
|
||||
# finishing WebSocket negotiation (onOpen): errback
|
||||
self.factory.d.errback(error.ConnectError(reason))
|
||||
|
||||
class WSFactory(websocket.WebSocketClientFactory):
|
||||
protocol = WSClient
|
||||
def buildProtocol(self, addr):
|
||||
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
|
||||
proto.wormhole = self.wormhole
|
||||
proto.wormhole_open = False
|
||||
return proto
|
||||
|
||||
class Wormhole:
|
||||
motd_displayed = False
|
||||
version_warning_displayed = False
|
||||
|
@ -270,35 +89,81 @@ class Wormhole:
|
|||
if not relay_url.endswith(u"/"): raise UsageError
|
||||
self._appid = appid
|
||||
self._relay_url = relay_url
|
||||
self._ws_url = relay_url.replace("http:", "ws:") + "ws"
|
||||
self._tor_manager = tor_manager
|
||||
self._timing = timing or DebugTiming()
|
||||
self._reactor = reactor
|
||||
self._set_side(hexlify(os.urandom(5)).decode("ascii"))
|
||||
self.code = None
|
||||
self.key = None
|
||||
self._side = hexlify(os.urandom(5)).decode("ascii")
|
||||
self._code = None
|
||||
self._channelid = None
|
||||
self._key = None
|
||||
self._started_get_code = False
|
||||
self._sent_data = set() # phases
|
||||
self._got_data = set()
|
||||
self._got_confirmation = False
|
||||
self._sent_messages = set() # (phase, body_bytes)
|
||||
self._delivered_messages = set() # (phase, body_bytes)
|
||||
self._received_messages = {} # phase -> body_bytes
|
||||
self._sent_phases = set() # phases, to prohibit double-send
|
||||
self._got_phases = set() # phases, to prohibit double-read
|
||||
self._sleepers = []
|
||||
self._confirmation_failed = False
|
||||
self._closed = False
|
||||
self._deallocated_status = None
|
||||
self._timing_started = self._timing.add_event("wormhole")
|
||||
self._ws = None
|
||||
self._ws_channel_claimed = False
|
||||
self._error = None
|
||||
|
||||
def _set_side(self, side):
|
||||
self._side = side
|
||||
self._channel_manager = ChannelManager(self._relay_url, self._appid,
|
||||
self._side, self.handle_welcome,
|
||||
self._tor_manager,
|
||||
self._timing,
|
||||
reactor=self._reactor)
|
||||
self._channel = None
|
||||
def _make_endpoint(self, hostname, port):
|
||||
if self._tor_manager:
|
||||
return self._tor_manager.endpointForURI()
|
||||
return endpoints.HostnameEndpoint(self._reactor, hostname, port) # 30s
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
@inlineCallbacks
|
||||
def _get_websocket(self):
|
||||
if not self._ws:
|
||||
# TODO: if we lose the connection, make a new one
|
||||
#from twisted.python import log
|
||||
#log.startLogging(sys.stderr)
|
||||
assert self._side
|
||||
assert not self._ws_channel_claimed
|
||||
p = urlparse(self._ws_url)
|
||||
f = WSFactory(self._ws_url)
|
||||
f.wormhole = self
|
||||
f.d = defer.Deferred()
|
||||
# TODO: if hostname="localhost", I get three factories starting
|
||||
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
|
||||
# an error in the factory is masked.
|
||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||
# .connect errbacks if the TCP connection fails
|
||||
self._ws = yield ep.connect(f)
|
||||
# f.d is errbacked if WebSocket negotiation fails
|
||||
yield f.d # WebSocket drops data sent before onOpen() fires
|
||||
self._ws_send(u"bind", appid=self._appid, side=self._side)
|
||||
# the socket is connected, and bound, but no channel has been claimed
|
||||
returnValue(self._ws)
|
||||
|
||||
@inlineCallbacks
|
||||
def _ws_send(self, mtype, **kwargs):
|
||||
ws = yield self._get_websocket()
|
||||
kwargs["type"] = mtype
|
||||
payload = json.dumps(kwargs).encode("utf-8")
|
||||
ws.sendMessage(payload, False)
|
||||
|
||||
def _ws_dispatch_msg(self, payload):
|
||||
msg = json.loads(payload.decode("utf-8"))
|
||||
mtype = msg["type"]
|
||||
meth = getattr(self, "_ws_handle_"+mtype, None)
|
||||
if not meth:
|
||||
raise ValueError("Unknown inbound message type %r" % (msg,))
|
||||
return meth(msg)
|
||||
|
||||
def _ws_handle_welcome(self, msg):
|
||||
welcome = msg["welcome"]
|
||||
if ("motd" in welcome and
|
||||
not self.motd_displayed):
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" %
|
||||
(self._relay_url, motd_formatted), file=sys.stderr)
|
||||
(self._ws_url, motd_formatted), file=sys.stderr)
|
||||
self.motd_displayed = True
|
||||
|
||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||
|
@ -312,94 +177,232 @@ class Wormhole:
|
|||
self.version_warning_displayed = True
|
||||
|
||||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self._relay_url)
|
||||
return self._signal_error(welcome["error"])
|
||||
|
||||
@inlineCallbacks
|
||||
def get_code(self, code_length=2):
|
||||
if self.code is not None: raise UsageError
|
||||
def _sleep(self):
|
||||
if self._error: # don't sleep if the bed's already on fire
|
||||
raise self._error
|
||||
d = defer.Deferred()
|
||||
self._sleepers.append(d)
|
||||
yield d
|
||||
if self._error:
|
||||
raise self._error
|
||||
|
||||
def _wakeup(self):
|
||||
sleepers = self._sleepers
|
||||
self._sleepers = []
|
||||
for d in sleepers:
|
||||
d.callback(None)
|
||||
# NOTE: callers should avoid reentrancy themselves. An
|
||||
# eventual-send would be safer here, but it makes synchronizing
|
||||
# unit tests annoying.
|
||||
|
||||
def _signal_error(self, error):
|
||||
assert isinstance(error, Exception)
|
||||
self._error = error
|
||||
self._wakeup()
|
||||
|
||||
def _ws_handle_error(self, msg):
|
||||
err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
|
||||
self._ws_url)
|
||||
return self._signal_error(err)
|
||||
|
||||
@inlineCallbacks
|
||||
def _claim_channel_and_watch(self):
|
||||
assert self._channelid is not None
|
||||
yield self._get_websocket()
|
||||
if not self._ws_channel_claimed:
|
||||
yield self._ws_send(u"claim", channelid=self._channelid)
|
||||
self._ws_channel_claimed = True
|
||||
yield self._ws_send(u"watch")
|
||||
|
||||
# entry point 1: generate a new code
|
||||
@inlineCallbacks
|
||||
def get_code(self, code_length=2): # rename to allocate_code()? create_?
|
||||
if self._code is not None: raise UsageError
|
||||
if self._started_get_code: raise UsageError
|
||||
self._started_get_code = True
|
||||
channelid = yield self._channel_manager.allocate()
|
||||
code = codes.make_code(channelid, code_length)
|
||||
_sent = self._timing.add_event("allocate")
|
||||
yield self._ws_send(u"allocate")
|
||||
while self._channelid is None:
|
||||
yield self._sleep()
|
||||
self._timing.finish_event(_sent)
|
||||
code = codes.make_code(self._channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
self._set_code(code)
|
||||
self._start()
|
||||
returnValue(code)
|
||||
|
||||
def _ws_handle_allocated(self, msg):
|
||||
if self._channelid is not None:
|
||||
return self._signal_error("got duplicate channelid")
|
||||
self._channelid = msg["channelid"]
|
||||
self._wakeup()
|
||||
|
||||
def _start(self):
|
||||
# allocate the rest now too, so it can be serialized
|
||||
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
|
||||
idSymmetric=to_bytes(self._appid))
|
||||
self._msg1 = self._sp.start()
|
||||
|
||||
# entry point 2a: interactively type in a code, with completion
|
||||
@inlineCallbacks
|
||||
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
|
||||
def _lister():
|
||||
return blockingCallFromThread(self._reactor,
|
||||
self._channel_manager.list_channels)
|
||||
return blockingCallFromThread(self._reactor, self._list_channels)
|
||||
# fetch the list of channels ahead of time, to give us a chance to
|
||||
# discover the welcome message (and warn the user about an obsolete
|
||||
# client)
|
||||
initial_channelids = yield self._channel_manager.list_channels()
|
||||
#
|
||||
# TODO: send the request early, show the prompt right away, hide the
|
||||
# latency in the user's indecision and slow typing. If we're lucky
|
||||
# the answer will come back before they hit TAB.
|
||||
initial_channelids = yield self._list_channels()
|
||||
_start = self._timing.add_event("input code", waiting="user")
|
||||
code = yield deferToThread(codes.input_code_with_completion,
|
||||
prompt,
|
||||
initial_channelids, _lister,
|
||||
code_length)
|
||||
self._timing.finish_event(_start)
|
||||
returnValue(code)
|
||||
returnValue(code) # application will give this to set_code()
|
||||
|
||||
@inlineCallbacks
|
||||
def _list_channels(self):
|
||||
_sent = self._timing.add_event("list")
|
||||
self._latest_channelids = None
|
||||
yield self._ws_send(u"list")
|
||||
while self._latest_channelids is None:
|
||||
yield self._sleep()
|
||||
self._timing.finish_event(_sent)
|
||||
returnValue(self._latest_channelids)
|
||||
|
||||
def _ws_handle_channelids(self, msg):
|
||||
self._latest_channelids = msg["channelids"]
|
||||
self._wakeup()
|
||||
|
||||
# entry point 2b: paste in a fully-formed code
|
||||
def set_code(self, code):
|
||||
if not isinstance(code, type(u"")): raise TypeError(type(code))
|
||||
if self.code is not None: raise UsageError
|
||||
self._set_code_and_channelid(code)
|
||||
self._start()
|
||||
|
||||
def _set_code_and_channelid(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
self._timing.add_event("code established")
|
||||
if self._code is not None: raise UsageError
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
self.code = code
|
||||
channelid = int(mo.group(1))
|
||||
self._channel = self._channel_manager.connect(channelid)
|
||||
monitor.add(self._channel)
|
||||
self._channelid = int(mo.group(1))
|
||||
self._set_code(code)
|
||||
self._start()
|
||||
|
||||
def _start(self):
|
||||
# allocate the rest now too, so it can be serialized
|
||||
self.sp = SPAKE2_Symmetric(to_bytes(self.code),
|
||||
idSymmetric=to_bytes(self._appid))
|
||||
self.msg1 = self.sp.start()
|
||||
def _set_code(self, code):
|
||||
if self._code is not None: raise UsageError
|
||||
self._timing.add_event("code established")
|
||||
self._code = code
|
||||
|
||||
def serialize(self):
|
||||
# I can only be serialized after get_code/set_code and before
|
||||
# get_verifier/get_data
|
||||
if self.code is None: raise UsageError
|
||||
if self.key is not None: raise UsageError
|
||||
if self._sent_data: raise UsageError
|
||||
if self._got_data: raise UsageError
|
||||
if self._code is None: raise UsageError
|
||||
if self._key is not None: raise UsageError
|
||||
if self._sent_phases: raise UsageError
|
||||
if self._got_phases: raise UsageError
|
||||
data = {
|
||||
"appid": self._appid,
|
||||
"relay_url": self._relay_url,
|
||||
"code": self.code,
|
||||
"code": self._code,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"spake2": json.loads(self.sp.serialize().decode("ascii")),
|
||||
"msg1": hexlify(self.msg1).decode("ascii"),
|
||||
"spake2": json.loads(self._sp.serialize().decode("ascii")),
|
||||
"msg1": hexlify(self._msg1).decode("ascii"),
|
||||
}
|
||||
return json.dumps(data)
|
||||
|
||||
# entry point 3: resume a previously-serialized session
|
||||
@classmethod
|
||||
def from_serialized(klass, data):
|
||||
d = json.loads(data)
|
||||
self = klass(d["appid"], d["relay_url"])
|
||||
self._set_side(d["side"])
|
||||
self._set_code_and_channelid(d["code"])
|
||||
self._side = d["side"]
|
||||
self._channelid = d["channelid"]
|
||||
self._set_code(d["code"])
|
||||
sp_data = json.dumps(d["spake2"]).encode("ascii")
|
||||
self.sp = SPAKE2_Symmetric.from_serialized(sp_data)
|
||||
self.msg1 = unhexlify(d["msg1"].encode("ascii"))
|
||||
self._sp = SPAKE2_Symmetric.from_serialized(sp_data)
|
||||
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
|
||||
return self
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
if self._code is None: raise UsageError
|
||||
yield self._get_master_key()
|
||||
returnValue(self._verifier)
|
||||
|
||||
@inlineCallbacks
|
||||
def _get_master_key(self):
|
||||
# TODO: prevent multiple invocation
|
||||
if not self._key:
|
||||
yield self._claim_channel_and_watch()
|
||||
yield self._msg_send(u"pake", self._msg1)
|
||||
pake_msg = yield self._msg_get(u"pake")
|
||||
|
||||
self._key = self._sp.finish(pake_msg)
|
||||
self._verifier = self.derive_key(u"wormhole:verifier")
|
||||
self._timing.add_event("key established")
|
||||
|
||||
if self._send_confirm:
|
||||
# both sides send different (random) confirmation messages
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
yield self._msg_send(u"_confirm", confmsg)
|
||||
|
||||
@inlineCallbacks
|
||||
def _msg_send(self, phase, body, wait=False):
|
||||
self._sent_messages.add( (phase, body) )
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
yield self._ws_send(u"add", phase=phase,
|
||||
body=hexlify(body).decode("ascii"))
|
||||
if wait:
|
||||
while (phase, body) not in self._delivered_messages:
|
||||
yield self._sleep()
|
||||
|
||||
def _ws_handle_message(self, msg):
|
||||
m = msg["message"]
|
||||
phase = m["phase"]
|
||||
body = unhexlify(m["body"].encode("ascii"))
|
||||
if (phase, body) in self._sent_messages:
|
||||
self._delivered_messages.add( (phase, body) ) # ack by server
|
||||
self._wakeup()
|
||||
return # ignore echoes of our outbound messages
|
||||
if phase in self._received_messages:
|
||||
# a channel collision would cause this
|
||||
err = ServerError("got duplicate phase %s" % phase, self._ws_url)
|
||||
return self._signal_error(err)
|
||||
self._received_messages[phase] = body
|
||||
if phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
# this makes all API calls fail
|
||||
return self._signal_error(WrongPasswordError())
|
||||
# now notify anyone waiting on it
|
||||
self._wakeup()
|
||||
|
||||
@inlineCallbacks
|
||||
def _msg_get(self, phase):
|
||||
_start = self._timing.add_event("get(%s)" % phase)
|
||||
while phase not in self._received_messages:
|
||||
yield self._sleep() # we can wait a long time here
|
||||
# that will throw an error if something goes wrong
|
||||
self._timing.finish_event(_start)
|
||||
returnValue(self._received_messages[phase])
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||
if self.key is None:
|
||||
if self._key is None:
|
||||
# call after get_verifier() or get_data()
|
||||
raise UsageError
|
||||
return HKDF(self.key, length, CTXinfo=to_bytes(purpose))
|
||||
return HKDF(self._key, length, CTXinfo=to_bytes(purpose))
|
||||
|
||||
def _encrypt_data(self, key, data):
|
||||
assert isinstance(key, type(b"")), type(key)
|
||||
|
@ -417,35 +420,6 @@ class Wormhole:
|
|||
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
@inlineCallbacks
|
||||
def _get_key(self):
|
||||
# TODO: prevent multiple invocation
|
||||
if self.key:
|
||||
returnValue(self.key)
|
||||
yield self._channel.send(u"pake", self.msg1)
|
||||
pake_msg = yield self._channel.get(u"pake")
|
||||
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
self._timing.add_event("key established")
|
||||
|
||||
if not self._send_confirm:
|
||||
returnValue(key)
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
yield self._channel.send(u"_confirm", confmsg)
|
||||
returnValue(key)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
yield self._get_key()
|
||||
returnValue(self.verifier)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def send_data(self, outbound_data, phase=u"data", wait=False):
|
||||
|
@ -453,52 +427,35 @@ class Wormhole:
|
|||
raise TypeError(type(outbound_data))
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if self._closed: raise UsageError
|
||||
if phase in self._sent_data: raise UsageError # only call this once
|
||||
if self._code is None:
|
||||
raise UsageError("You must set_code() before send_data()")
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
_sent = self._timing.add_event("API send data", phase=phase)
|
||||
if phase in self._sent_phases: raise UsageError # only call this once
|
||||
self._sent_phases.add(phase)
|
||||
_sent = self._timing.add_event("API send data", phase=phase, wait=wait)
|
||||
# Without predefined roles, we can't derive predictably unique keys
|
||||
# for each side, so we use the same key for both. We use random
|
||||
# nonces to keep the messages distinct, and the Channel automatically
|
||||
# ignores reflections.
|
||||
self._sent_data.add(phase)
|
||||
yield self._get_key()
|
||||
# nonces to keep the messages distinct, and we automatically ignore
|
||||
# reflections.
|
||||
yield self._get_master_key()
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
yield self._channel.send(phase, outbound_encrypted)
|
||||
# Since that always waits for the server to ack the POST, we always
|
||||
# behave as if wait=True.
|
||||
yield self._msg_send(phase, outbound_encrypted, wait)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
def get_data(self, phase=u"data"):
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if phase in self._got_data: raise UsageError # only call this once
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
if self._code is None: raise UsageError
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if phase in self._got_phases: raise UsageError # only call this once
|
||||
self._got_phases.add(phase)
|
||||
_sent = self._timing.add_event("API get data", phase=phase)
|
||||
self._got_data.add(phase)
|
||||
yield self._get_key()
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
phase_and_body = yield self._channel.get_first_of(phases)
|
||||
(got_phase, body) = phase_and_body
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
phase_and_body = yield self._channel.get_first_of([phase])
|
||||
(got_phase, body) = phase_and_body
|
||||
yield self._get_master_key()
|
||||
body = yield self._msg_get(phase) # we can wait a long time here
|
||||
self._timing.finish_event(_sent)
|
||||
assert got_phase == phase
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
inbound_data = self._decrypt_data(data_key, body)
|
||||
|
@ -506,16 +463,36 @@ class Wormhole:
|
|||
except CryptoError:
|
||||
raise WrongPasswordError
|
||||
|
||||
def _ws_closed(self, wasClean, code, reason):
|
||||
self._ws = None
|
||||
# TODO: schedule reconnect, unless we're done
|
||||
|
||||
@inlineCallbacks
|
||||
def close(self, res=None, mood=u"happy"):
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
if self._closed:
|
||||
returnValue(None)
|
||||
self._closed = True
|
||||
if not self._channel:
|
||||
if not self._ws:
|
||||
returnValue(None)
|
||||
self._timing.finish_event(self._timing_started, mood=mood)
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
yield c.deallocate(mood)
|
||||
yield self._channel_manager.shutdown()
|
||||
yield self._deallocate(mood)
|
||||
# TODO: mark WebSocket as don't-reconnect
|
||||
self._ws.transport.loseConnection() # probably flushes
|
||||
del self._ws
|
||||
|
||||
@inlineCallbacks
|
||||
def _deallocate(self, mood=None):
|
||||
_sent = self._timing.add_event("close")
|
||||
yield self._ws_send(u"deallocate", mood=mood)
|
||||
while self._deallocated_status is None:
|
||||
yield self._sleep()
|
||||
self._timing.finish_event(_sent)
|
||||
# TODO: set a timeout, don't wait forever for an ack
|
||||
# TODO: if the connection is lost, let it go
|
||||
returnValue(self._deallocated_status)
|
||||
|
||||
def _ws_handle_deallocated(self, msg):
|
||||
self._deallocated_status = msg["status"]
|
||||
self._wakeup()
|
||||
|
|
|
@ -1,136 +1,12 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks
|
||||
from txwormhole.transcribe import (Wormhole, UsageError, ChannelManager,
|
||||
WrongPasswordError)
|
||||
from txwormhole.eventsource import EventSourceParser
|
||||
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||
from txwormhole.transcribe import Wormhole, UsageError, WrongPasswordError
|
||||
from .common import ServerBase
|
||||
|
||||
APPID = u"appid"
|
||||
|
||||
class Channel(ServerBase, unittest.TestCase):
|
||||
def ignore(self, welcome):
|
||||
pass
|
||||
|
||||
def test_allocate(self):
|
||||
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
|
||||
d = cm.list_channels()
|
||||
def _got_channels(channels):
|
||||
self.failUnlessEqual(channels, [])
|
||||
d.addCallback(_got_channels)
|
||||
d.addCallback(lambda _: cm.allocate())
|
||||
def _allocated(channelid):
|
||||
self.failUnlessEqual(type(channelid), int)
|
||||
self._channelid = channelid
|
||||
d.addCallback(_allocated)
|
||||
d.addCallback(lambda _: cm.connect(self._channelid))
|
||||
def _connected(c):
|
||||
self._channel = c
|
||||
d.addCallback(_connected)
|
||||
d.addCallback(lambda _: self._channel.deallocate(u"happy"))
|
||||
d.addCallback(lambda _: cm.shutdown())
|
||||
return d
|
||||
|
||||
def test_messages(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
|
||||
d.addCallback(lambda _: c2.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
|
||||
d.addCallback(lambda _: c2.send(u"phase1", b"msg2"))
|
||||
d.addCallback(lambda _: c1.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# it's legal to fetch a phase multiple times, should be idempotent
|
||||
d.addCallback(lambda _: c1.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# deallocating one side is not enough to destroy the channel
|
||||
d.addCallback(lambda _: c2.deallocate())
|
||||
def _not_yet(_):
|
||||
self._rendezvous.prune()
|
||||
self.failUnlessEqual(len(self._rendezvous._apps), 1)
|
||||
d.addCallback(_not_yet)
|
||||
# but deallocating both will make the messages go away
|
||||
d.addCallback(lambda _: c1.deallocate(u"sad"))
|
||||
def _gone(_):
|
||||
self._rendezvous.prune()
|
||||
self.failUnlessEqual(len(self._rendezvous._apps), 0)
|
||||
d.addCallback(_gone)
|
||||
|
||||
d.addCallback(lambda _: cm1.shutdown())
|
||||
d.addCallback(lambda _: cm2.shutdown())
|
||||
|
||||
return d
|
||||
|
||||
def test_get_multiple_phases(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
|
||||
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
|
||||
d.addCallback(lambda _: c1.send(u"phase2", b"msg2"))
|
||||
d.addCallback(lambda _: c2.get(u"phase2"))
|
||||
|
||||
# if both are present, it should prefer the first one we asked for
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase2", b"msg2")))
|
||||
|
||||
d.addCallback(lambda _: cm1.shutdown())
|
||||
d.addCallback(lambda _: cm2.shutdown())
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
|
||||
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
|
||||
c1a = cm1a.connect(1)
|
||||
c2a = cm2a.connect(1)
|
||||
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
|
||||
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
|
||||
c1b = cm1b.connect(1)
|
||||
c2b = cm2b.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1a.send(u"phase1", b"msg1a"))
|
||||
d.addCallback(lambda _: c1b.send(u"phase1", b"msg1b"))
|
||||
d.addCallback(lambda _: c2a.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
|
||||
d.addCallback(lambda _: c2b.get(u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
|
||||
|
||||
d.addCallback(lambda _: cm1a.shutdown())
|
||||
d.addCallback(lambda _: cm2a.shutdown())
|
||||
d.addCallback(lambda _: cm1b.shutdown())
|
||||
d.addCallback(lambda _: cm2b.shutdown())
|
||||
return d
|
||||
|
||||
class Basic(ServerBase, unittest.TestCase):
|
||||
|
||||
def doBoth(self, d1, d2):
|
||||
|
@ -226,10 +102,32 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
# and w1 won't send CONFIRM until it sees a PAKE message, which w2
|
||||
# won't send until we call get_data. So we need both sides to be
|
||||
# running at the same time for this test.
|
||||
yield self.doBoth(w1.send_data(b"data1"),
|
||||
self.assertFailure(w2.get_data(), WrongPasswordError))
|
||||
d1 = w1.send_data(b"data1")
|
||||
# at this point, w1 should be waiting for w2.PAKE
|
||||
|
||||
# and now w1 should have enough information to throw too
|
||||
yield self.assertFailure(w2.get_data(), WrongPasswordError)
|
||||
# * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key,
|
||||
# send w2.CONFIRM, then wait for w1.DATA.
|
||||
# * w1 will get w2.PAKE, compute a key, send w1.CONFIRM.
|
||||
# * w2 gets w1.CONFIRM, notices the error, records it.
|
||||
# * w2 (waiting for w1.DATA) wakes up, sees the error, throws
|
||||
# * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not
|
||||
# have arrived by then
|
||||
yield d1
|
||||
|
||||
# When we ask w1 to get_data(), one of two things might happen:
|
||||
# * if w2.CONFIRM arrived already, it will have recorded the error.
|
||||
# When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the
|
||||
# error before sleeping, and throw WrongPasswordError
|
||||
# * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM
|
||||
# arrives, we notice and record the error, and wake up, and throw
|
||||
|
||||
# Note that we didn't do w2.send_data(), so we're hoping that w1 will
|
||||
# have enough information to detect the error before it sleeps
|
||||
# (waiting for w2.DATA). Checking for the error both before sleeping
|
||||
# and after waking up makes this happen.
|
||||
|
||||
# so now w1 should have enough information to throw too
|
||||
yield self.assertFailure(w1.get_data(), WrongPasswordError)
|
||||
|
||||
# both sides are closed automatically upon error, but it's still
|
||||
|
@ -349,41 +247,3 @@ class Basic(ServerBase, unittest.TestCase):
|
|||
yield gatherResults([w1.close(), w2.close(), self.new_w1.close()],
|
||||
True)
|
||||
|
||||
|
||||
data1 = b"""\
|
||||
event: welcome
|
||||
data: one and a
|
||||
data: two
|
||||
data:.
|
||||
|
||||
data: three
|
||||
|
||||
: this line is ignored
|
||||
event: e2
|
||||
: this line is ignored too
|
||||
i am a dataless field name
|
||||
data: four
|
||||
|
||||
"""
|
||||
|
||||
class FakeTransport:
|
||||
disconnecting = False
|
||||
|
||||
class EventSourceClient(unittest.TestCase):
|
||||
def test_parser(self):
|
||||
events = []
|
||||
p = EventSourceParser(lambda t,d: events.append((t,d)))
|
||||
p.transport = FakeTransport()
|
||||
p.dataReceived(data1)
|
||||
self.failUnlessEqual(events,
|
||||
[(u"welcome", u"one and a\ntwo\n."),
|
||||
(u"message", u"three"),
|
||||
(u"e2", u"four"),
|
||||
])
|
||||
|
||||
# new py3 support in 15.5.0: web.client.Agent, w.c.downloadPage, twistd
|
||||
|
||||
# However trying 'wormhole server start' with py3/twisted-15.5.0 throws an
|
||||
# error in t.i._twistd_unix.UnixApplicationRunner.postApplication, it calls
|
||||
# os.write with str, not bytes. This file does not cover that test (testing
|
||||
# daemonization is hard).
|
||||
|
|
Loading…
Reference in New Issue
Block a user