From 555c23d4fe3740389a3e7d4995be6dcea7ed8615 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 19 Jan 2021 15:44:13 -0700 Subject: [PATCH 01/96] first-cut of state-machine style code --- docs/server-statemachine.dot | 22 ++ src/wormhole_transit_relay/server_state.py | 299 +++++++++++++++++++++ 2 files changed, 321 insertions(+) create mode 100644 docs/server-statemachine.dot create mode 100644 src/wormhole_transit_relay/server_state.py diff --git a/docs/server-statemachine.dot b/docs/server-statemachine.dot new file mode 100644 index 0000000..d3a2215 --- /dev/null +++ b/docs/server-statemachine.dot @@ -0,0 +1,22 @@ +/** +. thinking about state-machine from "hand-drawn" perspective +. will it look the same as an Automat one? +**/ + +digraph { + listening -> wait_relay [label="connection_made"] + + wait_relay -> wait_partner [label="please_relay\nFindPartner"] + wait_relay -> wait_partner [label="please_relay_for_side\nFindPartner"] + wait_relay -> done [label="invalid_token\nSend('bad handshake')\nDisconnect"] + wait_relay -> done [label="connection_lost"] + + wait_partner -> relaying [label="got_partner\nConnectPartner(partner)\nSend('ok')"] + wait_partner -> done [label="got_bytes\nDisconnect"] + wait_partner -> done [label="connection_lost"] + + relaying -> relaying [label="got_bytes\nSend(bytes)"] + relaying -> done [label="partner_connection_lost\nDisconnectMe"] + relaying -> done [label="connection_lost\nDisconnectPartner"] +} + diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py new file mode 100644 index 0000000..6b2c6a5 --- /dev/null +++ b/src/wormhole_transit_relay/server_state.py @@ -0,0 +1,299 @@ + +import automat +from zope.interface import ( + Interface, + implementer, +) + + +class ITransitClient(Interface): + def send(data): + """ + Send some byets to the client + """ + + def disconnect(reason): + """ + Disconnect the client transport + """ + + def connect_partner(other): + """ + Hook up to our partner. + :param ITransitClient other: our partner + """ + + def disconnect_partner(): + """ + Disconnect our partner's transport + """ + + +@implementer(ITransitClient) +class TestClient(object): + _partner = None + _data = b"" + + def send_to_partner(self, data): + print("{} GOT:{}".format(id(self), repr(data))) + if self._partner: + self._partner.send(data) + + def send(self, data): + print("{} SEND:{}".format(id(self), repr(data))) + self._data += data + + def disconnect(self): + print("disconnect") + + def connect_partner(self, other): + print("connect_partner: {} <--> {}".format(id(self), id(other))) + assert self._partner is None, "double partner" + self._partner = other + + def disconnect_partner(self): + assert self._partner is not None, "no partner" + print("disconnect_partner: {}".format(id(self._partner))) + + +class PendingRequests(object): + """ + Tracks the tokens we have received from client connections and + maps them to their partner connections + """ + + def register_token(self, *args): + """ + """ + + +class TransitServer(object): + """ + Encapsulates the state-machine of the server side of a transit + relay connection. + + Once the protocol has been told to relay (or to relay for a side) + it starts passing all received bytes to the other side until it + closes. + """ + + _machine = automat.MethodicalMachine() + _client = None + + @_machine.input() + def connection_made(self, client): + """ + A client has connected. May only be called once. + + :param ITransitClient client: our client. + """ + # NB: the "only called once" is enforced by the state-machine; + # this input is only valid for the "listening" state, to which + # we never return. + + @_machine.input() + def please_relay(self, token): + pass + + @_machine.input() + def please_relay_for_side(self, token, side): + pass + + @_machine.input() + def bad_token(self): + """ + A bad token / relay line was received + """ + + @_machine.input() + def got_partner(self, client): + """ + The partner for this relay session has been found + """ + + @_machine.input() + def connection_lost(self): + pass + + @_machine.input() + def partner_connection_lost(self): + pass + + @_machine.input() + def got_bytes(self, data): + """ + Some bytes have arrived (that aren't part of the handshake) + """ + + @_machine.output() + def _remember_client(self, client): + self._client = client + + @_machine.output() + def _register_token(self, token): + return self._real_register_token_for_side(token, None) + + @_machine.output() + def _register_token_for_side(self, token, side): + return self._real_register_token_for_side(token, side) + + @_machine.output() + def _unregister(self): + """ + remove us from the thing that remembers tokens and sides + """ + + @_machine.output() + def _send_bad(self): + self._client.send("bad handshake\n") + + @_machine.output() + def _send_ok(self): + self._client.send("ok\n") + + @_machine.output() + def _send(self, data): + self._client.send(data) + + @_machine.output() + def _send_to_partner(self, data): + self._client.send_to_partner(data) + + @_machine.output() + def _connect_partner(self, client): + self._client.connect_partner(client) + + @_machine.output() + def _disconnect(self): + self._client.disconnect() + + @_machine.output() + def _disconnect_partner(self): + self._client.disconnect_partner() + + def _real_register_token_for_side(self, token, side): + """ + basically, _got_handshake() + connection_got_token() from "real" + code ...and if this is the "second" side, hook them up and + pass .got_partner() input to both + """ + + @_machine.state(initial=True) + def listening(self): + """ + Initial state, awaiting connection. + """ + + @_machine.state() + def wait_relay(self): + """ + Waiting for a 'relay' message + """ + + @_machine.state() + def wait_partner(self): + """ + Waiting for our partner to connect + """ + + @_machine.state() + def relaying(self): + """ + Relaying bytes to our partner + """ + + @_machine.state() + def done(self): + """ + Terminal state + """ + + listening.upon( + connection_made, + enter=wait_relay, + outputs=[_remember_client], + ) + + wait_relay.upon( + please_relay, + enter=wait_partner, + outputs=[_register_token], + ) + wait_relay.upon( + please_relay_for_side, + enter=wait_partner, + outputs=[_register_token_for_side], + ) + wait_relay.upon( + bad_token, + enter=done, + outputs=[_send_bad, _disconnect], + ) + wait_relay.upon( + connection_lost, + enter=done, + outputs=[_disconnect], + ) + + wait_partner.upon( + got_partner, + enter=relaying, + outputs=[_send_ok, _connect_partner], + ) + wait_partner.upon( + connection_lost, + enter=done, + outputs=[_unregister], + ) + + relaying.upon( + got_bytes, + enter=relaying, + outputs=[_send_to_partner], + ) + relaying.upon( + connection_lost, + enter=done, + outputs=[_disconnect_partner, _unregister], + ) + relaying.upon( + partner_connection_lost, + enter=done, + outputs=[_disconnect, _unregister], + ) + + + + +# actions: +# - send("ok") +# - send("bad handshake") +# - disconnect +# - ... + +if __name__ == "__main__": + server0 = TransitServer() + client0 = TestClient() + server1 = TransitServer() + client1 = TestClient() + server0.connection_made(client0) + server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + + # this would be an error, because our partner hasn't shown up yet + # print(server0.got_bytes(b"asdf")) + + server1.connection_made(client1) + server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + + # XXX the PendingRequests stuff should do this, going "by hand" for now + server0.got_partner(client1) + server1.got_partner(client0) + + # should be connected now + server0.got_bytes(b"asdf") + # client1 should receive b"asdf" + + server0.connection_lost() + print("----[ received data on both sides ]----") + print("client0:{}".format(repr(client0._data))) + print("client1:{}".format(repr(client1._data))) From 0e11f1b8f14539529ec7a06904439c8d73b75332 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 25 Jan 2021 17:59:14 -0700 Subject: [PATCH 02/96] (wip) refactor to use Automat state-machine --- src/wormhole_transit_relay/server_state.py | 252 +++++++++++++++++-- src/wormhole_transit_relay/transit_server.py | 214 ++++++++-------- 2 files changed, 327 insertions(+), 139 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 6b2c6a5..cf95d1d 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,3 +1,4 @@ +from collections import defaultdict import automat from zope.interface import ( @@ -12,7 +13,7 @@ class ITransitClient(Interface): Send some byets to the client """ - def disconnect(reason): + def disconnect(): """ Disconnect the client transport """ @@ -37,7 +38,7 @@ class TestClient(object): def send_to_partner(self, data): print("{} GOT:{}".format(id(self), repr(data))) if self._partner: - self._partner.send(data) + self._partner._client.send(data) def send(self, data): print("{} SEND:{}".format(id(self), repr(data))) @@ -56,18 +57,94 @@ class TestClient(object): print("disconnect_partner: {}".format(id(self._partner))) +class ActiveConnections(object): + """ + Tracks active connections. A connection is 'active' when both + sides have shown up and they are glued together. + """ + def __init__(self): + self._connections = set() + + def register(self, side0, side1): + """ + A connection has become active so register both its sides + + :param TransitConnection side0: one side of the connection + :param TransitConnection side1: one side of the connection + """ + self._connections.add(side0) + self._connections.add(side1) + + def unregister(self, side): + """ + One side of a connection has become inactive. + + :param TransitConnection side: an inactive side of a connection + """ + self._connections.discard(side) + + class PendingRequests(object): """ Tracks the tokens we have received from client connections and - maps them to their partner connections + maps them to their partner connections for requests that haven't + yet been 'glued together' (that is, one side hasn't yet shown up). """ - def register_token(self, *args): + def __init__(self, active_connections): + self._requests = defaultdict(set) # token -> set((side, TransitConnection)) + self._active = active_connections + + def unregister(self, token, side, tc): + if token in self._requests: + self._requests[token].discard((side, tc)) + self._active.unregister(tc) + + def register_token(self, token, new_side, new_tc): """ + A client has connected and successfully offered a token (and + optional 'side' token). If this is the first one for this + token, we merely remember it. If it is the second side for + this token we connect them together. + + :returns bool: True if we are the first side to register this + token """ + potentials = self._requests[token] + for old in potentials: + (old_side, old_tc) = old + if ((old_side is None) + or (new_side is None) + or (old_side != new_side)): + # we found a match + # FIXME: debug-log this + # print("transit relay 2: %s" % new_tc.get_token()) + + # drop and stop tracking the rest + potentials.remove(old) + for (_, leftover_tc) in potentials.copy(): + # Don't record this as errory. It's just a spare connection + # from the same side as a connection that got used. This + # can happen if the connection hint contains multiple + # addresses (we don't currently support those, but it'd + # probably be useful in the future). + leftover_tc.disconnect_redundant() + self._requests.pop(token, None) + + # glue the two ends together + self._active.register(new_tc, old_tc) + new_tc.got_partner(old_tc) + old_tc.got_partner(new_tc) + return False + + # FIXME: debug-log this + # print("transit relay 1: %s" % new_tc.get_token()) + potentials.add((new_side, new_tc)) + return True + # TODO: timer -class TransitServer(object): +class TransitServerState(object): """ Encapsulates the state-machine of the server side of a transit relay connection. @@ -79,6 +156,36 @@ class TransitServer(object): _machine = automat.MethodicalMachine() _client = None + _buddy = None + _token = None + _side = None + _first = None + _mood = "empty" + + def __init__(self, pending_requests): + self._pending_requests = pending_requests + + def get_token(self): + """ + :returns str: a string describing our token. This will be "-" if + we have no token yet, or "{16 chars}-" if we have + just a token or "{16 chars}-{16 chars}" if we have a token and + a side. + """ + d = "-" + if self._token is not None: + d = self._token[:16].decode("ascii") + if self._side is not None: + d += "-" + self._side.decode("ascii") + else: + d += "-" + return d + + def get_mood(self): + """ + :returns str: description of the current 'mood' of the connection + """ + return self._mood @_machine.input() def connection_made(self, client): @@ -93,16 +200,22 @@ class TransitServer(object): @_machine.input() def please_relay(self, token): - pass + """ + A 'please relay X' message has been received (the original version + of the protocol). + """ @_machine.input() def please_relay_for_side(self, token, side): - pass + """ + A 'please relay X for side Y' message has been received (the + second version of the protocol). + """ @_machine.input() def bad_token(self): """ - A bad token / relay line was received + A bad token / relay line was received (e.g. couldn't be parsed) """ @_machine.input() @@ -113,11 +226,15 @@ class TransitServer(object): @_machine.input() def connection_lost(self): - pass + """ + Our transport has failed. + """ @_machine.input() def partner_connection_lost(self): - pass + """ + Our partner's transport has failed. + """ @_machine.input() def got_bytes(self, data): @@ -142,14 +259,20 @@ class TransitServer(object): """ remove us from the thing that remembers tokens and sides """ + return self._pending_requests.unregister(self._token, self._side, self) @_machine.output() def _send_bad(self): - self._client.send("bad handshake\n") + self._mood = "errory" + self._client.send(b"bad handshake\n") @_machine.output() def _send_ok(self): - self._client.send("ok\n") + self._client.send(b"ok\n") + + @_machine.output() + def _send_impatient(self): + self._client.send(b"impatient\n") @_machine.output() def _send(self, data): @@ -157,10 +280,11 @@ class TransitServer(object): @_machine.output() def _send_to_partner(self, data): - self._client.send_to_partner(data) + self._buddy._client.send(data) @_machine.output() def _connect_partner(self, client): + self._buddy = client self._client.connect_partner(client) @_machine.output() @@ -171,12 +295,60 @@ class TransitServer(object): def _disconnect_partner(self): self._client.disconnect_partner() + # some outputs to record the "mood" .. + @_machine.output() + def _mood_happy(self): + self._mood = "happy" + + @_machine.output() + def _mood_lonely(self): + self._mood = "lonely" + + @_machine.output() + def _mood_impatient(self): + self._mood = "impatient" + + @_machine.output() + def _mood_errory(self): + self._mood = "errory" + + @_machine.output() + def _mood_happy_if_first(self): + """ + We disconnected first so we're only happy if we also connected + first. + """ + if self._first: + self._mood = "happy" + else: + self._mood = "jilted" + + @_machine.output() + def _mood_happy_if_second(self): + """ + We disconnected second so we're only happy if we also connected + second. + """ + if self._first: + self._mood = "jilted" + else: + self._mood = "happy" + def _real_register_token_for_side(self, token, side): """ - basically, _got_handshake() + connection_got_token() from "real" - code ...and if this is the "second" side, hook them up and - pass .got_partner() input to both + A client has connected and sent a valid version 1 or version 2 + handshake. If the former, `side` will be None. + + In either case, we remember the tokens and register + ourselves. This might result in 'got_partner' notifications to + two state-machines if this is the second side for a given token. + + :param bytes token: the token + :param bytes side: The side token (or None) """ + self._token = token + self._side = side + self._first = self._pending_requests.register_token(token, side, self) @_machine.state(initial=True) def listening(self): @@ -217,17 +389,22 @@ class TransitServer(object): wait_relay.upon( please_relay, enter=wait_partner, - outputs=[_register_token], + outputs=[_mood_lonely, _register_token], ) wait_relay.upon( please_relay_for_side, enter=wait_partner, - outputs=[_register_token_for_side], + outputs=[_mood_lonely, _register_token_for_side], ) wait_relay.upon( bad_token, enter=done, - outputs=[_send_bad, _disconnect], + outputs=[_mood_errory, _send_bad, _disconnect], + ) + wait_relay.upon( + got_bytes, + enter=done, + outputs=[_mood_errory, _disconnect], ) wait_relay.upon( connection_lost, @@ -238,12 +415,17 @@ class TransitServer(object): wait_partner.upon( got_partner, enter=relaying, - outputs=[_send_ok, _connect_partner], + outputs=[_mood_happy, _send_ok, _connect_partner], ) wait_partner.upon( connection_lost, enter=done, - outputs=[_unregister], + outputs=[_mood_lonely, _unregister], + ) + wait_partner.upon( + got_bytes, + enter=done, + outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister], ) relaying.upon( @@ -254,12 +436,23 @@ class TransitServer(object): relaying.upon( connection_lost, enter=done, - outputs=[_disconnect_partner, _unregister], + outputs=[_mood_happy_if_first, _disconnect_partner, _unregister], ) relaying.upon( partner_connection_lost, enter=done, - outputs=[_disconnect, _unregister], + outputs=[_mood_happy_if_second, _disconnect, _unregister], + ) + + done.upon( + connection_lost, + enter=done, + outputs=[], + ) + done.upon( + partner_connection_lost, + enter=done, + outputs=[], ) @@ -272,9 +465,12 @@ class TransitServer(object): # - ... if __name__ == "__main__": - server0 = TransitServer() + active = ActiveConnections() + pending = PendingRequests(active) + + server0 = TransitServerState(pending) client0 = TestClient() - server1 = TransitServer() + server1 = TransitServerState(pending) client1 = TestClient() server0.connection_made(client0) server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") @@ -282,12 +478,14 @@ if __name__ == "__main__": # this would be an error, because our partner hasn't shown up yet # print(server0.got_bytes(b"asdf")) + print("about to relay client1") server1.connection_made(client1) server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + print("done") # XXX the PendingRequests stuff should do this, going "by hand" for now - server0.got_partner(client1) - server1.got_partner(client0) +# server0.got_partner(client1) +# server1.got_partner(client0) # should be connected now server0.got_bytes(b"asdf") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 426f50f..9de3b1e 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -24,6 +24,17 @@ def blur_size(size): return round_to(size, 1e6) return round_to(size, 100e6) + +from wormhole_transit_relay.server_state import ( + TransitServerState, + PendingRequests, + ActiveConnections, + ITransitClient, +) +from zope.interface import implementer + + +@implementer(ITransitClient) class TransitConnection(LineReceiver): delimiter = b'\n' # maximum length of a line we will accept before the handshake is complete. @@ -32,13 +43,33 @@ class TransitConnection(LineReceiver): MAX_LENGTH = 1024 def __init__(self): - self._got_token = False - self._got_side = False - self._sent_ok = False - self._mood = "empty" - self._buddy = None self._total_sent = 0 + def send(self, data): + """ + ITransitClient API + """ + self.transport.write(data) + + def disconnect(self): + """ + ITransitClient API + """ + self.transport.loseConnection() + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + self._buddy._client.transport.loseConnection() + self._buddy = None + def describeToken(self): d = "-" if self._got_token: @@ -50,6 +81,8 @@ class TransitConnection(LineReceiver): return d def connectionMade(self): + self._state = TransitServerState(self.factory.pending_requests) + self._state.connection_made(self) self._started = time.time() self._log_requests = self.factory._log_requests try: @@ -71,10 +104,10 @@ class TransitConnection(LineReceiver): side = new.group(2) return self._got_handshake(token, side) - self.sendLine(b"bad handshake") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect_error() + # state-machine calls us via ITransitClient interface to do + # bad handshake etc. + return self._state.bad_token() + #return self._state.got_bytes(line) def rawDataReceived(self, data): # We are an IPushProducer to our buddy's IConsumer, so they'll @@ -83,33 +116,15 @@ class TransitConnection(LineReceiver): # practice, this buffers about 10MB per connection, after which # point the sender will only transmit data as fast as the # receiver can handle it. - if self._sent_ok: - # if self._buddy is None then our buddy disconnected - # (we're "jilted"), so we hung up too, but our incoming - # data hasn't stopped yet (it will in a moment, after our - # disconnect makes a roundtrip through the kernel). This - # probably means the file receiver hung up, and this - # connection is the file sender. In may-2020 this happened - # 11 times in 40 days. - if self._buddy: - self._total_sent += len(data) - self._buddy.transport.write(data) - return - - # handshake is complete but not yet sent_ok - self.sendLine(b"impatient") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure + self._state.got_bytes(data) + self._total_sent += len(data) def _got_handshake(self, token, side): - self._got_token = token - self._got_side = side - self._mood = "lonely" # until buddy connects + self._state.please_relay_for_side(token, side) + # self._mood = "lonely" # until buddy connects self.setRawMode() - self.factory.connection_got_token(token, side, self) - def buddy_connected(self, them): + def __buddy_connected(self, them): self._buddy = them self._mood = "happy" self.sendLine(b"ok") @@ -121,7 +136,7 @@ class TransitConnection(LineReceiver): # The Transit object calls buddy_connected() on both protocols, so # there will be two producer/consumer pairs. - def buddy_disconnected(self): + def __buddy_disconnected(self): if self._log_requests: log.msg("buddy_disconnected %s" % self.describeToken()) self._buddy = None @@ -145,56 +160,62 @@ class TransitConnection(LineReceiver): def connectionLost(self, reason): finished = time.time() total_time = finished - self._started + self._state.connection_lost() - # Record usage. There are eight cases: - # * n0: we haven't gotten a full handshake yet (empty) - # * n1: the handshake failed, not a real client (errory) - # * n2: real client disconnected before any buddy appeared (lonely) - # * n3: real client closed as redundant after buddy appears (redundant) - # * n4: real client connected first, buddy closes first (jilted) - # * n5: real client connected first, buddy close last (happy) - # * n6: real client connected last, buddy closes first (jilted) - # * n7: real client connected last, buddy closes last (happy) + # XXX FIXME record usage - # * non-connected clients (0,1,2,3) always write a usage record - # * for connected clients, whoever disconnects first gets to write the - # usage record (5, 7). The last disconnect doesn't write a record. + if False: + # Record usage. There are eight cases: + # * n0: we haven't gotten a full handshake yet (empty) + # * n1: the handshake failed, not a real client (errory) + # * n2: real client disconnected before any buddy appeared (lonely) + # * n3: real client closed as redundant after buddy appears (redundant) + # * n4: real client connected first, buddy closes first (jilted) + # * n5: real client connected first, buddy close last (happy) + # * n6: real client connected last, buddy closes first (jilted) + # * n7: real client connected last, buddy closes last (happy) + + # * non-connected clients (0,1,2,3) always write a usage record + # * for connected clients, whoever disconnects first gets to write the + # usage record (5, 7). The last disconnect doesn't write a record. + + if self._mood == "empty": # 0 + assert not self._buddy + self.factory.recordUsage(self._started, "empty", 0, + total_time, None) + elif self._mood == "errory": # 1 + assert not self._buddy + self.factory.recordUsage(self._started, "errory", 0, + total_time, None) + elif self._mood == "redundant": # 3 + assert not self._buddy + self.factory.recordUsage(self._started, "redundant", 0, + total_time, None) + elif self._mood == "jilted": # 4 or 6 + # we were connected, but our buddy hung up on us. They record the + # usage event, we do not + pass + elif self._mood == "lonely": # 2 + assert not self._buddy + self.factory.recordUsage(self._started, "lonely", 0, + total_time, None) + else: # 5 or 7 + # we were connected, we hung up first. We record the event. + assert self._mood == "happy", self._mood + assert self._buddy + starts = [self._started, self._buddy._started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = self._total_sent + self._buddy._total_sent + self.factory.recordUsage(self._started, "happy", total_bytes, + total_time, waiting_time) + + if self._buddy: + self._buddy.buddy_disconnected() + # self.factory.transitFinished(self, self._got_token, self._got_side, + # self.describeToken()) - if self._mood == "empty": # 0 - assert not self._buddy - self.factory.recordUsage(self._started, "empty", 0, - total_time, None) - elif self._mood == "errory": # 1 - assert not self._buddy - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - elif self._mood == "redundant": # 3 - assert not self._buddy - self.factory.recordUsage(self._started, "redundant", 0, - total_time, None) - elif self._mood == "jilted": # 4 or 6 - # we were connected, but our buddy hung up on us. They record the - # usage event, we do not - pass - elif self._mood == "lonely": # 2 - assert not self._buddy - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - else: # 5 or 7 - # we were connected, we hung up first. We record the event. - assert self._mood == "happy", self._mood - assert self._buddy - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - total_bytes = self._total_sent + self._buddy._total_sent - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - if self._buddy: - self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, self._got_side, - self.describeToken()) class Transit(protocol.ServerFactory): # I manage pairs of simultaneous connections to a secondary TCP port, @@ -230,6 +251,8 @@ class Transit(protocol.ServerFactory): protocol = TransitConnection def __init__(self, blur_usage, log_file, usage_db): + self.active_connections = ActiveConnections() + self.pending_requests = PendingRequests(self.active_connections) self._blur_usage = blur_usage self._log_requests = blur_usage is None if self._blur_usage: @@ -247,39 +270,6 @@ class Transit(protocol.ServerFactory): self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection - def connection_got_token(self, token, new_side, new_tc): - potentials = self._pending_requests[token] - for old in potentials: - (old_side, old_tc) = old - if ((old_side is None) - or (new_side is None) - or (old_side != new_side)): - # we found a match - if self._debug_log: - log.msg("transit relay 2: %s" % new_tc.describeToken()) - - # drop and stop tracking the rest - potentials.remove(old) - for (_, leftover_tc) in potentials.copy(): - # Don't record this as errory. It's just a spare connection - # from the same side as a connection that got used. This - # can happen if the connection hint contains multiple - # addresses (we don't currently support those, but it'd - # probably be useful in the future). - leftover_tc.disconnect_redundant() - self._pending_requests.pop(token, None) - - # glue the two ends together - self._active_connections.add(new_tc) - self._active_connections.add(old_tc) - new_tc.buddy_connected(old_tc) - old_tc.buddy_connected(new_tc) - return - if self._debug_log: - log.msg("transit relay 1: %s" % new_tc.describeToken()) - potentials.add((new_side, new_tc)) - # TODO: timer - def transitFinished(self, tc, token, side, description): if token in self._pending_requests: side_tc = (side, tc) From 0e647074597977fc4da34116decc8b3431c9c8be Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 1 Feb 2021 16:55:15 -0700 Subject: [PATCH 03/96] count totals in state-machine --- src/wormhole_transit_relay/server_state.py | 7 ++++++- src/wormhole_transit_relay/transit_server.py | 10 +++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index cf95d1d..ffa0819 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -161,6 +161,7 @@ class TransitServerState(object): _side = None _first = None _mood = "empty" + _total_sent = 0 def __init__(self, pending_requests): self._pending_requests = pending_requests @@ -274,6 +275,10 @@ class TransitServerState(object): def _send_impatient(self): self._client.send(b"impatient\n") + @_machine.output() + def _count_bytes(self, data): + self._total_sent += len(data) + @_machine.output() def _send(self, data): self._client.send(data) @@ -404,7 +409,7 @@ class TransitServerState(object): wait_relay.upon( got_bytes, enter=done, - outputs=[_mood_errory, _disconnect], + outputs=[_count_bytes, _mood_errory, _disconnect], ) wait_relay.upon( connection_lost, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 9de3b1e..abd6406 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -42,9 +42,6 @@ class TransitConnection(LineReceiver): MAX_LENGTH = 1024 - def __init__(self): - self._total_sent = 0 - def send(self, data): """ ITransitClient API @@ -104,10 +101,10 @@ class TransitConnection(LineReceiver): side = new.group(2) return self._got_handshake(token, side) - # state-machine calls us via ITransitClient interface to do - # bad handshake etc. + # we should have been switched to "raw data" mode on the first + # line received (after which rawDataReceived() is called for + # all bytes) so getting here means a bad handshake. return self._state.bad_token() - #return self._state.got_bytes(line) def rawDataReceived(self, data): # We are an IPushProducer to our buddy's IConsumer, so they'll @@ -117,7 +114,6 @@ class TransitConnection(LineReceiver): # point the sender will only transmit data as fast as the # receiver can handle it. self._state.got_bytes(data) - self._total_sent += len(data) def _got_handshake(self, token, side): self._state.please_relay_for_side(token, side) From b51237d958343a1c3f0d83c05191ae40b70f4502 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 00:06:19 -0700 Subject: [PATCH 04/96] start of refactoring usage-recording: pass one test --- src/wormhole_transit_relay/server_state.py | 85 +++++++++++++++++-- .../test/test_transit_server.py | 8 +- src/wormhole_transit_relay/transit_server.py | 15 +++- 3 files changed, 90 insertions(+), 18 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index ffa0819..21b2464 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,13 +1,18 @@ +import time from collections import defaultdict import automat from zope.interface import ( Interface, + Attribute, implementer, ) class ITransitClient(Interface): + + started_time = Attribute("timestamp when the connection was established") + def send(data): """ Send some byets to the client @@ -34,6 +39,11 @@ class ITransitClient(Interface): class TestClient(object): _partner = None _data = b"" + _started_time = time.time() + + @property + def started_time(self): + return _started_time def send_to_partner(self, data): print("{} GOT:{}".format(id(self), repr(data))) @@ -57,6 +67,53 @@ class TestClient(object): print("disconnect_partner: {}".format(id(self._partner))) +class UsageRecorder(object): + """ + Tracks usage statistics of connections + """ + + def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): + """ + :param int started: timestamp when our connection started + + :param int buddy_started: None, or the timestamp when our + partner's connection started (will be None if we don't yet + have a partner). + + :param str result: a label for the result of the connection + (one of the "moods"). + + :param int bytes_sent: number of bytes we sent + + :param int buddy_bytes: number of bytes our partner sent + """ + + # ideally self._reactor.seconds() or similar, but .. + finished = time.time() + if buddy_started is not None: + starts = [started, buddy_started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = bytes_sent + buddy_bytes + else: + total_time = finished - started + waiting_time = None + total_bytes = bytes_sent + # probably want like "backends" here or something? original + # code logs some JSON (maybe) and writes to a database (maybe) + # and tests record in memory. + self.json_record({ + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": result, + }) + + def json_record(self, data): + pass + + class ActiveConnections(object): """ Tracks active connections. A connection is 'active' when both @@ -163,8 +220,9 @@ class TransitServerState(object): _mood = "empty" _total_sent = 0 - def __init__(self, pending_requests): + def __init__(self, pending_requests, usage_recorder): self._pending_requests = pending_requests + self._usage = usage_recorder def get_token(self): """ @@ -300,6 +358,17 @@ class TransitServerState(object): def _disconnect_partner(self): self._client.disconnect_partner() + # some outputs to record "usage" information .. + @_machine.output() + def _record_usage(self): + self._usage.record( + started=self._client.started_time, + buddy_started=self._buddy._client.started_time if self._buddy is not None else None, + result=self._mood, + bytes_sent=self._total_sent, + buddy_bytes=self._buddy._total_sent if self._buddy is not None else None + ) + # some outputs to record the "mood" .. @_machine.output() def _mood_happy(self): @@ -404,17 +473,17 @@ class TransitServerState(object): wait_relay.upon( bad_token, enter=done, - outputs=[_mood_errory, _send_bad, _disconnect], + outputs=[_mood_errory, _send_bad, _disconnect, _record_usage], ) wait_relay.upon( got_bytes, enter=done, - outputs=[_count_bytes, _mood_errory, _disconnect], + outputs=[_count_bytes, _mood_errory, _disconnect, _record_usage], ) wait_relay.upon( connection_lost, enter=done, - outputs=[_disconnect], + outputs=[_disconnect, _record_usage], ) wait_partner.upon( @@ -425,12 +494,12 @@ class TransitServerState(object): wait_partner.upon( connection_lost, enter=done, - outputs=[_mood_lonely, _unregister], + outputs=[_mood_lonely, _unregister, _record_usage], ) wait_partner.upon( got_bytes, enter=done, - outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister], + outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage], ) relaying.upon( @@ -441,12 +510,12 @@ class TransitServerState(object): relaying.upon( connection_lost, enter=done, - outputs=[_mood_happy_if_first, _disconnect_partner, _unregister], + outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage], ) relaying.upon( partner_connection_lost, enter=done, - outputs=[_mood_happy_if_second, _disconnect, _unregister], + outputs=[_mood_happy_if_second, _disconnect, _unregister, _record_usage], ) done.upon( diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index bca740e..d151b21 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -340,10 +340,7 @@ class Usage(ServerBase, unittest.TestCase): def setUp(self): super(Usage, self).setUp() self._usage = [] - def record(started, result, total_bytes, total_time, waiting_time): - self._usage.append((started, result, total_bytes, - total_time, waiting_time)) - self._transit_server.recordUsage = record + self._transit_server.usage.json_record = self._usage.append def test_empty(self): p1 = self.new_protocol() @@ -365,8 +362,7 @@ class Usage(ServerBase, unittest.TestCase): # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "empty", self._usage) + self.assertEqual("empty", self._usage[0]["mood"]) def test_errory(self): p1 = self.new_protocol() diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index abd6406..1727610 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -29,6 +29,7 @@ from wormhole_transit_relay.server_state import ( TransitServerState, PendingRequests, ActiveConnections, + UsageRecorder, ITransitClient, ) from zope.interface import implementer @@ -41,6 +42,7 @@ class TransitConnection(LineReceiver): # This must be >= to the longest possible handshake message. MAX_LENGTH = 1024 + started_time = None def send(self, data): """ @@ -78,9 +80,15 @@ class TransitConnection(LineReceiver): return d def connectionMade(self): - self._state = TransitServerState(self.factory.pending_requests) + # ideally more like self._reactor.seconds() ... but Twisted + # doesn't have a good way to get the reactor for a protocol + # (besides "use the global one") + self.started_time = time.time() + self._state = TransitServerState( + self.factory.pending_requests, + self.factory.usage, + ) self._state.connection_made(self) - self._started = time.time() self._log_requests = self.factory._log_requests try: self.transport.setTcpKeepAlive(True) @@ -154,8 +162,6 @@ class TransitConnection(LineReceiver): self.transport.loseConnection() def connectionLost(self, reason): - finished = time.time() - total_time = finished - self._started self._state.connection_lost() # XXX FIXME record usage @@ -249,6 +255,7 @@ class Transit(protocol.ServerFactory): def __init__(self, blur_usage, log_file, usage_db): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) + self.usage = UsageRecorder() self._blur_usage = blur_usage self._log_requests = blur_usage is None if self._blur_usage: From 734ed809c2aa2c7342be917f04d98c8fbddbdc19 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 00:35:52 -0700 Subject: [PATCH 05/96] fix more tests --- .../test/test_transit_server.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index d151b21..a3b7a90 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -350,8 +350,7 @@ class Usage(ServerBase, unittest.TestCase): # that will log the "empty" usage event self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "empty", self._usage) + self.assertEqual(self._usage[0]["mood"], "empty", self._usage) def test_short(self): p1 = self.new_protocol() @@ -372,8 +371,7 @@ class Usage(ServerBase, unittest.TestCase): # that will log the "errory" usage event, then drop the connection p1.disconnect() self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "errory", self._usage) + self.assertEqual(self._usage[0]["mood"], "errory", self._usage) def test_lonely(self): p1 = self.new_protocol() @@ -387,9 +385,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "lonely", self._usage) - self.assertIdentical(waiting_time, None) + self.assertEqual(self._usage[0]["mood"], "lonely", self._usage) + self.assertIdentical(self._usage[0]["waiting_time"], None) def test_one_happy_one_jilted(self): p1 = self.new_protocol() @@ -414,9 +411,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "happy", self._usage) - self.assertEqual(total_bytes, 20) + self.assertEqual(self._usage[0]["mood"], "happy", self._usage) + self.assertEqual(self._usage[0]["total_bytes"], 20) self.assertNotIdentical(waiting_time, None) def test_redundant(self): @@ -441,20 +437,17 @@ class Usage(ServerBase, unittest.TestCase): self.flush() self.assertEqual(len(self._usage), 1, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[0] - self.assertEqual(result, "lonely", self._usage) + self.assertEqual(self._usage[0]["mood"], "lonely") p2.send(handshake(token1, side=side2)) self.flush() self.assertEqual(len(self._transit_server._pending_requests), 0) self.assertEqual(len(self._usage), 2, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[1] - self.assertEqual(result, "redundant", self._usage) + self.assertEqual(self._usage[1]["mood"], "redundant") # one of the these is unecessary, but probably harmless p1a.disconnect() p1b.disconnect() self.flush() self.assertEqual(len(self._usage), 3, self._usage) - (started, result, total_bytes, total_time, waiting_time) = self._usage[2] - self.assertEqual(result, "happy", self._usage) + self.assertEqual(self._usage[2]["mood"], "happy") From 7b91377e94df570bec76eaa5b61a68a61d76f246 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 00:36:15 -0700 Subject: [PATCH 06/96] try to make 'redudant' mood work --- src/wormhole_transit_relay/server_state.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 21b2464..31458b2 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -185,7 +185,8 @@ class PendingRequests(object): # can happen if the connection hint contains multiple # addresses (we don't currently support those, but it'd # probably be useful in the future). - leftover_tc.disconnect_redundant() + ##leftover_tc.disconnect_redundant() + leftover_tc.partner_connection_lost() self._requests.pop(token, None) # glue the two ends together @@ -378,6 +379,10 @@ class TransitServerState(object): def _mood_lonely(self): self._mood = "lonely" + @_machine.output() + def _mood_redundant(self): + self._mood = "redundant" + @_machine.output() def _mood_impatient(self): self._mood = "impatient" @@ -501,6 +506,11 @@ class TransitServerState(object): enter=done, outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister, _record_usage], ) + wait_partner.upon( + partner_connection_lost, + enter=done, + outputs=[_mood_redundant, _disconnect, _record_usage], + ) relaying.upon( got_bytes, From ff578fccf801d18a71d8665d6d98d892a3d9ba0d Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 00:50:05 -0700 Subject: [PATCH 07/96] fix more tests (that examine internals) --- .../test/test_transit_server.py | 16 +++++++++------- src/wormhole_transit_relay/transit_server.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index a3b7a90..a4f2f94 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -13,9 +13,11 @@ def handshake(token, side=None): class _Transit: def count(self): - return sum([len(potentials) - for potentials - in self._transit_server._pending_requests.values()]) + return sum([ + len(potentials) + for potentials + in self._transit_server.pending_requests._requests.values() + ]) def test_blur_size(self): blur = transit_server.blur_size @@ -50,7 +52,7 @@ class _Transit: self.assertEqual(self.count(), 0) # the token should be removed too - self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) def test_both_unsided(self): p1 = self.new_protocol() @@ -186,8 +188,8 @@ class _Transit: p3.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self.count(), 0) - self.assertEqual(len(self._transit_server._pending_requests), 0) - self.assertEqual(len(self._transit_server._active_connections), 2) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) @@ -441,7 +443,7 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - self.assertEqual(len(self._transit_server._pending_requests), 0) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) self.assertEqual(len(self._usage), 2, self._usage) self.assertEqual(self._usage[1]["mood"], "redundant") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 1727610..fba84be 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -270,8 +270,8 @@ class Transit(protocol.ServerFactory): self._db = get_db(usage_db) self._rebooted = time.time() # we don't track TransitConnections until they submit a token - self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) - self._active_connections = set() # TransitConnection +## self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) +## self._active_connections = set() # TransitConnection def transitFinished(self, tc, token, side, description): if token in self._pending_requests: From 4669619f7e18d4c49746331c90c4f8586adab86b Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 01:09:16 -0700 Subject: [PATCH 08/96] skip usage-counting if we're jilted but other side is happy? --- src/wormhole_transit_relay/server_state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 31458b2..7cdcb70 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -362,6 +362,10 @@ class TransitServerState(object): # some outputs to record "usage" information .. @_machine.output() def _record_usage(self): + if self._mood == "jilted": + if self._buddy: + if self._buddy._mood == "happy": + return self._usage.record( started=self._client.started_time, buddy_started=self._buddy._client.started_time if self._buddy is not None else None, From 40919b51be424451a8c7714ec3f60690aeb801d5 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 01:09:39 -0700 Subject: [PATCH 09/96] count bytes missing --- src/wormhole_transit_relay/server_state.py | 2 +- src/wormhole_transit_relay/test/test_transit_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 7cdcb70..690b2c5 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -519,7 +519,7 @@ class TransitServerState(object): relaying.upon( got_bytes, enter=relaying, - outputs=[_send_to_partner], + outputs=[_count_bytes, _send_to_partner], ) relaying.upon( connection_lost, diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index a4f2f94..09615c7 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -415,7 +415,7 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage), 1, self._usage) self.assertEqual(self._usage[0]["mood"], "happy", self._usage) self.assertEqual(self._usage[0]["total_bytes"], 20) - self.assertNotIdentical(waiting_time, None) + self.assertNotIdentical(self._usage[0]["waiting_time"], None) def test_redundant(self): p1a = self.new_protocol() From 5ed572187bc76e07add827533197d05ee3816abd Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 01:16:33 -0700 Subject: [PATCH 10/96] unregister completely --- src/wormhole_transit_relay/server_state.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 690b2c5..0e7c446 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -153,8 +153,14 @@ class PendingRequests(object): self._active = active_connections def unregister(self, token, side, tc): + """ + We no longer care about a particular client (e.g. it has + disconnected). + """ if token in self._requests: self._requests[token].discard((side, tc)) + if not self._requests[token]: + del self._requests[token] self._active.unregister(tc) def register_token(self, token, new_side, new_tc): From 53864f57f0cbb3535d4039abefe560592778e939 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 01:32:52 -0700 Subject: [PATCH 11/96] use 'backends' for usage-recording --- src/wormhole_transit_relay/server_state.py | 104 +++++++++++++++++- .../test/test_transit_server.py | 45 ++++---- 2 files changed, 124 insertions(+), 25 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 0e7c446..b98615e 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -67,11 +67,105 @@ class TestClient(object): print("disconnect_partner: {}".format(id(self._partner))) +class IUsageWriter(Interface): + """ + Records actual usage statistics in some way + """ + + def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + :param int started: timestemp when this connection began + + :param float total_time: total seconds this connection lasted + + :param float waiting_time: None or the total seconds one side + waited for the other + + :param int total_bytes: the total bytes sent. In case the + connection was concluded successfully, only one side will + record the total bytes (but count both). + + :param str mood: the 'mood' of the connection + """ + + +@implementer(IUsageWriter) +class MemoryUsageRecorder: + + def __init__(self): + self.events = [] + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self.events.append(data) + + +@implementer(IUsageWriter) +class LogFileUsageRecorder: + + def __init__(self, writable_file): + self._file = writable_file + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self._file.write(json.dumps(data)) + + + +@implementer(IUsageWriter) +class DatabaseUsageRecorder: + + def __init__(self, _db): + self._db = db + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + + class UsageRecorder(object): """ Tracks usage statistics of connections """ + def __init__(self): + self._backends = set() + + def add_backend(self, backend): + """ + Add a new backend. + + :param IUsageWriter backend: the backend to add + """ + self._backends.add(backend) + + def remove_backend(self, backend): + """ + Remove an existing backend + + :param IUsageWriter backend: the backend to remove + """ + self._backends.remove(backend) + def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): """ :param int started: timestamp when our connection started @@ -102,7 +196,7 @@ class UsageRecorder(object): # probably want like "backends" here or something? original # code logs some JSON (maybe) and writes to a database (maybe) # and tests record in memory. - self.json_record({ + self._notify_backends({ "started": started, "total_time": total_time, "waiting_time": waiting_time, @@ -110,8 +204,12 @@ class UsageRecorder(object): "mood": result, }) - def json_record(self, data): - pass + def _notify_backends(self, data): + """ + Internal helper. Tell every backend we have about a new usage. + """ + for backend in self._backends: + backend.record_usage(**data) class ActiveConnections(object): diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 09615c7..22c8ee4 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -3,6 +3,7 @@ from binascii import hexlify from twisted.trial import unittest from .common import ServerBase from .. import transit_server +from ..server_state import MemoryUsageRecorder def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -341,8 +342,8 @@ class Usage(ServerBase, unittest.TestCase): def setUp(self): super(Usage, self).setUp() - self._usage = [] - self._transit_server.usage.json_record = self._usage.append + self._usage = MemoryUsageRecorder() + self._transit_server.usage.add_backend(self._usage) def test_empty(self): p1 = self.new_protocol() @@ -351,8 +352,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "empty" usage event - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual(self._usage[0]["mood"], "empty", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) def test_short(self): p1 = self.new_protocol() @@ -362,8 +363,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "empty" usage event - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual("empty", self._usage[0]["mood"]) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual("empty", self._usage.events[0]["mood"]) def test_errory(self): p1 = self.new_protocol() @@ -372,8 +373,8 @@ class Usage(ServerBase, unittest.TestCase): self.flush() # that will log the "errory" usage event, then drop the connection p1.disconnect() - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual(self._usage[0]["mood"], "errory", self._usage) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) def test_lonely(self): p1 = self.new_protocol() @@ -386,9 +387,9 @@ class Usage(ServerBase, unittest.TestCase): p1.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual(self._usage[0]["mood"], "lonely", self._usage) - self.assertIdentical(self._usage[0]["waiting_time"], None) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) + self.assertIdentical(self._usage.events[0]["waiting_time"], None) def test_one_happy_one_jilted(self): p1 = self.new_protocol() @@ -402,7 +403,7 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - self.assertEqual(self._usage, []) # no events yet + self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) self.flush() @@ -412,10 +413,10 @@ class Usage(ServerBase, unittest.TestCase): p1.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual(self._usage[0]["mood"], "happy", self._usage) - self.assertEqual(self._usage[0]["total_bytes"], 20) - self.assertNotIdentical(self._usage[0]["waiting_time"], None) + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) + self.assertEqual(self._usage.events[0]["total_bytes"], 20) + self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) def test_redundant(self): p1a = self.new_protocol() @@ -438,18 +439,18 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() - self.assertEqual(len(self._usage), 1, self._usage) - self.assertEqual(self._usage[0]["mood"], "lonely") + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "lonely") p2.send(handshake(token1, side=side2)) self.flush() self.assertEqual(len(self._transit_server.pending_requests._requests), 0) - self.assertEqual(len(self._usage), 2, self._usage) - self.assertEqual(self._usage[1]["mood"], "redundant") + self.assertEqual(len(self._usage.events), 2, self._usage) + self.assertEqual(self._usage.events[1]["mood"], "redundant") # one of the these is unecessary, but probably harmless p1a.disconnect() p1b.disconnect() self.flush() - self.assertEqual(len(self._usage), 3, self._usage) - self.assertEqual(self._usage[2]["mood"], "happy") + self.assertEqual(len(self._usage.events), 3, self._usage) + self.assertEqual(self._usage.events[2]["mood"], "happy") From b7bcdfdca3d762a1e9bcdb8239badb16bb17bc28 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 02:18:36 -0700 Subject: [PATCH 12/96] more stats / recording works --- src/wormhole_transit_relay/server_state.py | 45 +++++++++++++++---- src/wormhole_transit_relay/test/test_stats.py | 40 ++++++++++++----- .../test/test_transit_server.py | 40 ++++++++++------- src/wormhole_transit_relay/transit_server.py | 26 +++++------ 4 files changed, 99 insertions(+), 52 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index b98615e..442adbf 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,4 +1,5 @@ import time +import json from collections import defaultdict import automat @@ -126,29 +127,54 @@ class LogFileUsageRecorder: "total_bytes": total_bytes, "mood": mood, } - self._file.write(json.dumps(data)) - + self._file.write(json.dumps(data) + "\n") + self._file.flush() @implementer(IUsageWriter) class DatabaseUsageRecorder: - def __init__(self, _db): + def __init__(self, db): self._db = db def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): """ IUsageWriter. """ + self._db.execute( + "INSERT INTO `usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?,?,?)", + (started, total_time, waiting_time, total_bytes, mood) + ) + # XXX FIXME see comment in transit_server + #self._update_stats() + self._db.commit() -class UsageRecorder(object): +def round_to(size, coarseness): + return int(coarseness*(1+int((size-1)/coarseness))) + + +def blur_size(size): + if size == 0: + return 0 + if size < 1e6: + return round_to(size, 10e3) + if size < 1e9: + return round_to(size, 1e6) + return round_to(size, 100e6) + + +class UsageTracker(object): """ Tracks usage statistics of connections """ - def __init__(self): + def __init__(self, blur_usage): self._backends = set() + self._blur_usage = blur_usage def add_backend(self, backend): """ @@ -181,7 +207,6 @@ class UsageRecorder(object): :param int buddy_bytes: number of bytes our partner sent """ - # ideally self._reactor.seconds() or similar, but .. finished = time.time() if buddy_started is not None: @@ -193,9 +218,11 @@ class UsageRecorder(object): total_time = finished - started waiting_time = None total_bytes = bytes_sent - # probably want like "backends" here or something? original - # code logs some JSON (maybe) and writes to a database (maybe) - # and tests record in memory. + + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + total_bytes = blur_size(total_bytes) + self._notify_backends({ "started": started, "total_time": total_time, diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 1f114b1..4c0b036 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -22,9 +22,10 @@ class DB(unittest.TestCase): with mock.patch("time.time", return_value=T+0): t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) db = self.open_db(usage_db) + usage = list(t.usage._backends)[0] with mock.patch("time.time", return_value=T+1): - t.recordUsage(started=123, result="happy", total_bytes=100, + usage.record_usage(started=123, mood="happy", total_bytes=100, total_time=10, waiting_time=2) self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, @@ -36,7 +37,7 @@ class DB(unittest.TestCase): waiting=0, connected=0)) with mock.patch("time.time", return_value=T+2): - t.recordUsage(started=150, result="errory", total_bytes=200, + usage.record_usage(started=150, mood="errory", total_bytes=200, total_time=11, waiting_time=3) self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, @@ -58,18 +59,22 @@ class DB(unittest.TestCase): def test_no_db(self): t = Transit(blur_usage=None, log_file=None, usage_db=None) + self.assertEqual(0, len(t.usage._backends)) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) - t.timerUpdateStats() class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() t = Transit(blur_usage=None, log_file=log_file, usage_db=None) - t.recordUsage(started=123, result="happy", total_bytes=100, - total_time=10, waiting_time=2) + with mock.patch("time.time", return_value=133): + t.usage.record( + started=123, + buddy_started=125, + result="happy", + bytes_sent=100, + buddy_bytes=0, + ) self.assertEqual(json.loads(log_file.getvalue()), {"started": 123, "total_time": 10, "waiting_time": 2, "total_bytes": 100, @@ -80,8 +85,16 @@ class LogToStdout(unittest.TestCase): # requested amount, and sizes should be rounded up too log_file = io.StringIO() t = Transit(blur_usage=60, log_file=log_file, usage_db=None) - t.recordUsage(started=123, result="happy", total_bytes=11999, - total_time=10, waiting_time=2) + + with mock.patch("time.time", return_value=123 + 10): + t.usage.record( + started=123, + buddy_started=125, + result="happy", + bytes_sent=11999, + buddy_bytes=8001, + ) + print(log_file.getvalue()) self.assertEqual(json.loads(log_file.getvalue()), {"started": 120, "total_time": 10, "waiting_time": 2, "total_bytes": 20000, @@ -89,5 +102,10 @@ class LogToStdout(unittest.TestCase): def test_do_not_log(self): t = Transit(blur_usage=60, log_file=None, usage_db=None) - t.recordUsage(started=123, result="happy", total_bytes=11999, - total_time=10, waiting_time=2) + t.usage.record( + started=123, + buddy_started=124, + result="happy", + bytes_sent=11999, + buddy_bytes=12, + ) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 22c8ee4..60cc8c9 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -3,7 +3,10 @@ from binascii import hexlify from twisted.trial import unittest from .common import ServerBase from .. import transit_server -from ..server_state import MemoryUsageRecorder +from ..server_state import ( + MemoryUsageRecorder, + blur_size, +) def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -21,22 +24,21 @@ class _Transit: ]) def test_blur_size(self): - blur = transit_server.blur_size - self.failUnlessEqual(blur(0), 0) - self.failUnlessEqual(blur(1), 10e3) - self.failUnlessEqual(blur(10e3), 10e3) - self.failUnlessEqual(blur(10e3+1), 20e3) - self.failUnlessEqual(blur(15e3), 20e3) - self.failUnlessEqual(blur(20e3), 20e3) - self.failUnlessEqual(blur(1e6), 1e6) - self.failUnlessEqual(blur(1e6+1), 2e6) - self.failUnlessEqual(blur(1.5e6), 2e6) - self.failUnlessEqual(blur(2e6), 2e6) - self.failUnlessEqual(blur(900e6), 900e6) - self.failUnlessEqual(blur(1000e6), 1000e6) - self.failUnlessEqual(blur(1050e6), 1100e6) - self.failUnlessEqual(blur(1100e6), 1100e6) - self.failUnlessEqual(blur(1150e6), 1200e6) + self.failUnlessEqual(blur_size(0), 0) + self.failUnlessEqual(blur_size(1), 10e3) + self.failUnlessEqual(blur_size(10e3), 10e3) + self.failUnlessEqual(blur_size(10e3+1), 20e3) + self.failUnlessEqual(blur_size(15e3), 20e3) + self.failUnlessEqual(blur_size(20e3), 20e3) + self.failUnlessEqual(blur_size(1e6), 1e6) + self.failUnlessEqual(blur_size(1e6+1), 2e6) + self.failUnlessEqual(blur_size(1.5e6), 2e6) + self.failUnlessEqual(blur_size(2e6), 2e6) + self.failUnlessEqual(blur_size(900e6), 900e6) + self.failUnlessEqual(blur_size(1000e6), 1000e6) + self.failUnlessEqual(blur_size(1050e6), 1100e6) + self.failUnlessEqual(blur_size(1100e6), 1100e6) + self.failUnlessEqual(blur_size(1150e6), 1200e6) def test_register(self): p1 = self.new_protocol() @@ -331,12 +333,15 @@ class _Transit: # hang up before sending anything p1.disconnect() + class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True + class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False + class Usage(ServerBase, unittest.TestCase): log_requests = True @@ -344,6 +349,7 @@ class Usage(ServerBase, unittest.TestCase): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() self._transit_server.usage.add_backend(self._usage) +## self._transit_server.usage._blur_usage = None def test_empty(self): p1 = self.new_protocol() diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index fba84be..ee71b06 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -12,24 +12,14 @@ HOUR = 60*MINUTE DAY = 24*HOUR MB = 1000*1000 -def round_to(size, coarseness): - return int(coarseness*(1+int((size-1)/coarseness))) - -def blur_size(size): - if size == 0: - return 0 - if size < 1e6: - return round_to(size, 10e3) - if size < 1e9: - return round_to(size, 1e6) - return round_to(size, 100e6) - from wormhole_transit_relay.server_state import ( TransitServerState, PendingRequests, ActiveConnections, - UsageRecorder, + UsageTracker, + DatabaseUsageRecorder, + LogFileUsageRecorder, ITransitClient, ) from zope.interface import implementer @@ -255,7 +245,7 @@ class Transit(protocol.ServerFactory): def __init__(self, blur_usage, log_file, usage_db): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) - self.usage = UsageRecorder() + self.usage = UsageTracker(blur_usage) self._blur_usage = blur_usage self._log_requests = blur_usage is None if self._blur_usage: @@ -264,10 +254,13 @@ class Transit(protocol.ServerFactory): else: log.msg("not blurring access times") self._debug_log = False - self._log_file = log_file +## self._log_file = log_file self._db = None if usage_db: self._db = get_db(usage_db) + self.usage.add_backend(DatabaseUsageRecorder(self._db)) + if log_file: + self.usage.add_backend(LogFileUsageRecorder(log_file)) self._rebooted = time.time() # we don't track TransitConnections until they submit a token ## self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) @@ -317,6 +310,9 @@ class Transit(protocol.ServerFactory): " VALUES (?,?,?, ?,?)", (started, total_time, waiting_time, total_bytes, result)) + # XXXX aaaaaAA! okay, so just this one type of usage also + # does some other random stats-stuff; need more + # refactorizing self._update_stats() self._db.commit() From 3ae3bb74434bbfafdb49fa26f5b42a487632550c Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 16:35:20 -0700 Subject: [PATCH 13/96] cleanup, remove dead code --- src/wormhole_transit_relay/server_state.py | 67 +++++++- src/wormhole_transit_relay/server_tap.py | 22 ++- src/wormhole_transit_relay/test/common.py | 6 +- .../test/test_service.py | 6 +- src/wormhole_transit_relay/test/test_stats.py | 11 +- src/wormhole_transit_relay/transit_server.py | 145 ++++++------------ 6 files changed, 131 insertions(+), 126 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 442adbf..1be9146 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -8,6 +8,7 @@ from zope.interface import ( Attribute, implementer, ) +from .database import get_db class ITransitClient(Interface): @@ -167,12 +168,41 @@ def blur_size(size): return round_to(size, 100e6) +def create_usage_tracker(blur_usage, log_file, usage_db): + """ + :param int blur_usage: see UsageTracker + + :param log_file: None or a file-like object to write JSON-encoded + lines of usage information to. + + :param usage_db: None or an sqlite3 database connection + + :returns: a new UsageTracker instance configured with backends. + """ + tracker = UsageTracker(blur_usage) + if usage_db: + db = get_db(usage_db) + tracker.add_backend(DatabaseUsageRecorder(db)) + if log_file: + tracker.add_backend(LogFileUsageRecorder(log_file)) + return tracker + + + + class UsageTracker(object): """ Tracks usage statistics of connections """ def __init__(self, blur_usage): + """ + :param int blur_usage: None or the number of seconds to use as a + window around which to blur time statistics (e.g. "60" means times + will be rounded to 1 minute intervals). When blur_usage is + non-zero, sizes will also be rounded into buckets of "one + megabyte", "one gigabyte" or "lots" + """ self._backends = set() self._blur_usage = blur_usage @@ -223,6 +253,9 @@ class UsageTracker(object): started = self._blur_usage * (started // self._blur_usage) total_bytes = blur_size(total_bytes) + # This is "a dict" instead of "kwargs" because we have to make + # it into a dict for the log use-case and in-memory/testing + # use-case anyway so this is less repeats of the names. self._notify_backends({ "started": started, "total_time": total_time, @@ -233,7 +266,7 @@ class UsageTracker(object): def _notify_backends(self, data): """ - Internal helper. Tell every backend we have about a new usage. + Internal helper. Tell every backend we have about a new usage record. """ for backend in self._backends: backend.record_usage(**data) @@ -241,8 +274,11 @@ class UsageTracker(object): class ActiveConnections(object): """ - Tracks active connections. A connection is 'active' when both - sides have shown up and they are glued together. + Tracks active connections. + + A connection is 'active' when both sides have shown up and they + are glued together (and thus could be passing data back and forth + if any is flowing). """ def __init__(self): self._connections = set() @@ -268,12 +304,20 @@ class ActiveConnections(object): class PendingRequests(object): """ - Tracks the tokens we have received from client connections and - maps them to their partner connections for requests that haven't - yet been 'glued together' (that is, one side hasn't yet shown up). + Tracks outstanding (non-"active") requests. + + We register client connections against the tokens we have + received. When the other side shows up we can thus match it to the + correct partner connection. At this point, the connection becomes + "active" is and is thus no longer "pending" and so will no longer + be in this collection. """ def __init__(self, active_connections): + """ + :param active_connections: an instance of ActiveConnections where + connections are put when both sides arrive. + """ self._requests = defaultdict(set) # token -> set((side, TransitConnection)) self._active = active_connections @@ -285,16 +329,23 @@ class PendingRequests(object): if token in self._requests: self._requests[token].discard((side, tc)) if not self._requests[token]: + # no more sides; token is dead del self._requests[token] self._active.unregister(tc) - def register_token(self, token, new_side, new_tc): + def register(self, token, new_side, new_tc): """ A client has connected and successfully offered a token (and optional 'side' token). If this is the first one for this token, we merely remember it. If it is the second side for this token we connect them together. + :param bytes token: the token for this connection. + + :param bytes new_side: None or the side token for this connection + + :param TransitServerState new_tc: the state-machine of the connection + :returns bool: True if we are the first side to register this token """ @@ -562,7 +613,7 @@ class TransitServerState(object): """ self._token = token self._side = side - self._first = self._pending_requests.register_token(token, side, self) + self._first = self._pending_requests.register(token, side, self) @_machine.state(initial=True) def listening(self): diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 8fbfde2..cbf3efa 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -6,6 +6,7 @@ from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints from . import transit_server +from .server_state import create_usage_tracker from .increase_rlimits import increase_rlimits LONGDESC = """\ @@ -32,13 +33,18 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() ep = endpoints.serverFromString(reactor, config["port"]) # to listen - log_file = (os.fdopen(int(config["log-fd"]), "w") - if config["log-fd"] is not None - else None) - f = transit_server.Transit(blur_usage=config["blur-usage"], - log_file=log_file, - usage_db=config["usage-db"]) + log_file = ( + os.fdopen(int(config["log-fd"]), "w") + if config["log-fd"] is not None + else None + ) + usage = create_usage_tracker( + blur_usage=config["blur-usage"], + log_file=log_file, + usage_db=config["usage-db"], + ) + factory = transit_server.Transit(usage) parent = MultiService() - StreamServerEndpointService(ep, f).setServiceParent(parent) - TimerService(5*60.0, f.timerUpdateStats).setServiceParent(parent) + StreamServerEndpointService(ep, factory).setServiceParent(parent) +### FIXME TODO TimerService(5*60.0, factory.timerUpdateStats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 8073ee0..adbecf8 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -11,6 +11,8 @@ from zope.interface import ( from ..transit_server import ( Transit, ) +from ..transit_server import Transit +from ..server_state import create_usage_tracker class IRelayTestClient(Interface): @@ -42,6 +44,7 @@ class IRelayTestClient(Interface): Erase any received data to this point. """ + class ServerBase: log_requests = False @@ -62,11 +65,12 @@ class ServerBase: self.flush() def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - self._transit_server = Transit( + usage = create_usage_tracker( blur_usage=blur_usage, log_file=log_file, usage_db=usage_db, ) + self._transit_server = Transit(usage) self._transit_server._debug_log = self.log_requests def new_protocol(self): diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index 003de32..f84b01a 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -11,7 +11,7 @@ class Service(unittest.TestCase): def test_defaults(self): o = server_tap.Options() o.parseOptions([]) - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: s = server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=None, @@ -21,7 +21,7 @@ class Service(unittest.TestCase): def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: server_tap.makeService(o) self.assertEqual(t.mock_calls, [mock.call(blur_usage=60, @@ -31,7 +31,7 @@ class Service(unittest.TestCase): o = server_tap.Options() o.parseOptions(["--log-fd=99"]) fd = object() - with mock.patch("wormhole_transit_relay.server_tap.transit_server.Transit") as t: + with mock.patch("wormhole_transit_relay.server_tap.create_usage_tracker") as t: with mock.patch("wormhole_transit_relay.server_tap.os.fdopen", return_value=fd) as f: server_tap.makeService(o) diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 4c0b036..bce450b 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -6,6 +6,7 @@ except ImportError: import mock from twisted.trial import unittest from ..transit_server import Transit +from ..server_state import create_usage_tracker from .. import database class DB(unittest.TestCase): @@ -20,7 +21,7 @@ class DB(unittest.TestCase): os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") with mock.patch("time.time", return_value=T+0): - t = Transit(blur_usage=None, log_file=None, usage_db=usage_db) + t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=usage_db)) db = self.open_db(usage_db) usage = list(t.usage._backends)[0] @@ -58,7 +59,7 @@ class DB(unittest.TestCase): waiting=0, connected=0)) def test_no_db(self): - t = Transit(blur_usage=None, log_file=None, usage_db=None) + t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=None)) self.assertEqual(0, len(t.usage._backends)) @@ -66,7 +67,7 @@ class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() - t = Transit(blur_usage=None, log_file=log_file, usage_db=None) + t = Transit(create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None)) with mock.patch("time.time", return_value=133): t.usage.record( started=123, @@ -84,7 +85,7 @@ class LogToStdout(unittest.TestCase): # if blurring is enabled, timestamps should be rounded to the # requested amount, and sizes should be rounded up too log_file = io.StringIO() - t = Transit(blur_usage=60, log_file=log_file, usage_db=None) + t = Transit(create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None)) with mock.patch("time.time", return_value=123 + 10): t.usage.record( @@ -101,7 +102,7 @@ class LogToStdout(unittest.TestCase): "mood": "happy"}) def test_do_not_log(self): - t = Transit(blur_usage=60, log_file=None, usage_db=None) + t = Transit(create_usage_tracker(blur_usage=60, log_file=None, usage_db=None)) t.usage.record( started=123, buddy_started=124, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index ee71b06..a897bd4 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -4,7 +4,6 @@ from collections import defaultdict from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver -from .database import get_db SECONDS = 1.0 MINUTE = 60*SECONDS @@ -79,7 +78,7 @@ class TransitConnection(LineReceiver): self.factory.usage, ) self._state.connection_made(self) - self._log_requests = self.factory._log_requests +## self._log_requests = self.factory._log_requests try: self.transport.setTcpKeepAlive(True) except AttributeError: @@ -131,8 +130,8 @@ class TransitConnection(LineReceiver): # there will be two producer/consumer pairs. def __buddy_disconnected(self): - if self._log_requests: - log.msg("buddy_disconnected %s" % self.describeToken()) +## if self._log_requests: +## log.msg("buddy_disconnected %s" % self.describeToken()) self._buddy = None self._mood = "jilted" self.transport.loseConnection() @@ -210,117 +209,61 @@ class TransitConnection(LineReceiver): class Transit(protocol.ServerFactory): - # I manage pairs of simultaneous connections to a secondary TCP port, - # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN for SIDE\n" (or a legacy form without the "for - # SIDE"). Two connections match if they use the same TOKEN and have - # different SIDEs (the redundant connections are dropped when a match is - # made). Legacy connections match any with the same TOKEN, ignoring SIDE - # (so two legacy connections will match each other). + """ + I manage pairs of simultaneous connections to a secondary TCP port, + both forwarded to the other. Clients must begin each connection with + "please relay TOKEN for SIDE\n" (or a legacy form without the "for + SIDE"). Two connections match if they use the same TOKEN and have + different SIDEs (the redundant connections are dropped when a match is + made). Legacy connections match any with the same TOKEN, ignoring SIDE + (so two legacy connections will match each other). - # I will send "ok\n" when the matching connection is established, or - # disconnect if no matching connection is made within MAX_WAIT_TIME - # seconds. I will disconnect if you send data before the "ok\n". All data - # you get after the "ok\n" will be from the other side. You will not - # receive "ok\n" until the other side has also connected and submitted a - # matching token (and differing SIDE). + I will send "ok\n" when the matching connection is established, or + disconnect if no matching connection is made within MAX_WAIT_TIME + seconds. I will disconnect if you send data before the "ok\n". All data + you get after the "ok\n" will be from the other side. You will not + receive "ok\n" until the other side has also connected and submitted a + matching token (and differing SIDE). - # In addition, the connections will be dropped after MAXLENGTH bytes have - # been sent by either side, or MAXTIME seconds have elapsed after the - # matching connections were established. A future API will reveal these - # limits to clients instead of causing mysterious spontaneous failures. + In addition, the connections will be dropped after MAXLENGTH bytes have + been sent by either side, or MAXTIME seconds have elapsed after the + matching connections were established. A future API will reveal these + limits to clients instead of causing mysterious spontaneous failures. - # These relay connections are not half-closeable (unlike full TCP - # connections, applications will not receive any data after half-closing - # their outgoing side). Applications must negotiate shutdown with their - # peer and not close the connection until all data has finished - # transferring in both directions. Applications which only need to send - # data in one direction can use close() as usual. + These relay connections are not half-closeable (unlike full TCP + connections, applications will not receive any data after half-closing + their outgoing side). Applications must negotiate shutdown with their + peer and not close the connection until all data has finished + transferring in both directions. Applications which only need to send + data in one direction can use close() as usual. + """ + # TODO: unused MAX_WAIT_TIME = 30*SECONDS + # TODO: unused MAXLENGTH = 10*MB + # TODO: unused MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self, blur_usage, log_file, usage_db): + def __init__(self, usage): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) - self.usage = UsageTracker(blur_usage) - self._blur_usage = blur_usage - self._log_requests = blur_usage is None - if self._blur_usage: - log.msg("blurring access times to %d seconds" % self._blur_usage) - log.msg("not logging Transit connections to Twisted log") - else: - log.msg("not blurring access times") + self.usage = usage + if False: + # these logs-message should be made by the usage-tracker + # .. or in the "tap" setup? + if blur_usage: + log.msg("blurring access times to %d seconds" % self._blur_usage) + log.msg("not logging Transit connections to Twisted log") + else: + log.msg("not blurring access times") self._debug_log = False -## self._log_file = log_file - self._db = None - if usage_db: - self._db = get_db(usage_db) - self.usage.add_backend(DatabaseUsageRecorder(self._db)) - if log_file: - self.usage.add_backend(LogFileUsageRecorder(log_file)) + self._rebooted = time.time() - # we don't track TransitConnections until they submit a token -## self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) -## self._active_connections = set() # TransitConnection - - def transitFinished(self, tc, token, side, description): - if token in self._pending_requests: - side_tc = (side, tc) - self._pending_requests[token].discard(side_tc) - if not self._pending_requests[token]: # set is now empty - del self._pending_requests[token] - if self._debug_log: - log.msg("transitFinished %s" % (description,)) - self._active_connections.discard(tc) - # we could update the usage database "current" row immediately, or wait - # until the 5-minute timer updates it. If we update it now, just after - # losing a connection, we should probably also update it just after - # establishing one (at the end of connection_got_token). For now I'm - # going to omit these, but maybe someday we'll turn them both on. The - # consequence is that a manual execution of the munin scripts ("munin - # run wormhole_transit_active") will give the wrong value just after a - # connect/disconnect event. Actual munin graphs should accurately - # report connections that last longer than the 5-minute sampling - # window, which is what we actually care about. - #self.timerUpdateStats() - - def recordUsage(self, started, result, total_bytes, - total_time, waiting_time): - if self._debug_log: - log.msg(format="Transit.recordUsage {bytes}B", bytes=total_bytes) - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - if self._log_file is not None: - data = {"started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": result, - } - self._log_file.write(json.dumps(data)+"\n") - self._log_file.flush() - if self._db: - self._db.execute("INSERT INTO `usage`" - " (`started`, `total_time`, `waiting_time`," - " `total_bytes`, `result`)" - " VALUES (?,?,?, ?,?)", - (started, total_time, waiting_time, - total_bytes, result)) - # XXXX aaaaaAA! okay, so just this one type of usage also - # does some other random stats-stuff; need more - # refactorizing - self._update_stats() - self._db.commit() - - def timerUpdateStats(self): - if self._db: - self._update_stats() - self._db.commit() + # XXX TODO self._rebooted and the below could be in a separate + # object? or in the DatabaseUsageRecorder .. but not here def _update_stats(self): # current status: should be zero when idle rebooted = self._rebooted From 9557bbf75a965801da31bce7b9932dd11bafb936 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 16:57:16 -0700 Subject: [PATCH 14/96] we never remove backends --- src/wormhole_transit_relay/server_state.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 1be9146..e2923e8 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -214,14 +214,6 @@ class UsageTracker(object): """ self._backends.add(backend) - def remove_backend(self, backend): - """ - Remove an existing backend - - :param IUsageWriter backend: the backend to remove - """ - self._backends.remove(backend) - def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): """ :param int started: timestamp when our connection started From 83de03c8c6651bb85ec5a6946deb10184e4777b4 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 16:57:49 -0700 Subject: [PATCH 15/96] remove old test-code --- src/wormhole_transit_relay/server_state.py | 73 ---------------------- 1 file changed, 73 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index e2923e8..e5c7bc3 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -37,38 +37,6 @@ class ITransitClient(Interface): """ -@implementer(ITransitClient) -class TestClient(object): - _partner = None - _data = b"" - _started_time = time.time() - - @property - def started_time(self): - return _started_time - - def send_to_partner(self, data): - print("{} GOT:{}".format(id(self), repr(data))) - if self._partner: - self._partner._client.send(data) - - def send(self, data): - print("{} SEND:{}".format(id(self), repr(data))) - self._data += data - - def disconnect(self): - print("disconnect") - - def connect_partner(self, other): - print("connect_partner: {} <--> {}".format(id(self), id(other))) - assert self._partner is None, "double partner" - self._partner = other - - def disconnect_partner(self): - assert self._partner is not None, "no partner" - print("disconnect_partner: {}".format(id(self._partner))) - - class IUsageWriter(Interface): """ Records actual usage statistics in some way @@ -716,44 +684,3 @@ class TransitServerState(object): enter=done, outputs=[], ) - - - - -# actions: -# - send("ok") -# - send("bad handshake") -# - disconnect -# - ... - -if __name__ == "__main__": - active = ActiveConnections() - pending = PendingRequests(active) - - server0 = TransitServerState(pending) - client0 = TestClient() - server1 = TransitServerState(pending) - client1 = TestClient() - server0.connection_made(client0) - server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") - - # this would be an error, because our partner hasn't shown up yet - # print(server0.got_bytes(b"asdf")) - - print("about to relay client1") - server1.connection_made(client1) - server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") - print("done") - - # XXX the PendingRequests stuff should do this, going "by hand" for now -# server0.got_partner(client1) -# server1.got_partner(client0) - - # should be connected now - server0.got_bytes(b"asdf") - # client1 should receive b"asdf" - - server0.connection_lost() - print("----[ received data on both sides ]----") - print("client0:{}".format(repr(client0._data))) - print("client1:{}".format(repr(client1._data))) From 03906ffe0d550b83fedd8c6bb0553510f93dcf4c Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 17:39:55 -0700 Subject: [PATCH 16/96] pass actual database, not config --- src/wormhole_transit_relay/server_state.py | 5 +---- src/wormhole_transit_relay/server_tap.py | 3 ++- src/wormhole_transit_relay/test/test_stats.py | 9 +++------ 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index e5c7bc3..9f83736 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -149,15 +149,12 @@ def create_usage_tracker(blur_usage, log_file, usage_db): """ tracker = UsageTracker(blur_usage) if usage_db: - db = get_db(usage_db) - tracker.add_backend(DatabaseUsageRecorder(db)) + tracker.add_backend(DatabaseUsageRecorder(usage_db)) if log_file: tracker.add_backend(LogFileUsageRecorder(log_file)) return tracker - - class UsageTracker(object): """ Tracks usage statistics of connections diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index cbf3efa..7f89409 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -38,10 +38,11 @@ def makeService(config, reactor=reactor): if config["log-fd"] is not None else None ) + db = None if config["usage-db"] is None else get_db(config["usage-db"]) usage = create_usage_tracker( blur_usage=config["blur-usage"], log_file=log_file, - usage_db=config["usage-db"], + usage_db=db, ) factory = transit_server.Transit(usage) parent = MultiService() diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index bce450b..6cdfc7b 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -10,19 +10,16 @@ from ..server_state import create_usage_tracker from .. import database class DB(unittest.TestCase): - def open_db(self, dbfile): - db = sqlite3.connect(dbfile) - database._initialize_db_connection(db) - return db def test_db(self): T = 1519075308.0 d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") + db = database.get_db(usage_db) with mock.patch("time.time", return_value=T+0): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=usage_db)) - db = self.open_db(usage_db) + t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=db)) + self.assertEqual(len(t.usage._backends), 1) usage = list(t.usage._backends)[0] with mock.patch("time.time", return_value=T+1): From 215a0f350b1bae30cef774ec89ee2afdfd2b8f94 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 17:47:26 -0700 Subject: [PATCH 17/96] restore 2 missing log-lines --- src/wormhole_transit_relay/server_state.py | 6 ++++++ src/wormhole_transit_relay/transit_server.py | 8 -------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 9f83736..60e3101 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -8,6 +8,7 @@ from zope.interface import ( Attribute, implementer, ) +from twisted.python import log from .database import get_db @@ -170,6 +171,11 @@ class UsageTracker(object): """ self._backends = set() self._blur_usage = blur_usage + if blur_usage: + log.msg("blurring access times to %d seconds" % self._blur_usage) +## XXX log.msg("not logging Transit connections to Twisted log") + else: + log.msg("not blurring access times") def add_backend(self, backend): """ diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index a897bd4..29e3279 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -250,14 +250,6 @@ class Transit(protocol.ServerFactory): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage - if False: - # these logs-message should be made by the usage-tracker - # .. or in the "tap" setup? - if blur_usage: - log.msg("blurring access times to %d seconds" % self._blur_usage) - log.msg("not logging Transit connections to Twisted log") - else: - log.msg("not blurring access times") self._debug_log = False self._rebooted = time.time() From 60e70bac3cc75589794ec712c46b1f10ee11d08a Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 17:47:40 -0700 Subject: [PATCH 18/96] cleanup / dead code --- src/wormhole_transit_relay/test/test_stats.py | 2 +- .../test/test_transit_server.py | 1 - src/wormhole_transit_relay/transit_server.py | 63 ++----------------- 3 files changed, 6 insertions(+), 60 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 6cdfc7b..0ce46a5 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -90,7 +90,7 @@ class LogToStdout(unittest.TestCase): buddy_started=125, result="happy", bytes_sent=11999, - buddy_bytes=8001, + buddy_bytes=0, ) print(log_file.getvalue()) self.assertEqual(json.loads(log_file.getvalue()), diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 60cc8c9..bb0633f 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -2,7 +2,6 @@ from __future__ import print_function, unicode_literals from binascii import hexlify from twisted.trial import unittest from .common import ServerBase -from .. import transit_server from ..server_state import ( MemoryUsageRecorder, blur_size, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 29e3279..7865c22 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -1,6 +1,6 @@ from __future__ import print_function, unicode_literals -import re, time, json -from collections import defaultdict +import re +import time from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver @@ -16,9 +16,6 @@ from wormhole_transit_relay.server_state import ( TransitServerState, PendingRequests, ActiveConnections, - UsageTracker, - DatabaseUsageRecorder, - LogFileUsageRecorder, ITransitClient, ) from zope.interface import implementer @@ -152,59 +149,9 @@ class TransitConnection(LineReceiver): def connectionLost(self, reason): self._state.connection_lost() - - # XXX FIXME record usage - - if False: - # Record usage. There are eight cases: - # * n0: we haven't gotten a full handshake yet (empty) - # * n1: the handshake failed, not a real client (errory) - # * n2: real client disconnected before any buddy appeared (lonely) - # * n3: real client closed as redundant after buddy appears (redundant) - # * n4: real client connected first, buddy closes first (jilted) - # * n5: real client connected first, buddy close last (happy) - # * n6: real client connected last, buddy closes first (jilted) - # * n7: real client connected last, buddy closes last (happy) - - # * non-connected clients (0,1,2,3) always write a usage record - # * for connected clients, whoever disconnects first gets to write the - # usage record (5, 7). The last disconnect doesn't write a record. - - if self._mood == "empty": # 0 - assert not self._buddy - self.factory.recordUsage(self._started, "empty", 0, - total_time, None) - elif self._mood == "errory": # 1 - assert not self._buddy - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - elif self._mood == "redundant": # 3 - assert not self._buddy - self.factory.recordUsage(self._started, "redundant", 0, - total_time, None) - elif self._mood == "jilted": # 4 or 6 - # we were connected, but our buddy hung up on us. They record the - # usage event, we do not - pass - elif self._mood == "lonely": # 2 - assert not self._buddy - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - else: # 5 or 7 - # we were connected, we hung up first. We record the event. - assert self._mood == "happy", self._mood - assert self._buddy - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - total_bytes = self._total_sent + self._buddy._total_sent - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - - if self._buddy: - self._buddy.buddy_disconnected() - # self.factory.transitFinished(self, self._got_token, self._got_side, - # self.describeToken()) +# XXX this probably resulted in a log message we've not refactored yet +# self.factory.transitFinished(self, self._got_token, self._got_side, +# self.describeToken()) From ca555097634cbdb43254a56ff9e5d4115cbd6305 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 18:16:30 -0700 Subject: [PATCH 19/96] fix global stats-gathering / recording --- src/wormhole_transit_relay/server_state.py | 26 ++++++++- src/wormhole_transit_relay/server_tap.py | 4 +- src/wormhole_transit_relay/test/common.py | 2 +- src/wormhole_transit_relay/test/test_stats.py | 53 ++++++++++++++----- src/wormhole_transit_relay/transit_server.py | 35 ++++++------ 5 files changed, 81 insertions(+), 39 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 60e3101..613eee3 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -118,8 +118,10 @@ class DatabaseUsageRecorder: " VALUES (?,?,?,?,?)", (started, total_time, waiting_time, total_bytes, mood) ) - # XXX FIXME see comment in transit_server - #self._update_stats() + # original code did "self._update_stats()" here, thus causing + # "global" stats update on every connection update .. should + # we repeat this behavior, or really only record every + # 60-seconds with the timer? self._db.commit() @@ -227,6 +229,26 @@ class UsageTracker(object): "mood": result, }) + def update_stats(self, rebooted, updated, connected, waiting, + incomplete_bytes): + """ + Update general statistics. + """ + # in original code, this is only recorded in the database + # .. perhaps a better way to do this, but .. + for backend in self._backends: + if isinstance(backend, DatabaseUsageRecorder): + backend._db.execute("DELETE FROM `current`") + backend._db.execute( + "INSERT INTO `current`" + " (`rebooted`, `updated`, `connected`, `waiting`," + " `incomplete_bytes`)" + " VALUES (?, ?, ?, ?, ?)", + (int(rebooted), int(updated), connected, waiting, + incomplete_bytes) + ) + + def _notify_backends(self, data): """ Internal helper. Tell every backend we have about a new usage record. diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 7f89409..704d404 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -44,8 +44,8 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - factory = transit_server.Transit(usage) + factory = transit_server.Transit(usage, reactor.seconds) parent = MultiService() StreamServerEndpointService(ep, factory).setServiceParent(parent) -### FIXME TODO TimerService(5*60.0, factory.timerUpdateStats).setServiceParent(parent) + TimerService(5*60.0, factory.update_stats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index adbecf8..d78b844 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -70,7 +70,7 @@ class ServerBase: log_file=log_file, usage_db=usage_db, ) - self._transit_server = Transit(usage) + self._transit_server = Transit(usage, lambda: 123456789.0) self._transit_server._debug_log = self.log_requests def new_protocol(self): diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 0ce46a5..390b524 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -12,19 +12,31 @@ from .. import database class DB(unittest.TestCase): def test_db(self): + T = 1519075308.0 + + class Timer: + t = T + def __call__(self): + return self.t + get_time = Timer() + d = self.mktemp() os.mkdir(d) usage_db = os.path.join(d, "usage.sqlite") db = database.get_db(usage_db) - with mock.patch("time.time", return_value=T+0): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=db)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=db), + get_time, + ) self.assertEqual(len(t.usage._backends), 1) usage = list(t.usage._backends)[0] - with mock.patch("time.time", return_value=T+1): - usage.record_usage(started=123, mood="happy", total_bytes=100, - total_time=10, waiting_time=2) + get_time.t = T + 1 + usage.record_usage(started=123, mood="happy", total_bytes=100, + total_time=10, waiting_time=2) + t.update_stats() + self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -34,9 +46,10 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+2): - usage.record_usage(started=150, mood="errory", total_bytes=200, - total_time=11, waiting_time=3) + get_time.t = T + 2 + usage.record_usage(started=150, mood="errory", total_bytes=200, + total_time=11, waiting_time=3) + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `usage`").fetchall(), [dict(result="happy", started=123, total_bytes=100, total_time=10, waiting_time=2), @@ -48,15 +61,18 @@ class DB(unittest.TestCase): incomplete_bytes=0, waiting=0, connected=0)) - with mock.patch("time.time", return_value=T+3): - t.timerUpdateStats() + get_time.t = T + 3 + t.update_stats() self.assertEqual(db.execute("SELECT * FROM `current`").fetchone(), dict(rebooted=T+0, updated=T+3, incomplete_bytes=0, waiting=0, connected=0)) def test_no_db(self): - t = Transit(create_usage_tracker(blur_usage=None, log_file=None, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=None, usage_db=None), + lambda: 0, + ) self.assertEqual(0, len(t.usage._backends)) @@ -64,7 +80,10 @@ class LogToStdout(unittest.TestCase): def test_log(self): # emit lines of JSON to log_file, if set log_file = io.StringIO() - t = Transit(create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=None, log_file=log_file, usage_db=None), + lambda: 0, + ) with mock.patch("time.time", return_value=133): t.usage.record( started=123, @@ -82,7 +101,10 @@ class LogToStdout(unittest.TestCase): # if blurring is enabled, timestamps should be rounded to the # requested amount, and sizes should be rounded up too log_file = io.StringIO() - t = Transit(create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=log_file, usage_db=None), + lambda: 0, + ) with mock.patch("time.time", return_value=123 + 10): t.usage.record( @@ -99,7 +121,10 @@ class LogToStdout(unittest.TestCase): "mood": "happy"}) def test_do_not_log(self): - t = Transit(create_usage_tracker(blur_usage=60, log_file=None, usage_db=None)) + t = Transit( + create_usage_tracker(blur_usage=60, log_file=None, usage_db=None), + lambda: 0, + ) t.usage.record( started=123, buddy_started=124, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 7865c22..640e972 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -193,33 +193,28 @@ class Transit(protocol.ServerFactory): MAXTIME = 60*SECONDS protocol = TransitConnection - def __init__(self, usage): + def __init__(self, usage, get_timestamp): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage self._debug_log = False + self._timestamp = get_timestamp + self._rebooted = self._timestamp() - self._rebooted = time.time() - - # XXX TODO self._rebooted and the below could be in a separate - # object? or in the DatabaseUsageRecorder .. but not here - def _update_stats(self): - # current status: should be zero when idle - rebooted = self._rebooted - updated = time.time() - connected = len(self._active_connections) / 2 + def update_stats(self): # TODO: when a connection is half-closed, len(active) will be odd. a # moment later (hopefully) the other side will disconnect, but # _update_stats isn't updated until later. - waiting = len(self._pending_requests) + # "waiting" doesn't count multiple parallel connections from the same # side - incomplete_bytes = sum(tc._total_sent - for tc in self._active_connections) - self._db.execute("DELETE FROM `current`") - self._db.execute("INSERT INTO `current`" - " (`rebooted`, `updated`, `connected`, `waiting`," - " `incomplete_bytes`)" - " VALUES (?, ?, ?, ?, ?)", - (rebooted, updated, connected, waiting, - incomplete_bytes)) + self.usage.update_stats( + rebooted=self._rebooted, + updated=self._timestamp(), + connected=len(self.active_connections._connections), + waiting=len(self.pending_requests._requests), + incomplete_bytes=sum( + tc._total_sent + for tc in self.active_connections._connections + ), + ) From 7e58767ac1b8414977c759334d308a942dd00eab Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 18:17:09 -0700 Subject: [PATCH 20/96] pyflakes --- src/wormhole_transit_relay/server_state.py | 1 - src/wormhole_transit_relay/server_tap.py | 1 + src/wormhole_transit_relay/test/test_stats.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 613eee3..d89d335 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -9,7 +9,6 @@ from zope.interface import ( implementer, ) from twisted.python import log -from .database import get_db class ITransitClient(Interface): diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 704d404..b4028d2 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -8,6 +8,7 @@ from twisted.internet import endpoints from . import transit_server from .server_state import create_usage_tracker from .increase_rlimits import increase_rlimits +from .database import get_db LONGDESC = """\ This plugin sets up a 'Transit Relay' server for magic-wormhole. This service diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 390b524..806aafb 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -1,5 +1,5 @@ from __future__ import print_function, unicode_literals -import os, io, json, sqlite3 +import os, io, json try: from unittest import mock except ImportError: From 942f2041400bccfd7f5ff7d0e655e58c1a3efae1 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 20:05:27 -0700 Subject: [PATCH 21/96] log again --- src/wormhole_transit_relay/transit_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 640e972..18f3b6d 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -52,6 +52,7 @@ class TransitConnection(LineReceiver): """ ITransitClient API """ + print("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.transport.loseConnection() self._buddy = None From b03801d155c617f6b14cff592cc83b19ecb9f63b Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 20:52:14 -0700 Subject: [PATCH 22/96] guard --- src/wormhole_transit_relay/transit_server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 18f3b6d..4c71616 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -52,9 +52,10 @@ class TransitConnection(LineReceiver): """ ITransitClient API """ - print("buddy_disconnected {}".format(self._buddy.get_token())) - self._buddy._client.transport.loseConnection() - self._buddy = None + if self._buddy is not None: + # print("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.transport.loseConnection() + self._buddy = None def describeToken(self): d = "-" From e0f5f556ccb29e5acfd15f8c2f75faf31de8fc24 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 21:46:32 -0700 Subject: [PATCH 23/96] does this ever get called? --- src/wormhole_transit_relay/server_state.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index d89d335..b747965 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -692,11 +692,11 @@ class TransitServerState(object): enter=done, outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage], ) - relaying.upon( - partner_connection_lost, - enter=done, - outputs=[_mood_happy_if_second, _disconnect, _unregister, _record_usage], - ) +# relaying.upon( +# partner_connection_lost, +# enter=done, +# outputs=[_mood_happy_if_second, _disconnect, _unregister], # no _record_usage; other side will +# ) done.upon( connection_lost, From c09f15d866f8d786857b60e3521a98ac5dfbc7bc Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 21:59:47 -0700 Subject: [PATCH 24/96] re-instate log message --- src/wormhole_transit_relay/transit_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 4c71616..7bd7d91 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -53,7 +53,7 @@ class TransitConnection(LineReceiver): ITransitClient API """ if self._buddy is not None: - # print("buddy_disconnected {}".format(self._buddy.get_token())) + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.transport.loseConnection() self._buddy = None From 9cf42c560b5014220a19a2ff9cae1e41792527c2 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 22:00:57 -0700 Subject: [PATCH 25/96] not sure we can hit this state at all --- src/wormhole_transit_relay/server_state.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index b747965..949176e 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -572,17 +572,6 @@ class TransitServerState(object): else: self._mood = "jilted" - @_machine.output() - def _mood_happy_if_second(self): - """ - We disconnected second so we're only happy if we also connected - second. - """ - if self._first: - self._mood = "jilted" - else: - self._mood = "happy" - def _real_register_token_for_side(self, token, side): """ A client has connected and sent a valid version 1 or version 2 @@ -692,11 +681,6 @@ class TransitServerState(object): enter=done, outputs=[_mood_happy_if_first, _disconnect_partner, _unregister, _record_usage], ) -# relaying.upon( -# partner_connection_lost, -# enter=done, -# outputs=[_mood_happy_if_second, _disconnect, _unregister], # no _record_usage; other side will -# ) done.upon( connection_lost, From 2b2f06d98485929b1ce02d74cbda08c4c3695e4f Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 22:01:35 -0700 Subject: [PATCH 26/96] unused --- src/wormhole_transit_relay/server_state.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 949176e..a773f87 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -407,12 +407,6 @@ class TransitServerState(object): d += "-" return d - def get_mood(self): - """ - :returns str: description of the current 'mood' of the connection - """ - return self._mood - @_machine.input() def connection_made(self, client): """ From 34dd36158e18e2df501345a065939d3b3d186d2b Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 22:11:17 -0700 Subject: [PATCH 27/96] unused --- src/wormhole_transit_relay/test/test_transit_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index bb0633f..74f51e5 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -348,7 +348,6 @@ class Usage(ServerBase, unittest.TestCase): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() self._transit_server.usage.add_backend(self._usage) -## self._transit_server.usage._blur_usage = None def test_empty(self): p1 = self.new_protocol() From b192b5ca71b437badf7915ca4652b266d60c100e Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 12 Feb 2021 22:14:46 -0700 Subject: [PATCH 28/96] dead code, correct input --- src/wormhole_transit_relay/transit_server.py | 28 ++++---------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 7bd7d91..a9cdddd 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -112,41 +112,23 @@ class TransitConnection(LineReceiver): self._state.got_bytes(data) def _got_handshake(self, token, side): - self._state.please_relay_for_side(token, side) - # self._mood = "lonely" # until buddy connects + if side is not None: + self._state.please_relay_for_side(token, side) + else: + self._state.please_relay(token) self.setRawMode() - def __buddy_connected(self, them): - self._buddy = them - self._mood = "happy" - self.sendLine(b"ok") - self._sent_ok = True - # Connect the two as a producer/consumer pair. We use streaming=True, - # so this expects the IPushProducer interface, and uses - # pauseProducing() to throttle, and resumeProducing() to unthrottle. - self._buddy.transport.registerProducer(self.transport, True) - # The Transit object calls buddy_connected() on both protocols, so - # there will be two producer/consumer pairs. - - def __buddy_disconnected(self): -## if self._log_requests: -## log.msg("buddy_disconnected %s" % self.describeToken()) - self._buddy = None - self._mood = "jilted" - self.transport.loseConnection() - def disconnect_error(self): # we haven't finished the handshake, so there are no tokens tracking # us - self._mood = "errory" self.transport.loseConnection() + # XXX probably should be logged by state? if self.factory._debug_log: log.msg("transitFailed %r" % self) def disconnect_redundant(self): # this is called if a buddy connected and we were found unnecessary. # Any token-tracking cleanup will have been done before we're called. - self._mood = "redundant" self.transport.loseConnection() def connectionLost(self, reason): From c8fbc22120aef8fd7d40a276fb06827623c9e50a Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 19 Feb 2021 17:17:42 -0700 Subject: [PATCH 29/96] dead code --- src/wormhole_transit_relay/transit_server.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index a9cdddd..0b23a66 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -57,16 +57,6 @@ class TransitConnection(LineReceiver): self._buddy._client.transport.loseConnection() self._buddy = None - def describeToken(self): - d = "-" - if self._got_token: - d = self._got_token[:16].decode("ascii") - if self._got_side: - d += "-" + self._got_side.decode("ascii") - else: - d += "-" - return d - def connectionMade(self): # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol @@ -136,6 +126,7 @@ class TransitConnection(LineReceiver): # XXX this probably resulted in a log message we've not refactored yet # self.factory.transitFinished(self, self._got_token, self._got_side, # self.describeToken()) +# XXX describeToken -> self._state.get_token() From c2147ee9850701c818c50281cb907e2a60891390 Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 19 Feb 2021 17:18:08 -0700 Subject: [PATCH 30/96] change from review: inline _got_handshake --- src/wormhole_transit_relay/transit_server.py | 26 ++++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 0b23a66..d9dabfd 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -74,25 +74,32 @@ class TransitConnection(LineReceiver): pass def lineReceived(self, line): + """ + LineReceiver API + """ # old: "please relay {64}\n" + token = None old = re.search(br"^please relay (\w{64})$", line) if old: token = old.group(1) - return self._got_handshake(token, None) + self._state.please_relay(token) # new: "please relay {64} for side {16}\n" new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line) if new: token = new.group(1) side = new.group(2) - return self._got_handshake(token, side) + self._state.please_relay_for_side(token, side) - # we should have been switched to "raw data" mode on the first - # line received (after which rawDataReceived() is called for - # all bytes) so getting here means a bad handshake. - return self._state.bad_token() + if token is None: + self._state.bad_token() + else: + self.setRawMode() def rawDataReceived(self, data): + """ + LineReceiver API + """ # We are an IPushProducer to our buddy's IConsumer, so they'll # throttle us (by calling pauseProducing()) when their outbound # buffer is full (e.g. when their downstream pipe is full). In @@ -101,13 +108,6 @@ class TransitConnection(LineReceiver): # receiver can handle it. self._state.got_bytes(data) - def _got_handshake(self, token, side): - if side is not None: - self._state.please_relay_for_side(token, side) - else: - self._state.please_relay(token) - self.setRawMode() - def disconnect_error(self): # we haven't finished the handshake, so there are no tokens tracking # us From 34d039c38cd28d905a418559ae5ca0459b0de4f6 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 23 Feb 2021 13:31:55 -0700 Subject: [PATCH 31/96] hack in prelim websocket support --- src/wormhole_transit_relay/server_state.py | 8 ++ src/wormhole_transit_relay/server_tap.py | 3 +- src/wormhole_transit_relay/transit_server.py | 110 ++++++++++++++++++- 3 files changed, 119 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index a773f87..86cf9b6 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -612,11 +612,19 @@ class TransitServerState(object): Terminal state """ + # need a listening.upon(connection_lost) for special websocket + # case where handshake fails? + listening.upon( connection_made, enter=wait_relay, outputs=[_remember_client], ) + listening.upon( + connection_lost, + enter=done, + outputs=[_mood_errory], + ) wait_relay.upon( please_relay, diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index b4028d2..4caa5c9 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -45,7 +45,8 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - factory = transit_server.Transit(usage, reactor.seconds) + ##factory = transit_server.Transit(usage, reactor.seconds) + factory = transit_server.WebSocketTransit(usage, reactor.seconds) parent = MultiService() StreamServerEndpointService(ep, factory).setServiceParent(parent) TimerService(5*60.0, factory.update_stats).setServiceParent(parent) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index d9dabfd..af6760c 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -4,6 +4,12 @@ import time from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver +from autobahn.twisted.websocket import ( + WebSocketServerProtocol, + WebSocketServerFactory, +) + + SECONDS = 1.0 MINUTE = 60*SECONDS @@ -129,8 +135,12 @@ class TransitConnection(LineReceiver): # XXX describeToken -> self._state.get_token() +# XXX multiple-inheritance sucks... +# ("Transit" wants to be "the factory" but the base class is slightly +# different for websocket versus "normal" socket .. so maybe we need +# to make Transit *not* the factory directly?) -class Transit(protocol.ServerFactory): +class Transit(WebSocketServerFactory):#protocol.ServerFactory): """ I manage pairs of simultaneous connections to a secondary TCP port, both forwarded to the other. Clients must begin each connection with @@ -169,6 +179,7 @@ class Transit(protocol.ServerFactory): protocol = TransitConnection def __init__(self, usage, get_timestamp): + super(Transit, self).__init__() self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage @@ -193,3 +204,100 @@ class Transit(protocol.ServerFactory): for tc in self.active_connections._connections ), ) + + +@implementer(ITransitClient) +class WebSocketTransitConnection(WebSocketServerProtocol): + started_time = None + + def send(self, data): + """ + ITransitClient API + """ + self.sendMessage(data, isBinary=True) + + def disconnect(self): + """ + ITransitClient API + """ + self.sendClose(1000, None) + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + if self._buddy is not None: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.transport.loseConnection() + self._buddy = None + + def onConnect(self, request): + """ + IWebSocketChannel API + """ + print("onConnect: {}".format(request)) + # ideally more like self._reactor.seconds() ... but Twisted + # doesn't have a good way to get the reactor for a protocol + # (besides "use the global one") + print("protocols: {}".format(request.protocols)) + return None#"transit_relay" + + def connectionMade(self): + """ + IProtocol API + """ + print("connectionMade") + self.started_time = time.time() + self._first_message = True + self._state = TransitServerState( + self.factory.pending_requests, + self.factory.usage, + ) + return super(WebSocketTransitConnection, self).connectionMade() + + def onOpen(self): + print("onOpen") + self._state.connection_made(self) + + def onMessage(self, payload, isBinary): + """ + We may have a 'handshake' on our hands or we may just have some bytes to relay + """ + print("onMessage isBinary={}: {}".format(isBinary, payload)) + if self._first_message: + self._first_message = False + token = None + old = re.search(br"^please relay (\w{64})$", payload) + if old: + token = old.group(1) + self._state.please_relay(token) + + # new: "please relay {64} for side {16}\n" + new = re.search(br"^please relay (\w{64}) for side (\w{16})$", payload) + if new: + token = new.group(1) + side = new.group(2) + self._state.please_relay_for_side(token, side) + + if token is None: + self._state.bad_token() + else: + self._state.got_bytes(payload) + + def onClose(self, wasClean, code, reason): + """ + IWebSocketChannel API + """ + self._state.connection_lost() + # XXX "transit finished", etc + + +class WebSocketTransit(Transit, WebSocketServerFactory): + protocol = WebSocketTransitConnection + websocket_protocols = ["transit_relay"] From 2b78fbec8fe0528c165eb150feca1429a2d90d77 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 23 Feb 2021 13:32:02 -0700 Subject: [PATCH 32/96] test WebSocket client --- ws_client.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 ws_client.py diff --git a/ws_client.py b/ws_client.py new file mode 100644 index 0000000..a8dee88 --- /dev/null +++ b/ws_client.py @@ -0,0 +1,74 @@ +from __future__ import print_function + +from twisted.internet import endpoints +from twisted.internet.defer import ( + Deferred, + inlineCallbacks, +) +from twisted.internet.task import react +from twisted.internet.error import ( + ConnectionDone, +) +from twisted.internet.protocol import ( + Protocol, + Factory, +) +from twisted.protocols.basic import LineReceiver +from twisted.application.internet import StreamServerEndpointService + +from autobahn.twisted.websocket import ( + WebSocketClientProtocol, + WebSocketClientFactory, +) +from autobahn.websocket import types + + +class RelayEchoClient(WebSocketClientProtocol): + + def onOpen(self): + self.data = b"" + self.sendMessage(u"please relay {}".format(self.factory.token).encode("ascii"), True) + + def onConnecting(self, details): + return types.ConnectingRequest( + protocols=["transit_relay"], + ) + + def onMessage(self, data, isBinary): + print(">onMessage: {} bytes".format(len(data))) + print(data, isBinary) + if data == b"ok\n": + self.factory.ready.callback(None) + else: + self.data += data + return True + + def onClose(self, wasClean, code, reason): + print(">onClose", wasClean, code, reason) + self.factory.done.callback(reason) + + +@react +@inlineCallbacks +def main(reactor): + #ep = endpoints.clientFromString(reactor, "ws://localhost:4001/") + from twisted.plugins.autobahn_endpoints import AutobahnClientEndpoint + ep = endpoints.clientFromString(reactor, "tcp:localhost:4001") + f = WebSocketClientFactory("ws://127.0.0.1:4001/") + f.protocol = RelayEchoClient + # NB: write our own factory, probably.. + f.token = "a" * 64 + f.done = Deferred() + f.ready = Deferred() + proto = yield ep.connect(f) + # proto_d = ep.connect(f) + # print("proto_d", proto_d) + # proto = yield proto_d + print("proto", proto, f.done) + yield f.ready + print("ready") + import sys + if len(sys.argv) > 2: + proto.sendMessage(b"it's a message", True) + yield proto.sendClose() + yield f.done From 1a461aa461519394abf31de974b5e6cede8f9d24 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 23 Feb 2021 13:47:22 -0700 Subject: [PATCH 33/96] haxxor --- src/wormhole_transit_relay/server_tap.py | 25 ++++++++++++---- src/wormhole_transit_relay/transit_server.py | 30 ++++++++------------ ws_client.py | 8 +++--- 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 4caa5c9..42e986a 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -5,6 +5,10 @@ from twisted.application.service import MultiService from twisted.application.internet import (TimerService, StreamServerEndpointService) from twisted.internet import endpoints +from twisted.internet import protocol + +from autobahn.twisted.websocket import WebSocketServerFactory + from . import transit_server from .server_state import create_usage_tracker from .increase_rlimits import increase_rlimits @@ -33,7 +37,9 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() - ep = endpoints.serverFromString(reactor, config["port"]) # to listen + tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen + # XXX FIXME proper websocket option + ws_ep = endpoints.serverFromString(reactor, "tcp:4002:interface=localhost") # to listen log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None @@ -45,9 +51,18 @@ def makeService(config, reactor=reactor): log_file=log_file, usage_db=db, ) - ##factory = transit_server.Transit(usage, reactor.seconds) - factory = transit_server.WebSocketTransit(usage, reactor.seconds) + transit = transit_server.Transit(usage, reactor.seconds) + tcp_factory = protocol.ServerFactory() + tcp_factory.protocol = transit_server.TransitConnection + + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = transit_server.WebSocketTransitConnection + ws_factory.websocket_protocols = ["transit_relay"] + + tcp_factory.transit = transit + ws_factory.transit = transit parent = MultiService() - StreamServerEndpointService(ep, factory).setServiceParent(parent) - TimerService(5*60.0, factory.update_stats).setServiceParent(parent) + StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) + StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) + TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index af6760c..5ca92f1 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -4,10 +4,8 @@ import time from twisted.python import log from twisted.internet import protocol from twisted.protocols.basic import LineReceiver -from autobahn.twisted.websocket import ( - WebSocketServerProtocol, - WebSocketServerFactory, -) +from autobahn.twisted.websocket import WebSocketServerProtocol + @@ -69,8 +67,8 @@ class TransitConnection(LineReceiver): # (besides "use the global one") self.started_time = time.time() self._state = TransitServerState( - self.factory.pending_requests, - self.factory.usage, + self.factory.transit.pending_requests, + self.factory.transit.usage, ) self._state.connection_made(self) ## self._log_requests = self.factory._log_requests @@ -119,7 +117,7 @@ class TransitConnection(LineReceiver): # us self.transport.loseConnection() # XXX probably should be logged by state? - if self.factory._debug_log: + if self.factory.transit._debug_log: log.msg("transitFailed %r" % self) def disconnect_redundant(self): @@ -140,7 +138,9 @@ class TransitConnection(LineReceiver): # different for websocket versus "normal" socket .. so maybe we need # to make Transit *not* the factory directly?) -class Transit(WebSocketServerFactory):#protocol.ServerFactory): +##WebSocketServerFactory):#protocol.ServerFactory): + +class Transit(object): """ I manage pairs of simultaneous connections to a secondary TCP port, both forwarded to the other. Clients must begin each connection with @@ -176,10 +176,9 @@ class Transit(WebSocketServerFactory):#protocol.ServerFactory): MAXLENGTH = 10*MB # TODO: unused MAXTIME = 60*SECONDS - protocol = TransitConnection +## protocol = TransitConnection def __init__(self, usage, get_timestamp): - super(Transit, self).__init__() self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage @@ -253,13 +252,13 @@ class WebSocketTransitConnection(WebSocketServerProtocol): IProtocol API """ print("connectionMade") + super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True self._state = TransitServerState( - self.factory.pending_requests, - self.factory.usage, + self.factory.transit.pending_requests, + self.factory.transit.usage, ) - return super(WebSocketTransitConnection, self).connectionMade() def onOpen(self): print("onOpen") @@ -296,8 +295,3 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ self._state.connection_lost() # XXX "transit finished", etc - - -class WebSocketTransit(Transit, WebSocketServerFactory): - protocol = WebSocketTransitConnection - websocket_protocols = ["transit_relay"] diff --git a/ws_client.py b/ws_client.py index a8dee88..03158a8 100644 --- a/ws_client.py +++ b/ws_client.py @@ -51,10 +51,8 @@ class RelayEchoClient(WebSocketClientProtocol): @react @inlineCallbacks def main(reactor): - #ep = endpoints.clientFromString(reactor, "ws://localhost:4001/") - from twisted.plugins.autobahn_endpoints import AutobahnClientEndpoint - ep = endpoints.clientFromString(reactor, "tcp:localhost:4001") - f = WebSocketClientFactory("ws://127.0.0.1:4001/") + ep = endpoints.clientFromString(reactor, "tcp:localhost:4002") + f = WebSocketClientFactory("ws://127.0.0.1:4002/") f.protocol = RelayEchoClient # NB: write our own factory, probably.. f.token = "a" * 64 @@ -72,3 +70,5 @@ def main(reactor): proto.sendMessage(b"it's a message", True) yield proto.sendClose() yield f.done + print("relayed {} bytes:".format(len(proto.data))) + print(proto.data.decode("utf8")) From 5df5f86e42e5b0c1069a94bfc1d72cb9ba3901f2 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 23 Feb 2021 13:48:43 -0700 Subject: [PATCH 34/96] not just localhost --- src/wormhole_transit_relay/server_tap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 42e986a..fd0c7ff 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -39,7 +39,7 @@ def makeService(config, reactor=reactor): increase_rlimits() tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen # XXX FIXME proper websocket option - ws_ep = endpoints.serverFromString(reactor, "tcp:4002:interface=localhost") # to listen + ws_ep = endpoints.serverFromString(reactor, "tcp:4002") # to listen log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None From 21af1f68a378c64a8db9d90c941afba4cb1af7aa Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 4 Mar 2021 22:37:15 -0700 Subject: [PATCH 35/96] Transit is no longer a factory --- src/wormhole_transit_relay/test/common.py | 3 ++- src/wormhole_transit_relay/test/test_transit_server.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index d78b844..5a3bbf6 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -10,8 +10,9 @@ from zope.interface import ( ) from ..transit_server import ( Transit, + TransitConnection, ) -from ..transit_server import Transit +from twisted.internet.protocol import ServerFactory from ..server_state import create_usage_tracker diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 74f51e5..1854a17 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -19,7 +19,7 @@ class _Transit: return sum([ len(potentials) for potentials - in self._transit_server.pending_requests._requests.values() + in self._transit.pending_requests._requests.values() ]) def test_blur_size(self): @@ -54,7 +54,7 @@ class _Transit: self.assertEqual(self.count(), 0) # the token should be removed too - self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._transit.pending_requests._requests), 0) def test_both_unsided(self): p1 = self.new_protocol() @@ -190,8 +190,8 @@ class _Transit: p3.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self.count(), 0) - self.assertEqual(len(self._transit_server.pending_requests._requests), 0) - self.assertEqual(len(self._transit_server.active_connections._connections), 2) + self.assertEqual(len(self._transit.pending_requests._requests), 0) + self.assertEqual(len(self._transit.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) @@ -347,7 +347,7 @@ class Usage(ServerBase, unittest.TestCase): def setUp(self): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() - self._transit_server.usage.add_backend(self._usage) + self._transit.usage.add_backend(self._usage) def test_empty(self): p1 = self.new_protocol() From 521056615013aea1357eb1b55cdc0101a1ce5781 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 4 Mar 2021 23:30:59 -0700 Subject: [PATCH 36/96] websocket version of tests, with handshake --- src/wormhole_transit_relay/test/common.py | 1 - .../test/test_transit_server.py | 54 +++++++++++++++++++ src/wormhole_transit_relay/transit_server.py | 5 ++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 5a3bbf6..8690e91 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -72,7 +72,6 @@ class ServerBase: usage_db=usage_db, ) self._transit_server = Transit(usage, lambda: 123456789.0) - self._transit_server._debug_log = self.log_requests def new_protocol(self): """ diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 1854a17..9fee87d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -1,11 +1,19 @@ from __future__ import print_function, unicode_literals +import base64 from binascii import hexlify from twisted.trial import unittest +from twisted.test import proto_helpers from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, blur_size, ) +from ..transit_server import ( + WebSocketTransitConnection, +) + +from autobahn.twisted.websocket import WebSocketServerFactory + def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -458,3 +466,49 @@ class Usage(ServerBase, unittest.TestCase): self.flush() self.assertEqual(len(self._usage.events), 3, self._usage) self.assertEqual(self._usage.events[2]["mood"], "happy") + + +class UsageWebSockets(Usage): + """ + All the tests of 'Usage' except with a WebSocket (instead of TCP) + transport. + + This overrides ServerBase.new_protocol to achieve this. It might + be nicer to parametrize these tests in a way that doesn't use + inheritance .. but all the support etc classes are set up that way + already. + """ + + def new_protocol(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = WebSocketTransitConnection + ws_factory.websocket_protocols = ["transit_relay"] + ws_factory.transit = self._transit + + protocol = ws_factory.buildProtocol(('127.0.0.1', 4002)) + transport = proto_helpers.StringTransportWithDisconnection() + protocol.makeConnection(transport) + transport.protocol = protocol + + class Producer: + pass + protocol.registerProducer(Producer(), False) +## protocol.transport.abortConnection = protocol.transport.loseConnection + + # unlike in the TCP case, we need to drive a WebSocket + # handshake through the server first. + options = {} + self._websocket_key = b"0" * 16 + request = ( + "GET /ws HTTP/1.1\x0d\x0a" + "Host: 127.0.0.1:4002\x0d\x0a" + "Upgrade: WebSocket\x0d\x0a" + "Connection: Upgrade\x0d\x0a" + "Sec-WebSocket-Key: {}\x0d\x0a" + "Sec-WebSocket-Protocol: transit-relay\x0d\x0a" + "Sec-WebSocket-Version: 13\x0d\x0a" + "\x0d\x0a" + ).format(base64.b64encode(self._websocket_key).decode()) + protocol.dataReceived(request.encode("utf8")) + + return protocol diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 5ca92f1..9e5a419 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -289,6 +289,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol): else: self._state.got_bytes(payload) + def disconnect_redundant(self): + # this is called if a buddy connected and we were found unnecessary. + # Any token-tracking cleanup will have been done before we're called. + self.transport.loseConnection() + def onClose(self, wasClean, code, reason): """ IWebSocketChannel API From 002773d79fe431bed99f99fe1777bd4629f9f04c Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 31 Mar 2021 19:47:29 -0600 Subject: [PATCH 37/96] WIP: first passing IOPump test --- .../test/test_transit_server.py | 66 +++++++++++-------- src/wormhole_transit_relay/transit_server.py | 1 + 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 9fee87d..6101501 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -3,6 +3,7 @@ import base64 from binascii import hexlify from twisted.trial import unittest from twisted.test import proto_helpers +from twisted.internet.defer import inlineCallbacks from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -367,8 +368,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) + @inlineCallbacks def test_short(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") p1.disconnect() @@ -451,6 +453,7 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() + print(self._usage.events) self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") @@ -468,6 +471,16 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[2]["mood"], "happy") +from twisted.test import iosim +from twisted.internet.testing import MemoryReactorClock +from twisted.internet.address import IPv4Address +from autobahn.twisted.testing import ( + create_pumper, + create_memory_agent, + MemoryReactorClockResolver, +) + + class UsageWebSockets(Usage): """ All the tests of 'Usage' except with a WebSocket (instead of TCP) @@ -479,36 +492,31 @@ class UsageWebSockets(Usage): already. """ + def setUp(self): + super(UsageWebSockets, self).setUp() + self._pump = create_pumper() + self._reactor = MemoryReactorClockResolver() + return self._pump.start() + + def tearDown(self): + return self._pump.stop() + + @inlineCallbacks def new_protocol(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url - ws_factory.protocol = WebSocketTransitConnection - ws_factory.websocket_protocols = ["transit_relay"] - ws_factory.transit = self._transit - protocol = ws_factory.buildProtocol(('127.0.0.1', 4002)) - transport = proto_helpers.StringTransportWithDisconnection() - protocol.makeConnection(transport) - transport.protocol = protocol + class RelayFactory(WebSocketServerFactory): + protocol = WebSocketTransitConnection + websocket_protocols = ["transit_relay"] + transit = self._transit - class Producer: - pass - protocol.registerProducer(Producer(), False) -## protocol.transport.abortConnection = protocol.transport.loseConnection + server_factory = RelayFactory("ws://localhost:4002") - # unlike in the TCP case, we need to drive a WebSocket - # handshake through the server first. - options = {} - self._websocket_key = b"0" * 16 - request = ( - "GET /ws HTTP/1.1\x0d\x0a" - "Host: 127.0.0.1:4002\x0d\x0a" - "Upgrade: WebSocket\x0d\x0a" - "Connection: Upgrade\x0d\x0a" - "Sec-WebSocket-Key: {}\x0d\x0a" - "Sec-WebSocket-Protocol: transit-relay\x0d\x0a" - "Sec-WebSocket-Version: 13\x0d\x0a" - "\x0d\x0a" - ).format(base64.b64encode(self._websocket_key).decode()) - protocol.dataReceived(request.encode("utf8")) + agent = create_memory_agent( + self._reactor, + self._pump, + lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), + ) + client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) + print("PROTO", client_proto) + return client_proto - return protocol diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 9e5a419..133fbd5 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -298,5 +298,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ + print("onClose", wasClean, code, reason) self._state.connection_lost() # XXX "transit finished", etc From dd1cc7d52041f7582ae6c779849b507abf07c5a9 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 31 Mar 2021 20:11:04 -0600 Subject: [PATCH 38/96] upgrade all tests; reactor_turn(); remove debug --- .../test/test_transit_server.py | 84 ++++++++++++------- src/wormhole_transit_relay/transit_server.py | 14 ++-- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 6101501..943013d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -48,8 +48,9 @@ class _Transit: self.failUnlessEqual(blur_size(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6) + @inlineCallbacks def test_register(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -65,9 +66,10 @@ class _Transit: # the token should be removed too self.assertEqual(len(self._transit.pending_requests._requests), 0) + @inlineCallbacks def test_both_unsided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 p1.send(handshake(token1, side=None)) @@ -92,9 +94,10 @@ class _Transit: p2.disconnect() self.flush() + @inlineCallbacks def test_sided_unsided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -121,9 +124,10 @@ class _Transit: p2.disconnect() self.flush() + @inlineCallbacks def test_unsided_sided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -148,9 +152,10 @@ class _Transit: p1.disconnect() p2.disconnect() + @inlineCallbacks def test_both_sided(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -177,10 +182,11 @@ class _Transit: p1.disconnect() p2.disconnect() + @inlineCallbacks def test_ignore_same_side(self): - p1 = self.new_protocol() - p2 = self.new_protocol() - p3 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() + p3 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -209,8 +215,9 @@ class _Transit: p2.disconnect() p3.disconnect() + @inlineCallbacks def test_bad_handshake_old(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 p1.send(b"please DELAY " + hexlify(token1) + b"\n") @@ -220,8 +227,9 @@ class _Transit: self.assertEqual(p1.get_received_data(), exp) p1.disconnect() + @inlineCallbacks def test_bad_handshake_old_slow(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() p1.send(b"please DELAY ") self.flush() @@ -241,8 +249,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_bad_handshake_new(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -257,8 +266,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_binary_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full @@ -275,8 +285,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_old(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. @@ -288,8 +299,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_new(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -303,8 +315,9 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_impatience_new_slow(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms @@ -329,15 +342,17 @@ class _Transit: p1.disconnect() + @inlineCallbacks def test_short_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") self.flush() p1.disconnect() + @inlineCallbacks def test_empty_handshake(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending anything p1.disconnect() @@ -358,8 +373,9 @@ class Usage(ServerBase, unittest.TestCase): self._usage = MemoryUsageRecorder() self._transit.usage.add_backend(self._usage) + @inlineCallbacks def test_empty(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() # hang up before sending anything p1.disconnect() self.flush() @@ -380,8 +396,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual("empty", self._usage.events[0]["mood"]) + @inlineCallbacks def test_errory(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() p1.send(b"this is a very bad handshake\n") self.flush() @@ -390,8 +407,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) + @inlineCallbacks def test_lonely(self): - p1 = self.new_protocol() + p1 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -405,9 +423,10 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None) + @inlineCallbacks def test_one_happy_one_jilted(self): - p1 = self.new_protocol() - p2 = self.new_protocol() + p1 = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -432,11 +451,12 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["total_bytes"], 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) + @inlineCallbacks def test_redundant(self): - p1a = self.new_protocol() - p1b = self.new_protocol() - p1c = self.new_protocol() - p2 = self.new_protocol() + p1a = yield self.new_protocol() + p1b = yield self.new_protocol() + p1c = yield self.new_protocol() + p2 = yield self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -453,7 +473,8 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() - print(self._usage.events) + for x in self._usage.events: + print(x) self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") @@ -517,6 +538,5 @@ class UsageWebSockets(Usage): lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), ) client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) - print("PROTO", client_proto) return client_proto diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 133fbd5..0c2510a 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -240,18 +240,18 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - print("onConnect: {}".format(request)) + # print("onConnect: {}".format(request)) # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol # (besides "use the global one") - print("protocols: {}".format(request.protocols)) - return None#"transit_relay" + # print("protocols: {}".format(request.protocols)) + return None #"transit_relay" def connectionMade(self): """ IProtocol API """ - print("connectionMade") + # print("connectionMade") super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True @@ -261,14 +261,14 @@ class WebSocketTransitConnection(WebSocketServerProtocol): ) def onOpen(self): - print("onOpen") + # print("onOpen") self._state.connection_made(self) def onMessage(self, payload, isBinary): """ We may have a 'handshake' on our hands or we may just have some bytes to relay """ - print("onMessage isBinary={}: {}".format(isBinary, payload)) + # print("onMessage isBinary={}: {}".format(isBinary, payload)) if self._first_message: self._first_message = False token = None @@ -298,6 +298,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - print("onClose", wasClean, code, reason) + # print("onClose", wasClean, code, reason) self._state.connection_lost() # XXX "transit finished", etc From 99c71112b6d33eb66be4ac8564eba7b74630fea3 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 1 Apr 2021 19:56:47 -0600 Subject: [PATCH 39/96] a passing thing --- src/wormhole_transit_relay/server_state.py | 6 + src/wormhole_transit_relay/server_tap.py | 2 +- .../test/test_transit_server.py | 144 ++++++++++++++++++ src/wormhole_transit_relay/transit_server.py | 17 ++- 4 files changed, 164 insertions(+), 5 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 86cf9b6..e37f95d 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -497,6 +497,7 @@ class TransitServerState(object): @_machine.output() def _count_bytes(self, data): self._total_sent += len(data) + print("COUNT BYTES +{} now {}".format(len(data), self._total_sent)) @_machine.output() def _send(self, data): @@ -522,6 +523,7 @@ class TransitServerState(object): # some outputs to record "usage" information .. @_machine.output() def _record_usage(self): + print("RECORD", self, self._mood, self._total_sent) if self._mood == "jilted": if self._buddy: if self._buddy._mood == "happy": @@ -694,3 +696,7 @@ class TransitServerState(object): enter=done, outputs=[], ) + + + ## XXX tracing + set_trace_function = _machine._setTrace diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index fd0c7ff..27e5a21 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -57,7 +57,7 @@ def makeService(config, reactor=reactor): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = transit_server.WebSocketTransitConnection - ws_factory.websocket_protocols = ["transit_relay"] + ws_factory.websocket_protocols = ["binary"] tcp_factory.transit = transit ws_factory.transit = transit diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 943013d..946ca66 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -4,6 +4,7 @@ from binascii import hexlify from twisted.trial import unittest from twisted.test import proto_helpers from twisted.internet.defer import inlineCallbacks +from twisted.internet.task import deferLater from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -13,6 +14,13 @@ from ..transit_server import ( WebSocketTransitConnection, ) +from ..transit_server import Transit, TransitConnection, WebSocketTransitConnection +from twisted.internet.protocol import ( + ServerFactory, + ClientFactory, +) +from twisted.internet import protocol +from ..server_state import create_usage_tracker from autobahn.twisted.websocket import WebSocketServerFactory @@ -427,6 +435,8 @@ class Usage(ServerBase, unittest.TestCase): def test_one_happy_one_jilted(self): p1 = yield self.new_protocol() p2 = yield self.new_protocol() + print(dir(p1.factory)) + return token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -436,6 +446,7 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() + print("shouldn't be events yet") self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) @@ -540,3 +551,136 @@ class UsageWebSockets(Usage): client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) return client_proto + +class New(unittest.TestCase): + """ + A completely fresh approach using: + + - no base classes (besides TestCase to match rest) + - twisted.test.iosim.* (IOPump etc) + - no "faking" any interfaces + """ + log_requests = False + + def setUp(self): + self._pumps = [] + self._usage = MemoryUsageRecorder() + self._setup_relay(blur_usage=60.0 if self.log_requests else None) + + def flush(self): + for pump in self._pumps: + pump.flush() + + def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): + usage = create_usage_tracker( + blur_usage=blur_usage, + log_file=log_file, + usage_db=usage_db, + ) + self._transit = Transit(usage, lambda: 123456789.0) + self._transit._debug_log = self.log_requests + self._transit.usage.add_backend(self._usage) + + def new_protocol(self): + if False: + return self._new_protocol_tcp() + else: + return self._new_protocol_ws() + + def _new_protocol_tcp(self): + server_factory = ServerFactory() + server_factory.protocol = TransitConnection + server_factory.transit = self._transit + server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) + + class ClientProtocol(protocol.Protocol): + def sendMessage(self, data): + self.transport.write(data) + + def disconnect(self): + self.transport.loseConnection() + + client_factory = ClientFactory() + client_factory.protocol = ClientProtocol + client_protocol = client_factory.buildProtocol(('128.0.0.1', 31337)) + + pump = iosim.connect( + server_protocol, + iosim.makeFakeServer(server_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + print("did connectionmade get called yet?") + pump.flush() + self._pumps.append(pump) + return client_protocol + + def _new_protocol_ws(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = self._transit + ws_factory.websocket_protocols = ["binary"] + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol + client_factory = WebSocketClientFactory() + client_factory.protocol = WebSocketClientProtocol + client_factory.protocols = ["binary"] + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + self._pumps.append(pump) + return client_protocol + + def test_short(self): + p1 = self.new_protocol() + # hang up before sending a complete handshake +# p1.sendMessage(b"short") # <-- only makes sense for TCP + p1.disconnect() + self.flush() + + # that will log the "empty" usage event + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual("empty", self._usage.events[0]["mood"]) + + def test_one_happy_one_jilted(self): + p1 = self.new_protocol() + p2 = self.new_protocol() + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + from twisted.internet import reactor + + print("p1 data") + p1.sendMessage(handshake(token1, side=side1), True) + print("p2 data") + p2.sendMessage(handshake(token1, side=side2), True) + self.flush() + + print("shouldn't be events yet") + self.assertEqual(self._usage.events, []) # no events yet + + print("p1 moar") + for x in range(13): + p1.sendMessage(b"\x00", True) + ##p1.sendMessage(b"\x00" * 13) + self.flush() + print("p2 moar") + p2.sendMessage(b"\xff" * 7, True) + self.flush() + + print("p1 lose") + p1.disconnect() + self.flush() + + self.assertEqual(len(self._usage.events), 1, self._usage) + self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) + self.assertEqual(self._usage.events[0]["total_bytes"], 20) + self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 0c2510a..b81b5d6 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -77,6 +77,10 @@ class TransitConnection(LineReceiver): except AttributeError: pass + def tracer(oldstate, theinput, newstate): + print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) + def lineReceived(self, line): """ LineReceiver API @@ -213,6 +217,7 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ ITransitClient API """ + print("send: {}".format(repr(data))) self.sendMessage(data, isBinary=True) def disconnect(self): @@ -244,14 +249,14 @@ class WebSocketTransitConnection(WebSocketServerProtocol): # ideally more like self._reactor.seconds() ... but Twisted # doesn't have a good way to get the reactor for a protocol # (besides "use the global one") - # print("protocols: {}".format(request.protocols)) - return None #"transit_relay" + print("protocols: {}".format(request.protocols)) + return 'binary' def connectionMade(self): """ IProtocol API """ - # print("connectionMade") + print("connectionMade") super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True @@ -260,6 +265,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol): self.factory.transit.usage, ) + def tracer(oldstate, theinput, newstate): + print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) + def onOpen(self): # print("onOpen") self._state.connection_made(self) @@ -298,6 +307,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - # print("onClose", wasClean, code, reason) + print("{} onClose: {} {} {}".format(id(self), wasClean, code, reason)) self._state.connection_lost() # XXX "transit finished", etc From f18edc89f95047e7a94028bc3b5b7cd0bf675f8d Mon Sep 17 00:00:00 2001 From: meejah Date: Fri, 2 Apr 2021 14:58:31 -0600 Subject: [PATCH 40/96] refine --- .../test/test_transit_server.py | 22 +++++++++++-------- src/wormhole_transit_relay/transit_server.py | 4 ++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 946ca66..795b996 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -594,7 +594,7 @@ class New(unittest.TestCase): server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) class ClientProtocol(protocol.Protocol): - def sendMessage(self, data): + def send(self, data): self.transport.write(data) def disconnect(self): @@ -623,8 +623,13 @@ class New(unittest.TestCase): ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol + + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + def send(self, data): + self.sendMessage(data, True) + client_factory = WebSocketClientFactory() - client_factory.protocol = WebSocketClientProtocol + client_factory.protocol = TransitWebSocketClientProtocol client_factory.protocols = ["binary"] client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol.disconnect = client_protocol.dropConnection @@ -639,9 +644,10 @@ class New(unittest.TestCase): return client_protocol def test_short(self): + # XXX this test only makes sense for TCP p1 = self.new_protocol() # hang up before sending a complete handshake -# p1.sendMessage(b"short") # <-- only makes sense for TCP + p1.send(b"short") p1.disconnect() self.flush() @@ -659,21 +665,19 @@ class New(unittest.TestCase): from twisted.internet import reactor print("p1 data") - p1.sendMessage(handshake(token1, side=side1), True) + p1.send(handshake(token1, side=side1)) print("p2 data") - p2.sendMessage(handshake(token1, side=side2), True) + p2.send(handshake(token1, side=side2)) self.flush() print("shouldn't be events yet") self.assertEqual(self._usage.events, []) # no events yet print("p1 moar") - for x in range(13): - p1.sendMessage(b"\x00", True) - ##p1.sendMessage(b"\x00" * 13) + p1.send(b"\x00" * 13) self.flush() print("p2 moar") - p2.sendMessage(b"\xff" * 7, True) + p2.send(b"\xff" * 7) self.flush() print("p1 lose") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index b81b5d6..f096bfc 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -277,6 +277,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ We may have a 'handshake' on our hands or we may just have some bytes to relay """ + if not isBinary: + raise ValueError( + "All messages must be binary" + ) # print("onMessage isBinary={}: {}".format(isBinary, payload)) if self._first_message: self._first_message = False From 816e997b016dcebe988cf4aede21778554307b84 Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 3 Apr 2021 00:06:14 -0600 Subject: [PATCH 41/96] post-rebase fixups --- src/wormhole_transit_relay/test/common.py | 5 +- .../test/test_transit_server.py | 120 ++++++------------ 2 files changed, 45 insertions(+), 80 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index 8690e91..a98a338 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -78,7 +78,10 @@ class ServerBase: Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation """ - server_protocol = self._transit_server.buildProtocol(('127.0.0.1', 0)) + server_factory = ServerFactory() + server_factory.protocol = TransitConnection + server_factory.transit = self._transit_server + server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) class TransitClientProtocolTcp(Protocol): diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 795b996..e75dac2 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -36,7 +36,7 @@ class _Transit: return sum([ len(potentials) for potentials - in self._transit.pending_requests._requests.values() + in self._transit_server.pending_requests._requests.values() ]) def test_blur_size(self): @@ -56,9 +56,8 @@ class _Transit: self.failUnlessEqual(blur_size(1100e6), 1100e6) self.failUnlessEqual(blur_size(1150e6), 1200e6) - @inlineCallbacks def test_register(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -72,12 +71,11 @@ class _Transit: self.assertEqual(self.count(), 0) # the token should be removed too - self.assertEqual(len(self._transit.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) - @inlineCallbacks def test_both_unsided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 p1.send(handshake(token1, side=None)) @@ -102,10 +100,9 @@ class _Transit: p2.disconnect() self.flush() - @inlineCallbacks def test_sided_unsided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -132,10 +129,9 @@ class _Transit: p2.disconnect() self.flush() - @inlineCallbacks def test_unsided_sided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -160,10 +156,9 @@ class _Transit: p1.disconnect() p2.disconnect() - @inlineCallbacks def test_both_sided(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -190,11 +185,10 @@ class _Transit: p1.disconnect() p2.disconnect() - @inlineCallbacks def test_ignore_same_side(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() - p3 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() + p3 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -213,8 +207,8 @@ class _Transit: p3.send(handshake(token1, side=side2)) self.flush() self.assertEqual(self.count(), 0) - self.assertEqual(len(self._transit.pending_requests._requests), 0) - self.assertEqual(len(self._transit.active_connections._connections), 2) + self.assertEqual(len(self._transit_server.pending_requests._requests), 0) + self.assertEqual(len(self._transit_server.active_connections._connections), 2) # That will trigger a disconnect on exactly one of (p1 or p2). # The other connection should still be connected self.assertEqual(sum([int(t.connected) for t in [p1, p2]]), 1) @@ -223,9 +217,8 @@ class _Transit: p2.disconnect() p3.disconnect() - @inlineCallbacks def test_bad_handshake_old(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 p1.send(b"please DELAY " + hexlify(token1) + b"\n") @@ -235,9 +228,8 @@ class _Transit: self.assertEqual(p1.get_received_data(), exp) p1.disconnect() - @inlineCallbacks def test_bad_handshake_old_slow(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() p1.send(b"please DELAY ") self.flush() @@ -257,9 +249,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_bad_handshake_new(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -274,9 +265,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_binary_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff" # the embedded \n makes the server trigger early, before the full @@ -293,9 +283,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_old(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 # sending too many bytes is impatience. @@ -307,9 +296,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_new(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -323,9 +311,8 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_impatience_new_slow(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # For full coverage, we need dataReceived to see a particular framing # of these two pieces of data, and ITCPTransport doesn't have flush() # (which probably wouldn't work anyways). For now, force a 100ms @@ -350,17 +337,15 @@ class _Transit: p1.disconnect() - @inlineCallbacks def test_short_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") self.flush() p1.disconnect() - @inlineCallbacks def test_empty_handshake(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending anything p1.disconnect() @@ -379,11 +364,10 @@ class Usage(ServerBase, unittest.TestCase): def setUp(self): super(Usage, self).setUp() self._usage = MemoryUsageRecorder() - self._transit.usage.add_backend(self._usage) + self._transit_server.usage.add_backend(self._usage) - @inlineCallbacks def test_empty(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending anything p1.disconnect() self.flush() @@ -392,9 +376,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) - @inlineCallbacks def test_short(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") p1.disconnect() @@ -404,9 +387,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual("empty", self._usage.events[0]["mood"]) - @inlineCallbacks def test_errory(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() p1.send(b"this is a very bad handshake\n") self.flush() @@ -415,9 +397,8 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "errory", self._usage) - @inlineCallbacks def test_lonely(self): - p1 = yield self.new_protocol() + p1 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -431,10 +412,9 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "lonely", self._usage) self.assertIdentical(self._usage.events[0]["waiting_time"], None) - @inlineCallbacks def test_one_happy_one_jilted(self): - p1 = yield self.new_protocol() - p2 = yield self.new_protocol() + p1 = self.new_protocol() + p2 = self.new_protocol() print(dir(p1.factory)) return @@ -462,12 +442,11 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["total_bytes"], 20) self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) - @inlineCallbacks def test_redundant(self): - p1a = yield self.new_protocol() - p1b = yield self.new_protocol() - p1c = yield self.new_protocol() - p2 = yield self.new_protocol() + p1a = self.new_protocol() + p1b = self.new_protocol() + p1c = self.new_protocol() + p2 = self.new_protocol() token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -533,23 +512,6 @@ class UsageWebSockets(Usage): def tearDown(self): return self._pump.stop() - @inlineCallbacks - def new_protocol(self): - - class RelayFactory(WebSocketServerFactory): - protocol = WebSocketTransitConnection - websocket_protocols = ["transit_relay"] - transit = self._transit - - server_factory = RelayFactory("ws://localhost:4002") - - agent = create_memory_agent( - self._reactor, - self._pump, - lambda: server_factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 31337)), - ) - client_proto = yield agent.open("ws://127.0.0.1:4002/", dict()) - return client_proto class New(unittest.TestCase): @@ -577,9 +539,9 @@ class New(unittest.TestCase): log_file=log_file, usage_db=usage_db, ) - self._transit = Transit(usage, lambda: 123456789.0) - self._transit._debug_log = self.log_requests - self._transit.usage.add_backend(self._usage) + self._transit_server = Transit(usage, lambda: 123456789.0) + self._transit_server._debug_log = self.log_requests + self._transit_server.usage.add_backend(self._usage) def new_protocol(self): if False: @@ -590,7 +552,7 @@ class New(unittest.TestCase): def _new_protocol_tcp(self): server_factory = ServerFactory() server_factory.protocol = TransitConnection - server_factory.transit = self._transit + server_factory.transit = self._transit_server server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) class ClientProtocol(protocol.Protocol): @@ -618,7 +580,7 @@ class New(unittest.TestCase): def _new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit + ws_factory.transit = self._transit_server ws_factory.websocket_protocols = ["binary"] ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) From a89988af9083b00987c7d66780cd1ff313e0a5d5 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 5 Apr 2021 16:03:25 -0600 Subject: [PATCH 42/96] get rid of placeholder/test code; skip test_short for websockets --- .../test/test_transit_server.py | 118 ++---------------- 1 file changed, 8 insertions(+), 110 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index e75dac2..f00db6c 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -377,6 +377,7 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) def test_short(self): + # XXX this test only makes sense for TCP p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") @@ -512,72 +513,15 @@ class UsageWebSockets(Usage): def tearDown(self): return self._pump.stop() - - -class New(unittest.TestCase): - """ - A completely fresh approach using: - - - no base classes (besides TestCase to match rest) - - twisted.test.iosim.* (IOPump etc) - - no "faking" any interfaces - """ - log_requests = False - - def setUp(self): - self._pumps = [] - self._usage = MemoryUsageRecorder() - self._setup_relay(blur_usage=60.0 if self.log_requests else None) - - def flush(self): - for pump in self._pumps: - pump.flush() - - def _setup_relay(self, blur_usage=None, log_file=None, usage_db=None): - usage = create_usage_tracker( - blur_usage=blur_usage, - log_file=log_file, - usage_db=usage_db, - ) - self._transit_server = Transit(usage, lambda: 123456789.0) - self._transit_server._debug_log = self.log_requests - self._transit_server.usage.add_backend(self._usage) + def test_short(self): + """ + This test essentially just tests the framing of the line-oriented + TCP protocol; it doesnt' make sense for the WebSockets case + because WS handles frameing: you either sent a 'bad handshake' + because it is semantically invalid or no handshake (yet). + """ def new_protocol(self): - if False: - return self._new_protocol_tcp() - else: - return self._new_protocol_ws() - - def _new_protocol_tcp(self): - server_factory = ServerFactory() - server_factory.protocol = TransitConnection - server_factory.transit = self._transit_server - server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) - - class ClientProtocol(protocol.Protocol): - def send(self, data): - self.transport.write(data) - - def disconnect(self): - self.transport.loseConnection() - - client_factory = ClientFactory() - client_factory.protocol = ClientProtocol - client_protocol = client_factory.buildProtocol(('128.0.0.1', 31337)) - - pump = iosim.connect( - server_protocol, - iosim.makeFakeServer(server_protocol), - client_protocol, - iosim.makeFakeClient(client_protocol), - ) - print("did connectionmade get called yet?") - pump.flush() - self._pumps.append(pump) - return client_protocol - - def _new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server @@ -604,49 +548,3 @@ class New(unittest.TestCase): ) self._pumps.append(pump) return client_protocol - - def test_short(self): - # XXX this test only makes sense for TCP - p1 = self.new_protocol() - # hang up before sending a complete handshake - p1.send(b"short") - p1.disconnect() - self.flush() - - # that will log the "empty" usage event - self.assertEqual(len(self._usage.events), 1, self._usage) - self.assertEqual("empty", self._usage.events[0]["mood"]) - - def test_one_happy_one_jilted(self): - p1 = self.new_protocol() - p2 = self.new_protocol() - - token1 = b"\x00"*32 - side1 = b"\x01"*8 - side2 = b"\x02"*8 - from twisted.internet import reactor - - print("p1 data") - p1.send(handshake(token1, side=side1)) - print("p2 data") - p2.send(handshake(token1, side=side2)) - self.flush() - - print("shouldn't be events yet") - self.assertEqual(self._usage.events, []) # no events yet - - print("p1 moar") - p1.send(b"\x00" * 13) - self.flush() - print("p2 moar") - p2.send(b"\xff" * 7) - self.flush() - - print("p1 lose") - p1.disconnect() - self.flush() - - self.assertEqual(len(self._usage.events), 1, self._usage) - self.assertEqual(self._usage.events[0]["mood"], "happy", self._usage) - self.assertEqual(self._usage.events[0]["total_bytes"], 20) - self.assertNotIdentical(self._usage.events[0]["waiting_time"], None) From b829cae9404271b6f44f7c0a31f5b031990fc268 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 5 Apr 2021 21:59:07 -0600 Subject: [PATCH 43/96] obsolete --- docs/server-statemachine.dot | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 docs/server-statemachine.dot diff --git a/docs/server-statemachine.dot b/docs/server-statemachine.dot deleted file mode 100644 index d3a2215..0000000 --- a/docs/server-statemachine.dot +++ /dev/null @@ -1,22 +0,0 @@ -/** -. thinking about state-machine from "hand-drawn" perspective -. will it look the same as an Automat one? -**/ - -digraph { - listening -> wait_relay [label="connection_made"] - - wait_relay -> wait_partner [label="please_relay\nFindPartner"] - wait_relay -> wait_partner [label="please_relay_for_side\nFindPartner"] - wait_relay -> done [label="invalid_token\nSend('bad handshake')\nDisconnect"] - wait_relay -> done [label="connection_lost"] - - wait_partner -> relaying [label="got_partner\nConnectPartner(partner)\nSend('ok')"] - wait_partner -> done [label="got_bytes\nDisconnect"] - wait_partner -> done [label="connection_lost"] - - relaying -> relaying [label="got_bytes\nSend(bytes)"] - relaying -> done [label="partner_connection_lost\nDisconnectMe"] - relaying -> done [label="connection_lost\nDisconnectPartner"] -} - From 317b5a8dae886a0eec8a4e37e0ff43ff570a9271 Mon Sep 17 00:00:00 2001 From: meejah Date: Tue, 6 Apr 2021 10:34:00 -0600 Subject: [PATCH 44/96] test-client cleanup --- ws_client.py | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/ws_client.py b/ws_client.py index 03158a8..e5fbbd3 100644 --- a/ws_client.py +++ b/ws_client.py @@ -1,11 +1,12 @@ from __future__ import print_function +import sys from twisted.internet import endpoints from twisted.internet.defer import ( Deferred, inlineCallbacks, ) -from twisted.internet.task import react +from twisted.internet.task import react, deferLater from twisted.internet.error import ( ConnectionDone, ) @@ -26,49 +27,64 @@ from autobahn.websocket import types class RelayEchoClient(WebSocketClientProtocol): def onOpen(self): - self.data = b"" - self.sendMessage(u"please relay {}".format(self.factory.token).encode("ascii"), True) - - def onConnecting(self, details): - return types.ConnectingRequest( - protocols=["transit_relay"], + self._received = b"" + self.sendMessage( + u"please relay {} for side {}".format( + self.factory.token, + self.factory.side, + ).encode("ascii"), + True, ) + # def onConnecting(self, details): + # return types.ConnectingRequest( + # protocols=["binary"], + # ) + def onMessage(self, data, isBinary): print(">onMessage: {} bytes".format(len(data))) print(data, isBinary) if data == b"ok\n": self.factory.ready.callback(None) else: - self.data += data - return True + self._received += data +# return False def onClose(self, wasClean, code, reason): print(">onClose", wasClean, code, reason) self.factory.done.callback(reason) + if not self.factory.ready.called: + self.factory.ready.errback(RuntimeError(reason)) @react @inlineCallbacks def main(reactor): + will_send_message = len(sys.argv) > 1 ep = endpoints.clientFromString(reactor, "tcp:localhost:4002") f = WebSocketClientFactory("ws://127.0.0.1:4002/") + f.reactor = reactor f.protocol = RelayEchoClient +## f.protocols = ["binary"] # NB: write our own factory, probably.. f.token = "a" * 64 + f.side = "0" * 16 if will_send_message else "1" * 16 f.done = Deferred() f.ready = Deferred() proto = yield ep.connect(f) # proto_d = ep.connect(f) # print("proto_d", proto_d) # proto = yield proto_d - print("proto", proto, f.done) + print("proto", proto) yield f.ready print("ready") - import sys - if len(sys.argv) > 2: - proto.sendMessage(b"it's a message", True) + if will_send_message: + for _ in range(5): + print("sending message") + proto.sendMessage(b"it's a message", True) + yield deferLater(reactor, 0.2) yield proto.sendClose() + print("closing") yield f.done - print("relayed {} bytes:".format(len(proto.data))) - print(proto.data.decode("utf8")) + print("relayed {} bytes:".format(len(proto._received))) + print(proto._received.decode("utf8")) From 0aaf00f803695d6acde321a8fbc8f5264a5694bb Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:26:04 -0600 Subject: [PATCH 45/96] get rid of prints --- src/wormhole_transit_relay/server_state.py | 2 -- src/wormhole_transit_relay/transit_server.py | 5 ----- 2 files changed, 7 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index e37f95d..ae4ac2f 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -497,7 +497,6 @@ class TransitServerState(object): @_machine.output() def _count_bytes(self, data): self._total_sent += len(data) - print("COUNT BYTES +{} now {}".format(len(data), self._total_sent)) @_machine.output() def _send(self, data): @@ -523,7 +522,6 @@ class TransitServerState(object): # some outputs to record "usage" information .. @_machine.output() def _record_usage(self): - print("RECORD", self, self._mood, self._total_sent) if self._mood == "jilted": if self._buddy: if self._buddy._mood == "happy": diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index f096bfc..8a5bead 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -217,7 +217,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ ITransitClient API """ - print("send: {}".format(repr(data))) self.sendMessage(data, isBinary=True) def disconnect(self): @@ -256,7 +255,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IProtocol API """ - print("connectionMade") super(WebSocketTransitConnection, self).connectionMade() self.started_time = time.time() self._first_message = True @@ -281,7 +279,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): raise ValueError( "All messages must be binary" ) - # print("onMessage isBinary={}: {}".format(isBinary, payload)) if self._first_message: self._first_message = False token = None @@ -311,6 +308,4 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ IWebSocketChannel API """ - print("{} onClose: {} {} {}".format(id(self), wasClean, code, reason)) self._state.connection_lost() - # XXX "transit finished", etc From 40e14174e7f1391bac7c7450ac9784f08ab62367 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:26:28 -0600 Subject: [PATCH 46/96] don't need 'binary' subprotocol stuff --- src/wormhole_transit_relay/server_tap.py | 1 - src/wormhole_transit_relay/transit_server.py | 11 ----------- 2 files changed, 12 deletions(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 27e5a21..7223fa4 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -57,7 +57,6 @@ def makeService(config, reactor=reactor): ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url ws_factory.protocol = transit_server.WebSocketTransitConnection - ws_factory.websocket_protocols = ["binary"] tcp_factory.transit = transit ws_factory.transit = transit diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 8a5bead..a20c5de 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -240,17 +240,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): self._buddy._client.transport.loseConnection() self._buddy = None - def onConnect(self, request): - """ - IWebSocketChannel API - """ - # print("onConnect: {}".format(request)) - # ideally more like self._reactor.seconds() ... but Twisted - # doesn't have a good way to get the reactor for a protocol - # (besides "use the global one") - print("protocols: {}".format(request.protocols)) - return 'binary' - def connectionMade(self): """ IProtocol API From 0bfff5242b375b8b731e8160da86ee869df43d28 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:28:00 -0600 Subject: [PATCH 47/96] note --- src/wormhole_transit_relay/server_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index ae4ac2f..0dfed82 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -618,7 +618,7 @@ class TransitServerState(object): listening.upon( connection_made, enter=wait_relay, - outputs=[_remember_client], + outputs=[_remember_client], # XXX need _forget_client ? ) listening.upon( connection_lost, From b095b6919a4d829f282bb1e5674a3743c0a1889c Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:28:19 -0600 Subject: [PATCH 48/96] cleanup --- src/wormhole_transit_relay/transit_server.py | 4 +++- ws_client.py | 13 ++----------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index a20c5de..be543d1 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -58,6 +58,7 @@ class TransitConnection(LineReceiver): """ if self._buddy is not None: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + # XXX if our buddy is a WebSocket, this isn't the right way? self._buddy._client.transport.loseConnection() self._buddy = None @@ -237,7 +238,8 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ if self._buddy is not None: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) - self._buddy._client.transport.loseConnection() + # XXX if our buddy is tcp this is wrong + self._buddy._client.disconnect() self._buddy = None def connectionMade(self): diff --git a/ws_client.py b/ws_client.py index e5fbbd3..f6219da 100644 --- a/ws_client.py +++ b/ws_client.py @@ -36,11 +36,6 @@ class RelayEchoClient(WebSocketClientProtocol): True, ) - # def onConnecting(self, details): - # return types.ConnectingRequest( - # protocols=["binary"], - # ) - def onMessage(self, data, isBinary): print(">onMessage: {} bytes".format(len(data))) print(data, isBinary) @@ -48,7 +43,6 @@ class RelayEchoClient(WebSocketClientProtocol): self.factory.ready.callback(None) else: self._received += data -# return False def onClose(self, wasClean, code, reason): print(">onClose", wasClean, code, reason) @@ -65,18 +59,15 @@ def main(reactor): f = WebSocketClientFactory("ws://127.0.0.1:4002/") f.reactor = reactor f.protocol = RelayEchoClient -## f.protocols = ["binary"] - # NB: write our own factory, probably.. f.token = "a" * 64 f.side = "0" * 16 if will_send_message else "1" * 16 f.done = Deferred() f.ready = Deferred() + proto = yield ep.connect(f) - # proto_d = ep.connect(f) - # print("proto_d", proto_d) - # proto = yield proto_d print("proto", proto) yield f.ready + print("ready") if will_send_message: for _ in range(5): From 27d7ea85e89f95bac367e470a487c2569032a115 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:30:54 -0600 Subject: [PATCH 49/96] error-handling --- src/wormhole_transit_relay/transit_server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index be543d1..259c61c 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -288,7 +288,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol): if token is None: self._state.bad_token() else: - self._state.got_bytes(payload) + try: + self._state.got_bytes(payload) + except Exception as e: + log.err("Failed to send to partner: {}".format(e)) + self.sendClose(3000, "send to partner failed") def disconnect_redundant(self): # this is called if a buddy connected and we were found unnecessary. From 9b4e9577b3f7b74b18f139a9977fdd05065ff6cc Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:31:11 -0600 Subject: [PATCH 50/96] whitespace --- src/wormhole_transit_relay/transit_server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 259c61c..9a9c31a 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -7,8 +7,6 @@ from twisted.protocols.basic import LineReceiver from autobahn.twisted.websocket import WebSocketServerProtocol - - SECONDS = 1.0 MINUTE = 60*SECONDS HOUR = 60*MINUTE From e7b7b4cd6b1bee0f5383255ea9ffe2f2ab1571bf Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:31:32 -0600 Subject: [PATCH 51/96] unexpected hangup --- ws_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ws_client.py b/ws_client.py index f6219da..b0407bf 100644 --- a/ws_client.py +++ b/ws_client.py @@ -43,6 +43,9 @@ class RelayEchoClient(WebSocketClientProtocol): self.factory.ready.callback(None) else: self._received += data + if False: + # test abrupt hangup from receiving side + self.transport.loseConnection() def onClose(self, wasClean, code, reason): print(">onClose", wasClean, code, reason) From 09e46d3713880b8624853160b8e851b8e11e39a7 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:37:25 -0600 Subject: [PATCH 52/96] cleanup --- .../test/test_transit_server.py | 5 ---- src/wormhole_transit_relay/transit_server.py | 28 ++++++------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index f00db6c..d94577b 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -416,8 +416,6 @@ class Usage(ServerBase, unittest.TestCase): def test_one_happy_one_jilted(self): p1 = self.new_protocol() p2 = self.new_protocol() - print(dir(p1.factory)) - return token1 = b"\x00"*32 side1 = b"\x01"*8 @@ -427,7 +425,6 @@ class Usage(ServerBase, unittest.TestCase): p2.send(handshake(token1, side=side2)) self.flush() - print("shouldn't be events yet") self.assertEqual(self._usage.events, []) # no events yet p1.send(b"\x00" * 13) @@ -464,8 +461,6 @@ class Usage(ServerBase, unittest.TestCase): p1c.disconnect() self.flush() - for x in self._usage.events: - print(x) self.assertEqual(len(self._usage.events), 1, self._usage) self.assertEqual(self._usage.events[0]["mood"], "lonely") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 9a9c31a..7f9a1ca 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -56,8 +56,7 @@ class TransitConnection(LineReceiver): """ if self._buddy is not None: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) - # XXX if our buddy is a WebSocket, this isn't the right way? - self._buddy._client.transport.loseConnection() + self._buddy._client.disconnect() self._buddy = None def connectionMade(self): @@ -76,9 +75,10 @@ class TransitConnection(LineReceiver): except AttributeError: pass - def tracer(oldstate, theinput, newstate): - print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) - self._state.set_trace_function(tracer) + if False: + def tracer(oldstate, theinput, newstate): + print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) def lineReceived(self, line): """ @@ -133,16 +133,8 @@ class TransitConnection(LineReceiver): # XXX this probably resulted in a log message we've not refactored yet # self.factory.transitFinished(self, self._got_token, self._got_side, # self.describeToken()) -# XXX describeToken -> self._state.get_token() -# XXX multiple-inheritance sucks... -# ("Transit" wants to be "the factory" but the base class is slightly -# different for websocket versus "normal" socket .. so maybe we need -# to make Transit *not* the factory directly?) - -##WebSocketServerFactory):#protocol.ServerFactory): - class Transit(object): """ I manage pairs of simultaneous connections to a secondary TCP port, @@ -179,7 +171,6 @@ class Transit(object): MAXLENGTH = 10*MB # TODO: unused MAXTIME = 60*SECONDS -## protocol = TransitConnection def __init__(self, usage, get_timestamp): self.active_connections = ActiveConnections() @@ -236,7 +227,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ if self._buddy is not None: log.msg("buddy_disconnected {}".format(self._buddy.get_token())) - # XXX if our buddy is tcp this is wrong self._buddy._client.disconnect() self._buddy = None @@ -252,12 +242,12 @@ class WebSocketTransitConnection(WebSocketServerProtocol): self.factory.transit.usage, ) - def tracer(oldstate, theinput, newstate): - print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) - self._state.set_trace_function(tracer) + if False: + def tracer(oldstate, theinput, newstate): + print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + self._state.set_trace_function(tracer) def onOpen(self): - # print("onOpen") self._state.connection_made(self) def onMessage(self, payload, isBinary): From d7ebd02f780cf730df5e5c927d86350c8ad071c8 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:44:01 -0600 Subject: [PATCH 53/96] unused --- src/wormhole_transit_relay/server_state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 0dfed82..31a4040 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -498,10 +498,6 @@ class TransitServerState(object): def _count_bytes(self, data): self._total_sent += len(data) - @_machine.output() - def _send(self, data): - self._client.send(data) - @_machine.output() def _send_to_partner(self, data): self._buddy._client.send(data) From 941e4fe18a5e1ea6396cab0971b4daf6b544b44a Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:44:17 -0600 Subject: [PATCH 54/96] clean up imports --- .../test/test_transit_server.py | 33 ++++++------------- src/wormhole_transit_relay/transit_server.py | 1 - 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index d94577b..e281d3d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -1,10 +1,16 @@ from __future__ import print_function, unicode_literals -import base64 from binascii import hexlify from twisted.trial import unittest -from twisted.test import proto_helpers -from twisted.internet.defer import inlineCallbacks -from twisted.internet.task import deferLater +from twisted.test import iosim +from autobahn.twisted.websocket import ( + WebSocketServerFactory, + WebSocketClientFactory, + WebSocketClientProtocol, +) +from autobahn.twisted.testing import ( + create_pumper, + MemoryReactorClockResolver, +) from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -14,15 +20,6 @@ from ..transit_server import ( WebSocketTransitConnection, ) -from ..transit_server import Transit, TransitConnection, WebSocketTransitConnection -from twisted.internet.protocol import ( - ServerFactory, - ClientFactory, -) -from twisted.internet import protocol -from ..server_state import create_usage_tracker -from autobahn.twisted.websocket import WebSocketServerFactory - def handshake(token, side=None): hs = b"please relay " + hexlify(token) @@ -478,16 +475,6 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[2]["mood"], "happy") -from twisted.test import iosim -from twisted.internet.testing import MemoryReactorClock -from twisted.internet.address import IPv4Address -from autobahn.twisted.testing import ( - create_pumper, - create_memory_agent, - MemoryReactorClockResolver, -) - - class UsageWebSockets(Usage): """ All the tests of 'Usage' except with a WebSocket (instead of TCP) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 7f9a1ca..bc16b86 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -2,7 +2,6 @@ from __future__ import print_function, unicode_literals import re import time from twisted.python import log -from twisted.internet import protocol from twisted.protocols.basic import LineReceiver from autobahn.twisted.websocket import WebSocketServerProtocol From eb3bc6b5a88020f6ddee01ea4ee7cd81d59bd549 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:44:24 -0600 Subject: [PATCH 55/96] cleanup --- src/wormhole_transit_relay/test/test_transit_server.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index e281d3d..8835522 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -504,14 +504,11 @@ class UsageWebSockets(Usage): """ def new_protocol(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url + ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server - ws_factory.websocket_protocols = ["binary"] ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) - from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol - class TransitWebSocketClientProtocol(WebSocketClientProtocol): def send(self, data): self.sendMessage(data, True) From 4112f718d4ed17ec8fba7f68626037e533a46eee Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 16:46:12 -0600 Subject: [PATCH 56/96] unused --- src/wormhole_transit_relay/test/test_transit_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 8835522..4095f45 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -515,7 +515,6 @@ class UsageWebSockets(Usage): client_factory = WebSocketClientFactory() client_factory.protocol = TransitWebSocketClientProtocol - client_factory.protocols = ["binary"] client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol.disconnect = client_protocol.dropConnection From b73c76c8df2b6c30a3becbe01bfab0fb1b37b952 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 7 Apr 2021 17:03:23 -0600 Subject: [PATCH 57/96] run '_Transit' tests on websockets too --- .../test/test_transit_server.py | 69 ++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 4095f45..73f75fd 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -196,6 +196,7 @@ class _Transit: p2.send(handshake(token1, side=side1)) self.flush() + self.flush() self.assertEqual(self.count(), 2) # same-side connections don't match # when the second side arrives, the spare first connection should be @@ -285,7 +286,8 @@ class _Transit: token1 = b"\x00"*32 # sending too many bytes is impatience. - p1.send(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") + p1.send(b"please relay " + hexlify(token1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -300,7 +302,8 @@ class _Transit: side1 = b"\x01"*8 # sending too many bytes is impatience. p1.send(b"please relay " + hexlify(token1) + - b" for side " + hexlify(side1) + b"\nNOWNOWNOW") + b" for side " + hexlify(side1)) + p1.send(b"\nNOWNOWNOW") self.flush() exp = b"impatient\n" @@ -355,6 +358,58 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False +class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): + + def test_bad_handshake_old_slow(self): + """ + This test only makes sense for TCP + """ + + def new_protocol(self): + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = self._transit_server + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + _received = b"" + connected = False + + def connectionMade(self): + self.connected = True + return super(TransitWebSocketClientProtocol, self).connectionMade() + + def connectionLost(self, reason): + self.connected = False + return super(TransitWebSocketClientProtocol, self).connectionLost(reason) + + def send(self, data): + self.sendMessage(data, True) + + def onMessage(self, data, isBinary): + self._received = self._received + data + + def get_received_data(self): + return self._received + + def reset_received_data(self): + self._received = b"" + + client_factory = WebSocketClientFactory() + client_factory.protocol = TransitWebSocketClientProtocol + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + self._pumps.append(pump) + return client_protocol + + class Usage(ServerBase, unittest.TestCase): log_requests = True @@ -503,6 +558,16 @@ class UsageWebSockets(Usage): because it is semantically invalid or no handshake (yet). """ + def test_send_non_binary_message(self): + """ + A non-binary WebSocket message is an error + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + with self.assertRaises(ValueError): + ws_protocol.onMessage(u"foo", isBinary=False) + def new_protocol(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection From e689bfcf4ffd01739ad58ae5773e0d22ad254247 Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 10 Apr 2021 18:33:47 -0600 Subject: [PATCH 58/96] remove debug, doesn't make sense anymore --- src/wormhole_transit_relay/transit_server.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index bc16b86..15d235d 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -114,14 +114,6 @@ class TransitConnection(LineReceiver): # receiver can handle it. self._state.got_bytes(data) - def disconnect_error(self): - # we haven't finished the handshake, so there are no tokens tracking - # us - self.transport.loseConnection() - # XXX probably should be logged by state? - if self.factory.transit._debug_log: - log.msg("transitFailed %r" % self) - def disconnect_redundant(self): # this is called if a buddy connected and we were found unnecessary. # Any token-tracking cleanup will have been done before we're called. @@ -175,7 +167,6 @@ class Transit(object): self.active_connections = ActiveConnections() self.pending_requests = PendingRequests(self.active_connections) self.usage = usage - self._debug_log = False self._timestamp = get_timestamp self._rebooted = self._timestamp() From ce7458e604d31ba148e2de5ea5855292234defe4 Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 10 Apr 2021 18:34:11 -0600 Subject: [PATCH 59/96] test for disconnect / error propagation --- .../test/test_transit_server.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 73f75fd..fee418a 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -11,6 +11,7 @@ from autobahn.twisted.testing import ( create_pumper, MemoryReactorClockResolver, ) +from autobahn.exception import Disconnected from .common import ServerBase from ..server_state import ( MemoryUsageRecorder, @@ -365,6 +366,34 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): This test only makes sense for TCP """ + def test_send_closed_partner(self): + """ + Sending data to a closed partner causes an error that propogates + to the sender. + """ + p1 = self.new_protocol() + p2 = self.new_protocol() + + # set up a successful connection + token = b"a" * 32 + p1.send(handshake(token)) + p2.send(handshake(token)) + self.flush() + + # p2 loses connection, then p1 sends a message + p2.transport.loseConnection() + self.flush() + p1.send(b"more message") + self.flush() + + # at this point, p1 learns that p2 is disconnected (because it + # tried to relay "a message" but failed) + + # try to send more (our partner p2 is gone now though so it + # should be an immediate error) + with self.assertRaises(Disconnected): + p1.send(b"more message") + def new_protocol(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection From a057da49cfb78bfe53b9cbf8c8475ab5aeed2c6d Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 10 Apr 2021 18:37:43 -0600 Subject: [PATCH 60/96] disconnect_redundant is redundant --- src/wormhole_transit_relay/server_state.py | 1 - src/wormhole_transit_relay/transit_server.py | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 31a4040..d7e7055 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -351,7 +351,6 @@ class PendingRequests(object): # can happen if the connection hint contains multiple # addresses (we don't currently support those, but it'd # probably be useful in the future). - ##leftover_tc.disconnect_redundant() leftover_tc.partner_connection_lost() self._requests.pop(token, None) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 15d235d..d85e383 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -114,11 +114,6 @@ class TransitConnection(LineReceiver): # receiver can handle it. self._state.got_bytes(data) - def disconnect_redundant(self): - # this is called if a buddy connected and we were found unnecessary. - # Any token-tracking cleanup will have been done before we're called. - self.transport.loseConnection() - def connectionLost(self, reason): self._state.connection_lost() # XXX this probably resulted in a log message we've not refactored yet @@ -272,11 +267,6 @@ class WebSocketTransitConnection(WebSocketServerProtocol): log.err("Failed to send to partner: {}".format(e)) self.sendClose(3000, "send to partner failed") - def disconnect_redundant(self): - # this is called if a buddy connected and we were found unnecessary. - # Any token-tracking cleanup will have been done before we're called. - self.transport.loseConnection() - def onClose(self, wasClean, code, reason): """ IWebSocketChannel API From ef96af2a80497b9f21a81ca2a1735584c9adea2b Mon Sep 17 00:00:00 2001 From: meejah Date: Sat, 10 Apr 2021 18:45:20 -0600 Subject: [PATCH 61/96] websocket tests already use the interface --- src/wormhole_transit_relay/test/test_transit_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index fee418a..826093d 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -12,7 +12,11 @@ from autobahn.twisted.testing import ( MemoryReactorClockResolver, ) from autobahn.exception import Disconnected -from .common import ServerBase +from zope.interface import implementer +from .common import ( + ServerBase, + IRelayTestClient, +) from ..server_state import ( MemoryUsageRecorder, blur_size, @@ -400,6 +404,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): ws_factory.transit = self._transit_server ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + @implementer(IRelayTestClient) class TransitWebSocketClientProtocol(WebSocketClientProtocol): _received = b"" connected = False From 8132ea8f917f89700b46c1ab3161645df817a7c4 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 08:09:57 -0600 Subject: [PATCH 62/96] more docstrings --- src/wormhole_transit_relay/server_state.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index d7e7055..e3e1d06 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -12,6 +12,10 @@ from twisted.python import log class ITransitClient(Interface): + """ + Represents the client side of a connection to this transit + relay. This is used by TransitServerState instances. + """ started_time = Attribute("timestamp when the connection was established") @@ -61,6 +65,9 @@ class IUsageWriter(Interface): @implementer(IUsageWriter) class MemoryUsageRecorder: + """ + Remebers usage records in memory. + """ def __init__(self): self.events = [] @@ -81,6 +88,10 @@ class MemoryUsageRecorder: @implementer(IUsageWriter) class LogFileUsageRecorder: + """ + Writes usage records to a file. The records are written in JSON, + one record per line. + """ def __init__(self, writable_file): self._file = writable_file @@ -102,6 +113,9 @@ class LogFileUsageRecorder: @implementer(IUsageWriter) class DatabaseUsageRecorder: + """ + Write usage records into a database + """ def __init__(self, db): self._db = db @@ -174,7 +188,6 @@ class UsageTracker(object): self._blur_usage = blur_usage if blur_usage: log.msg("blurring access times to %d seconds" % self._blur_usage) -## XXX log.msg("not logging Transit connections to Twisted log") else: log.msg("not blurring access times") From 5f43e53db17daeddc3139a67eedfc1bf35e8b941 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 09:35:55 -0600 Subject: [PATCH 63/96] cleanup --- .../test/test_transit_server.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 826093d..994121e 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -99,7 +99,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_sided_unsided(self): @@ -128,7 +127,6 @@ class _Transit: self.assertEqual(p2.get_received_data(), s1) p1.disconnect() - p2.disconnect() self.flush() def test_unsided_sided(self): @@ -365,6 +363,9 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): + # XXX note to self, from pairing with Flo: + # - write a WS <--> TCP version of at least one of these tests? + def test_bad_handshake_old_slow(self): """ This test only makes sense for TCP @@ -387,8 +388,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): # p2 loses connection, then p1 sends a message p2.transport.loseConnection() self.flush() - p1.send(b"more message") - self.flush() # at this point, p1 learns that p2 is disconnected (because it # tried to relay "a message" but failed) @@ -417,18 +416,21 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): self.connected = False return super(TransitWebSocketClientProtocol, self).connectionLost(reason) - def send(self, data): - self.sendMessage(data, True) - def onMessage(self, data, isBinary): self._received = self._received + data + def send(self, data): + self.sendMessage(data, True) + def get_received_data(self): return self._received def reset_received_data(self): self._received = b"" + def disconnect(self): + self.sendClose(1000, True) + client_factory = WebSocketClientFactory() client_factory.protocol = TransitWebSocketClientProtocol client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) From 8aeea711eb10875eb2370351e355fc252aa53b42 Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 09:43:07 -0600 Subject: [PATCH 64/96] cleanup --- src/wormhole_transit_relay/test/test_transit_server.py | 1 + src/wormhole_transit_relay/transit_server.py | 6 +----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 994121e..400c234 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -396,6 +396,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): # should be an immediate error) with self.assertRaises(Disconnected): p1.send(b"more message") + self.flush() def new_protocol(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index d85e383..8bbc8d1 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -261,11 +261,7 @@ class WebSocketTransitConnection(WebSocketServerProtocol): if token is None: self._state.bad_token() else: - try: - self._state.got_bytes(payload) - except Exception as e: - log.err("Failed to send to partner: {}".format(e)) - self.sendClose(3000, "send to partner failed") + self._state.got_bytes(payload) def onClose(self, wasClean, code, reason): """ From 5b7ec9ef4c9f6893fd9599b6a5e9738f3532e59a Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 12 Apr 2021 09:47:52 -0600 Subject: [PATCH 65/96] move usage-tracking stuff to own module --- src/wormhole_transit_relay/server_state.py | 228 ----------------- src/wormhole_transit_relay/server_tap.py | 2 +- src/wormhole_transit_relay/test/common.py | 2 +- src/wormhole_transit_relay/test/test_stats.py | 2 +- .../test/test_transit_server.py | 2 +- src/wormhole_transit_relay/usage.py | 236 ++++++++++++++++++ 6 files changed, 240 insertions(+), 232 deletions(-) create mode 100644 src/wormhole_transit_relay/usage.py diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index e3e1d06..19e3c3c 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -41,234 +41,6 @@ class ITransitClient(Interface): """ -class IUsageWriter(Interface): - """ - Records actual usage statistics in some way - """ - - def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): - """ - :param int started: timestemp when this connection began - - :param float total_time: total seconds this connection lasted - - :param float waiting_time: None or the total seconds one side - waited for the other - - :param int total_bytes: the total bytes sent. In case the - connection was concluded successfully, only one side will - record the total bytes (but count both). - - :param str mood: the 'mood' of the connection - """ - - -@implementer(IUsageWriter) -class MemoryUsageRecorder: - """ - Remebers usage records in memory. - """ - - def __init__(self): - self.events = [] - - def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): - """ - IUsageWriter. - """ - data = { - "started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": mood, - } - self.events.append(data) - - -@implementer(IUsageWriter) -class LogFileUsageRecorder: - """ - Writes usage records to a file. The records are written in JSON, - one record per line. - """ - - def __init__(self, writable_file): - self._file = writable_file - - def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): - """ - IUsageWriter. - """ - data = { - "started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": mood, - } - self._file.write(json.dumps(data) + "\n") - self._file.flush() - - -@implementer(IUsageWriter) -class DatabaseUsageRecorder: - """ - Write usage records into a database - """ - - def __init__(self, db): - self._db = db - - def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): - """ - IUsageWriter. - """ - self._db.execute( - "INSERT INTO `usage`" - " (`started`, `total_time`, `waiting_time`," - " `total_bytes`, `result`)" - " VALUES (?,?,?,?,?)", - (started, total_time, waiting_time, total_bytes, mood) - ) - # original code did "self._update_stats()" here, thus causing - # "global" stats update on every connection update .. should - # we repeat this behavior, or really only record every - # 60-seconds with the timer? - self._db.commit() - - -def round_to(size, coarseness): - return int(coarseness*(1+int((size-1)/coarseness))) - - -def blur_size(size): - if size == 0: - return 0 - if size < 1e6: - return round_to(size, 10e3) - if size < 1e9: - return round_to(size, 1e6) - return round_to(size, 100e6) - - -def create_usage_tracker(blur_usage, log_file, usage_db): - """ - :param int blur_usage: see UsageTracker - - :param log_file: None or a file-like object to write JSON-encoded - lines of usage information to. - - :param usage_db: None or an sqlite3 database connection - - :returns: a new UsageTracker instance configured with backends. - """ - tracker = UsageTracker(blur_usage) - if usage_db: - tracker.add_backend(DatabaseUsageRecorder(usage_db)) - if log_file: - tracker.add_backend(LogFileUsageRecorder(log_file)) - return tracker - - -class UsageTracker(object): - """ - Tracks usage statistics of connections - """ - - def __init__(self, blur_usage): - """ - :param int blur_usage: None or the number of seconds to use as a - window around which to blur time statistics (e.g. "60" means times - will be rounded to 1 minute intervals). When blur_usage is - non-zero, sizes will also be rounded into buckets of "one - megabyte", "one gigabyte" or "lots" - """ - self._backends = set() - self._blur_usage = blur_usage - if blur_usage: - log.msg("blurring access times to %d seconds" % self._blur_usage) - else: - log.msg("not blurring access times") - - def add_backend(self, backend): - """ - Add a new backend. - - :param IUsageWriter backend: the backend to add - """ - self._backends.add(backend) - - def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): - """ - :param int started: timestamp when our connection started - - :param int buddy_started: None, or the timestamp when our - partner's connection started (will be None if we don't yet - have a partner). - - :param str result: a label for the result of the connection - (one of the "moods"). - - :param int bytes_sent: number of bytes we sent - - :param int buddy_bytes: number of bytes our partner sent - """ - # ideally self._reactor.seconds() or similar, but .. - finished = time.time() - if buddy_started is not None: - starts = [started, buddy_started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - total_bytes = bytes_sent + buddy_bytes - else: - total_time = finished - started - waiting_time = None - total_bytes = bytes_sent - - if self._blur_usage: - started = self._blur_usage * (started // self._blur_usage) - total_bytes = blur_size(total_bytes) - - # This is "a dict" instead of "kwargs" because we have to make - # it into a dict for the log use-case and in-memory/testing - # use-case anyway so this is less repeats of the names. - self._notify_backends({ - "started": started, - "total_time": total_time, - "waiting_time": waiting_time, - "total_bytes": total_bytes, - "mood": result, - }) - - def update_stats(self, rebooted, updated, connected, waiting, - incomplete_bytes): - """ - Update general statistics. - """ - # in original code, this is only recorded in the database - # .. perhaps a better way to do this, but .. - for backend in self._backends: - if isinstance(backend, DatabaseUsageRecorder): - backend._db.execute("DELETE FROM `current`") - backend._db.execute( - "INSERT INTO `current`" - " (`rebooted`, `updated`, `connected`, `waiting`," - " `incomplete_bytes`)" - " VALUES (?, ?, ?, ?, ?)", - (int(rebooted), int(updated), connected, waiting, - incomplete_bytes) - ) - - - def _notify_backends(self, data): - """ - Internal helper. Tell every backend we have about a new usage record. - """ - for backend in self._backends: - backend.record_usage(**data) - - class ActiveConnections(object): """ Tracks active connections. diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 7223fa4..3973617 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -10,7 +10,7 @@ from twisted.internet import protocol from autobahn.twisted.websocket import WebSocketServerFactory from . import transit_server -from .server_state import create_usage_tracker +from .usage import create_usage_tracker from .increase_rlimits import increase_rlimits from .database import get_db diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index a98a338..c502c33 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -13,7 +13,7 @@ from ..transit_server import ( TransitConnection, ) from twisted.internet.protocol import ServerFactory -from ..server_state import create_usage_tracker +from ..usage import create_usage_tracker class IRelayTestClient(Interface): diff --git a/src/wormhole_transit_relay/test/test_stats.py b/src/wormhole_transit_relay/test/test_stats.py index 806aafb..5137036 100644 --- a/src/wormhole_transit_relay/test/test_stats.py +++ b/src/wormhole_transit_relay/test/test_stats.py @@ -6,7 +6,7 @@ except ImportError: import mock from twisted.trial import unittest from ..transit_server import Transit -from ..server_state import create_usage_tracker +from ..usage import create_usage_tracker from .. import database class DB(unittest.TestCase): diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 400c234..4dbdf31 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -17,7 +17,7 @@ from .common import ( ServerBase, IRelayTestClient, ) -from ..server_state import ( +from ..usage import ( MemoryUsageRecorder, blur_size, ) diff --git a/src/wormhole_transit_relay/usage.py b/src/wormhole_transit_relay/usage.py new file mode 100644 index 0000000..b73fd7d --- /dev/null +++ b/src/wormhole_transit_relay/usage.py @@ -0,0 +1,236 @@ +import time +import json + +from twisted.python import log +from zope.interface import ( + implementer, + Interface, +) + + +def create_usage_tracker(blur_usage, log_file, usage_db): + """ + :param int blur_usage: see UsageTracker + + :param log_file: None or a file-like object to write JSON-encoded + lines of usage information to. + + :param usage_db: None or an sqlite3 database connection + + :returns: a new UsageTracker instance configured with backends. + """ + tracker = UsageTracker(blur_usage) + if usage_db: + tracker.add_backend(DatabaseUsageRecorder(usage_db)) + if log_file: + tracker.add_backend(LogFileUsageRecorder(log_file)) + return tracker + + +class IUsageWriter(Interface): + """ + Records actual usage statistics in some way + """ + + def record_usage(started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + :param int started: timestemp when this connection began + + :param float total_time: total seconds this connection lasted + + :param float waiting_time: None or the total seconds one side + waited for the other + + :param int total_bytes: the total bytes sent. In case the + connection was concluded successfully, only one side will + record the total bytes (but count both). + + :param str mood: the 'mood' of the connection + """ + + +@implementer(IUsageWriter) +class MemoryUsageRecorder: + """ + Remebers usage records in memory. + """ + + def __init__(self): + self.events = [] + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self.events.append(data) + + +@implementer(IUsageWriter) +class LogFileUsageRecorder: + """ + Writes usage records to a file. The records are written in JSON, + one record per line. + """ + + def __init__(self, writable_file): + self._file = writable_file + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + data = { + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": mood, + } + self._file.write(json.dumps(data) + "\n") + self._file.flush() + + +@implementer(IUsageWriter) +class DatabaseUsageRecorder: + """ + Write usage records into a database + """ + + def __init__(self, db): + self._db = db + + def record_usage(self, started=None, total_time=None, waiting_time=None, total_bytes=None, mood=None): + """ + IUsageWriter. + """ + self._db.execute( + "INSERT INTO `usage`" + " (`started`, `total_time`, `waiting_time`," + " `total_bytes`, `result`)" + " VALUES (?,?,?,?,?)", + (started, total_time, waiting_time, total_bytes, mood) + ) + # original code did "self._update_stats()" here, thus causing + # "global" stats update on every connection update .. should + # we repeat this behavior, or really only record every + # 60-seconds with the timer? + self._db.commit() + + +class UsageTracker(object): + """ + Tracks usage statistics of connections + """ + + def __init__(self, blur_usage): + """ + :param int blur_usage: None or the number of seconds to use as a + window around which to blur time statistics (e.g. "60" means times + will be rounded to 1 minute intervals). When blur_usage is + non-zero, sizes will also be rounded into buckets of "one + megabyte", "one gigabyte" or "lots" + """ + self._backends = set() + self._blur_usage = blur_usage + if blur_usage: + log.msg("blurring access times to %d seconds" % self._blur_usage) + else: + log.msg("not blurring access times") + + def add_backend(self, backend): + """ + Add a new backend. + + :param IUsageWriter backend: the backend to add + """ + self._backends.add(backend) + + def record(self, started, buddy_started, result, bytes_sent, buddy_bytes): + """ + :param int started: timestamp when our connection started + + :param int buddy_started: None, or the timestamp when our + partner's connection started (will be None if we don't yet + have a partner). + + :param str result: a label for the result of the connection + (one of the "moods"). + + :param int bytes_sent: number of bytes we sent + + :param int buddy_bytes: number of bytes our partner sent + """ + # ideally self._reactor.seconds() or similar, but .. + finished = time.time() + if buddy_started is not None: + starts = [started, buddy_started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = bytes_sent + buddy_bytes + else: + total_time = finished - started + waiting_time = None + total_bytes = bytes_sent + + if self._blur_usage: + started = self._blur_usage * (started // self._blur_usage) + total_bytes = blur_size(total_bytes) + + # This is "a dict" instead of "kwargs" because we have to make + # it into a dict for the log use-case and in-memory/testing + # use-case anyway so this is less repeats of the names. + self._notify_backends({ + "started": started, + "total_time": total_time, + "waiting_time": waiting_time, + "total_bytes": total_bytes, + "mood": result, + }) + + def update_stats(self, rebooted, updated, connected, waiting, + incomplete_bytes): + """ + Update general statistics. + """ + # in original code, this is only recorded in the database + # .. perhaps a better way to do this, but .. + for backend in self._backends: + if isinstance(backend, DatabaseUsageRecorder): + backend._db.execute("DELETE FROM `current`") + backend._db.execute( + "INSERT INTO `current`" + " (`rebooted`, `updated`, `connected`, `waiting`," + " `incomplete_bytes`)" + " VALUES (?, ?, ?, ?, ?)", + (int(rebooted), int(updated), connected, waiting, + incomplete_bytes) + ) + + + def _notify_backends(self, data): + """ + Internal helper. Tell every backend we have about a new usage record. + """ + for backend in self._backends: + backend.record_usage(**data) + + +def round_to(size, coarseness): + return int(coarseness*(1+int((size-1)/coarseness))) + + +def blur_size(size): + if size == 0: + return 0 + if size < 1e6: + return round_to(size, 10e3) + if size < 1e9: + return round_to(size, 1e6) + return round_to(size, 100e6) From 5a405443b9862bf176316b4cbd32a60cc05f86a3 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 15:10:58 -0600 Subject: [PATCH 66/96] more-explicit about which protocol clients use --- src/wormhole_transit_relay/test/common.py | 7 +++++++ .../test/test_transit_server.py | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index c502c33..cb84de1 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -74,6 +74,13 @@ class ServerBase: self._transit_server = Transit(usage, lambda: 123456789.0) def new_protocol(self): + """ + This should be overridden by derived test-case classes to decide + if they want a TCP or WebSockets protocol. + """ + raise NotImplementedError() + + def new_protocol_tcp(self): """ Create a new client protocol connected to the server. :returns: a IRelayTestClient implementation diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 4dbdf31..5937c3b 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -34,6 +34,9 @@ def handshake(token, side=None): return hs class _Transit: + def new_protocol(self): + return self.new_protocol_tcp() + def count(self): return sum([ len(potentials) @@ -366,6 +369,9 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): # XXX note to self, from pairing with Flo: # - write a WS <--> TCP version of at least one of these tests? + def new_protocol(self): + return self.new_protocol_ws() + def test_bad_handshake_old_slow(self): """ This test only makes sense for TCP @@ -398,7 +404,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): p1.send(b"more message") self.flush() - def new_protocol(self): + def new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server @@ -455,6 +461,9 @@ class Usage(ServerBase, unittest.TestCase): self._usage = MemoryUsageRecorder() self._transit_server.usage.add_backend(self._usage) + def new_protocol(self): + return self.new_protocol_tcp() + def test_empty(self): p1 = self.new_protocol() # hang up before sending anything @@ -587,6 +596,9 @@ class UsageWebSockets(Usage): def tearDown(self): return self._pump.stop() + def new_protocol(self): + return self.new_protocol_ws() + def test_short(self): """ This test essentially just tests the framing of the line-oriented @@ -605,7 +617,7 @@ class UsageWebSockets(Usage): with self.assertRaises(ValueError): ws_protocol.onMessage(u"foo", isBinary=False) - def new_protocol(self): + def new_protocol_ws(self): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server From 0db8ed3225256f556eef6f1bd42183c1cccb5b1e Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 15:11:38 -0600 Subject: [PATCH 67/96] even more explicit --- src/wormhole_transit_relay/test/test_transit_server.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 5937c3b..f05e3e4 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -34,9 +34,6 @@ def handshake(token, side=None): return hs class _Transit: - def new_protocol(self): - return self.new_protocol_tcp() - def count(self): return sum([ len(potentials) @@ -359,10 +356,16 @@ class _Transit: class TransitWithLogs(_Transit, ServerBase, unittest.TestCase): log_requests = True + def new_protocol(self): + return self.new_protocol_tcp() + class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): log_requests = False + def new_protocol(self): + return self.new_protocol_tcp() + class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): From 31979460730e5fa3a379067a99a1f9aa198ad986 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 15:13:16 -0600 Subject: [PATCH 68/96] websocket<-->TCP test --- .../test/test_transit_server.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index f05e3e4..ae00ad3 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -369,12 +369,42 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): - # XXX note to self, from pairing with Flo: - # - write a WS <--> TCP version of at least one of these tests? - def new_protocol(self): return self.new_protocol_ws() + def test_websocket_to_tcp(self): + """ + One client is WebSocket and one is TCP + """ + p1 = self.new_protocol_ws() + p2 = self.new_protocol_tcp() + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + p1.send(handshake(token1, side=side1)) + self.flush() + p2.send(handshake(token1, side=side2)) + self.flush() + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + self.assertEqual(p1.get_received_data(), exp) + self.assertEqual(p2.get_received_data(), exp) + + p1.reset_received_data() + p2.reset_received_data() + + # all data they sent after the handshake should be given to us + s1 = b"data1" + p1.send(s1) + self.flush() + self.assertEqual(p2.get_received_data(), s1) + + p1.disconnect() + p2.disconnect() + self.flush() + def test_bad_handshake_old_slow(self): """ This test only makes sense for TCP From 786cd08350ae9470d7554c8611aac069a54187ca Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 15:17:48 -0600 Subject: [PATCH 69/96] pyflakes; unused imports --- src/wormhole_transit_relay/server_state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 19e3c3c..a2b9fec 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,14 +1,10 @@ -import time -import json from collections import defaultdict import automat from zope.interface import ( Interface, Attribute, - implementer, ) -from twisted.python import log class ITransitClient(Interface): From 7b92c3701e54ba7ddd12d303fa1b9340f90eb044 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:33:21 -0600 Subject: [PATCH 70/96] leave state-machine tracing code (but commented) --- src/wormhole_transit_relay/server_state.py | 5 ++--- src/wormhole_transit_relay/transit_server.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index a2b9fec..3192480 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -471,6 +471,5 @@ class TransitServerState(object): outputs=[], ) - - ## XXX tracing - set_trace_function = _machine._setTrace + # uncomment to turn on state-machine tracing + # set_trace_function = _machine._setTrace diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 8bbc8d1..2be5231 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -74,10 +74,10 @@ class TransitConnection(LineReceiver): except AttributeError: pass - if False: - def tracer(oldstate, theinput, newstate): - print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) - self._state.set_trace_function(tracer) + # uncomment to turn on state-machine tracing + # def tracer(oldstate, theinput, newstate): + # print("TRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + # self._state.set_trace_function(tracer) def lineReceived(self, line): """ @@ -227,10 +227,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol): self.factory.transit.usage, ) - if False: - def tracer(oldstate, theinput, newstate): - print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) - self._state.set_trace_function(tracer) + # uncomment to turn on state-machine tracing + # def tracer(oldstate, theinput, newstate): + # print("WSTRACE: {}: {} --{}--> {}".format(id(self), oldstate, theinput, newstate)) + # self._state.set_trace_function(tracer) def onOpen(self): self._state.connection_made(self) From e7466a359505962783fc05002cab0c504c898406 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:46:46 -0600 Subject: [PATCH 71/96] add CLI options for WebSockets support --- src/wormhole_transit_relay/server_tap.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 3973617..a54082b 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -26,6 +26,8 @@ class Options(usage.Options): optParameters = [ ("port", "p", "tcp:4001:interface=\:\:", "endpoint to listen on"), + ("websocket", "w", None, "endpoint to listen for WebSocket connections"), + ("websocket-url", "u", None, "WebSocket URL (derived from endpoint if not provided)"), ("blur-usage", None, None, "blur timestamps and data sizes in logs"), ("log-fd", None, None, "write JSON usage logs to this file descriptor"), ("usage-db", None, None, "record usage data (SQLite)"), @@ -38,8 +40,11 @@ class Options(usage.Options): def makeService(config, reactor=reactor): increase_rlimits() tcp_ep = endpoints.serverFromString(reactor, config["port"]) # to listen - # XXX FIXME proper websocket option - ws_ep = endpoints.serverFromString(reactor, "tcp:4002") # to listen + ws_ep = ( + endpoints.serverFromString(reactor, config["websocket"]) + if config["websocket"] is not None + else None + ) log_file = ( os.fdopen(int(config["log-fd"]), "w") if config["log-fd"] is not None @@ -55,13 +60,22 @@ def makeService(config, reactor=reactor): tcp_factory = protocol.ServerFactory() tcp_factory.protocol = transit_server.TransitConnection - ws_factory = WebSocketServerFactory("ws://localhost:4002") # FIXME: url - ws_factory.protocol = transit_server.WebSocketTransitConnection + if ws_ep is not None: + ws_url = config["websocket-url"] + if ws_url is None: + # we're using a "private" attribute here but I don't see + # any useful alternative unless we also want to parse + # Twisted endpoint-strings. + ws_url = "ws://localhost:{}/".format(ws_ep._port) + print("Using WebSocket URL '{}'".format(ws_url)) + ws_factory = WebSocketServerFactory(ws_url) + ws_factory.protocol = transit_server.WebSocketTransitConnection tcp_factory.transit = transit ws_factory.transit = transit parent = MultiService() StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) - StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) + if ws_ep is not None: + StreamServerEndpointService(ws_ep, ws_factory).setServiceParent(parent) TimerService(5*60.0, transit.update_stats).setServiceParent(parent) return parent From 319145608de0b478ef7c201f4e42e59a7e6e6be7 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:47:52 -0600 Subject: [PATCH 72/96] better comment --- src/wormhole_transit_relay/test/test_transit_server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index ae00ad3..f812a9a 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -508,7 +508,10 @@ class Usage(ServerBase, unittest.TestCase): self.assertEqual(self._usage.events[0]["mood"], "empty", self._usage) def test_short(self): - # XXX this test only makes sense for TCP + # Note: this test only runs on TCP clients because WebSockets + # already does framing (so it's either "a bad handshake" or + # there's no handshake at all yet .. you can't have a "short" + # one). p1 = self.new_protocol() # hang up before sending a complete handshake p1.send(b"short") From 95a72e6ac9044f50f23907cb38df9212005ce66b Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:53:37 -0600 Subject: [PATCH 73/96] better comment --- src/wormhole_transit_relay/server_state.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 3192480..06a26bf 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -246,6 +246,10 @@ class TransitServerState(object): def _remember_client(self, client): self._client = client + # note that there is no corresponding "_forget_client" because we + # may still want to access it after it is gone .. for example, to + # get the .started_time for logging purposes + @_machine.output() def _register_token(self, token): return self._real_register_token_for_side(token, None) @@ -388,13 +392,10 @@ class TransitServerState(object): Terminal state """ - # need a listening.upon(connection_lost) for special websocket - # case where handshake fails? - listening.upon( connection_made, enter=wait_relay, - outputs=[_remember_client], # XXX need _forget_client ? + outputs=[_remember_client], ) listening.upon( connection_lost, From 6698cf95d54d0280613c7d33c6437f2e8522ac4e Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:54:41 -0600 Subject: [PATCH 74/96] irrelevant (there was only a debug message logged) --- src/wormhole_transit_relay/transit_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 2be5231..667aa69 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -116,9 +116,6 @@ class TransitConnection(LineReceiver): def connectionLost(self, reason): self._state.connection_lost() -# XXX this probably resulted in a log message we've not refactored yet -# self.factory.transitFinished(self, self._got_token, self._got_side, -# self.describeToken()) class Transit(object): From d8da1a62d663dee4166e33c9687cced7a525659a Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:55:34 -0600 Subject: [PATCH 75/96] no more debug-log --- src/wormhole_transit_relay/server_state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 06a26bf..3bc112e 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -121,8 +121,6 @@ class PendingRequests(object): or (new_side is None) or (old_side != new_side)): # we found a match - # FIXME: debug-log this - # print("transit relay 2: %s" % new_tc.get_token()) # drop and stop tracking the rest potentials.remove(old) @@ -141,8 +139,6 @@ class PendingRequests(object): old_tc.got_partner(new_tc) return False - # FIXME: debug-log this - # print("transit relay 1: %s" % new_tc.get_token()) potentials.add((new_side, new_tc)) return True # TODO: timer From bd06fad7e7d805276f1f90ec6c201c5c6488d7b8 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 16:58:05 -0600 Subject: [PATCH 76/96] websocket defaults test --- src/wormhole_transit_relay/server_tap.py | 2 +- src/wormhole_transit_relay/test/test_config.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index a54082b..8198674 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -70,9 +70,9 @@ def makeService(config, reactor=reactor): print("Using WebSocket URL '{}'".format(ws_url)) ws_factory = WebSocketServerFactory(ws_url) ws_factory.protocol = transit_server.WebSocketTransitConnection + ws_factory.transit = transit tcp_factory.transit = transit - ws_factory.transit = transit parent = MultiService() StreamServerEndpointService(tcp_ep, tcp_factory).setServiceParent(parent) if ws_ep is not None: diff --git a/src/wormhole_transit_relay/test/test_config.py b/src/wormhole_transit_relay/test/test_config.py index 72aa7ec..0b014d2 100644 --- a/src/wormhole_transit_relay/test/test_config.py +++ b/src/wormhole_transit_relay/test/test_config.py @@ -9,12 +9,14 @@ class Config(unittest.TestCase): o = server_tap.Options() o.parseOptions([]) self.assertEqual(o, {"blur-usage": None, "log-fd": None, - "usage-db": None, "port": PORT}) + "usage-db": None, "port": PORT, + "websocket": None, "websocket-url": None}) def test_blur(self): o = server_tap.Options() o.parseOptions(["--blur-usage=60"]) self.assertEqual(o, {"blur-usage": 60, "log-fd": None, - "usage-db": None, "port": PORT}) + "usage-db": None, "port": PORT, + "websocket": None, "websocket-url": None}) def test_string(self): o = server_tap.Options() From 425f040168187d4a2c8068aa3efdabc9d10a74a8 Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 17:01:20 -0600 Subject: [PATCH 77/96] test for websocket option-parsing --- src/wormhole_transit_relay/test/test_config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_config.py b/src/wormhole_transit_relay/test/test_config.py index 0b014d2..e942d30 100644 --- a/src/wormhole_transit_relay/test/test_config.py +++ b/src/wormhole_transit_relay/test/test_config.py @@ -18,6 +18,13 @@ class Config(unittest.TestCase): "usage-db": None, "port": PORT, "websocket": None, "websocket-url": None}) + def test_websocket(self): + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004"]) + self.assertEqual(o, {"blur-usage": None, "log-fd": None, + "usage-db": None, "port": PORT, + "websocket": "tcp:4004", "websocket-url": None}) + def test_string(self): o = server_tap.Options() s = str(o) From 30688f638c7f66a1ebffa160ac734ad62f30518a Mon Sep 17 00:00:00 2001 From: meejah Date: Wed, 14 Apr 2021 17:07:50 -0600 Subject: [PATCH 78/96] test to ensure we make a websocket service when passing --websocket --- src/wormhole_transit_relay/test/test_service.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index f84b01a..3205bb7 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -5,6 +5,7 @@ try: except ImportError: import mock from twisted.application.service import MultiService +from autobahn.twisted.websocket import WebSocketServerFactory from .. import server_tap class Service(unittest.TestCase): @@ -40,3 +41,13 @@ class Service(unittest.TestCase): [mock.call(blur_usage=None, log_file=fd, usage_db=None)]) + def test_websocket(self): + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004"]) + services = server_tap.makeService(o) + self.assertTrue( + any( + isinstance(s.factory, WebSocketServerFactory) + for s in services.services + ) + ) From b2ec2981d655083876775de1e11a596e870fedf0 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 11:15:16 -0600 Subject: [PATCH 79/96] add autobahn dep, upgrade twisted --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 92c87c9..6ea23a0 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,8 @@ setup(name="magic-wormhole-transit-relay", ], package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, install_requires=[ - "twisted >= 17.5.0", + "twisted >= 21.2.1", + "autobahn >= 21.3.1", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], From 141709bf62cca0254c79110563ea528fee1114fa Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 11:17:16 -0600 Subject: [PATCH 80/96] correct twisted version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6ea23a0..7119506 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup(name="magic-wormhole-transit-relay", ], package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, install_requires=[ - "twisted >= 21.2.1", + "twisted >= 21.2.0", "autobahn >= 21.3.1", ], extras_require={ From 360c7999a8095a50fa1c4c60d3dcfb5d50436d4d Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 11:36:00 -0600 Subject: [PATCH 81/96] unconstrained autobahn version so we can support other pythons --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7119506..71ee619 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup(name="magic-wormhole-transit-relay", package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, install_requires=[ "twisted >= 21.2.0", - "autobahn >= 21.3.1", + "autobahn", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], From 553158da633f167488b4d4f5f380444b6ed09c06 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 11:39:39 -0600 Subject: [PATCH 82/96] drop unsupported pythons for now --- .github/workflows/test.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d854f89..19d3656 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [2.7, 3.5, 3.6, 3.7, 3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 diff --git a/setup.py b/setup.py index 71ee619..7119506 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup(name="magic-wormhole-transit-relay", package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]}, install_requires=[ "twisted >= 21.2.0", - "autobahn", + "autobahn >= 21.3.1", ], extras_require={ ':sys_platform=="win32"': ["pypiwin32"], From f225fded53385741edec242214e5c5125e8f3982 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:06:09 -0600 Subject: [PATCH 83/96] pyflakes cleanup for test-websocket-client --- ws_client.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ws_client.py b/ws_client.py index b0407bf..93dacfc 100644 --- a/ws_client.py +++ b/ws_client.py @@ -7,21 +7,11 @@ from twisted.internet.defer import ( inlineCallbacks, ) from twisted.internet.task import react, deferLater -from twisted.internet.error import ( - ConnectionDone, -) -from twisted.internet.protocol import ( - Protocol, - Factory, -) -from twisted.protocols.basic import LineReceiver -from twisted.application.internet import StreamServerEndpointService from autobahn.twisted.websocket import ( WebSocketClientProtocol, WebSocketClientFactory, ) -from autobahn.websocket import types class RelayEchoClient(WebSocketClientProtocol): From 82c175a02e82fac5268e27c0d6f6635809aec40d Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:07:18 -0600 Subject: [PATCH 84/96] add tcp test-client too --- client.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 client.py diff --git a/client.py b/client.py new file mode 100644 index 0000000..8998a64 --- /dev/null +++ b/client.py @@ -0,0 +1,49 @@ +from __future__ import print_function + +from twisted.internet import endpoints +from twisted.internet.defer import ( + Deferred, +) +from twisted.internet.task import react +from twisted.internet.error import ( + ConnectionDone, +) +from twisted.internet.protocol import ( + Protocol, + Factory, +) + + +class RelayEchoClient(Protocol): + """ + Speaks the version1 magic wormhole transit relay protocol (as a client) + """ + + def connectionMade(self): + print(">CONNECT") + self.data = b"" + self.transport.write(u"please relay {}\n".format(self.factory.token).encode("ascii")) + + def dataReceived(self, data): + print(">RECV {} bytes".format(len(data))) + print(data.decode("ascii")) + self.data += data + if data == "ok\n": + self.transport.write("ding\n") + + def connectionLost(self, reason): + if isinstance(reason.value, ConnectionDone): + self.factory.done.callback(None) + else: + print(">DISCONNCT: {}".format(reason)) + self.factory.done.callback(reason) + + +@react +def main(reactor): + ep = endpoints.clientFromString(reactor, "tcp:localhost:4001") + f = Factory.forProtocol(RelayEchoClient) + f.token = "a" * 64 + f.done = Deferred() + ep.connect(f) + return f.done From 3cda647883b1f804e9d5bcf55e4d5c1e2b15d328 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:09:44 -0600 Subject: [PATCH 85/96] document test-clients --- client.py | 7 ++++++- ws_client.py | 10 +++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/client.py b/client.py index 8998a64..5c7d235 100644 --- a/client.py +++ b/client.py @@ -1,4 +1,9 @@ -from __future__ import print_function +""" +This is a test-client for the transit-relay that uses TCP. It +doesn't send any data, only prints out data that is received. Uses a +fixed token of 64 'a' characters. Always connects on localhost:4001 +""" + from twisted.internet import endpoints from twisted.internet.defer import ( diff --git a/ws_client.py b/ws_client.py index 93dacfc..27e989c 100644 --- a/ws_client.py +++ b/ws_client.py @@ -1,4 +1,12 @@ -from __future__ import print_function +""" +This is a test-client for the transit-relay that uses WebSockets. + +If an additional command-line argument (anything) is added, it will +send 5 messages upon connection. Otherwise, it just prints out what is +received. Uses a fixed token of 64 'a' characters. Always connects on +localhost:4002 +""" + import sys from twisted.internet import endpoints From 088757a7c9f39f30cf5c9c9056ba37697a39f86a Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:16:34 -0600 Subject: [PATCH 86/96] add words about websockets support --- docs/running.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/running.md b/docs/running.md index 5ea0601..5908584 100644 --- a/docs/running.md +++ b/docs/running.md @@ -50,6 +50,15 @@ The relevant arguments are: * ``--usage-db=``: maintains a SQLite database with current and historical usage data * ``--blur-usage=``: round logged timestamps and data sizes +For WebSockets support, two additional arguments: + +* ``--websocket``: the endpoint to listen for websocket connections + on, like ``tcp:4002`` +* ``--websocket-url``: the URL of the WebSocket connection. This may + be different from the listening endpoint because of port-forwarding + and so forth. By default it will be ``ws://localhost:`` if not + provided + When you use ``twist``, the relay runs in the foreground, so it will generally exit as soon as the controlling terminal exits. For persistent environments, you should daemonize the server. From 1a1947d7e92794fea9af0744b5a296ce6e883c0c Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:19:40 -0600 Subject: [PATCH 87/96] irrelevant --- src/wormhole_transit_relay/transit_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index b0b35e1..c7baffb 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -67,7 +67,6 @@ class TransitConnection(LineReceiver): self.factory.transit.usage, ) self._state.connection_made(self) -## self._log_requests = self.factory._log_requests self.transport.setTcpKeepAlive(True) # uncomment to turn on state-machine tracing From aa58b85ace2fb52a0899b2ee24fd6eb933da4e47 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 12:28:23 -0600 Subject: [PATCH 88/96] honour log_requests as original code did --- src/wormhole_transit_relay/server_state.py | 5 +++++ src/wormhole_transit_relay/server_tap.py | 2 ++ src/wormhole_transit_relay/test/common.py | 1 + src/wormhole_transit_relay/test/test_transit_server.py | 2 ++ src/wormhole_transit_relay/transit_server.py | 6 ++++-- 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 3bc112e..60ea710 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,6 +1,7 @@ from collections import defaultdict import automat +from twisted.python import log from zope.interface import ( Interface, Attribute, @@ -265,6 +266,8 @@ class TransitServerState(object): def _send_bad(self): self._mood = "errory" self._client.send(b"bad handshake\n") + if self._client.factory.log_requests: + log.msg("transit handshake failure") @_machine.output() def _send_ok(self): @@ -273,6 +276,8 @@ class TransitServerState(object): @_machine.output() def _send_impatient(self): self._client.send(b"impatient\n") + if self._client.factory.log_requests: + log.msg("transit impatience failure") @_machine.output() def _count_bytes(self, data): diff --git a/src/wormhole_transit_relay/server_tap.py b/src/wormhole_transit_relay/server_tap.py index 8198674..0db3ef6 100644 --- a/src/wormhole_transit_relay/server_tap.py +++ b/src/wormhole_transit_relay/server_tap.py @@ -59,6 +59,7 @@ def makeService(config, reactor=reactor): transit = transit_server.Transit(usage, reactor.seconds) tcp_factory = protocol.ServerFactory() tcp_factory.protocol = transit_server.TransitConnection + tcp_factory.log_requests = False if ws_ep is not None: ws_url = config["websocket-url"] @@ -71,6 +72,7 @@ def makeService(config, reactor=reactor): ws_factory = WebSocketServerFactory(ws_url) ws_factory.protocol = transit_server.WebSocketTransitConnection ws_factory.transit = transit + ws_factory.log_requests = False tcp_factory.transit = transit parent = MultiService() diff --git a/src/wormhole_transit_relay/test/common.py b/src/wormhole_transit_relay/test/common.py index cb84de1..4b2469f 100644 --- a/src/wormhole_transit_relay/test/common.py +++ b/src/wormhole_transit_relay/test/common.py @@ -88,6 +88,7 @@ class ServerBase: server_factory = ServerFactory() server_factory.protocol = TransitConnection server_factory.transit = self._transit_server + server_factory.log_requests = self.log_requests server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 1333781..c46b936 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -440,6 +440,7 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server + ws_factory.log_requests = self.log_requests ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) @implementer(IRelayTestClient) @@ -656,6 +657,7 @@ class UsageWebSockets(Usage): ws_factory = WebSocketServerFactory("ws://localhost:4002") ws_factory.protocol = WebSocketTransitConnection ws_factory.transit = self._transit_server + ws_factory.log_requests = self.log_requests ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) class TransitWebSocketClientProtocol(WebSocketClientProtocol): diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index c7baffb..35a2853 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -53,7 +53,8 @@ class TransitConnection(LineReceiver): ITransitClient API """ if self._buddy is not None: - log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.disconnect() self._buddy = None @@ -203,7 +204,8 @@ class WebSocketTransitConnection(WebSocketServerProtocol): ITransitClient API """ if self._buddy is not None: - log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) self._buddy._client.disconnect() self._buddy = None From bfd8312ef06035407b52253a186a879536af5568 Mon Sep 17 00:00:00 2001 From: meejah Date: Thu, 15 Apr 2021 19:19:14 -0600 Subject: [PATCH 89/96] render empty token correctly --- src/wormhole_transit_relay/server_state.py | 9 +++++---- .../test/test_transit_server.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 60ea710..f851b91 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -178,10 +178,11 @@ class TransitServerState(object): d = "-" if self._token is not None: d = self._token[:16].decode("ascii") - if self._side is not None: - d += "-" + self._side.decode("ascii") - else: - d += "-" + + if self._side is not None: + d += "-" + self._side.decode("ascii") + else: + d += "-" return d @_machine.input() diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index c46b936..2b55ef7 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -22,6 +22,7 @@ from ..usage import ( ) from ..transit_server import ( WebSocketTransitConnection, + TransitServerState, ) @@ -677,3 +678,18 @@ class UsageWebSockets(Usage): ) self._pumps.append(pump) return client_protocol + + +class State(unittest.TestCase): + """ + Tests related to server_state.TransitServerState + """ + + def setUp(self): + self.state = TransitServerState(None, None) + + def test_empty_token(self): + self.assertEqual( + "-", + self.state.get_token(), + ) From 513e5bed6ec47456675baa3c45b5767fca878948 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 18 Apr 2021 21:13:45 -0600 Subject: [PATCH 90/96] combine checks --- src/wormhole_transit_relay/server_state.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index f851b91..6018252 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -305,9 +305,8 @@ class TransitServerState(object): @_machine.output() def _record_usage(self): if self._mood == "jilted": - if self._buddy: - if self._buddy._mood == "happy": - return + if self._buddy and self._buddy._mood == "happy": + return self._usage.record( started=self._client.started_time, buddy_started=self._buddy._client.started_time if self._buddy is not None else None, From ade99eb8b3e13030bfebc992e15400a4ae5fbf44 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 18 Apr 2021 21:14:07 -0600 Subject: [PATCH 91/96] docstring, whitespace --- src/wormhole_transit_relay/test/test_service.py | 3 +++ src/wormhole_transit_relay/usage.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index 20e681d..23faed8 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -38,6 +38,9 @@ class Service(unittest.TestCase): log_file=fd, usage_db=None)]) def test_websocket(self): + """ + A websocket factory is created when passing --websocket + """ o = server_tap.Options() o.parseOptions(["--websocket=tcp:4004"]) services = server_tap.makeService(o) diff --git a/src/wormhole_transit_relay/usage.py b/src/wormhole_transit_relay/usage.py index b73fd7d..61bf308 100644 --- a/src/wormhole_transit_relay/usage.py +++ b/src/wormhole_transit_relay/usage.py @@ -213,7 +213,6 @@ class UsageTracker(object): incomplete_bytes) ) - def _notify_backends(self, data): """ Internal helper. Tell every backend we have about a new usage record. From dfd3bdd1a1a62eb86709e847bd3ea132f84962db Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 18 Apr 2021 21:14:28 -0600 Subject: [PATCH 92/96] test we can pass an explicit websocket URL --- .../test/test_service.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_service.py b/src/wormhole_transit_relay/test/test_service.py index 23faed8..9ab30c8 100644 --- a/src/wormhole_transit_relay/test/test_service.py +++ b/src/wormhole_transit_relay/test/test_service.py @@ -50,3 +50,21 @@ class Service(unittest.TestCase): for s in services.services ) ) + + def test_websocket_explicit_url(self): + """ + A websocket factory is created with --websocket and + --websocket-url + """ + o = server_tap.Options() + o.parseOptions([ + "--websocket=tcp:4004", + "--websocket-url=ws://example.com:4004", + ]) + services = server_tap.makeService(o) + self.assertTrue( + any( + isinstance(s.factory, WebSocketServerFactory) + for s in services.services + ) + ) From 0ce08b66cf48fe840387bae1e2e502bbc984b830 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 18 Apr 2021 21:25:25 -0600 Subject: [PATCH 93/96] defensive if's -> assert --- src/wormhole_transit_relay/transit_server.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 35a2853..4b7b0b5 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -52,11 +52,11 @@ class TransitConnection(LineReceiver): """ ITransitClient API """ - if self._buddy is not None: - if self.factory.log_requests: - log.msg("buddy_disconnected {}".format(self._buddy.get_token())) - self._buddy._client.disconnect() - self._buddy = None + assert self._buddy is not None, "internal error: no buddy" + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.disconnect() + self._buddy = None def connectionMade(self): # ideally more like self._reactor.seconds() ... but Twisted @@ -203,11 +203,11 @@ class WebSocketTransitConnection(WebSocketServerProtocol): """ ITransitClient API """ - if self._buddy is not None: - if self.factory.log_requests: - log.msg("buddy_disconnected {}".format(self._buddy.get_token())) - self._buddy._client.disconnect() - self._buddy = None + assert self._buddy is not None, "internal error: no buddy" + if self.factory.log_requests: + log.msg("buddy_disconnected {}".format(self._buddy.get_token())) + self._buddy._client.disconnect() + self._buddy = None def connectionMade(self): """ From 807dfc1c183337693eb2e7a45cccf401a2b6e2a5 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 18 Apr 2021 21:33:01 -0600 Subject: [PATCH 94/96] unify new_protocol_ws, make it a bare helper --- .../test/test_transit_server.py | 140 +++++++++--------- 1 file changed, 66 insertions(+), 74 deletions(-) diff --git a/src/wormhole_transit_relay/test/test_transit_server.py b/src/wormhole_transit_relay/test/test_transit_server.py index 2b55ef7..c5370de 100644 --- a/src/wormhole_transit_relay/test/test_transit_server.py +++ b/src/wormhole_transit_relay/test/test_transit_server.py @@ -367,11 +367,72 @@ class TransitWithoutLogs(_Transit, ServerBase, unittest.TestCase): return self.new_protocol_tcp() +def _new_protocol_ws(transit_server, log_requests): + """ + Internal helper for test-suites that need to provide WebSocket + client/server pairs. + + :returns: a 2-tuple: (iosim.IOPump, protocol) + """ + ws_factory = WebSocketServerFactory("ws://localhost:4002") + ws_factory.protocol = WebSocketTransitConnection + ws_factory.transit = transit_server + ws_factory.log_requests = log_requests + ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) + + @implementer(IRelayTestClient) + class TransitWebSocketClientProtocol(WebSocketClientProtocol): + _received = b"" + connected = False + + def connectionMade(self): + self.connected = True + return super(TransitWebSocketClientProtocol, self).connectionMade() + + def connectionLost(self, reason): + self.connected = False + return super(TransitWebSocketClientProtocol, self).connectionLost(reason) + + def onMessage(self, data, isBinary): + self._received = self._received + data + + def send(self, data): + self.sendMessage(data, True) + + def get_received_data(self): + return self._received + + def reset_received_data(self): + self._received = b"" + + def disconnect(self): + self.sendClose(1000, True) + + client_factory = WebSocketClientFactory() + client_factory.protocol = TransitWebSocketClientProtocol + client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) + client_protocol.disconnect = client_protocol.dropConnection + + pump = iosim.connect( + ws_protocol, + iosim.makeFakeServer(ws_protocol), + client_protocol, + iosim.makeFakeClient(client_protocol), + ) + return pump, client_protocol + + + class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): def new_protocol(self): return self.new_protocol_ws() + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + def test_websocket_to_tcp(self): """ One client is WebSocket and one is TCP @@ -437,55 +498,6 @@ class TransitWebSockets(_Transit, ServerBase, unittest.TestCase): p1.send(b"more message") self.flush() - def new_protocol_ws(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") - ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit_server - ws_factory.log_requests = self.log_requests - ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) - - @implementer(IRelayTestClient) - class TransitWebSocketClientProtocol(WebSocketClientProtocol): - _received = b"" - connected = False - - def connectionMade(self): - self.connected = True - return super(TransitWebSocketClientProtocol, self).connectionMade() - - def connectionLost(self, reason): - self.connected = False - return super(TransitWebSocketClientProtocol, self).connectionLost(reason) - - def onMessage(self, data, isBinary): - self._received = self._received + data - - def send(self, data): - self.sendMessage(data, True) - - def get_received_data(self): - return self._received - - def reset_received_data(self): - self._received = b"" - - def disconnect(self): - self.sendClose(1000, True) - - client_factory = WebSocketClientFactory() - client_factory.protocol = TransitWebSocketClientProtocol - client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) - client_protocol.disconnect = client_protocol.dropConnection - - pump = iosim.connect( - ws_protocol, - iosim.makeFakeServer(ws_protocol), - client_protocol, - iosim.makeFakeClient(client_protocol), - ) - self._pumps.append(pump) - return client_protocol - class Usage(ServerBase, unittest.TestCase): log_requests = True @@ -636,6 +648,11 @@ class UsageWebSockets(Usage): def new_protocol(self): return self.new_protocol_ws() + def new_protocol_ws(self): + pump, proto = _new_protocol_ws(self._transit_server, self.log_requests) + self._pumps.append(pump) + return proto + def test_short(self): """ This test essentially just tests the framing of the line-oriented @@ -654,31 +671,6 @@ class UsageWebSockets(Usage): with self.assertRaises(ValueError): ws_protocol.onMessage(u"foo", isBinary=False) - def new_protocol_ws(self): - ws_factory = WebSocketServerFactory("ws://localhost:4002") - ws_factory.protocol = WebSocketTransitConnection - ws_factory.transit = self._transit_server - ws_factory.log_requests = self.log_requests - ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) - - class TransitWebSocketClientProtocol(WebSocketClientProtocol): - def send(self, data): - self.sendMessage(data, True) - - client_factory = WebSocketClientFactory() - client_factory.protocol = TransitWebSocketClientProtocol - client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) - client_protocol.disconnect = client_protocol.dropConnection - - pump = iosim.connect( - ws_protocol, - iosim.makeFakeServer(ws_protocol), - client_protocol, - iosim.makeFakeClient(client_protocol), - ) - self._pumps.append(pump) - return client_protocol - class State(unittest.TestCase): """ From 845c55ddec9904e5571813180df2f1a6f9186513 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 9 May 2021 23:38:40 -0600 Subject: [PATCH 95/96] test --websocket-url option too --- src/wormhole_transit_relay/test/test_config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/wormhole_transit_relay/test/test_config.py b/src/wormhole_transit_relay/test/test_config.py index 9d4f762..b2bb7e8 100644 --- a/src/wormhole_transit_relay/test/test_config.py +++ b/src/wormhole_transit_relay/test/test_config.py @@ -24,6 +24,14 @@ class Config(unittest.TestCase): "usage-db": None, "port": PORT, "websocket": "tcp:4004", "websocket-url": None}) + def test_websocket_url(self): + o = server_tap.Options() + o.parseOptions(["--websocket=tcp:4004", "--websocket-url=ws://example.com/"]) + self.assertEqual(o, {"blur-usage": None, "log-fd": None, + "usage-db": None, "port": PORT, + "websocket": "tcp:4004", + "websocket-url": "ws://example.com/"}) + def test_string(self): o = server_tap.Options() s = str(o) From 6bd063a91771b085a9a850dd257afe29de8b07e8 Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 9 May 2021 23:40:18 -0600 Subject: [PATCH 96/96] clarify --- src/wormhole_transit_relay/usage.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/wormhole_transit_relay/usage.py b/src/wormhole_transit_relay/usage.py index 61bf308..92f8e35 100644 --- a/src/wormhole_transit_relay/usage.py +++ b/src/wormhole_transit_relay/usage.py @@ -178,6 +178,9 @@ class UsageTracker(object): total_time = finished - started waiting_time = None total_bytes = bytes_sent + # note that "bytes_sent" should always be 0 here, but + # we're recording what the state-machine remembered in any + # case if self._blur_usage: started = self._blur_usage * (started // self._blur_usage)