diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 86cf9b6..e37f95d 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -497,6 +497,7 @@ class TransitServerState(object): @_machine.output() def _count_bytes(self, data): self._total_sent += len(data) + print("COUNT BYTES +{} now {}".format(len(data), self._total_sent)) @_machine.output() def _send(self, data): @@ -522,6 +523,7 @@ class TransitServerState(object): # some outputs to record "usage" information .. @_machine.output() def _record_usage(self): + print("RECORD", self, self._mood, self._total_sent) if self._mood == "jilted": if self._buddy: if self._buddy._mood == "happy": @@ -694,3 +696,7 @@ class TransitServerState(object): enter=done, outputs=[], ) + + + ## XXX tracing + set_trace_function = _machine._setTrace diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index fd0c7ff..27e5a21 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -57,7 +57,7 @@ def makeService(config, reactor=reactor): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = transit_server.WebSocketTransitConnection - ws_factory.websocket_protocols = ["transit_relay"] + ws_factory.websocket_protocols = ["binary"] tcp_factory.transit = transit ws_factory.transit = transit diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 943013d..946ca66 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -4,6 +4,7 @@ from binascii import hexlify from twisted.trial import unittest from twisted.test import proto_helpers from twisted.internet.defer import inlineCallbacks +from twisted.internet.task import deferLater from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -13,6 +14,13 @@ from ..transit_server import ( WebSocketTransitConnection, ) +from ..transit_server import Transit, TransitConnection, WebSocketTransitConnection +from twisted.internet.protocol import ( + ServerFactory, + ClientFactory, +) +from twisted.internet import protocol +from ..server_state import create_usage_tracker from autobahn.twisted.websocket import WebSocketServerFactory @@ -427,6 +435,8 @@ class Usage(ServerBase, unittest.TestCase): def test_one_happy_one_jilted(self): p1 = yield self.new_protocol() p2 = yield self.new_protocol() + print(dir(p1.factory)) + return token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -436,6 +446,7 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() + print("shouldn't be events yet") self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) @@ -540,3 +551,136 @@ class UsageWebSockets(Usage): client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) return client_proto + +class New(unittest.TestCase): + """ + A completely fresh approach using: + + - no base classes (besides TestCase to match rest) + - twisted.test.iosim.* (IOPump etc) + - no "faking" any interfaces + """ + log_requests = False + + def setUp(self): + self._pumps = [] + self._usage = MemoryUsageRecorder() + self._setup_relay(blur_usage=60.0 if self.log_requests else None) + + def flush(self): + for pump in self._pumps: + pump.flush() + + def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): + usage = create_usage_tracker( + blur_usage=blur_usage, + log_file=log_file, + usage_db=usage_db, + ) + self._transit = Transit(usage, lambda: 123456789.0) + self._transit._debug_log = self.log_requests + self._transit.usage.add_backend(self._usage) + + def new_protocol(self): + if False: + return self._new_protocol_tcp() + else: + return self._new_protocol_ws() + + def _new_protocol_tcp(self): + server_factory = ServerFactory() + server_factory.protocol = TransitConnection + server_factory.transit = self._transit + server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) + + class ClientProtocol(protocol.Protocol): + def sendMessage(self, data): + self.transport.write(data) + + def disconnect(self): + self.transport.loseConnection() + + client_factory = ClientFactory() + client_factory.protocol = ClientProtocol + client_protocol = client_factory.buildProtocol(('128.0.0.1', 31337)) + + pump = iosim.connect( + server_protocol, + iosim.makeFakeServer(server_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + print("did connectionmade get called yet?") + pump.flush() + self._pumps.append(pump) + return client_protocol + + def _new_protocol_ws(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = self._transit + ws_factory.websocket_protocols = ["binary"] + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol + client_factory = WebSocketClientFactory() + client_factory.protocol = WebSocketClientProtocol + client_factory.protocols = ["binary"] + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + self._pumps.append(pump) + return client_protocol + + def test_short(self): + p1 = self.new_protocol() + # hang up before sending a complete handshake +# p1.sendMessage(b"short") # <-- only makes sense for TCP + p1.disconnect() + self.flush() + + # that will log the "empty" usage event + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual("empty", self._usage.events[0]["mood"]) + + def test_one_happy_one_jilted(self): + p1 = self.new_protocol() + p2 = self.new_protocol() + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + from twisted.internet import reactor + + print("p1 data") + p1.sendMessage(handshake(token1, side=side1), True) + print("p2 data") + p2.sendMessage(handshake(token1, side=side2), True) + self.flush() + + print("shouldn't be events yet") + self.assertEqual(self._usage.events, []) # no events yet + + print("p1 moar") + for x in range(13): + p1.sendMessage(b"\x00", True) + ##p1.sendMessage(b"\x00" * 13) + self.flush() + print("p2 moar") + p2.sendMessage(b"\xff" * 7, True) + self.flush() + + print("p1 lose") + p1.disconnect() + self.flush() + + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) + self.assertEqual(self._usage.events[0]["total_bytes"], 20) + self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 0c2510a..b81b5d6 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -77,6 +77,10 @@ class TransitConnection(LineReceiver): except AttributeError: pass + def tracer(oldstate, theinput, newstate): + print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) + def lineReceived(self, line): """ LineReceiver API @@ -213,6 +217,7 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ ITransitClient API """ + print("send: {}".format(repr(data))) self.sendMessage(data, isBinary=True) def disconnect(self): @@ -244,14 +249,14 @@ class WebSocketTransitConnection(WebSocketServerProtocol): # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol # (besides "use the global one") - # print("protocols: {}".format(request.protocols)) - return None #"transit_relay" + print("protocols: {}".format(request.protocols)) + return 'binary' def connectionMade(self): """ IProtocol API """ - # print("connectionMade") + print("connectionMade") super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True @@ -260,6 +265,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol): self.factory.transit.usage, ) + def tracer(oldstate, theinput, newstate): + print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) + def onOpen(self): # print("onOpen") self._state.connection_made(self) @@ -298,6 +307,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - # print("onClose", wasClean, code, reason) + print("{} onClose: {} {} {}".format(id(self), wasClean, code, reason)) self._state.connection_lost() # XXX "transit finished", etc