diff --git a/src/wormhole_server/rendezvous_websocket.py b/src/wormhole_server/rendezvous_websocket.py new file mode 100644 index 0000000..b640a54 --- /dev/null +++ b/src/wormhole_server/rendezvous_websocket.py @@ -0,0 +1,169 @@ +import json, time +from twisted.internet import reactor +from twisted.python import log +from autobahn.twisted import websocket + +# Each WebSocket connection is bound to one "appid", one "side", and one +# "channelid". The connection's appid and side are set by the "bind" message +# (which must be the first message on the connection). The channelid is set +# by either a "allocate" message (where the server picks the channelid), or +# by a "claim" message (where the client picks it). All three values must be +# set before any other message (watch, add, deallocate) can be sent. + +# All websocket messages are JSON-encoded. The client can send us "inbound" +# messages (marked as "->" below), which may (or may not) provoke immediate +# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed +# correlation between requests and responses. In this list, "A -> B" means +# that some time after A is received, at least one message of type B will be +# sent out. + +# All outbound messages include a "sent" key, which is a float (seconds since +# epoch) with the server clock just before the outbound message was written +# to the socket. + +# connection -> welcome +# <- {type: "welcome", welcome: {}} # .welcome keys are all optional: +# current_version: out-of-date clients display a warning +# motd: all clients display message, then continue normally +# error: all clients display mesage, then terminate with error +# -> {type: "bind", appid:, side:} +# -> {type: "list"} -> all-channelids +# <- {type: "all-channelids", channelids: [int..]} +# -> {type: "allocate"} -> allocated +# <- {type: "allocated", channelid: int} +# -> {type: "claim", channelid: int} +# -> {type: "watch"} -> message # sends old messages and more in future +# <- {type: "message", message: {phase:, body:}} # body is base64 +# -> {type: "add", phase: str, body: base64} # may send echo +# -> {type: "deallocate", mood: str} -> deallocated +# <- {type: "deallocated", status: waiting|deleted} +# <- {type: "error", error: str, orig: {}} # in response to malformed msgs + +# for tests that need to know when a message has been processed: +# -> {type: "ping", ping: int} -> pong (does not require bind/claim) +# <- {type: "pong", pong: int} + +class Error(Exception): + def __init__(self, explain, orig_msg): + self._explain = explain + +class WebSocketRendezvous(websocket.WebSocketServerProtocol): + def __init__(self): + websocket.WebSocketServerProtocol.__init__(self) + self._app = None + self._side = None + self._channel = None + self._watching = False + + def onConnect(self, request): + rv = self.factory.rendezvous + if rv.get_log_requests(): + log.msg("ws client connecting: %s" % (request.peer,)) + self._reactor = self.factory.reactor + + def onOpen(self): + rv = self.factory.rendezvous + self.send("welcome", welcome=rv.get_welcome()) + + def onMessage(self, payload, isBinary): + msg = json.loads(payload.decode("utf-8")) + try: + if "type" not in msg: + raise Error("missing 'type'") + mtype = msg["type"] + if mtype == "ping": + return self.handle_ping(msg) + if mtype == "bind": + return self.handle_bind(msg) + + if not self._app: + raise Error("Must bind first") + if mtype == "list": + return self.handle_list() + if mtype == "allocate": + return self.handle_allocate() + if mtype == "claim": + return self.handle_claim(msg) + + if not self._channel: + raise Error("Must set channel first") + meth = getattr(self, "handle_"+mtype, None) + if not meth: + raise Error("Unknown type") + return meth(self._channel, msg) + except Error as e: + self.send("error", error=e._explain, orig=msg) + + def send_rendezvous_event(self, event): + self.send("message", message=event) + + def stop_rendezvous_watcher(self): + self._reactor.callLater(0, self.transport.loseConnection) + + def handle_ping(self, msg): + if "ping" not in msg: + raise Error("ping requires 'ping'") + self.send("pong", pong=msg["ping"]) + + def handle_bind(self, msg): + if self._app or self._side: + raise Error("already bound") + if "appid" not in msg: + raise Error("bind requires 'appid'") + if "side" not in msg: + raise Error("bind requires 'side'") + self._app = self.factory.rendezvous.get_app(msg["appid"]) + self._side = msg["side"] + + def handle_list(self): + channelids = sorted(self._app.get_allocated()) + self.send("all-channelids", channelids=channelids) + + def handle_allocate(self): + if self._channel: + raise Error("Already bound to a channelid") + channelid = self._app.find_available_channelid() + self._channel = self._app.allocate_channel(channelid, self._side) + self.send("allocated", channelid=channelid) + + def handle_claim(self, msg): + if self._channel: + raise Error("Already bound to a channelid") + if "channelid" not in msg: + raise Error("claim requires 'channelid'") + self._channel = self._app.allocate_channel(msg["channelid"], self._side) + + def handle_watch(self, channel, msg): + if self._watching: + raise Error("already watching") + self._watching = True + for old_message in channel.add_listener(self): + self.send_rendezvous_event(old_message) + + def handle_add(self, channel, msg): + if "phase" not in msg: + raise Error("missing 'phase'") + if "body" not in msg: + raise Error("missing 'body'") + channel.add_message(self._side, msg["phase"], msg["body"]) + + def handle_deallocate(self, channel, msg): + deleted = channel.deallocate(self._side, msg.get("mood")) + self.send("deallocated", status="deleted" if deleted else "waiting") + + def send(self, mtype, **kwargs): + kwargs["type"] = mtype + kwargs["sent"] = time.time() + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + + def onClose(self, wasClean, code, reason): + pass + + +class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): + protocol = WebSocketRendezvous + def __init__(self, url, rendezvous): + websocket.WebSocketServerFactory.__init__(self, url) + self.rendezvous = rendezvous + self.reactor = reactor # for tests to control diff --git a/src/wormhole_server/server.py b/src/wormhole_server/server.py index 90971b9..a8912c9 100644 --- a/src/wormhole_server/server.py +++ b/src/wormhole_server/server.py @@ -3,11 +3,13 @@ from twisted.python import log from twisted.internet import reactor, endpoints from twisted.application import service from twisted.web import server, static, resource +from autobahn.twisted.resource import WebSocketResource from .endpoint_service import ServerEndpointService from wormhole import __version__ from .database import get_db from .rendezvous import Rendezvous from .rendezvous_web import WebRendezvous +from .rendezvous_websocket import WebSocketRendezvousFactory from .transit_server import Transit class Root(resource.Resource): @@ -48,6 +50,9 @@ class RelayServer(service.MultiService): wr = WebRendezvous(rendezvous) root.putChild(b"wormhole-relay", wr) + wsrf = WebSocketRendezvousFactory(None, rendezvous) + wr.putChild(b"ws", WebSocketResource(wsrf)) + site = PrivacyEnhancedSite(root) if blur_usage: site.logRequests = False @@ -69,6 +74,7 @@ class RelayServer(service.MultiService): self._root = root self._rendezvous_web = wr self._rendezvous_web_service = rendezvous_web_service + self._rendezvous_websocket = wsrf if transit_port: self._transit = transit self._transit_service = transit_service diff --git a/tests/common.py b/tests/common.py index 80b2da2..96f1548 100644 --- a/tests/common.py +++ b/tests/common.py @@ -18,6 +18,9 @@ class ServerBase: self._rendezvous = s._rendezvous self._transit_server = s._transit self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport + self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws" + self.rdv_ws_port = relayport + # ws://127.0.0.1:%d/wormhole-relay/ws self.transit = u"tcp:127.0.0.1:%d" % transitport def tearDown(self): diff --git a/tests/test_server.py b/tests/test_server.py index 723e8c7..16a3869 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,13 +1,15 @@ from __future__ import print_function -import json -import requests +import json, itertools from binascii import hexlify +import requests from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import protocol, reactor, defer +from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.threads import deferToThread from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody +from autobahn.twisted import websocket from wormhole import __version__ from .common import ServerBase from wormhole_server import rendezvous, transit_server @@ -367,6 +369,322 @@ class OneEventAtATime: self.connected_d.errback(why) self.disconnected_d.callback((why,)) +class WSClient(websocket.WebSocketClientProtocol): + def __init__(self): + websocket.WebSocketClientProtocol.__init__(self) + self.events = [] + self.d = None + self.ping_counter = itertools.count(0) + def onOpen(self): + self.factory.d.callback(self) + def onMessage(self, payload, isBinary): + assert not isBinary + event = json.loads(payload.decode("utf-8")) + if self.d: + assert not self.events + d,self.d = self.d,None + d.callback(event) + return + self.events.append(event) + + def next_event(self): + assert not self.d + if self.events: + event = self.events.pop(0) + return defer.succeed(event) + self.d = defer.Deferred() + return self.d + + def send(self, mtype, **kwargs): + kwargs["type"] = mtype + payload = json.dumps(kwargs).encode("utf-8") + self.sendMessage(payload, False) + + @inlineCallbacks + def sync(self): + ping = next(self.ping_counter) + self.send("ping", ping=ping) + # queue all messages until the pong, then put them back + old_events = [] + while True: + ev = yield self.next_event() + if ev["type"] == "pong" and ev["pong"] == ping: + self.events = old_events + self.events + returnValue(None) + old_events.append(ev) + +class WSFactory(websocket.WebSocketClientFactory): + protocol = WSClient + +class WSClientSync(unittest.TestCase): + # make sure my 'sync' method actually works + + @inlineCallbacks + def test_sync(self): + sent = [] + c = WSClient() + def _send(mtype, **kwargs): + sent.append( (mtype, kwargs) ) + c.send = _send + def add(mtype, **kwargs): + kwargs["type"] = mtype + c.onMessage(json.dumps(kwargs).encode("utf-8"), False) + # no queued messages + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + self.assertEqual(sent, [("ping", {"ping": 0})]) + self.assertEqual(sunc, []) + add("pong", pong=0) + yield d + self.assertEqual(c.events, []) + + # one,two,ping,pong + add("one") + add("two", two=2) + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("pong", pong=1) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + # one,ping,two,pong + add("one") + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("two", two=2) + add("pong", pong=2) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + # ping,one,two,pong + sunc = [] + d = c.sync() + d.addBoth(sunc.append) + add("one") + add("two", two=2) + add("pong", pong=3) + yield d + m = yield c.next_event() + self.assertEqual(m["type"], "one") + m = yield c.next_event() + self.assertEqual(m["type"], "two") + self.assertEqual(c.events, []) + + + +class WebSocketAPI(ServerBase, unittest.TestCase): + def setUp(self): + self._clients = [] + return ServerBase.setUp(self) + + def tearDown(self): + for c in self._clients: + c.transport.loseConnection() + return ServerBase.tearDown(self) + + @inlineCallbacks + def make_client(self): + f = WSFactory(self.rdv_ws_url) + f.d = defer.Deferred() + reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f) + c = yield f.d + self._clients.append(c) + returnValue(c) + + def check_welcome(self, data): + self.failUnlessIn("welcome", data) + self.failUnlessEqual(data["welcome"], {"current_version": __version__}) + + @inlineCallbacks + def test_welcome(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + self.assertEqual(self._rendezvous._apps, {}) + + @inlineCallbacks + def test_allocate_1(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + yield c1.sync() + self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"]) + app = self._rendezvous.get_app(u"appid") + self.assertEqual(app.get_allocated(), set()) + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + self.failUnlessIsInstance(cid, int) + self.assertEqual(app.get_allocated(), set([cid])) + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c1.send(u"deallocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"deleted") + self.assertEqual(app.get_allocated(), set()) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + @inlineCallbacks + def test_allocate_2(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + yield c1.sync() + app = self._rendezvous.get_app(u"appid") + self.assertEqual(app.get_allocated(), set()) + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + self.failUnlessIsInstance(cid, int) + self.assertEqual(app.get_allocated(), set([cid])) + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + # second caller increases the number of known sides to 2 + c2 = yield self.make_client() + msg = yield c2.next_event() + self.check_welcome(msg) + c2.send(u"bind", appid=u"appid", side=u"side-2") + c2.send(u"claim", channelid=cid) + c2.send(u"add", phase="1", body="") + yield c2.sync() + + self.assertEqual(app.get_allocated(), set([cid])) + self.assertEqual(channel.get_messages(), [{"phase": "1", "body": ""}]) + + c1.send(u"list") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c2.send(u"list") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], [cid]) + + c1.send(u"deallocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"waiting") + + c2.send(u"deallocate") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"deallocated") + self.assertEqual(msg["status"], u"deleted") + + c2.send(u"list") + msg = yield c2.next_event() + self.assertEqual(msg["type"], u"all-channelids") + self.assertEqual(msg["channelids"], []) + + @inlineCallbacks + def test_message(self): + c1 = yield self.make_client() + msg = yield c1.next_event() + self.check_welcome(msg) + c1.send(u"bind", appid=u"appid", side=u"side") + c1.send(u"allocate") + msg = yield c1.next_event() + self.assertEqual(msg["type"], u"allocated") + cid = msg["channelid"] + app = self._rendezvous.get_app(u"appid") + channel = app.get_channel(cid) + self.assertEqual(channel.get_messages(), []) + + c1.send(u"watch") + yield c1.sync() + self.assertEqual(len(channel._listeners), 1) + self.assertEqual(c1.events, []) + + c1.send(u"add", phase="1", body="msg1A") + yield c1.sync() + self.assertEqual(channel.get_messages(), + [{"phase": "1", "body": "msg1A"}]) + self.assertEqual(len(c1.events), 1) # echo should be sent right away + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + self.assertIn("sent", msg) + self.assertIsInstance(msg["sent"], float) + + c1.send(u"add", phase="1", body="msg1B") + c1.send(u"add", phase="2", body="msg2A") + + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + + msg = yield c1.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + self.assertEqual(channel.get_messages(), [ + {"phase": "1", "body": "msg1A"}, + {"phase": "1", "body": "msg1B"}, + {"phase": "2", "body": "msg2A"}, + ]) + + # second client should see everything + c2 = yield self.make_client() + msg = yield c2.next_event() + self.check_welcome(msg) + c2.send(u"bind", appid=u"appid", side=u"side") + c2.send(u"claim", channelid=cid) + # 'watch' triggers delivery of old messages, in temporal order + c2.send(u"watch") + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"}) + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"}) + + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + # adding a duplicate is not an error, and clients will ignore it + c1.send(u"add", phase="2", body="msg2A") + + # the duplicate message *does* get stored, and delivered + msg = yield c2.next_event() + self.assertEqual(msg["type"], "message") + self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"}) + + class Summary(unittest.TestCase): def test_summarize(self): c = rendezvous.Channel(None, None, None, None, False, None, None)