diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index fb232b5..8073ee0 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -55,8 +55,11 @@ class ServerBase: self._setup_relay(blur_usage=blur_usage) def flush(self): + did_work = False for pump in self._pumps: - pump.flush() + did_work = pump.flush() or did_work + if did_work: + self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): self._transit_server = Transit( diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index e44b099..8fbdef8 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -61,7 +61,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=None)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -89,7 +88,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=None)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -117,7 +115,6 @@ class _Transit: p1.send(handshake(token1, side=None)) p2.send(handshake(token1, side=side1)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -147,7 +144,6 @@ class _Transit: self.flush() p2.send(handshake(token1, side=side2)) self.flush() - self.flush() # a correct handshake yields an ack, after which we can send exp = b"ok\n" @@ -187,7 +183,6 @@ class _Transit: side2 = b"\x02"*8 p3.send(handshake(token1, side=side2)) self.flush() - self.flush() self.assertEqual(self.count(), 0) self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._transit_server._active_connections), 2) @@ -453,7 +448,6 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) (started, result, total_bytes, total_time, waiting_time) = self._usage[1]