diff --git a/setup.py b/setup.py index 3ae39bb..bd49900 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ setup(name="magic-wormhole", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], - "dev": ["mock", "tox", "pyflakes"], + "dev": ["mock", "tox", "pyflakes", "magic-wormhole-transit-relay"], }, test_suite="wormhole.test", cmdclass=commands, diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index c913b1c..b19d7f0 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -16,7 +16,6 @@ from autobahn.twisted.resource import WebSocketResource from .database import get_db from .rendezvous import Rendezvous from .rendezvous_websocket import WebSocketRendezvousFactory -from .transit_server import Transit SECONDS = 1.0 MINUTE = 60*SECONDS @@ -82,13 +81,6 @@ class RelayServer(service.MultiService): rendezvous_web_service = internet.StreamServerEndpointService(r, site) rendezvous_web_service.setServiceParent(self) - if transit_port: - transit = Transit(db, blur_usage) - transit.setServiceParent(self) # for the timer - t = endpoints.serverFromString(reactor, transit_port) - transit_service = internet.StreamServerEndpointService(t, transit) - transit_service.setServiceParent(self) - self._stats_file = stats_file if self._stats_file and os.path.exists(self._stats_file): os.unlink(self._stats_file) @@ -103,10 +95,6 @@ class RelayServer(service.MultiService): self._root = root self._rendezvous_web_service = rendezvous_web_service self._rendezvous_websocket = wsrf - self._transit = None - if transit_port: - self._transit = transit - self._transit_service = transit_service def increase_rlimits(self): if getrlimit is None: @@ -141,10 +129,10 @@ class RelayServer(service.MultiService): service.MultiService.startService(self) self.increase_rlimits() log.msg("websocket listening on /wormhole-relay/ws") - log.msg("Wormhole relay server (Rendezvous and Transit) running") + log.msg("Wormhole relay server (Rendezvous) running") if self._blur_usage: log.msg("blurring access times to %d seconds" % self._blur_usage) - log.msg("not logging HTTP requests or Transit connections") + log.msg("not logging HTTP requests") else: log.msg("not blurring access times") if not self._allow_list: @@ -167,8 +155,6 @@ class RelayServer(service.MultiService): start = time.time() data["rendezvous"] = self._rendezvous.get_stats() - if self._transit: - data["transit"] = self._transit.get_stats() log.msg("get_stats took:", time.time() - start) with open(tmpfn, "wb") as f: diff --git a/src/wormhole/server/transit_server.py b/src/wormhole/server/transit_server.py deleted file mode 100644 index 92265a5..0000000 --- a/src/wormhole/server/transit_server.py +++ /dev/null @@ -1,328 +0,0 @@ -from __future__ import print_function, unicode_literals -import re, time, collections -from twisted.python import log -from twisted.internet import protocol -from twisted.application import service - -SECONDS = 1.0 -MINUTE = 60*SECONDS -HOUR = 60*MINUTE -DAY = 24*HOUR -MB = 1000*1000 - -def round_to(size, coarseness): - return int(coarseness*(1+int((size-1)/coarseness))) - -def blur_size(size): - if size == 0: - return 0 - if size < 1e6: - return round_to(size, 10e3) - if size < 1e9: - return round_to(size, 1e6) - return round_to(size, 100e6) - -class TransitConnection(protocol.Protocol): - def __init__(self): - self._got_token = False - self._got_side = False - self._token_buffer = b"" - self._sent_ok = False - self._buddy = None - self._had_buddy = False - self._total_sent = 0 - - def describeToken(self): - d = "-" - if self._got_token: - d = self._got_token[:16].decode("ascii") - if self._got_side: - d += "-" + self._got_side.decode("ascii") - else: - d += "-" - return d - - def connectionMade(self): - self._started = time.time() - self._log_requests = self.factory._log_requests - - def dataReceived(self, data): - if self._sent_ok: - # We are an IPushProducer to our buddy's IConsumer, so they'll - # throttle us (by calling pauseProducing()) when their outbound - # buffer is full (e.g. when their downstream pipe is full). In - # practice, this buffers about 10MB per connection, after which - # point the sender will only transmit data as fast as the - # receiver can handle it. - self._total_sent += len(data) - self._buddy.transport.write(data) - return - - if self._got_token: # but not yet sent_ok - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - - # else this should be (part of) the token - self._token_buffer += data - buf = self._token_buffer - - # old: "please relay {64}\n" - # new: "please relay {64} for side {16}\n" - (old, handshake_len, token) = self._check_old_handshake(buf) - assert old in ("yes", "waiting", "no") - if old == "yes": - # remember they aren't supposed to send anything past their - # handshake until we've said go - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - return self._got_handshake(token, None) - (new, handshake_len, token, side) = self._check_new_handshake(buf) - assert new in ("yes", "waiting", "no") - if new == "yes": - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - return self._got_handshake(token, side) - if (old == "no" and new == "no"): - self.transport.write(b"bad handshake\n") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect() # incorrectness yields failure - # else we'll keep waiting - - def _check_old_handshake(self, buf): - # old: "please relay {64}\n" - # return ("yes", handshake, token) if buf contains an old-style handshake - # return ("waiting", None, None) if it might eventually contain one - # return ("no", None, None) if it could never contain one - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None) - if len(buf) < wanted: - return ("waiting", None, None) - - mo = re.search(br"^please relay (\w{64})\n", buf, re.M) - if mo: - token = mo.group(1) - return ("yes", wanted, token) - return ("no", None, None) - - def _check_new_handshake(self, buf): - # new: "please relay {64} for side {16}\n" - wanted = len("please relay for side \n")+32*2+8*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None, None) - if len(buf) < wanted: - return ("waiting", None, None, None) - - mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M) - if mo: - token = mo.group(1) - side = mo.group(2) - return ("yes", wanted, token, side) - return ("no", None, None, None) - - def _got_handshake(self, token, side): - self._got_token = token - self._got_side = side - self.factory.connection_got_token(token, side, self) - - def buddy_connected(self, them): - self._buddy = them - self._had_buddy = True - self.transport.write(b"ok\n") - self._sent_ok = True - # Connect the two as a producer/consumer pair. We use streaming=True, - # so this expects the IPushProducer interface, and uses - # pauseProducing() to throttle, and resumeProducing() to unthrottle. - self._buddy.transport.registerProducer(self.transport, True) - # The Transit object calls buddy_connected() on both protocols, so - # there will be two producer/consumer pairs. - - def buddy_disconnected(self): - if self._log_requests: - log.msg("buddy_disconnected %s" % self.describeToken()) - self._buddy = None - self.transport.loseConnection() - - def connectionLost(self, reason): - if self._buddy: - self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, self._got_side, - self.describeToken()) - - # Record usage. There are four cases: - # * 1: we connected, never had a buddy - # * 2: we connected first, we disconnect before the buddy - # * 3: we connected first, buddy disconnects first - # * 4: buddy connected first, we disconnect before buddy - # * 5: buddy connected first, buddy disconnects first - - # whoever disconnects first gets to write the usage record (1,2,4) - - finished = time.time() - if not self._had_buddy: # 1 - total_time = finished - self._started - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - if self._had_buddy and self._buddy: # 2,4 - total_bytes = self._total_sent + self._buddy._total_sent - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - - def disconnect(self): - self.transport.loseConnection() - self.factory.transitFailed(self) - finished = time.time() - total_time = finished - self._started - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - -class Transit(protocol.ServerFactory, service.MultiService): - # I manage pairs of simultaneous connections to a secondary TCP port, - # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN for SIDE\n" (or a legacy form without the "for - # SIDE"). Two connections match if they use the same TOKEN and have - # different SIDEs (the redundant connections are dropped when a match is - # made). Legacy connections match any with the same TOKEN, ignoring SIDE - # (so two legacy connections will match each other). - - # I will send "ok\n" when the matching connection is established, or - # disconnect if no matching connection is made within MAX_WAIT_TIME - # seconds. I will disconnect if you send data before the "ok\n". All data - # you get after the "ok\n" will be from the other side. You will not - # receive "ok\n" until the other side has also connected and submitted a - # matching token (and differing SIDE). - - # In addition, the connections will be dropped after MAXLENGTH bytes have - # been sent by either side, or MAXTIME seconds have elapsed after the - # matching connections were established. A future API will reveal these - # limits to clients instead of causing mysterious spontaneous failures. - - # These relay connections are not half-closeable (unlike full TCP - # connections, applications will not receive any data after half-closing - # their outgoing side). Applications must negotiate shutdown with their - # peer and not close the connection until all data has finished - # transferring in both directions. Applications which only need to send - # data in one direction can use close() as usual. - - MAX_WAIT_TIME = 30*SECONDS - MAXLENGTH = 10*MB - MAXTIME = 60*SECONDS - protocol = TransitConnection - - def __init__(self, db, blur_usage): - service.MultiService.__init__(self) - self._db = db - self._blur_usage = blur_usage - self._log_requests = blur_usage is None - self._pending_requests = {} # token -> set((side, TransitConnection)) - self._active_connections = set() # TransitConnection - self._counts = collections.defaultdict(int) - self._count_bytes = 0 - - def connection_got_token(self, token, new_side, new_tc): - if token not in self._pending_requests: - self._pending_requests[token] = set() - potentials = self._pending_requests[token] - for old in potentials: - (old_side, old_tc) = old - if ((old_side is None) - or (new_side is None) - or (old_side != new_side)): - # we found a match - if self._log_requests: - log.msg("transit relay 2: %s" % new_tc.describeToken()) - - # drop and stop tracking the rest - potentials.remove(old) - for (_, leftover_tc) in potentials: - leftover_tc.disconnect() # TODO: not "errory"? - self._pending_requests.pop(token) - - # glue the two ends together - self._active_connections.add(new_tc) - self._active_connections.add(old_tc) - new_tc.buddy_connected(old_tc) - old_tc.buddy_connected(new_tc) - return - if self._log_requests: - log.msg("transit relay 1: %s" % new_tc.describeToken()) - potentials.add((new_side, new_tc)) - # TODO: timer - - def recordUsage(self, started, result, total_bytes, - total_time, waiting_time): - if self._log_requests: - log.msg("Transit.recordUsage (%dB)" % total_bytes) - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - self._db.execute("INSERT INTO `transit_usage`" - " (`started`, `total_time`, `waiting_time`," - " `total_bytes`, `result`)" - " VALUES (?,?,?, ?,?)", - (started, total_time, waiting_time, - total_bytes, result)) - self._db.commit() - self._counts[result] += 1 - self._count_bytes += total_bytes - - def transitFinished(self, tc, token, side, description): - if token in self._pending_requests: - side_tc = (side, tc) - if side_tc in self._pending_requests[token]: - self._pending_requests[token].remove(side_tc) - if not self._pending_requests[token]: # set is now empty - del self._pending_requests[token] - if self._log_requests: - log.msg("transitFinished %s" % (description,)) - self._active_connections.discard(tc) - - def transitFailed(self, p): - if self._log_requests: - log.msg("transitFailed %r" % p) - pass - - def get_stats(self): - stats = {} - def q(query, values=()): - row = self._db.execute(query, values).fetchone() - return list(row.values())[0] - - # current status: expected to be zero most of the time - c = stats["active"] = {} - c["connected"] = len(self._active_connections) / 2 - c["waiting"] = len(self._pending_requests) - - # usage since last reboot - rb = stats["since_reboot"] = {} - rb["bytes"] = self._count_bytes - rb["total"] = sum(self._counts.values(), 0) - rbm = rb["moods"] = {} - for result, count in self._counts.items(): - rbm[result] = count - - # historical usage (all-time) - u = stats["all_time"] = {} - u["total"] = q("SELECT COUNT() FROM `transit_usage`") - u["bytes"] = q("SELECT SUM(`total_bytes`) FROM `transit_usage`") or 0 - um = u["moods"] = {} - um["happy"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='happy'") - um["lonely"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='lonely'") - um["errory"] = q("SELECT COUNT() FROM `transit_usage`" - " WHERE `result`='errory'") - - return stats diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index a0761f2..e227e2d 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,12 +1,13 @@ # no unicode_literals untill twisted update -from twisted.application import service -from twisted.internet import defer, task, reactor +from twisted.application import service, internet +from twisted.internet import defer, task, reactor, endpoints from twisted.python import log from click.testing import CliRunner import mock from ..cli import cli from ..transit import allocate_tcp_port from ..server.server import RelayServer +from wormhole_transit_relay.transit_server import Transit class ServerBase: def setUp(self): @@ -16,20 +17,26 @@ class ServerBase: self.sp = service.MultiService() self.sp.startService() self.relayport = allocate_tcp_port() - self.transitport = allocate_tcp_port() # need to talk to twisted team about only using unicode in # endpoints.serverFromString s = RelayServer("tcp:%d:interface=127.0.0.1" % self.relayport, - "tcp:%s:interface=127.0.0.1" % self.transitport, + "XXXremovemytransit", advertise_version=advertise_version, signal_error=error) s.setServiceParent(self.sp) self._relay_server = s self._rendezvous = s._rendezvous - self._transit_server = s._transit self.relayurl = u"ws://127.0.0.1:%d/v1" % self.relayport self.rdv_ws_port = self.relayport # ws://127.0.0.1:%d/wormhole-relay/ws + + self.transitport = allocate_tcp_port() + ep = endpoints.serverFromString(reactor, + "tcp:%d:interface=127.0.0.1" % + self.transitport) + self._transit_server = f = Transit(blur_usage=None, log_file=None, + usage_db=None) + internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) self.transit = u"tcp:127.0.0.1:%d" % self.transitport def tearDown(self): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 8f55919..f9aac5e 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1403,7 +1403,6 @@ class DumpStats(unittest.TestCase): self.assertEqual(data["created"], now) self.assertEqual(data["valid_until"], now+validity) self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0) - self.assertEqual(data["transit"]["all_time"]["total"], 0) class Startup(unittest.TestCase): diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index a710dce..c93d9d3 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -10,9 +10,9 @@ from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log, failure from twisted.test import proto_helpers +from wormhole_transit_relay import transit_server from ..errors import InternalError from .. import transit -from ..server import transit_server from .common import ServerBase from nacl.secret import SecretBox from nacl.exceptions import CryptoError diff --git a/src/wormhole/test/test_transit_server.py b/src/wormhole/test/test_transit_server.py deleted file mode 100644 index acdca61..0000000 --- a/src/wormhole/test/test_transit_server.py +++ /dev/null @@ -1,306 +0,0 @@ -from __future__ import print_function, unicode_literals -from binascii import hexlify -from twisted.trial import unittest -from twisted.internet import protocol, reactor, defer -from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web import client -from .common import ServerBase -from ..server import transit_server - -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - self._disconnect = defer.Deferred() - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - self._disconnect.callback(None) - -class Transit(ServerBase, unittest.TestCase): - def test_blur_size(self): - blur = transit_server.blur_size - self.failUnlessEqual(blur(0), 0) - self.failUnlessEqual(blur(1), 10e3) - self.failUnlessEqual(blur(10e3), 10e3) - self.failUnlessEqual(blur(10e3+1), 20e3) - self.failUnlessEqual(blur(15e3), 20e3) - self.failUnlessEqual(blur(20e3), 20e3) - self.failUnlessEqual(blur(1e6), 1e6) - self.failUnlessEqual(blur(1e6+1), 2e6) - self.failUnlessEqual(blur(1.5e6), 2e6) - self.failUnlessEqual(blur(2e6), 2e6) - self.failUnlessEqual(blur(900e6), 900e6) - self.failUnlessEqual(blur(1000e6), 1000e6) - self.failUnlessEqual(blur(1050e6), 1100e6) - self.failUnlessEqual(blur(1100e6), 1100e6) - self.failUnlessEqual(blur(1150e6), 1200e6) - - @defer.inlineCallbacks - def test_web_request(self): - resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii')) - self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) - - @defer.inlineCallbacks - def test_register(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # let that arrive - while self.count() == 0: - yield self.wait() - self.assertEqual(self.count(), 1) - - a1.transport.loseConnection() - - # let that get removed - while self.count() > 0: - yield self.wait() - self.assertEqual(self.count(), 0) - - # the token should be removed too - self.assertEqual(len(self._transit_server._pending_requests), 0) - - @defer.inlineCallbacks - def test_both_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_sided_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_unsided_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_both_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - side2 = b"\x02"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) - def wait(self): - d = defer.Deferred() - reactor.callLater(0.001, d.callback, None) - return d - - @defer.inlineCallbacks - def test_ignore_same_side(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 0: - yield self.wait() - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 1: - yield self.wait() - self.assertEqual(self.count(), 2) # same-side connections don't match - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_bad_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_binary_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" - # the embedded \n makes the server trigger early, before the full - # expected handshake length has arrived. A non-wormhole client - # writing non-ascii junk to the transit port used to trigger a - # UnicodeDecodeError when it tried to coerce the incoming handshake - # to unicode, due to the ("\n" in buf) check. This was fixed to use - # (b"\n" in buf). This exercises the old failure. - a1.transport.write(binary_bad_handshake) - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_old(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_new(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection()