"""Tests for the transit relay server's handshake parsing and byte relaying."""

from __future__ import print_function, unicode_literals

from binascii import hexlify

from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web import client

from .common import ServerBase
from ..server import transit_server


def _relay_handshake(token, side=None):
    """Build the bytes of a "please relay" handshake line.

    :param token: raw (un-hexlified) token bytes.
    :param side: optional raw side bytes; when omitted, the older
        handshake form without a " for side " clause is produced.
    :returns: the newline-terminated handshake as bytes.
    """
    line = b"please relay " + hexlify(token)
    if side is not None:
        line += b" for side " + hexlify(side)
    return line + b"\n"


class Accumulator(protocol.Protocol):
    """Client protocol that buffers everything it receives.

    Tests use ``waitForBytes`` to pause until the buffer has grown to a
    given total size, then inspect ``data`` directly.
    """

    def __init__(self):
        self.data = b""  # every byte received so far
        self.count = 0  # total buffer length the current waiter needs
        self._wait = None  # Deferred of the pending waitForBytes, if any
        self._disconnect = defer.Deferred()  # fires on connectionLost

    def waitForBytes(self, more):
        """Return a Deferred that fires once ``len(self.data) >= more``.

        Despite the parameter name, ``more`` is compared against the total
        buffered length, not the number of additional bytes. Only one wait
        may be outstanding at a time.

        :returns: a Deferred that fires with this protocol instance, or
            errbacks with RuntimeError("closed") if the connection is lost
            first.
        """
        assert self._wait is None
        self.count = more
        # Keep a local reference: _check_done() resets self._wait to None
        # when the data is already buffered, and returning self._wait in
        # that case would hand the caller None instead of a Deferred.
        d = self._wait = defer.Deferred()
        self._check_done()
        return d

    def dataReceived(self, data):
        """Append incoming bytes and wake the waiter if it is satisfied."""
        self.data = self.data + data
        self._check_done()

    def _check_done(self):
        """Fire the pending waiter once enough bytes have been buffered."""
        if self._wait is not None and len(self.data) >= self.count:
            d = self._wait
            self._wait = None
            d.callback(self)

    def connectionLost(self, why):
        """Fail any pending waiter, then fire the disconnect Deferred."""
        if self._wait is not None:
            d = self._wait
            # Clear the slot so a Deferred that has already fired is not
            # left behind to trip the assert in a later waitForBytes().
            self._wait = None
            d.errback(RuntimeError("closed"))
        self._disconnect.callback(None)


