diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 63dd643..09bea1e 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,10 +1,12 @@ from __future__ import print_function import json import requests +from binascii import hexlify from six.moves.urllib_parse import urlencode from twisted.trial import unittest -from twisted.internet import reactor, defer +from twisted.internet import protocol, reactor, defer from twisted.internet.threads import deferToThread +from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody from .. import __version__ from .common import ServerBase @@ -434,7 +436,30 @@ class Summary(unittest.TestCase): self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41), (1, "scary", 40, 9)) -class Transit(unittest.TestCase): +class Accumulator(protocol.Protocol): + def __init__(self): + self.data = b"" + self.count = 0 + self._wait = None + def waitForBytes(self, more): + assert self._wait is None + self.count = more + self._wait = defer.Deferred() + self._check_done() + return self._wait + def dataReceived(self, data): + self.data = self.data + data + self._check_done() + def _check_done(self): + if self._wait and len(self.data) >= self.count: + d = self._wait + self._wait = None + d.callback(self) + def connectionLost(self, why): + if self._wait: + self._wait.errback(RuntimeError("closed")) + +class Transit(ServerBase, unittest.TestCase): def test_blur_size(self): blur = transit_server.blur_size self.failUnlessEqual(blur(0), 0) @@ -453,3 +478,62 @@ class Transit(unittest.TestCase): self.failUnlessEqual(blur(1100e6), 1100e6) self.failUnlessEqual(blur(1150e6), 1200e6) + @defer.inlineCallbacks + def test_basic(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + a1.transport.write(b"please relay " + hexlify(token1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_bad_handshake(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # sending too many bytes is impatience. + a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py deleted file mode 100644 index 0ebb750..0000000 --- a/src/wormhole/test/test_transit.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import print_function -from binascii import hexlify -from twisted.trial import unittest -from twisted.internet import protocol, defer, reactor -from twisted.internet.endpoints import clientFromString, connectProtocol -from .common import ServerBase - -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - -class Transit(ServerBase, unittest.TestCase): - @defer.inlineCallbacks - def test_basic(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_bad_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # sending too many bytes is impatience. - a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection()