Merge branch 'websocket'

This commit is contained in:
Brian Warner 2016-04-18 18:10:31 -07:00
commit f606207163
5 changed files with 503 additions and 3 deletions

View File

@ -24,7 +24,11 @@ setup(name="magic-wormhole",
"wormhole-server = wormhole_server.runner:entry",
]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six", "twisted >= 16.1.0", "hkdf"],
"six", "twisted >= 16.1.0", "hkdf",
"autobahn[twisted]", "pytrie",
# autobahn seems to have a bug, and one plugin throws
# errors unless pytrie is installed
],
extras_require={"tor": ["txtorcon", "ipaddr"]},
test_suite="wormhole.test",
cmdclass=commands,

View File

@ -0,0 +1,169 @@
import json, time
from twisted.internet import reactor
from twisted.python import log
from autobahn.twisted import websocket
# Each WebSocket connection is bound to one "appid", one "side", and one
# "channelid". The connection's appid and side are set by the "bind" message
# (which must be the first message on the connection). The channelid is set
# by either a "allocate" message (where the server picks the channelid), or
# by a "claim" message (where the client picks it). All three values must be
# set before any other message (watch, add, deallocate) can be sent.
# All websocket messages are JSON-encoded. The client can send us "inbound"
# messages (marked as "->" below), which may (or may not) provoke immediate
# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed
# correlation between requests and responses. In this list, "A -> B" means
# that some time after A is received, at least one message of type B will be
# sent out.
# All outbound messages include a "sent" key, which is a float (seconds since
# epoch) with the server clock just before the outbound message was written
# to the socket.
# connection -> welcome
# <- {type: "welcome", welcome: {}} # .welcome keys are all optional:
# current_version: out-of-date clients display a warning
# motd: all clients display message, then continue normally
# error: all clients display mesage, then terminate with error
# -> {type: "bind", appid:, side:}
# -> {type: "list"} -> all-channelids
# <- {type: "all-channelids", channelids: [int..]}
# -> {type: "allocate"} -> allocated
# <- {type: "allocated", channelid: int}
# -> {type: "claim", channelid: int}
# -> {type: "watch"} -> message # sends old messages and more in future
# <- {type: "message", message: {phase:, body:}} # body is base64
# -> {type: "add", phase: str, body: base64} # may send echo
# -> {type: "deallocate", mood: str} -> deallocated
# <- {type: "deallocated", status: waiting|deleted}
# <- {type: "error", error: str, orig: {}} # in response to malformed msgs
# for tests that need to know when a message has been processed:
# -> {type: "ping", ping: int} -> pong (does not require bind/claim)
# <- {type: "pong", pong: int}
class Error(Exception):
def __init__(self, explain, orig_msg):
self._explain = explain
class WebSocketRendezvous(websocket.WebSocketServerProtocol):
def __init__(self):
websocket.WebSocketServerProtocol.__init__(self)
self._app = None
self._side = None
self._channel = None
self._watching = False
def onConnect(self, request):
rv = self.factory.rendezvous
if rv.get_log_requests():
log.msg("ws client connecting: %s" % (request.peer,))
self._reactor = self.factory.reactor
def onOpen(self):
rv = self.factory.rendezvous
self.send("welcome", welcome=rv.get_welcome())
def onMessage(self, payload, isBinary):
msg = json.loads(payload.decode("utf-8"))
try:
if "type" not in msg:
raise Error("missing 'type'")
mtype = msg["type"]
if mtype == "ping":
return self.handle_ping(msg)
if mtype == "bind":
return self.handle_bind(msg)
if not self._app:
raise Error("Must bind first")
if mtype == "list":
return self.handle_list()
if mtype == "allocate":
return self.handle_allocate()
if mtype == "claim":
return self.handle_claim(msg)
if not self._channel:
raise Error("Must set channel first")
meth = getattr(self, "handle_"+mtype, None)
if not meth:
raise Error("Unknown type")
return meth(self._channel, msg)
except Error as e:
self.send("error", error=e._explain, orig=msg)
def send_rendezvous_event(self, event):
self.send("message", message=event)
def stop_rendezvous_watcher(self):
self._reactor.callLater(0, self.transport.loseConnection)
def handle_ping(self, msg):
if "ping" not in msg:
raise Error("ping requires 'ping'")
self.send("pong", pong=msg["ping"])
def handle_bind(self, msg):
if self._app or self._side:
raise Error("already bound")
if "appid" not in msg:
raise Error("bind requires 'appid'")
if "side" not in msg:
raise Error("bind requires 'side'")
self._app = self.factory.rendezvous.get_app(msg["appid"])
self._side = msg["side"]
def handle_list(self):
channelids = sorted(self._app.get_allocated())
self.send("all-channelids", channelids=channelids)
def handle_allocate(self):
if self._channel:
raise Error("Already bound to a channelid")
channelid = self._app.find_available_channelid()
self._channel = self._app.allocate_channel(channelid, self._side)
self.send("allocated", channelid=channelid)
def handle_claim(self, msg):
if self._channel:
raise Error("Already bound to a channelid")
if "channelid" not in msg:
raise Error("claim requires 'channelid'")
self._channel = self._app.allocate_channel(msg["channelid"], self._side)
def handle_watch(self, channel, msg):
if self._watching:
raise Error("already watching")
self._watching = True
for old_message in channel.add_listener(self):
self.send_rendezvous_event(old_message)
def handle_add(self, channel, msg):
if "phase" not in msg:
raise Error("missing 'phase'")
if "body" not in msg:
raise Error("missing 'body'")
channel.add_message(self._side, msg["phase"], msg["body"])
def handle_deallocate(self, channel, msg):
deleted = channel.deallocate(self._side, msg.get("mood"))
self.send("deallocated", status="deleted" if deleted else "waiting")
def send(self, mtype, **kwargs):
kwargs["type"] = mtype
kwargs["sent"] = time.time()
payload = json.dumps(kwargs).encode("utf-8")
self.sendMessage(payload, False)
def onClose(self, wasClean, code, reason):
pass
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
protocol = WebSocketRendezvous
def __init__(self, url, rendezvous):
websocket.WebSocketServerFactory.__init__(self, url)
self.rendezvous = rendezvous
self.reactor = reactor # for tests to control

