diff --git a/MANIFEST.in b/MANIFEST.in index e6cd1b9..c42f3c6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -11,5 +11,3 @@ include misc/munin/wormhole_errors include misc/munin/wormhole_event_rate include misc/munin/wormhole_events include misc/munin/wormhole_events_alltime -include misc/munin/wormhole_transit -include misc/munin/wormhole_transit_alltime diff --git a/docs/welcome.md b/docs/welcome.md index abedf7c..65d74da 100644 --- a/docs/welcome.md +++ b/docs/welcome.md @@ -184,6 +184,8 @@ addresses of each client with the other (inside the encrypted message), and both clients first attempt to connect directly. If this fails, they fall back to using the transit relay. As before, the host/port of a public server is baked into the library, and should be sufficient to handle moderate traffic. +Code for the Transit Relay is provided a separate package named +`magic-wormhole-transit-relay`. The protocol includes provisions to deliver notices and error messages to clients: if either relay must be shut down, these channels will be used to diff --git a/misc/munin/wormhole_active b/misc/munin/wormhole_active index a9d1db7..132ce86 100755 --- a/misc/munin/wormhole_active +++ b/misc/munin/wormhole_active @@ -22,12 +22,6 @@ mailboxes.type GAUGE messages.label Messages messages.draw LINE1 messages.type GAUGE -transit_waiting.label Transit Waiting -transit_waiting.draw LINE1 -transit_waiting.type GAUGE -transit_connected.label Transit Connected -transit_connected.draw LINE1 -transit_connected.type GAUGE """ if len(sys.argv) > 1 and sys.argv[1] == "config": @@ -45,6 +39,3 @@ ra = data["rendezvous"]["active"] print "nameplates.value", ra["nameplates_total"] print "mailboxes.value", ra["mailboxes_total"] print "messages.value", ra["messages_total"] -ta = data["transit"]["active"] -print "transit_waiting.value", ta["waiting"] -print "transit_connected.value", ta["connected"] diff --git a/misc/munin/wormhole_errors b/misc/munin/wormhole_errors index 5ad312f..d3f4a86 100755 --- a/misc/munin/wormhole_errors +++ b/misc/munin/wormhole_errors @@ -22,9 +22,6 @@ mailboxes.type GAUGE mailboxes_scary.label Mailboxes (scary) mailboxes_scary.draw LINE1 mailboxes_scary.type GAUGE -transit.label Transit -transit.draw LINE1 -transit.type GAUGE """ if len(sys.argv) > 1 and sys.argv[1] == "config": @@ -44,5 +41,3 @@ print "nameplates.value", (r["nameplates_total"] print "mailboxes.value", (r["mailboxes_total"] - r["mailbox_moods"].get("happy", 0)) print "mailboxes_scary.value", r["mailbox_moods"].get("scary", 0) -t = data["transit"]["since_reboot"] -print "transit.value", (t["total"] - t["moods"].get("happy", 0)) diff --git a/misc/munin/wormhole_transit b/misc/munin/wormhole_transit deleted file mode 100755 index e7ba14f..0000000 --- a/misc/munin/wormhole_transit +++ /dev/null @@ -1,33 +0,0 @@ -#! /usr/bin/env python - -""" -Use the following in /etc/munin/plugin-conf.d/wormhole : - -[wormhole_*] -env.serverdir /path/to/your/wormhole/server -""" - -import os, sys, time, json - -CONFIG = """\ -graph_title Magic-Wormhole Transit Usage (since reboot) -graph_vlabel Bytes Since Reboot -graph_category network -bytes.label Transit Bytes -bytes.draw LINE1 -bytes.type GAUGE -""" - -if len(sys.argv) > 1 and sys.argv[1] == "config": - print CONFIG.rstrip() - sys.exit(0) - -serverdir = os.environ["serverdir"] -fn = os.path.join(serverdir, "stats.json") -with open(fn) as f: - data = json.load(f) -if time.time() > data["valid_until"]: - sys.exit(1) # expired - -t = data["transit"]["since_reboot"] -print "bytes.value", t["bytes"] diff --git a/misc/munin/wormhole_transit_alltime b/misc/munin/wormhole_transit_alltime deleted file mode 100644 index 459116b..0000000 --- a/misc/munin/wormhole_transit_alltime +++ /dev/null @@ -1,33 +0,0 @@ -#! /usr/bin/env python - -""" -Use the following in /etc/munin/plugin-conf.d/wormhole : - -[wormhole_*] -env.serverdir /path/to/your/wormhole/server -""" - -import os, sys, time, json - -CONFIG = """\ -graph_title Magic-Wormhole Transit Usage (all time) -graph_vlabel Bytes Since DB Creation -graph_category network -bytes.label Transit Bytes -bytes.draw LINE1 -bytes.type GAUGE -""" - -if len(sys.argv) > 1 and sys.argv[1] == "config": - print CONFIG.rstrip() - sys.exit(0) - -serverdir = os.environ["serverdir"] -fn = os.path.join(serverdir, "stats.json") -with open(fn) as f: - data = json.load(f) -if time.time() > data["valid_until"]: - sys.exit(1) # expired - -t = data["transit"]["all_time"] -print "bytes.value", t["bytes"] diff --git a/setup.py b/setup.py index 3ae39bb..8cc051f 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,8 @@ setup(name="magic-wormhole", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], - "dev": ["mock", "tox", "pyflakes"], + "dev": ["mock", "tox", "pyflakes", + "magic-wormhole-transit-relay==0.1.0"], }, test_suite="wormhole.test", cmdclass=commands, diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py index 9596dbd..e205413 100644 --- a/src/wormhole/server/cli.py +++ b/src/wormhole/server/cli.py @@ -41,10 +41,6 @@ LaunchArgs = _compose( "--rendezvous", default="tcp:4000", metavar="tcp:PORT", help="endpoint specification for the rendezvous port", ), - click.option( - "--transit", default="tcp:4001", metavar="tcp:PORT", - help="endpoint specification for the transit-relay port", - ), click.option( "--advertise-version", metavar="VERSION", help="version to recommend to clients", diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index 2bbd0a3..f73295e 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -15,7 +15,6 @@ class MyPlugin(object): from .server import RelayServer return RelayServer( str(self.args.rendezvous), - str(self.args.transit), self.args.advertise_version, self.args.relay_database_path, self.args.blur_usage, diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index c913b1c..94b2b98 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -16,7 +16,6 @@ from autobahn.twisted.resource import WebSocketResource from .database import get_db from .rendezvous import Rendezvous from .rendezvous_websocket import WebSocketRendezvousFactory -from .transit_server import Transit SECONDS = 1.0 MINUTE = 60*SECONDS @@ -38,7 +37,7 @@ class PrivacyEnhancedSite(server.Site): class RelayServer(service.MultiService): - def __init__(self, rendezvous_web_port, transit_port, + def __init__(self, rendezvous_web_port, advertise_version, db_url=":memory:", blur_usage=None, signal_error=None, stats_file=None, allow_list=True, websocket_protocol_options=()): @@ -82,13 +81,6 @@ class RelayServer(service.MultiService): rendezvous_web_service = internet.StreamServerEndpointService(r, site) rendezvous_web_service.setServiceParent(self) - if transit_port: - transit = Transit(db, blur_usage) - transit.setServiceParent(self) # for the timer - t = endpoints.serverFromString(reactor, transit_port) - transit_service = internet.StreamServerEndpointService(t, transit) - transit_service.setServiceParent(self) - self._stats_file = stats_file if self._stats_file and os.path.exists(self._stats_file): os.unlink(self._stats_file) @@ -103,10 +95,6 @@ class RelayServer(service.MultiService): self._root = root self._rendezvous_web_service = rendezvous_web_service self._rendezvous_websocket = wsrf - self._transit = None - if transit_port: - self._transit = transit - self._transit_service = transit_service def increase_rlimits(self): if getrlimit is None: @@ -141,10 +129,10 @@ class RelayServer(service.MultiService): service.MultiService.startService(self) self.increase_rlimits() log.msg("websocket listening on /wormhole-relay/ws") - log.msg("Wormhole relay server (Rendezvous and Transit) running") + log.msg("Wormhole relay server (Rendezvous) running") if self._blur_usage: log.msg("blurring access times to %d seconds" % self._blur_usage) - log.msg("not logging HTTP requests or Transit connections") + log.msg("not logging HTTP requests") else: log.msg("not blurring access times") if not self._allow_list: @@ -167,8 +155,6 @@ class RelayServer(service.MultiService): start = time.time() data["rendezvous"] = self._rendezvous.get_stats() - if self._transit: - data["transit"] = self._transit.get_stats() log.msg("get_stats took:", time.time() - start) with open(tmpfn, "wb") as f: diff --git a/src/wormhole/server/transit_server.py b/src/wormhole/server/transit_server.py deleted file mode 100644 index 92265a5..0000000 --- a/src/wormhole/server/transit_server.py +++ /dev/null @@ -1,328 +0,0 @@ -from __future__ import print_function, unicode_literals -import re, time, collections -from twisted.python import log -from twisted.internet import protocol -from twisted.application import service - -SECONDS = 1.0 -MINUTE = 60*SECONDS -HOUR = 60*MINUTE -DAY = 24*HOUR -MB = 1000*1000 - -def round_to(size, coarseness): - return int(coarseness*(1+int((size-1)/coarseness))) - -def blur_size(size): - if size == 0: - return 0 - if size < 1e6: - return round_to(size, 10e3) - if size < 1e9: - return round_to(size, 1e6) - return round_to(size, 100e6) - -class TransitConnection(protocol.Protocol): - def __init__(self): - self._got_token = False - self._got_side = False - self._token_buffer = b"" - self._sent_ok = False - self._buddy = None - self._had_buddy = False - self._total_sent = 0 - - def describeToken(self): - d = "-" - if self._got_token: - d = self._got_token[:16].decode("ascii") - if self._got_side: - d += "-" + self._got_side.decode("ascii") - else: - d += "-" - return d - - def connectionMade(self): - self._started = time.time() - self._log_requests = self.factory._log_requests - - def dataReceived(self, data): - if self._sent_ok: - # We are an IPushProducer to our buddy's IConsumer, so they'll - # throttle us (by calling pauseProducing()) when their outbound - # buffer is full (e.g. when their downstream pipe is full). In - # practice, this buffers about 10MB per connection, after which - # point the sender will only transmit data as fast as the - # receiver can handle it. - self._total_sent += len(data) - self._buddy.transport.write(data) - return - - if self._got_token: # but not yet sent_ok - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - - # else this should be (part of) the token - self._token_buffer += data - buf = self._token_buffer - - # old: "please relay {64}\n" - # new: "please relay {64} for side {16}\n" - (old, handshake_len, token) = self._check_old_handshake(buf) - assert old in ("yes", "waiting", "no") - if old == "yes": - # remember they aren't supposed to send anything past their - # handshake until we've said go - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - return self._got_handshake(token, None) - (new, handshake_len, token, side) = self._check_new_handshake(buf) - assert new in ("yes", "waiting", "no") - if new == "yes": - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - return self._got_handshake(token, side) - if (old == "no" and new == "no"): - self.transport.write(b"bad handshake\n") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect() # incorrectness yields failure - # else we'll keep waiting - - def _check_old_handshake(self, buf): - # old: "please relay {64}\n" - # return ("yes", handshake, token) if buf contains an old-style handshake - # return ("waiting", None, None) if it might eventually contain one - # return ("no", None, None) if it could never contain one - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None) - if len(buf) < wanted: - return ("waiting", None, None) - - mo = re.search(br"^please relay (\w{64})\n", buf, re.M) - if mo: - token = mo.group(1) - return ("yes", wanted, token) - return ("no", None, None) - - def _check_new_handshake(self, buf): - # new: "please relay {64} for side {16}\n" - wanted = len("please relay for side \n")+32*2+8*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None, None) - if len(buf) < wanted: - return ("waiting", None, None, None) - - mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M) - if mo: - token = mo.group(1) - side = mo.group(2) - return ("yes", wanted, token, side) - return ("no", None, None, None) - - def _got_handshake(self, token, side): - self._got_token = token - self._got_side = side - self.factory.connection_got_token(token, side, self) - - def buddy_connected(self, them): - self._buddy = them - self._had_buddy = True - self.transport.write(b"ok\n") - self._sent_ok = True - # Connect the two as a producer/consumer pair. We use streaming=True, - # so this expects the IPushProducer interface, and uses - # pauseProducing() to throttle, and resumeProducing() to unthrottle. - self._buddy.transport.registerProducer(self.transport, True) - # The Transit object calls buddy_connected() on both protocols, so - # there will be two producer/consumer pairs. - - def buddy_disconnected(self): - if self._log_requests: - log.msg("buddy_disconnected %s" % self.describeToken()) - self._buddy = None - self.transport.loseConnection() - - def connectionLost(self, reason): - if self._buddy: - self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, self._got_side, - self.describeToken()) - - # Record usage. There are four cases: - # * 1: we connected, never had a buddy - # * 2: we connected first, we disconnect before the buddy - # * 3: we connected first, buddy disconnects first - # * 4: buddy connected first, we disconnect before buddy - # * 5: buddy connected first, buddy disconnects first - - # whoever disconnects first gets to write the usage record (1,2,4) - - finished = time.time() - if not self._had_buddy: # 1 - total_time = finished - self._started - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - if self._had_buddy and self._buddy: # 2,4 - total_bytes = self._total_sent + self._buddy._total_sent - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - - def disconnect(self): - self.transport.loseConnection() - self.factory.transitFailed(self) - finished = time.time() - total_time = finished - self._started - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - -class Transit(protocol.ServerFactory, service.MultiService): - # I manage pairs of simultaneous connections to a secondary TCP port, - # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN for SIDE\n" (or a legacy form without the "for - # SIDE"). Two connections match if they use the same TOKEN and have - # different SIDEs (the redundant connections are dropped when a match is - # made). Legacy connections match any with the same TOKEN, ignoring SIDE - # (so two legacy connections will match each other). - - # I will send "ok\n" when the matching connection is established, or - # disconnect if no matching connection is made within MAX_WAIT_TIME - # seconds. I will disconnect if you send data before the "ok\n". All data - # you get after the "ok\n" will be from the other side. You will not - # receive "ok\n" until the other side has also connected and submitted a - # matching token (and differing SIDE). - - # In addition, the connections will be dropped after MAXLENGTH bytes have - # been sent by either side, or MAXTIME seconds have elapsed after the - # matching connections were established. A future API will reveal these - # limits to clients instead of causing mysterious spontaneous failures. - - # These relay connections are not half-closeable (unlike full TCP - # connections, applications will not receive any data after half-closing - # their outgoing side). Applications must negotiate shutdown with their - # peer and not close the connection until all data has finished - # transferring in both directions. Applications which only need to send - # data in one direction can use close() as usual. - - MAX_WAIT_TIME = 30*SECONDS - MAXLENGTH = 10*MB - MAXTIME = 60*SECONDS - protocol = TransitConnection - - def __init__(self, db, blur_usage): - service.MultiService.__init__(self) - self._db = db - self._blur_usage = blur_usage - self._log_requests = blur_usage is None - self._pending_requests = {} # token -> set((side, TransitConnection)) - self._active_connections = set() # TransitConnection - self._counts = collections.defaultdict(int) - self._count_bytes = 0 - - def connection_got_token(self, token, new_side, new_tc): - if token not in self._pending_requests: - self._pending_requests[token] = set() - potentials = self._pending_requests[token] - for old in potentials: - (old_side, old_tc) = old - if ((old_side is None) - or (new_side is None) - or (old_side != new_side)): - # we found a match - if self._log_requests: - log.msg("transit relay 2: %s" % new_tc.describeToken()) - - # drop and stop tracking the rest - potentials.remove(old) - for (_, leftover_tc) in potentials: - leftover_tc.disconnect() # TODO: not "errory"? - self._pending_requests.pop(token) - - # glue the two ends together - self._active_connections.add(new_tc) - self._active_connections.add(old_tc) - new_tc.buddy_connected(old_tc) - old_tc.buddy_connected(new_tc) - return - if self._log_requests: - log.msg("transit relay 1: %s" % new_tc.describeToken()) - potentials.add((new_side, new_tc)) - # TODO: timer - - def recordUsage(self, started, result, total_bytes, - total_time, waiting_time): - if self._log_requests: - log.msg("Transit.recordUsage (%dB)" % total_bytes) - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - self._db.execute("INSERT INTO `transit_usage`" - " (`started`, `total_time`, `waiting_time`," - " `total_bytes`, `result`)" - " VALUES (?,?,?, ?,?)", - (started, total_time, waiting_time, - total_bytes, result)) - self._db.commit() - self._counts[result] += 1 - self._count_bytes += total_bytes - - def transitFinished(self, tc, token, side, description): - if token in self._pending_requests: - side_tc = (side, tc) - if side_tc in self._pending_requests[token]: - self._pending_requests[token].remove(side_tc) - if not self._pending_requests[token]: # set is now empty - del self._pending_requests[token] - if self._log_requests: - log.msg("transitFinished %s" % (description,)) - self._active_connections.discard(tc) - - def transitFailed(self, p): - if self._log_requests: - log.msg("transitFailed %r" % p) - pass - - def get_stats(self): - stats = {} - def q(query, values=()): - row = self._db.execute(query, values).fetchone() - return list(row.values())[0] - - # current status: expected to be zero most of the time - c = stats["active"] = {} - c["connected"] = len(self._active_connections) / 2 - c["waiting"] = len(self._pending_requests) - - # usage since last reboot - rb = stats["since_reboot"] = {} - rb["bytes"] = self._count_bytes - rb["total"] = sum(self._counts.values(), 0) - rbm = rb["moods"] = {} - for result, count in self._counts.items(): - rbm[result] = count - - # historical usage (all-time) - u = stats["all_time"] = {} - u["total"] = q("SELECT COUNT() FROM `transit_usage`") - u["bytes"] = q("SELECT SUM(`total_bytes`) FROM `transit_usage`") or 0 - um = u["moods"] = {} - um["happy"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='happy'") - um["lonely"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='lonely'") - um["errory"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='errory'") - - return stats diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index a0761f2..c1c4fe9 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,12 +1,13 @@ # no unicode_literals untill twisted update -from twisted.application import service -from twisted.internet import defer, task, reactor +from twisted.application import service, internet +from twisted.internet import defer, task, reactor, endpoints from twisted.python import log from click.testing import CliRunner import mock from ..cli import cli from ..transit import allocate_tcp_port from ..server.server import RelayServer +from wormhole_transit_relay.transit_server import Transit class ServerBase: def setUp(self): @@ -16,20 +17,25 @@ class ServerBase: self.sp = service.MultiService() self.sp.startService() self.relayport = allocate_tcp_port() - self.transitport = allocate_tcp_port() # need to talk to twisted team about only using unicode in # endpoints.serverFromString s = RelayServer("tcp:%d:interface=127.0.0.1" % self.relayport, - "tcp:%s:interface=127.0.0.1" % self.transitport, advertise_version=advertise_version, signal_error=error) s.setServiceParent(self.sp) self._relay_server = s self._rendezvous = s._rendezvous - self._transit_server = s._transit self.relayurl = u"ws://127.0.0.1:%d/v1" % self.relayport self.rdv_ws_port = self.relayport # ws://127.0.0.1:%d/wormhole-relay/ws + + self.transitport = allocate_tcp_port() + ep = endpoints.serverFromString(reactor, + "tcp:%d:interface=127.0.0.1" % + self.transitport) + self._transit_server = f = Transit(blur_usage=None, log_file=None, + usage_db=None) + internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) self.transit = u"tcp:127.0.0.1:%d" % self.transitport def tearDown(self): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 8f55919..5482594 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -13,13 +13,11 @@ from ..server.database import get_db def easy_relay( rendezvous_web_port=str("tcp:0"), - transit_port=str("tcp:0"), advertise_version=None, **kwargs ): return server.RelayServer( rendezvous_web_port, - transit_port, advertise_version, **kwargs ) @@ -33,7 +31,7 @@ class RLimits(unittest.TestCase): # is easier than just passing "tcp:0" ep = endpoints.TCP4ServerEndpoint(None, 0) with patch_s("endpoints.serverFromString", return_value=ep): - s = server.RelayServer("fake", None, None) + s = server.RelayServer("fake", None) fakelog = [] def checklog(*expected): self.assertEqual(fakelog, list(expected)) @@ -1403,7 +1401,6 @@ class DumpStats(unittest.TestCase): self.assertEqual(data["created"], now) self.assertEqual(data["valid_until"], now+validity) self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0) - self.assertEqual(data["transit"]["all_time"]["total"], 0) class Startup(unittest.TestCase): diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index a710dce..c93d9d3 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -10,9 +10,9 @@ from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from twisted.test import proto_helpers +from wormhole_transit_relay import transit_server from ..errors import InternalError from .. import transit -from ..server import transit_server from .common import ServerBase from nacl.secret import SecretBox from nacl.exceptions import CryptoError diff --git a/src/wormhole/test/test_transit_server.py b/src/wormhole/test/test_transit_server.py deleted file mode 100644 index acdca61..0000000 --- a/src/wormhole/test/test_transit_server.py +++ /dev/null @@ -1,306 +0,0 @@ -from __future__ import print_function, unicode_literals -from binascii import hexlify -from twisted.trial import unittest -from twisted.internet import protocol, reactor, defer -from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web import client -from .common import ServerBase -from ..server import transit_server - -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - self._disconnect = defer.Deferred() - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - self._disconnect.callback(None) - -class Transit(ServerBase, unittest.TestCase): - def test_blur_size(self): - blur = transit_server.blur_size - self.failUnlessEqual(blur(0), 0) - self.failUnlessEqual(blur(1), 10e3) - self.failUnlessEqual(blur(10e3), 10e3) - self.failUnlessEqual(blur(10e3+1), 20e3) - self.failUnlessEqual(blur(15e3), 20e3) - self.failUnlessEqual(blur(20e3), 20e3) - self.failUnlessEqual(blur(1e6), 1e6) - self.failUnlessEqual(blur(1e6+1), 2e6) - self.failUnlessEqual(blur(1.5e6), 2e6) - self.failUnlessEqual(blur(2e6), 2e6) - self.failUnlessEqual(blur(900e6), 900e6) - self.failUnlessEqual(blur(1000e6), 1000e6) - self.failUnlessEqual(blur(1050e6), 1100e6) - self.failUnlessEqual(blur(1100e6), 1100e6) - self.failUnlessEqual(blur(1150e6), 1200e6) - - @defer.inlineCallbacks - def test_web_request(self): - resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii')) - self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) - - @defer.inlineCallbacks - def test_register(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # let that arrive - while self.count() == 0: - yield self.wait() - self.assertEqual(self.count(), 1) - - a1.transport.loseConnection() - - # let that get removed - while self.count() > 0: - yield self.wait() - self.assertEqual(self.count(), 0) - - # the token should be removed too - self.assertEqual(len(self._transit_server._pending_requests), 0) - - @defer.inlineCallbacks - def test_both_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_sided_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_unsided_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_both_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - side2 = b"\x02"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) - def wait(self): - d = defer.Deferred() - reactor.callLater(0.001, d.callback, None) - return d - - @defer.inlineCallbacks - def test_ignore_same_side(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 0: - yield self.wait() - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 1: - yield self.wait() - self.assertEqual(self.count(), 2) # same-side connections don't match - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_bad_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_binary_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" - # the embedded \n makes the server trigger early, before the full - # expected handshake length has arrived. A non-wormhole client - # writing non-ascii junk to the transit port used to trigger a - # UnicodeDecodeError when it tried to coerce the incoming handshake - # to unicode, due to the ("\n" in buf) check. This was fixed to use - # (b"\n" in buf). This exercises the old failure. - a1.transport.write(binary_bad_handshake) - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_old(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_new(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection()