From b99adecdde8bd41fc7813c544ac093c0a618c96b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 18 Apr 2016 17:19:49 -0700 Subject: [PATCH 1/2] Depend upon autobahn, for upcoming websocket support. Use 'autobahn[twisted]' just to be sure (plain 'autobahn' worked fine for py27, but maybe it's needed for py35 or something). Autobahn is failing to do some conditional import and accidentally depends upon pytrie (for some encrypted WAMP thing) when we didn't ask for it (https://github.com/crossbario/autobahn-python/issues/604). This commit also adds a manual dependency on pytrie (which is pretty small) until the upstream bug is fixed. --- setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2c413a6..73dd899 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,11 @@ setup(name="magic-wormhole", "wormhole-server = wormhole_server.runner:entry", ]}, install_requires=["spake2==0.3", "pynacl", "requests", "argparse", - "six", "twisted >= 16.1.0", "hkdf"], + "six", "twisted >= 16.1.0", "hkdf", + "autobahn[twisted]", "pytrie", + # autobahn seems to have a bug, and one plugin throws + # errors unless pytrie is installed + ], extras_require={"tor": ["txtorcon", "ipaddr"]}, test_suite="wormhole.test", cmdclass=commands, From aec0615d66ce97857788675de75a348f18aa581c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 18 Apr 2016 17:57:27 -0700 Subject: [PATCH 2/2] Add WebSocket-based rendezvous protocol frontend The websocket lives on a Resource of the main rendezvous web site, and the websocket URL is derived from the main "relay_url", so there's no extra port to allocate, and no extra service to shut down. --- src/wormhole_server/rendezvous_websocket.py | 169 ++++++++++ src/wormhole_server/server.py | 6 + tests/common.py | 3 + tests/test_server.py | 322 +++++++++++++++++++- 4 files changed, 498 insertions(+), 2 deletions(-) create mode 100644 src/wormhole_server/rendezvous_websocket.py diff --git a/src/wormhole_server/rendezvous_websocket.py b/src/wormhole_server/rendezvous_websocket.py new file mode 100644 index 0000000..b640a54 --- /dev/null +++ b/src/wormhole_server/rendezvous_websocket.py @@ -0,0 +1,169 @@ +import json, time +from twisted.internet import reactor +from twisted.python import log +from autobahn.twisted import websocket + +# Each WebSocket connection is bound to one "appid", one "side", and one +# "channelid". The connection's appid and side are set by the "bind" message +# (which must be the first message on the connection). The channelid is set +# by either a "allocate" message (where the server picks the channelid), or +# by a "claim" message (where the client picks it). All three values must be +# set before any other message (watch, add, deallocate) can be sent. + +# All websocket messages are JSON-encoded. The client can send us "inbound" +# messages (marked as "->" below), which may (or may not) provoke immediate +# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed +# correlation between requests and responses. In this list, "A -> B" means +# that some time after A is received, at least one message of type B will be +# sent out. + +# All outbound messages include a "sent" key, which is a float (seconds since +# epoch) with the server clock just before the outbound message was written +# to the socket. + +# connection -> welcome +# <- {type: "welcome", welcome: {}} # .welcome keys are all optional: +# current_version: out-of-date clients display a warning +# motd: all clients display message, then continue normally +# error: all clients display mesage, then terminate with error +# -> {type: "bind", appid:, side:} +# -> {type: "list"} -> all-channelids +# <- {type: "all-channelids", channelids: [int..]} +# -> {type: "allocate"} -> allocated +# <- {type: "allocated", channelid: int} +# -> {type: "claim", channelid: int} +# -> {type: "watch"} -> message # sends old messages and more in future +# <- {type: "message", message: {phase:, body:}} # body is base64 +# -> {type: "add", phase: str, body: base64} # may send echo +# -> {type: "deallocate", mood: str} -> deallocated +# <- {type: "deallocated", status: waiting|deleted} +# <- {type: "error", error: str, orig: {}} # in response to malformed msgs + +# for tests that need to know when a message has been processed: +# -> {type: "ping", ping: int} -> pong (does not require bind/claim) +# <- {type: "pong", pong: int} + +class Error(Exception): + def __init__(self, explain, orig_msg): + self._explain = explain + +class WebSocketRendezvous(websocket.WebSocketServerProtocol): + def __init__(self): + websocket.WebSocketServerProtocol.__init__(self) + self._app = None + self._side = None + self._channel = None + self._watching = False + + def onConnect(self, request): + rv = self.factory.rendezvous + if rv.get_log_requests(): + log.msg("ws client connecting: %s" % (request.peer,)) + self._reactor = self.factory.reactor + + def onOpen(self): + rv = self.factory.rendezvous + self.send("welcome", welcome=rv.get_welcome()) + + def onMessage(self, payload, isBinary): + msg = json.loads(payload.decode("utf-8")) + try: + if "type" not in msg: + raise Error("missing 'type'") + mtype = msg["type"] + if mtype == "ping": + return self.handle_ping(msg) + if mtype == "bind": + return self.handle_bind(msg) + + if not self._app: + raise Error("Must bind first") + if mtype == "list": + return self.handle_list() + if mtype == "allocate": + return self.handle_allocate() + if mtype == "claim": + return self.handle_claim(msg) + + if not self._channel: + raise Error("Must set channel first") + meth = getattr(self, "handle_"+mtype, None) + if not meth: + raise Error("Unknown type") + return meth(self._channel, msg) + except Error as e: + self.send("error", error=e._explain, orig=msg) + + def send_rendezvous_event(self, event): + self.send("message", message=event) + + def stop_rendezvous_watcher(self): + self._reactor.callLater(0, self.transport.loseConnection) + + def handle_ping(self, msg): + if "ping" not in msg: + raise Error("ping requires 'ping'") + self.send("pong", pong=msg["ping"]) + + def handle_bind(self, msg): + if self._app or self._side: + raise Error("already bound") + if "appid" not in msg: + raise Error("bind requires 'appid'") + if "side" not in msg: + raise Error("bind requires 'side'") + self._app = self.factory.rendezvous.get_app(msg["appid"]) + self._side = msg["side"] + + def handle_list(self): + channelids = sorted(self._app.get_allocated()) + self.send("all-channelids", channelids=channelids) + + def handle_allocate(self): + if self._channel: + raise Error("Already bound to a channelid") + channelid = self._app.find_available_channelid() + self._channel = self._app.allocate_channel(channelid, self._side) + self.send("allocated", channelid=channelid) + + def handle_claim(self, msg): + if self._channel: + raise Error("Already bound to a channelid") + if "channelid" not in msg: + raise Error("claim requires 'channelid'") + self._channel = self._app.allocate_channel(msg["channelid"], self._side) + + def handle_watch(self, channel, msg): + if self._watching: + raise Error("already watching") + self._watching = True + for old_message in channel.add_listener(self): + self.send_rendezvous_event(old_message) + + def handle_add(self, channel, msg): + if "phase" not in msg: + raise Error("missing 'phase'") + if "body" not in msg: + raise Error("missing 'body'") + channel.add_message(self._side, msg["phase"], msg["body"]) + + def handle_deallocate(self, channel, msg): + deleted = channel.deallocate(self._side, msg.get("mood")) + self.send("deallocated", status="deleted" if deleted else "waiting") + + def send(self, mtype, **kwargs): + kwargs["type"] = mtype + kwargs["sent"] = time.time() + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + + def onClose(self, wasClean, code, reason): + pass + + +class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): + protocol = WebSocketRendezvous + def __init__(self, url, rendezvous): + websocket.WebSocketServerFactory.__init__(self, url) + self.rendezvous = rendezvous + self.reactor = reactor # for tests to control diff --git a/src/wormhole_server/server.py b/src/wormhole_server/server.py index 90971b9..a8912c9 100644 --- a/src/wormhole_server/server.py +++ b/src/wormhole_server/server.py @@ -3,11 +3,13 @@ from twisted.python import log from twisted.internet import reactor, endpoints from twisted.application import service from twisted.web import server, static, resource +from autobahn.twisted.resource import WebSocketResource from .endpoint_service import ServerEndpointService from wormhole import __version__ from .database import get_db from .rendezvous import Rendezvous from .rendezvous_web import WebRendezvous +from .rendezvous_websocket import WebSocketRendezvousFactory from .transit_server import Transit class Root(resource.Resource): @@ -48,6 +50,9 @@ class RelayServer(service.MultiService): wr = WebRendezvous(rendezvous) root.putChild(b"wormhole-relay", wr) + wsrf = WebSocketRendezvousFactory(None, rendezvous) + wr.putChild(b"ws", WebSocketResource(wsrf)) + site = PrivacyEnhancedSite(root) if blur_usage: site.logRequests = False @@ -69,6 +74,7 @@ class RelayServer(service.MultiService): self._root = root self._rendezvous_web = wr self._rendezvous_web_service = rendezvous_web_service + self._rendezvous_websocket = wsrf if transit_port: self._transit = transit self._transit_service = transit_service diff --git a/tests/common.py b/tests/common.py index 80b2da2..96f1548 100644 --- a/tests/common.py +++ b/tests/common.py @@ -18,6 +18,9 @@ class ServerBase: self._rendezvous = s._rendezvous self._transit_server = s._transit self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport + self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws" + self.rdv_ws_port = relayport + # ws://127.0.0.1:%d/wormhole-relay/ws self.transit = u"tcp:127.0.0.1:%d" % transitport def tearDown(self): diff --git a/tests/test_server.py b/tests/test_server.py index 723e8c7..16a3869 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,13 +1,15 @@ from __future__ import print_function -import json -import requests +import json, itertools from binascii import hexlify +import requests from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import protocol, reactor, defer +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.threads import deferToThread from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody +from autobahn.twisted import websocket from wormhole import __version__ from .common import ServerBase from wormhole_server import rendezvous, transit_server @@ -367,6 +369,322 @@ class OneEventAtATime: self.connected_d.errback(why) self.disconnected_d.callback((why,)) +class WSClient(websocket.WebSocketClientProtocol): + def __init__(self): + websocket.WebSocketClientProtocol.__init__(self) + self.events = [] + self.d = None + self.ping_counter = itertools.count(0) + def onOpen(self): + self.factory.d.callback(self) + def onMessage(self, payload, isBinary): + assert not isBinary + event = json.loads(payload.decode("utf-8")) + if self.d: + assert not self.events + d,self.d = self.d,None + d.callback(event) + return + self.events.append(event) + + def next_event(self): + assert not self.d + if self.events: + event = self.events.pop(0) + return defer.succeed(event) + self.d = defer.Deferred() + return self.d + + def send(self, mtype, **kwargs): + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + + @inlineCallbacks + def sync(self): + ping = next(self.ping_counter) + self.send("ping", ping=ping) + # queue all messages until the pong, then put them back + old_events = [] + while True: + ev = yield self.next_event() + if ev["type"] == "pong" and ev["pong"] == ping: + self.events = old_events + self.events + returnValue(None) + old_events.append(ev) + +class WSFactory(websocket.WebSocketClientFactory): + protocol = WSClient + +class WSClientSync(unittest.TestCase): + # make sure my 'sync' method actually works + + @inlineCallbacks + def test_sync(self): + sent = [] + c = WSClient() + def _send(mtype, **kwargs): + sent.append( (mtype, kwargs) ) + c.send = _send + def add(mtype, **kwargs): + kwargs["type"] = mtype + c.onMessage(json.dumps(kwargs).encode("utf-8"), False) + # no queued messages + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + self.assertEqual(sent, [("ping", {"ping": 0})]) + self.assertEqual(sunc, []) + add("pong", pong=0) + yield d + self.assertEqual(c.events, []) + + # one,two,ping,pong + add("one") + add("two", two=2) + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("pong", pong=1) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + # one,ping,two,pong + add("one") + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("two", two=2) + add("pong", pong=2) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + # ping,one,two,pong + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("one") + add("two", two=2) + add("pong", pong=3) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + + +class WebSocketAPI(ServerBase, unittest.TestCase): + def setUp(self): + self._clients = [] + return ServerBase.setUp(self) + + def tearDown(self): + for c in self._clients: + c.transport.loseConnection() + return ServerBase.tearDown(self) + + @inlineCallbacks + def make_client(self): + f = WSFactory(self.rdv_ws_url) + f.d = defer.Deferred() + reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f) + c = yield f.d + self._clients.append(c) + returnValue(c) + + def check_welcome(self, data): + self.failUnlessIn("welcome", data) + self.failUnlessEqual(data["welcome"], {"current_version": __version__}) + + @inlineCallbacks + def test_welcome(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + self.assertEqual(self._rendezvous._apps, {}) + + @inlineCallbacks + def test_allocate_1(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + yield c1.sync() + self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) + app = self._rendezvous.get_app(u"appid") + self.assertEqual(app.get_allocated(), set()) + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + self.failUnlessIsInstance(cid, int) + self.assertEqual(app.get_allocated(), set([cid])) + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c1.send(u"deallocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"deleted") + self.assertEqual(app.get_allocated(), set()) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + @inlineCallbacks + def test_allocate_2(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + yield c1.sync() + app = self._rendezvous.get_app(u"appid") + self.assertEqual(app.get_allocated(), set()) + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + self.failUnlessIsInstance(cid, int) + self.assertEqual(app.get_allocated(), set([cid])) + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + # second caller increases the number of known sides to 2 + c2 = yield self.make_client() + msg = yield c2.next_event() + self.check_welcome(msg) + c2.send(u"bind", appid=u"appid", side=u"side-2") + c2.send(u"claim", channelid=cid) + c2.send(u"add", phase="1", body="") + yield c2.sync() + + self.assertEqual(app.get_allocated(), set([cid])) + self.assertEqual(channel.get_messages(), [{"phase": "1", "body": ""}]) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c2.send(u"list") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c1.send(u"deallocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"waiting") + + c2.send(u"deallocate") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"deleted") + + c2.send(u"list") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + @inlineCallbacks + def test_message(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + app = self._rendezvous.get_app(u"appid") + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + c1.send(u"watch") + yield c1.sync() + self.assertEqual(len(channel._listeners), 1) + self.assertEqual(c1.events, []) + + c1.send(u"add", phase="1", body="msg1A") + yield c1.sync() + self.assertEqual(channel.get_messages(), + [{"phase": "1", "body": "msg1A"}]) + self.assertEqual(len(c1.events), 1) # echo should be sent right away + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + self.assertIn("sent", msg) + self.assertIsInstance(msg["sent"], float) + + c1.send(u"add", phase="1", body="msg1B") + c1.send(u"add", phase="2", body="msg2A") + + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + self.assertEqual(channel.get_messages(), [ + {"phase": "1", "body": "msg1A"}, + {"phase": "1", "body": "msg1B"}, + {"phase": "2", "body": "msg2A"}, + ]) + + # second client should see everything + c2 = yield self.make_client() + msg = yield c2.next_event() + self.check_welcome(msg) + c2.send(u"bind", appid=u"appid", side=u"side") + c2.send(u"claim", channelid=cid) + # 'watch' triggers delivery of old messages, in temporal order + c2.send(u"watch") + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + # adding a duplicate is not an error, and clients will ignore it + c1.send(u"add", phase="2", body="msg2A") + + # the duplicate message *does* get stored, and delivered + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + class Summary(unittest.TestCase): def test_summarize(self): c = rendezvous.Channel(None, None, None, None, False, None, None)