View File

@ -3,11 +3,13 @@ from twisted.python import log
from twisted.internet import reactor, endpoints
from twisted.application import service
from twisted.web import server, static, resource
from autobahn.twisted.resource import WebSocketResource
from .endpoint_service import ServerEndpointService
from wormhole import __version__
from .database import get_db
from .rendezvous import Rendezvous
from .rendezvous_web import WebRendezvous
from .rendezvous_websocket import WebSocketRendezvousFactory
from .transit_server import Transit
class Root(resource.Resource):
@ -48,6 +50,9 @@ class RelayServer(service.MultiService):
wr = WebRendezvous(rendezvous)
root.putChild(b"wormhole-relay", wr)
wsrf = WebSocketRendezvousFactory(None, rendezvous)
wr.putChild(b"ws", WebSocketResource(wsrf))
site = PrivacyEnhancedSite(root)
if blur_usage:
site.logRequests = False
@ -69,6 +74,7 @@ class RelayServer(service.MultiService):
self._root = root
self._rendezvous_web = wr
self._rendezvous_web_service = rendezvous_web_service
self._rendezvous_websocket = wsrf
if transit_port:
self._transit = transit
self._transit_service = transit_service

View File

@ -18,6 +18,9 @@ class ServerBase:
self._rendezvous = s._rendezvous
self._transit_server = s._transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws"
self.rdv_ws_port = relayport
# ws://127.0.0.1:%d/wormhole-relay/ws
self.transit = u"tcp:127.0.0.1:%d" % transitport
def tearDown(self):

View File

