diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index cf1a5f7..e0654d8 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,17 +1,14 @@ from __future__ import print_function, unicode_literals import os, json, itertools, time -from binascii import hexlify import mock from twisted.trial import unittest from twisted.python import log -from twisted.internet import protocol, reactor, defer +from twisted.internet import reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.endpoints import clientFromString, connectProtocol -from twisted.web import client from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase -from ..server import server, rendezvous, transit_server +from ..server import server, rendezvous from ..server.rendezvous import Usage, SidedMessage from ..server.database import get_db @@ -1066,301 +1063,3 @@ class DumpStats(unittest.TestCase): self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0) self.assertEqual(data["transit"]["all_time"]["total"], 0) - -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - self._disconnect = defer.Deferred() - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - self._disconnect.callback(None) - -class Transit(ServerBase, unittest.TestCase): - def test_blur_size(self): - blur = transit_server.blur_size - self.failUnlessEqual(blur(0), 0) - self.failUnlessEqual(blur(1), 10e3) - self.failUnlessEqual(blur(10e3), 10e3) - self.failUnlessEqual(blur(10e3+1), 20e3) - self.failUnlessEqual(blur(15e3), 20e3) - self.failUnlessEqual(blur(20e3), 20e3) - self.failUnlessEqual(blur(1e6), 1e6) - self.failUnlessEqual(blur(1e6+1), 2e6) - self.failUnlessEqual(blur(1.5e6), 2e6) - self.failUnlessEqual(blur(2e6), 2e6) - self.failUnlessEqual(blur(900e6), 900e6) - self.failUnlessEqual(blur(1000e6), 1000e6) - self.failUnlessEqual(blur(1050e6), 1100e6) - self.failUnlessEqual(blur(1100e6), 1100e6) - self.failUnlessEqual(blur(1150e6), 1200e6) - - @defer.inlineCallbacks - def test_web_request(self): - resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii')) - self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) - - @defer.inlineCallbacks - def test_register(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # let that arrive - while self.count() == 0: - yield self.wait() - self.assertEqual(self.count(), 1) - - a1.transport.loseConnection() - - # let that get removed - while self.count() > 0: - yield self.wait() - self.assertEqual(self.count(), 0) - - # the token should be removed too - self.assertEqual(len(self._transit_server._pending_requests), 0) - - @defer.inlineCallbacks - def test_both_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_sided_unsided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_unsided_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_both_sided(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - side2 = b"\x02"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side2) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) - def wait(self): - d = defer.Deferred() - reactor.callLater(0.001, d.callback, None) - return d - - @defer.inlineCallbacks - def test_ignore_same_side(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 0: - yield self.wait() - a2.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\n") - # let that arrive - while self.count() == 1: - yield self.wait() - self.assertEqual(self.count(), 2) # same-side connections don't match - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_bad_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_binary_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" - # the embedded \n makes the server trigger early, before the full - # expected handshake length has arrived. A non-wormhole client - # writing non-ascii junk to the transit port used to trigger a - # UnicodeDecodeError when it tried to coerce the incoming handshake - # to unicode, due to the ("\n" in buf) check. This was fixed to use - # (b"\n" in buf). This exercises the old failure. - a1.transport.write(binary_bad_handshake) - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_old(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience_new(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - # sending too many bytes is impatience. - a1.transport.write(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() diff --git a/src/wormhole/test/test_transit_server.py b/src/wormhole/test/test_transit_server.py new file mode 100644 index 0000000..acdca61 --- /dev/null +++ b/src/wormhole/test/test_transit_server.py @@ -0,0 +1,306 @@ +from __future__ import print_function, unicode_literals +from binascii import hexlify +from twisted.trial import unittest +from twisted.internet import protocol, reactor, defer +from twisted.internet.endpoints import clientFromString, connectProtocol +from twisted.web import client +from .common import ServerBase +from ..server import transit_server + +class Accumulator(protocol.Protocol): + def __init__(self): + self.data = b"" + self.count = 0 + self._wait = None + self._disconnect = defer.Deferred() + def waitForBytes(self, more): + assert self._wait is None + self.count = more + self._wait = defer.Deferred() + self._check_done() + return self._wait + def dataReceived(self, data): + self.data = self.data + data + self._check_done() + def _check_done(self): + if self._wait and len(self.data) >= self.count: + d = self._wait + self._wait = None + d.callback(self) + def connectionLost(self, why): + if self._wait: + self._wait.errback(RuntimeError("closed")) + self._disconnect.callback(None) + +class Transit(ServerBase, unittest.TestCase): + def test_blur_size(self): + blur = transit_server.blur_size + self.failUnlessEqual(blur(0), 0) + self.failUnlessEqual(blur(1), 10e3) + self.failUnlessEqual(blur(10e3), 10e3) + self.failUnlessEqual(blur(10e3+1), 20e3) + self.failUnlessEqual(blur(15e3), 20e3) + self.failUnlessEqual(blur(20e3), 20e3) + self.failUnlessEqual(blur(1e6), 1e6) + self.failUnlessEqual(blur(1e6+1), 2e6) + self.failUnlessEqual(blur(1.5e6), 2e6) + self.failUnlessEqual(blur(2e6), 2e6) + self.failUnlessEqual(blur(900e6), 900e6) + self.failUnlessEqual(blur(1000e6), 1000e6) + self.failUnlessEqual(blur(1050e6), 1100e6) + self.failUnlessEqual(blur(1100e6), 1100e6) + self.failUnlessEqual(blur(1150e6), 1200e6) + + @defer.inlineCallbacks + def test_web_request(self): + resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii')) + self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) + + @defer.inlineCallbacks + def test_register(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + # let that arrive + while self.count() == 0: + yield self.wait() + self.assertEqual(self.count(), 1) + + a1.transport.loseConnection() + + # let that get removed + while self.count() > 0: + yield self.wait() + self.assertEqual(self.count(), 0) + + # the token should be removed too + self.assertEqual(len(self._transit_server._pending_requests), 0) + + @defer.inlineCallbacks + def test_both_unsided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + a1.transport.write(b"please relay " + hexlify(token1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_sided_unsided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_unsided_sided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_both_sided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + def count(self): + return sum([len(potentials) + for potentials + in self._transit_server._pending_requests.values()]) + def wait(self): + d = defer.Deferred() + reactor.callLater(0.001, d.callback, None) + return d + + @defer.inlineCallbacks + def test_ignore_same_side(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + # let that arrive + while self.count() == 0: + yield self.wait() + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + # let that arrive + while self.count() == 1: + yield self.wait() + self.assertEqual(self.count(), 2) # same-side connections don't match + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_bad_handshake(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_binary_handshake(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" + # the embedded \n makes the server trigger early, before the full + # expected handshake length has arrived. A non-wormhole client + # writing non-ascii junk to the transit port used to trigger a + # UnicodeDecodeError when it tried to coerce the incoming handshake + # to unicode, due to the ("\n" in buf) check. This was fixed to use + # (b"\n" in buf). This exercises the old failure. + a1.transport.write(binary_bad_handshake) + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience_old(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # sending too many bytes is impatience. + a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience_new(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + # sending too many bytes is impatience. + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection()