all tests pass

This commit is contained in:
meejah 2021-04-02 15:50:37 -06:00
parent b9c2bbc524
commit 5e21a3c35a
2 changed files with 153 additions and 96 deletions

View File

@ -43,12 +43,25 @@ class ServerBase:
""" """
Speak the transit client protocol used by the tests over TCP Speak the transit client protocol used by the tests over TCP
""" """
received = b""
connected = False
def connectionMade(self):
self.connected = True
def connectionLost(self, reason):
self.connected = False
def send(self, data): def send(self, data):
self.transport.write(data) self.transport.write(data)
def disconnect(self): def disconnect(self):
self.transport.loseConnection() self.transport.loseConnection()
def dataReceived(self, data):
self.received = self.received + data
client_factory = ClientFactory() client_factory = ClientFactory()
client_factory.protocol = TransitClientProtocolTcp client_factory.protocol = TransitClientProtocolTcp
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))

View File

@ -57,23 +57,27 @@ class _Transit:
p2 = self.new_protocol() p2 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
p1.dataReceived(handshake(token1, side=None)) p1.send(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=None)) self.flush()
p2.send(handshake(token1, side=None))
self.flush()
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.received, exp)
p1.transport.clear() p1.received = b""
p2.transport.clear() p2.received = b""
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.received, s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_sided_unsided(self): def test_sided_unsided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -81,24 +85,28 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=None)) self.flush()
p2.send(handshake(token1, side=None))
self.flush()
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.received, exp)
p1.transport.clear() p1.received = b""
p2.transport.clear() p2.received = b""
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.received, s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_unsided_sided(self): def test_unsided_sided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -106,24 +114,27 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=None)) p1.send(handshake(token1, side=None))
p2.dataReceived(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush()
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.received, exp)
p1.transport.clear() p1.received = b""
p2.transport.clear() p2.received = b""
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.received, s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_both_sided(self): def test_both_sided(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -132,24 +143,28 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2)) self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
self.flush()
# a correct handshake yields an ack, after which we can send # a correct handshake yields an ack, after which we can send
exp = b"ok\n" exp = b"ok\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
self.assertEqual(p2.transport.value(), exp) self.assertEqual(p2.received, exp)
p1.transport.clear() p1.received = b""
p2.transport.clear() p2.received = b""
# all data they sent after the handshake should be given to us # all data they sent after the handshake should be given to us
s1 = b"data1" s1 = b"data1"
p1.dataReceived(s1) p1.send(s1)
self.assertEqual(p2.transport.value(), s1) self.flush()
self.assertEqual(p2.received, s1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
def test_ignore_same_side(self): def test_ignore_same_side(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -159,41 +174,47 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 1) self.assertEqual(self.count(), 1)
p2.dataReceived(handshake(token1, side=side1)) p2.send(handshake(token1, side=side1))
self.flush()
self.assertEqual(self.count(), 2) # same-side connections don't match self.assertEqual(self.count(), 2) # same-side connections don't match
# when the second side arrives, the spare first connection should be # when the second side arrives, the spare first connection should be
# closed # closed
side2 = b"\x02"*8 side2 = b"\x02"*8
p3.dataReceived(handshake(token1, side=side2)) p3.send(handshake(token1, side=side2))
self.flush()
self.flush()
self.assertEqual(self.count(), 0) self.assertEqual(self.count(), 0)
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._transit_server._active_connections), 2) self.assertEqual(len(self._transit_server._active_connections), 2)
# That will trigger a disconnect on exactly one of (p1 or p2). # That will trigger a disconnect on exactly one of (p1 or p2).
# The other connection should still be connected # The other connection should still be connected
self.assertEqual(sum([int(t.transport.connected) for t in [p1, p2]]), 1) self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1)
p1.transport.loseConnection() p1.disconnect()
p2.transport.loseConnection() p2.disconnect()
p3.transport.loseConnection() p3.disconnect()
def test_bad_handshake_old(self): def test_bad_handshake_old(self):
p1 = self.new_protocol() p1 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
p1.dataReceived(b"please DELAY " + hexlify(token1) + b"\n") p1.send(b"please DELAY " + hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_bad_handshake_old_slow(self): def test_bad_handshake_old_slow(self):
p1 = self.new_protocol() p1 = self.new_protocol()
p1.dataReceived(b"please DELAY ") p1.send(b"please DELAY ")
self.flush()
# As in test_impatience_new_slow, the current state machine has code # As in test_impatience_new_slow, the current state machine has code
# that can only be reached if we insert a stall here, so dataReceived # that can only be reached if we insert a stall here, so dataReceived
# gets called twice. Hopefully we can delete this test once # gets called twice. Hopefully we can delete this test once
@ -202,12 +223,13 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(hexlify(token1) + b"\n") p1.send(hexlify(token1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_bad_handshake_new(self): def test_bad_handshake_new(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -216,13 +238,14 @@ class _Transit:
side1 = b"\x01"*8 side1 = b"\x01"*8
# the server waits for the exact number of bytes in the expected # the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match. # handshake message. to trigger "bad handshake", we must match.
p1.dataReceived(b"please DELAY " + hexlify(token1) + p1.send(b"please DELAY " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_binary_handshake(self): def test_binary_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -234,24 +257,26 @@ class _Transit:
# UnicodeDecodeError when it tried to coerce the incoming handshake # UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use # to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure. # (b"\n" in buf). This exercises the old failure.
p1.dataReceived(binary_bad_handshake) p1.send(binary_bad_handshake)
self.flush()
exp = b"bad handshake\n" exp = b"bad handshake\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_old(self): def test_impatience_old(self):
p1 = self.new_protocol() p1 = self.new_protocol()
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_new(self): def test_impatience_new(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -259,13 +284,14 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW") b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_impatience_new_slow(self): def test_impatience_new_slow(self):
p1 = self.new_protocol() p1 = self.new_protocol()
@ -281,27 +307,29 @@ class _Transit:
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
# sending too many bytes is impatience. # sending too many bytes is impatience.
p1.dataReceived(b"please relay " + hexlify(token1) + p1.send(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n") b" for side " + hexlify(side1) + b"\n")
self.flush()
p1.send(b"NOWNOWNOW")
p1.dataReceived(b"NOWNOWNOW") self.flush()
exp = b"impatient\n" exp = b"impatient\n"
self.assertEqual(p1.transport.value(), exp) self.assertEqual(p1.received, exp)
p1.transport.loseConnection() p1.disconnect()
def test_short_handshake(self): def test_short_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
p1.dataReceived(b"short") p1.send(b"short")
p1.transport.loseConnection() self.flush()
p1.disconnect()
def test_empty_handshake(self): def test_empty_handshake(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending anything # hang up before sending anything
p1.transport.loseConnection() p1.disconnect()
class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): class TransitWithLogs(_Transit, ServerBase, unittest.TestCase):
log_requests = True log_requests = True
@ -321,7 +349,8 @@ class Usage(ServerBase, unittest.TestCase):
def test_empty(self): def test_empty(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending anything # hang up before sending anything
p1.transport.loseConnection() p1.disconnect()
self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
@ -331,8 +360,9 @@ class Usage(ServerBase, unittest.TestCase):
def test_short(self): def test_short(self):
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
p1.transport.write(b"short") p1.send(b"short")
p1.transport.loseConnection() p1.disconnect()
self.flush()
# that will log the "empty" usage event # that will log the "empty" usage event
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
@ -342,9 +372,10 @@ class Usage(ServerBase, unittest.TestCase):
def test_errory(self): def test_errory(self):
p1 = self.new_protocol() p1 = self.new_protocol()
p1.dataReceived(b"this is a very bad handshake\n") p1.send(b"this is a very bad handshake\n")
self.flush()
# that will log the "errory" usage event, then drop the connection # that will log the "errory" usage event, then drop the connection
p1.transport.loseConnection() p1.disconnect()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "errory", self._usage) self.assertEqual(result, "errory", self._usage)
@ -354,9 +385,11 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
self.flush()
# now we disconnect before the peer connects # now we disconnect before the peer connects
p1.transport.loseConnection() p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -370,15 +403,20 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1.dataReceived(handshake(token1, side=side1)) p1.send(handshake(token1, side=side1))
p2.dataReceived(handshake(token1, side=side2)) self.flush()
p2.send(handshake(token1, side=side2))
self.flush()
self.assertEqual(self._usage, []) # no events yet self.assertEqual(self._usage, []) # no events yet
p1.dataReceived(b"\x00" * 13) p1.send(b"\x00" * 13)
p2.dataReceived(b"\xff" * 7) self.flush()
p2.send(b"\xff" * 7)
self.flush()
p1.transport.loseConnection() p1.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
@ -395,28 +433,34 @@ class Usage(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
side1 = b"\x01"*8 side1 = b"\x01"*8
side2 = b"\x02"*8 side2 = b"\x02"*8
p1a.dataReceived(handshake(token1, side=side1)) p1a.send(handshake(token1, side=side1))
p1b.dataReceived(handshake(token1, side=side1)) self.flush()
p1b.send(handshake(token1, side=side1))
self.flush()
# connect and disconnect a third client (for side1) to exercise the # connect and disconnect a third client (for side1) to exercise the
# code that removes a pending connection without removing the entire # code that removes a pending connection without removing the entire
# token # token
p1c.dataReceived(handshake(token1, side=side1)) p1c.send(handshake(token1, side=side1))
p1c.transport.loseConnection() p1c.disconnect()
self.flush()
self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(len(self._usage), 1, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[0] (started, result, total_bytes, total_time, waiting_time) = self._usage[0]
self.assertEqual(result, "lonely", self._usage) self.assertEqual(result, "lonely", self._usage)
p2.dataReceived(handshake(token1, side=side2)) p2.send(handshake(token1, side=side2))
self.flush()
self.flush()
self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._pending_requests), 0)
self.assertEqual(len(self._usage), 2, self._usage) self.assertEqual(len(self._usage), 2, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[1] (started, result, total_bytes, total_time, waiting_time) = self._usage[1]
self.assertEqual(result, "redundant", self._usage) self.assertEqual(result, "redundant", self._usage)
# one of the these is unecessary, but probably harmless # one of the these is unecessary, but probably harmless
p1a.transport.loseConnection() p1a.disconnect()
p1b.transport.loseConnection() p1b.disconnect()
self.flush()
self.assertEqual(len(self._usage), 3, self._usage) self.assertEqual(len(self._usage), 3, self._usage)
(started, result, total_bytes, total_time, waiting_time) = self._usage[2] (started, result, total_bytes, total_time, waiting_time) = self._usage[2]
self.assertEqual(result, "happy", self._usage) self.assertEqual(result, "happy", self._usage)