@ -1,13 +1,15 @@
from __future__ import print_function
import json
import requests
import json, itertools
from binascii import hexlify
import requests
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import deferToThread
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody
from autobahn.twisted import websocket
from wormhole import __version__
from .common import ServerBase
from wormhole_server import rendezvous, transit_server
@ -367,6 +369,322 @@ class OneEventAtATime:
self.connected_d.errback(why)
self.disconnected_d.callback((why,))
class WSClient(websocket.WebSocketClientProtocol):
def __init__(self):
websocket.WebSocketClientProtocol.__init__(self)
self.events = []
self.d = None
self.ping_counter = itertools.count(0)
def onOpen(self):
self.factory.d.callback(self)
def onMessage(self, payload, isBinary):
assert not isBinary
event = json.loads(payload.decode("utf-8"))
if self.d:
assert not self.events
d,self.d = self.d,None
d.callback(event)
return
self.events.append(event)
def next_event(self):
assert not self.d
if self.events:
event = self.events.pop(0)
return defer.succeed(event)
self.d = defer.Deferred()
return self.d
def send(self, mtype, **kwargs):
kwargs["type"] = mtype
payload = json.dumps(kwargs).encode("utf-8")
self.sendMessage(payload, False)
@inlineCallbacks
def sync(self):
ping = next(self.ping_counter)
self.send("ping", ping=ping)
# queue all messages until the pong, then put them back
old_events = []
while True:
ev = yield self.next_event()
if ev["type"] == "pong" and ev["pong"] == ping:
self.events = old_events + self.events
returnValue(None)
old_events.append(ev)
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
class WSClientSync(unittest.TestCase):
# make sure my 'sync' method actually works
@inlineCallbacks
def test_sync(self):
sent = []
c = WSClient()
def _send(mtype, **kwargs):
sent.append( (mtype, kwargs) )
c.send = _send
def add(mtype, **kwargs):
kwargs["type"] = mtype
c.onMessage(json.dumps(kwargs).encode("utf-8"), False)
# no queued messages
sunc = []
d = c.sync()
d.addBoth(sunc.append)
self.assertEqual(sent, [("ping", {"ping": 0})])
self.assertEqual(sunc, [])
add("pong", pong=0)
yield d
self.assertEqual(c.events, [])
# one,two,ping,pong
add("one")
add("two", two=2)
sunc = []
d = c.sync()
d.addBoth(sunc.append)
add("pong", pong=1)
yield d
m = yield c.next_event()
self.assertEqual(m["type"], "one")
m = yield c.next_event()
self.assertEqual(m["type"], "two")
self.assertEqual(c.events, [])
# one,ping,two,pong
add("one")
sunc = []
d = c.sync()
d.addBoth(sunc.append)
add("two", two=2)
add("pong", pong=2)
yield d
m = yield c.next_event()
self.assertEqual(m["type"], "one")
m = yield c.next_event()
self.assertEqual(m["type"], "two")
self.assertEqual(c.events, [])
# ping,one,two,pong
sunc = []
d = c.sync()
d.addBoth(sunc.append)
add("one")
add("two", two=2)
add("pong", pong=3)
yield d
m = yield c.next_event()
self.assertEqual(m["type"], "one")
m = yield c.next_event()
self.assertEqual(m["type"], "two")
self.assertEqual(c.events, [])
class WebSocketAPI(ServerBase, unittest.TestCase):
def setUp(self):
self._clients = []
return ServerBase.setUp(self)
def tearDown(self):
for c in self._clients:
c.transport.loseConnection()
return ServerBase.tearDown(self)
@inlineCallbacks
def make_client(self):
f = WSFactory(self.rdv_ws_url)
f.d = defer.Deferred()
reactor.connectTCP("127.0.0.1", self.rdv_ws_port, f)
c = yield f.d
self._clients.append(c)
returnValue(c)
def check_welcome(self, data):
self.failUnlessIn("welcome", data)
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
@inlineCallbacks
def test_welcome(self):
c1 = yield self.make_client()
msg = yield c1.next_event()
self.check_welcome(msg)
self.assertEqual(self._rendezvous._apps, {})
@inlineCallbacks
def test_allocate_1(self):
c1 = yield self.make_client()
msg = yield c1.next_event()
self.check_welcome(msg)
c1.send(u"bind", appid=u"appid", side=u"side")
yield c1.sync()
self.assertEqual(list(self._rendezvous._apps.keys()), [u"appid"])
app = self._rendezvous.get_app(u"appid")
self.assertEqual(app.get_allocated(), set())
c1.send(u"list")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [])
c1.send(u"allocate")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"allocated")
cid = msg["channelid"]
self.failUnlessIsInstance(cid, int)
self.assertEqual(app.get_allocated(), set([cid]))
channel = app.get_channel(cid)
self.assertEqual(channel.get_messages(), [])
c1.send(u"list")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [cid])
c1.send(u"deallocate")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"deallocated")
self.assertEqual(msg["status"], u"deleted")
self.assertEqual(app.get_allocated(), set())
c1.send(u"list")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [])
@inlineCallbacks
def test_allocate_2(self):
c1 = yield self.make_client()
msg = yield c1.next_event()
self.check_welcome(msg)
c1.send(u"bind", appid=u"appid", side=u"side")
yield c1.sync()
app = self._rendezvous.get_app(u"appid")
self.assertEqual(app.get_allocated(), set())
c1.send(u"allocate")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"allocated")
cid = msg["channelid"]
self.failUnlessIsInstance(cid, int)
self.assertEqual(app.get_allocated(), set([cid]))
channel = app.get_channel(cid)
self.assertEqual(channel.get_messages(), [])
# second caller increases the number of known sides to 2
c2 = yield self.make_client()
msg = yield c2.next_event()
self.check_welcome(msg)
c2.send(u"bind", appid=u"appid", side=u"side-2")
c2.send(u"claim", channelid=cid)
c2.send(u"add", phase="1", body="")
yield c2.sync()
self.assertEqual(app.get_allocated(), set([cid]))
self.assertEqual(channel.get_messages(), [{"phase": "1", "body": ""}])
c1.send(u"list")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [cid])
c2.send(u"list")
msg = yield c2.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [cid])
c1.send(u"deallocate")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"deallocated")
self.assertEqual(msg["status"], u"waiting")
c2.send(u"deallocate")
msg = yield c2.next_event()
self.assertEqual(msg["type"], u"deallocated")
self.assertEqual(msg["status"], u"deleted")
c2.send(u"list")
msg = yield c2.next_event()
self.assertEqual(msg["type"], u"all-channelids")
self.assertEqual(msg["channelids"], [])
@inlineCallbacks
def test_message(self):
c1 = yield self.make_client()
msg = yield c1.next_event()
self.check_welcome(msg)
c1.send(u"bind", appid=u"appid", side=u"side")
c1.send(u"allocate")
msg = yield c1.next_event()
self.assertEqual(msg["type"], u"allocated")
cid = msg["channelid"]
app = self._rendezvous.get_app(u"appid")
channel = app.get_channel(cid)
self.assertEqual(channel.get_messages(), [])
c1.send(u"watch")
yield c1.sync()
self.assertEqual(len(channel._listeners), 1)
self.assertEqual(c1.events, [])
c1.send(u"add", phase="1", body="msg1A")
yield c1.sync()
self.assertEqual(channel.get_messages(),
[{"phase": "1", "body": "msg1A"}])
self.assertEqual(len(c1.events), 1) # echo should be sent right away
msg = yield c1.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"})
self.assertIn("sent", msg)
self.assertIsInstance(msg["sent"], float)
c1.send(u"add", phase="1", body="msg1B")
c1.send(u"add", phase="2", body="msg2A")
msg = yield c1.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"})
msg = yield c1.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"})
self.assertEqual(channel.get_messages(), [
{"phase": "1", "body": "msg1A"},
{"phase": "1", "body": "msg1B"},
{"phase": "2", "body": "msg2A"},
])
# second client should see everything
c2 = yield self.make_client()
msg = yield c2.next_event()
self.check_welcome(msg)
c2.send(u"bind", appid=u"appid", side=u"side")
c2.send(u"claim", channelid=cid)
# 'watch' triggers delivery of old messages, in temporal order
c2.send(u"watch")
msg = yield c2.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "1", "body": "msg1A"})
msg = yield c2.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "1", "body": "msg1B"})
msg = yield c2.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"})
# adding a duplicate is not an error, and clients will ignore it
c1.send(u"add", phase="2", body="msg2A")
# the duplicate message *does* get stored, and delivered
msg = yield c2.next_event()
self.assertEqual(msg["type"], "message")
self.assertEqual(msg["message"], {"phase": "2", "body": "msg2A"})
class Summary(unittest.TestCase):
def test_summarize(self):
c = rendezvous.Channel(None, None, None, None, False, None, None)