diff --git a/src/wormhole/servers/transit_server.py b/src/wormhole/servers/transit_server.py index 6600752..122ac5e 100644 --- a/src/wormhole/servers/transit_server.py +++ b/src/wormhole/servers/transit_server.py @@ -1,5 +1,6 @@ from __future__ import print_function import re +from twisted.python import log from twisted.internet import protocol from twisted.application import service @@ -24,27 +25,27 @@ class TransitConnection(protocol.Protocol): self.buddy.transport.write(data) return if self.got_token: # but not yet sent_ok - self.transport.write("impatient\n") - print("transit impatience failure") + self.transport.write(b"impatient\n") + log.msg("transit impatience failure") return self.disconnect() # impatience yields failure # else this should be (part of) the token self.token_buffer += data buf = self.token_buffer wanted = len("please relay \n")+32*2 if len(buf) < wanted-1 and "\n" in buf: - self.transport.write("bad handshake\n") - print("transit handshake early failure") + self.transport.write(b"bad handshake\n") + log.msg("transit handshake early failure") return self.disconnect() if len(buf) < wanted: return if len(buf) > wanted: - self.transport.write("impatient\n") - print("transit impatience failure") + self.transport.write(b"impatient\n") + log.msg("transit impatience failure") return self.disconnect() # impatience yields failure - mo = re.search(r"^please relay (\w{64})\n", buf, re.M) + mo = re.search(br"^please relay (\w{64})\n", buf, re.M) if not mo: - self.transport.write("bad handshake\n") - print("transit handshake failure") + self.transport.write(b"bad handshake\n") + log.msg("transit handshake failure") return self.disconnect() # incorrectness yields failure token = mo.group(1) @@ -58,12 +59,12 @@ class TransitConnection(protocol.Protocol): # TODO: connect as producer/consumer def buddy_disconnected(self): - print("buddy_disconnected %r" % self) + log.msg("buddy_disconnected %r" % self) self.buddy = None self.transport.loseConnection() def connectionLost(self, reason): - print("connectionLost %r %s" % (self, reason)) + log.msg("connectionLost %r %s" % (self, reason)) if self.buddy: self.buddy.buddy_disconnected() self.factory.transitFinished(self, self.total_sent) @@ -106,7 +107,7 @@ class Transit(protocol.ServerFactory, service.MultiService): def connection_got_token(self, token, p): if token in self.pending_requests: - print("transit relay 2: %r" % token) + log.msg("transit relay 2: %r" % token) buddy = self.pending_requests.pop(token) self.active_connections.add(p) self.active_connections.add(buddy) @@ -114,10 +115,10 @@ class Transit(protocol.ServerFactory, service.MultiService): buddy.buddy_connected(p) else: self.pending_requests[token] = p - print("transit relay 1: %r" % token) + log.msg("transit relay 1: %r" % token) # TODO: timer def transitFinished(self, p, total_sent): - print("transitFinished (%dB) %r" % (total_sent, p)) + log.msg("transitFinished (%dB) %r" % (total_sent, p)) for token,tc in self.pending_requests.items(): if tc is p: del self.pending_requests[token] @@ -125,5 +126,5 @@ class Transit(protocol.ServerFactory, service.MultiService): self.active_connections.discard(p) def transitFailed(self, p): - print("transitFailed %r" % p) + log.msg("transitFailed %r" % p) pass diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py new file mode 100644 index 0000000..576abad --- /dev/null +++ b/src/wormhole/test/test_transit.py @@ -0,0 +1,90 @@ +from __future__ import print_function +from binascii import hexlify +from twisted.trial import unittest +from twisted.internet import protocol, defer, reactor +from twisted.internet.endpoints import clientFromString, connectProtocol +from .common import ServerBase + +class Accumulator(protocol.Protocol): + def __init__(self): + self.data = b"" + self.count = 0 + self._wait = None + def waitForBytes(self, more): + assert self._wait is None + self.count = more + self._wait = defer.Deferred() + self._check_done() + return self._wait + def dataReceived(self, data): + self.data = self.data + data + self._check_done() + def _check_done(self): + if self._wait and len(self.data) >= self.count: + d = self._wait + self._wait = None + d.callback(self) + def connectionLost(self, why): + if self._wait: + self._wait.errback("closed") + +class Transit(ServerBase, unittest.TestCase): + @defer.inlineCallbacks + def test_basic(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + a1.transport.write(b"please relay %s\n" % hexlify(token1)) + a2.transport.write(b"please relay %s\n" % hexlify(token1)) + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_bad_handshake(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(b"please DELAY %s\n" % hexlify(token1)) + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # sending too many bytes is impatience. + a1.transport.write(b"please RELAY NOWNOW %s\n" % hexlify(token1)) + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection()