merge test_transit.py into test_server.py

This commit is contained in:
Brian Warner 2016-04-15 16:14:25 -07:00
parent 68b22bec97
commit 8c67a98259
2 changed files with 86 additions and 92 deletions

View File

@ -1,10 +1,12 @@
from __future__ import print_function
import json
import requests
from binascii import hexlify
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest
from twisted.internet import reactor, defer
from twisted.internet import protocol, reactor, defer
from twisted.internet.threads import deferToThread
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody
from .. import __version__
from .common import ServerBase
@ -434,7 +436,30 @@ class Summary(unittest.TestCase):
self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41),
(1, "scary", 40, 9))
class Transit(unittest.TestCase):
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self):
blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0)
@ -453,3 +478,62 @@ class Transit(unittest.TestCase):
self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_basic(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@ -1,90 +0,0 @@
from __future__ import print_function
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import protocol, defer, reactor
from twisted.internet.endpoints import clientFromString, connectProtocol
from .common import ServerBase
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
class Transit(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def test_basic(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()