class Transit(ServerBase, unittest.TestCase):
    """Exercise the transit relay over real connections to ``self.transit``."""

    def test_blur_size(self):
        """Pin the outputs of ``blur_size`` for a table of input sizes."""
        blur = transit_server.blur_size
        cases = [
            (0, 0),
            (1, 10e3),
            (10e3, 10e3),
            (10e3+1, 20e3),
            (15e3, 20e3),
            (20e3, 20e3),
            (1e6, 1e6),
            (1e6+1, 2e6),
            (1.5e6, 2e6),
            (2e6, 2e6),
            (900e6, 900e6),
            (1000e6, 1000e6),
            (1050e6, 1100e6),
            (1100e6, 1100e6),
            (1150e6, 1200e6),
        ]
        for size, expected in cases:
            self.assertEqual(blur(size), expected,
                             "blur_size({!r})".format(size))

    @defer.inlineCallbacks
    def test_web_request(self):
        """The relay's web port serves a page containing 'Wormhole Relay'."""
        # NOTE(review): getPage is reported as deprecated in newer Twisted
        # releases; confirm against the Twisted version this project pins
        # before replacing it.
        url = 'http://127.0.0.1:{}/'.format(self.relayport).encode('ascii')
        resp = yield client.getPage(url)
        self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())

    @defer.inlineCallbacks
    def test_register(self):
        """A lone request is tracked as pending and dropped on disconnect."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())

        token1 = b"\x00"*32
        side1 = b"\x01"*8
        a1.transport.write(_relay_handshake(token1, side1))

        # let that arrive
        while self.count() == 0:
            yield self.wait()
        self.assertEqual(self.count(), 1)

        a1.transport.loseConnection()

        # let that get removed
        while self.count() > 0:
            yield self.wait()
        self.assertEqual(self.count(), 0)

        # the token should be removed too
        self.assertEqual(len(self._transit_server._pending_requests), 0)

    @defer.inlineCallbacks
    def _check_relay(self, handshake1, handshake2):
        """Connect two clients and check that they get glued together.

        Each client sends its handshake; both must receive the "ok" ack,
        and data written by the first client after its ack must show up on
        the second client.

        :param handshake1: handshake bytes written by the first client.
        :param handshake2: handshake bytes written by the second client.
        """
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())
        a2 = yield connectProtocol(ep, Accumulator())
        try:
            a1.transport.write(handshake1)
            a2.transport.write(handshake2)

            # a correct handshake yields an ack, after which we can send
            exp = b"ok\n"
            yield a1.waitForBytes(len(exp))
            self.assertEqual(a1.data, exp)
            s1 = b"data1"
            a1.transport.write(s1)

            exp = b"ok\n"
            yield a2.waitForBytes(len(exp))
            self.assertEqual(a2.data, exp)

            # all data they sent after the handshake should be given to us
            exp = b"ok\n"+s1
            yield a2.waitForBytes(len(exp))
            self.assertEqual(a2.data, exp)
        finally:
            # Close both ends even when an assertion above fails, so the
            # real failure is not followed by leftover open connections.
            a1.transport.loseConnection()
            a2.transport.loseConnection()

    @defer.inlineCallbacks
    def test_both_unsided(self):
        """Two handshakes without a side clause are connected."""
        token1 = b"\x00"*32
        yield self._check_relay(_relay_handshake(token1),
                                _relay_handshake(token1))

    @defer.inlineCallbacks
    def test_sided_unsided(self):
        """A sided handshake followed by an unsided one is connected."""
        token1 = b"\x00"*32
        side1 = b"\x01"*8
        yield self._check_relay(_relay_handshake(token1, side1),
                                _relay_handshake(token1))

    @defer.inlineCallbacks
    def test_unsided_sided(self):
        """An unsided handshake followed by a sided one is connected."""
        token1 = b"\x00"*32
        side1 = b"\x01"*8
        yield self._check_relay(_relay_handshake(token1),
                                _relay_handshake(token1, side1))

    @defer.inlineCallbacks
    def test_both_sided(self):
        """Two handshakes naming different sides are connected."""
        token1 = b"\x00"*32
        side1 = b"\x01"*8
        side2 = b"\x02"*8
        yield self._check_relay(_relay_handshake(token1, side1),
                                _relay_handshake(token1, side2))

    def count(self):
        """Return the number of pending requests summed over all tokens."""
        pending = self._transit_server._pending_requests
        return sum(len(potentials) for potentials in pending.values())

    def wait(self):
        """Return a Deferred that fires after a 1 ms reactor delay."""
        d = defer.Deferred()
        reactor.callLater(0.001, d.callback, None)
        return d

    @defer.inlineCallbacks
    def test_ignore_same_side(self):
        """Two requests from the same side stay pending and are not paired."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())
        a2 = yield connectProtocol(ep, Accumulator())

        token1 = b"\x00"*32
        side1 = b"\x01"*8
        a1.transport.write(_relay_handshake(token1, side1))
        # let that arrive
        while self.count() == 0:
            yield self.wait()
        a2.transport.write(_relay_handshake(token1, side1))
        # let that arrive
        while self.count() == 1:
            yield self.wait()
        self.assertEqual(self.count(), 2)  # same-side connections don't match

        a1.transport.loseConnection()
        a2.transport.loseConnection()

    @defer.inlineCallbacks
    def test_bad_handshake(self):
        """A right-length but malformed handshake gets "bad handshake"."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())

        token1 = b"\x00"*32
        # the server waits for the exact number of bytes in the expected
        # handshake message. to trigger "bad handshake", we must match.
        a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")

        exp = b"bad handshake\n"
        yield a1.waitForBytes(len(exp))
        self.assertEqual(a1.data, exp)

        a1.transport.loseConnection()

    @defer.inlineCallbacks
    def test_binary_handshake(self):
        """Non-ASCII junk with an embedded newline gets "bad handshake"."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())

        binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
        # the embedded \n makes the server trigger early, before the full
        # expected handshake length has arrived. A non-wormhole client
        # writing non-ascii junk to the transit port used to trigger a
        # UnicodeDecodeError when it tried to coerce the incoming handshake
        # to unicode, due to the ("\n" in buf) check. This was fixed to use
        # (b"\n" in buf). This exercises the old failure.
        a1.transport.write(binary_bad_handshake)

        exp = b"bad handshake\n"
        yield a1.waitForBytes(len(exp))
        self.assertEqual(a1.data, exp)

        a1.transport.loseConnection()

    @defer.inlineCallbacks
    def test_impatience_old(self):
        """Extra bytes after an unsided handshake get "impatient"."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())

        token1 = b"\x00"*32
        # sending too many bytes is impatience.
        a1.transport.write(_relay_handshake(token1) + b"NOWNOWNOW")

        exp = b"impatient\n"
        yield a1.waitForBytes(len(exp))
        self.assertEqual(a1.data, exp)

        a1.transport.loseConnection()

    @defer.inlineCallbacks
    def test_impatience_new(self):
        """Extra bytes after a sided handshake get "impatient"."""
        ep = clientFromString(reactor, self.transit)
        a1 = yield connectProtocol(ep, Accumulator())

        token1 = b"\x00"*32
        side1 = b"\x01"*8
        # sending too many bytes is impatience.
        a1.transport.write(_relay_handshake(token1, side1) + b"NOWNOWNOW")

        exp = b"impatient\n"
        yield a1.waitForBytes(len(exp))
        self.assertEqual(a1.data, exp)

        a1.transport.loseConnection()