diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index 6b2c6a5..cf95d1d 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -1,3 +1,4 @@ +from collections import defaultdict import automat from zope.interface import ( @@ -12,7 +13,7 @@ class ITransitClient(Interface): Send some byets to the client """ - def disconnect(reason): + def disconnect(): """ Disconnect the client transport """ @@ -37,7 +38,7 @@ class TestClient(object): def send_to_partner(self, data): print("{} GOT:{}".format(id(self), repr(data))) if self._partner: - self._partner.send(data) + self._partner._client.send(data) def send(self, data): print("{} SEND:{}".format(id(self), repr(data))) @@ -56,18 +57,94 @@ class TestClient(object): print("disconnect_partner: {}".format(id(self._partner))) +class ActiveConnections(object): + """ + Tracks active connections. A connection is 'active' when both + sides have shown up and they are glued together. + """ + def __init__(self): + self._connections = set() + + def register(self, side0, side1): + """ + A connection has become active so register both its sides + + :param TransitConnection side0: one side of the connection + :param TransitConnection side1: one side of the connection + """ + self._connections.add(side0) + self._connections.add(side1) + + def unregister(self, side): + """ + One side of a connection has become inactive. + + :param TransitConnection side: an inactive side of a connection + """ + self._connections.discard(side) + + class PendingRequests(object): """ Tracks the tokens we have received from client connections and - maps them to their partner connections + maps them to their partner connections for requests that haven't + yet been 'glued together' (that is, one side hasn't yet shown up). """ - def register_token(self, *args): + def __init__(self, active_connections): + self._requests = defaultdict(set) # token -> set((side, TransitConnection)) + self._active = active_connections + + def unregister(self, token, side, tc): + if token in self._requests: + self._requests[token].discard((side, tc)) + self._active.unregister(tc) + + def register_token(self, token, new_side, new_tc): """ + A client has connected and successfully offered a token (and + optional 'side' token). If this is the first one for this + token, we merely remember it. If it is the second side for + this token we connect them together. + + :returns bool: True if we are the first side to register this + token """ + potentials = self._requests[token] + for old in potentials: + (old_side, old_tc) = old + if ((old_side is None) + or (new_side is None) + or (old_side != new_side)): + # we found a match + # FIXME: debug-log this + # print("transit relay 2: %s" % new_tc.get_token()) + + # drop and stop tracking the rest + potentials.remove(old) + for (_, leftover_tc) in potentials.copy(): + # Don't record this as errory. It's just a spare connection + # from the same side as a connection that got used. This + # can happen if the connection hint contains multiple + # addresses (we don't currently support those, but it'd + # probably be useful in the future). + leftover_tc.disconnect_redundant() + self._requests.pop(token, None) + + # glue the two ends together + self._active.register(new_tc, old_tc) + new_tc.got_partner(old_tc) + old_tc.got_partner(new_tc) + return False + + # FIXME: debug-log this + # print("transit relay 1: %s" % new_tc.get_token()) + potentials.add((new_side, new_tc)) + return True + # TODO: timer -class TransitServer(object): +class TransitServerState(object): """ Encapsulates the state-machine of the server side of a transit relay connection. @@ -79,6 +156,36 @@ class TransitServer(object): _machine = automat.MethodicalMachine() _client = None + _buddy = None + _token = None + _side = None + _first = None + _mood = "empty" + + def __init__(self, pending_requests): + self._pending_requests = pending_requests + + def get_token(self): + """ + :returns str: a string describing our token. This will be "-" if + we have no token yet, or "{16 chars}-" if we have + just a token or "{16 chars}-{16 chars}" if we have a token and + a side. + """ + d = "-" + if self._token is not None: + d = self._token[:16].decode("ascii") + if self._side is not None: + d += "-" + self._side.decode("ascii") + else: + d += "-" + return d + + def get_mood(self): + """ + :returns str: description of the current 'mood' of the connection + """ + return self._mood @_machine.input() def connection_made(self, client): @@ -93,16 +200,22 @@ class TransitServer(object): @_machine.input() def please_relay(self, token): - pass + """ + A 'please relay X' message has been received (the original version + of the protocol). + """ @_machine.input() def please_relay_for_side(self, token, side): - pass + """ + A 'please relay X for side Y' message has been received (the + second version of the protocol). + """ @_machine.input() def bad_token(self): """ - A bad token / relay line was received + A bad token / relay line was received (e.g. couldn't be parsed) """ @_machine.input() @@ -113,11 +226,15 @@ class TransitServer(object): @_machine.input() def connection_lost(self): - pass + """ + Our transport has failed. + """ @_machine.input() def partner_connection_lost(self): - pass + """ + Our partner's transport has failed. + """ @_machine.input() def got_bytes(self, data): @@ -142,14 +259,20 @@ class TransitServer(object): """ remove us from the thing that remembers tokens and sides """ + return self._pending_requests.unregister(self._token, self._side, self) @_machine.output() def _send_bad(self): - self._client.send("bad handshake\n") + self._mood = "errory" + self._client.send(b"bad handshake\n") @_machine.output() def _send_ok(self): - self._client.send("ok\n") + self._client.send(b"ok\n") + + @_machine.output() + def _send_impatient(self): + self._client.send(b"impatient\n") @_machine.output() def _send(self, data): @@ -157,10 +280,11 @@ class TransitServer(object): @_machine.output() def _send_to_partner(self, data): - self._client.send_to_partner(data) + self._buddy._client.send(data) @_machine.output() def _connect_partner(self, client): + self._buddy = client self._client.connect_partner(client) @_machine.output() @@ -171,12 +295,60 @@ class TransitServer(object): def _disconnect_partner(self): self._client.disconnect_partner() + # some outputs to record the "mood" .. + @_machine.output() + def _mood_happy(self): + self._mood = "happy" + + @_machine.output() + def _mood_lonely(self): + self._mood = "lonely" + + @_machine.output() + def _mood_impatient(self): + self._mood = "impatient" + + @_machine.output() + def _mood_errory(self): + self._mood = "errory" + + @_machine.output() + def _mood_happy_if_first(self): + """ + We disconnected first so we're only happy if we also connected + first. + """ + if self._first: + self._mood = "happy" + else: + self._mood = "jilted" + + @_machine.output() + def _mood_happy_if_second(self): + """ + We disconnected second so we're only happy if we also connected + second. + """ + if self._first: + self._mood = "jilted" + else: + self._mood = "happy" + def _real_register_token_for_side(self, token, side): """ - basically, _got_handshake() + connection_got_token() from "real" - code ...and if this is the "second" side, hook them up and - pass .got_partner() input to both + A client has connected and sent a valid version 1 or version 2 + handshake. If the former, `side` will be None. + + In either case, we remember the tokens and register + ourselves. This might result in 'got_partner' notifications to + two state-machines if this is the second side for a given token. + + :param bytes token: the token + :param bytes side: The side token (or None) """ + self._token = token + self._side = side + self._first = self._pending_requests.register_token(token, side, self) @_machine.state(initial=True) def listening(self): @@ -217,17 +389,22 @@ class TransitServer(object): wait_relay.upon( please_relay, enter=wait_partner, - outputs=[_register_token], + outputs=[_mood_lonely, _register_token], ) wait_relay.upon( please_relay_for_side, enter=wait_partner, - outputs=[_register_token_for_side], + outputs=[_mood_lonely, _register_token_for_side], ) wait_relay.upon( bad_token, enter=done, - outputs=[_send_bad, _disconnect], + outputs=[_mood_errory, _send_bad, _disconnect], + ) + wait_relay.upon( + got_bytes, + enter=done, + outputs=[_mood_errory, _disconnect], ) wait_relay.upon( connection_lost, @@ -238,12 +415,17 @@ class TransitServer(object): wait_partner.upon( got_partner, enter=relaying, - outputs=[_send_ok, _connect_partner], + outputs=[_mood_happy, _send_ok, _connect_partner], ) wait_partner.upon( connection_lost, enter=done, - outputs=[_unregister], + outputs=[_mood_lonely, _unregister], + ) + wait_partner.upon( + got_bytes, + enter=done, + outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister], ) relaying.upon( @@ -254,12 +436,23 @@ class TransitServer(object): relaying.upon( connection_lost, enter=done, - outputs=[_disconnect_partner, _unregister], + outputs=[_mood_happy_if_first, _disconnect_partner, _unregister], ) relaying.upon( partner_connection_lost, enter=done, - outputs=[_disconnect, _unregister], + outputs=[_mood_happy_if_second, _disconnect, _unregister], + ) + + done.upon( + connection_lost, + enter=done, + outputs=[], + ) + done.upon( + partner_connection_lost, + enter=done, + outputs=[], ) @@ -272,9 +465,12 @@ class TransitServer(object): # - ... if __name__ == "__main__": - server0 = TransitServer() + active = ActiveConnections() + pending = PendingRequests(active) + + server0 = TransitServerState(pending) client0 = TestClient() - server1 = TransitServer() + server1 = TransitServerState(pending) client1 = TestClient() server0.connection_made(client0) server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") @@ -282,12 +478,14 @@ if __name__ == "__main__": # this would be an error, because our partner hasn't shown up yet # print(server0.got_bytes(b"asdf")) + print("about to relay client1") server1.connection_made(client1) server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + print("done") # XXX the PendingRequests stuff should do this, going "by hand" for now - server0.got_partner(client1) - server1.got_partner(client0) +# server0.got_partner(client1) +# server1.got_partner(client0) # should be connected now server0.got_bytes(b"asdf") diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 426f50f..9de3b1e 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -24,6 +24,17 @@ def blur_size(size): return round_to(size, 1e6) return round_to(size, 100e6) + +from wormhole_transit_relay.server_state import ( + TransitServerState, + PendingRequests, + ActiveConnections, + ITransitClient, +) +from zope.interface import implementer + + +@implementer(ITransitClient) class TransitConnection(LineReceiver): delimiter = b'\n' # maximum length of a line we will accept before the handshake is complete. @@ -32,13 +43,33 @@ class TransitConnection(LineReceiver): MAX_LENGTH = 1024 def __init__(self): - self._got_token = False - self._got_side = False - self._sent_ok = False - self._mood = "empty" - self._buddy = None self._total_sent = 0 + def send(self, data): + """ + ITransitClient API + """ + self.transport.write(data) + + def disconnect(self): + """ + ITransitClient API + """ + self.transport.loseConnection() + + def connect_partner(self, other): + """ + ITransitClient API + """ + self._buddy = other + + def disconnect_partner(self): + """ + ITransitClient API + """ + self._buddy._client.transport.loseConnection() + self._buddy = None + def describeToken(self): d = "-" if self._got_token: @@ -50,6 +81,8 @@ class TransitConnection(LineReceiver): return d def connectionMade(self): + self._state = TransitServerState(self.factory.pending_requests) + self._state.connection_made(self) self._started = time.time() self._log_requests = self.factory._log_requests try: @@ -71,10 +104,10 @@ class TransitConnection(LineReceiver): side = new.group(2) return self._got_handshake(token, side) - self.sendLine(b"bad handshake") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect_error() + # state-machine calls us via ITransitClient interface to do + # bad handshake etc. + return self._state.bad_token() + #return self._state.got_bytes(line) def rawDataReceived(self, data): # We are an IPushProducer to our buddy's IConsumer, so they'll @@ -83,33 +116,15 @@ class TransitConnection(LineReceiver): # practice, this buffers about 10MB per connection, after which # point the sender will only transmit data as fast as the # receiver can handle it. - if self._sent_ok: - # if self._buddy is None then our buddy disconnected - # (we're "jilted"), so we hung up too, but our incoming - # data hasn't stopped yet (it will in a moment, after our - # disconnect makes a roundtrip through the kernel). This - # probably means the file receiver hung up, and this - # connection is the file sender. In may-2020 this happened - # 11 times in 40 days. - if self._buddy: - self._total_sent += len(data) - self._buddy.transport.write(data) - return - - # handshake is complete but not yet sent_ok - self.sendLine(b"impatient") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure + self._state.got_bytes(data) + self._total_sent += len(data) def _got_handshake(self, token, side): - self._got_token = token - self._got_side = side - self._mood = "lonely" # until buddy connects + self._state.please_relay_for_side(token, side) + # self._mood = "lonely" # until buddy connects self.setRawMode() - self.factory.connection_got_token(token, side, self) - def buddy_connected(self, them): + def __buddy_connected(self, them): self._buddy = them self._mood = "happy" self.sendLine(b"ok") @@ -121,7 +136,7 @@ class TransitConnection(LineReceiver): # The Transit object calls buddy_connected() on both protocols, so # there will be two producer/consumer pairs. - def buddy_disconnected(self): + def __buddy_disconnected(self): if self._log_requests: log.msg("buddy_disconnected %s" % self.describeToken()) self._buddy = None @@ -145,56 +160,62 @@ class TransitConnection(LineReceiver): def connectionLost(self, reason): finished = time.time() total_time = finished - self._started + self._state.connection_lost() - # Record usage. There are eight cases: - # * n0: we haven't gotten a full handshake yet (empty) - # * n1: the handshake failed, not a real client (errory) - # * n2: real client disconnected before any buddy appeared (lonely) - # * n3: real client closed as redundant after buddy appears (redundant) - # * n4: real client connected first, buddy closes first (jilted) - # * n5: real client connected first, buddy close last (happy) - # * n6: real client connected last, buddy closes first (jilted) - # * n7: real client connected last, buddy closes last (happy) + # XXX FIXME record usage - # * non-connected clients (0,1,2,3) always write a usage record - # * for connected clients, whoever disconnects first gets to write the - # usage record (5, 7). The last disconnect doesn't write a record. + if False: + # Record usage. There are eight cases: + # * n0: we haven't gotten a full handshake yet (empty) + # * n1: the handshake failed, not a real client (errory) + # * n2: real client disconnected before any buddy appeared (lonely) + # * n3: real client closed as redundant after buddy appears (redundant) + # * n4: real client connected first, buddy closes first (jilted) + # * n5: real client connected first, buddy close last (happy) + # * n6: real client connected last, buddy closes first (jilted) + # * n7: real client connected last, buddy closes last (happy) + + # * non-connected clients (0,1,2,3) always write a usage record + # * for connected clients, whoever disconnects first gets to write the + # usage record (5, 7). The last disconnect doesn't write a record. + + if self._mood == "empty": # 0 + assert not self._buddy + self.factory.recordUsage(self._started, "empty", 0, + total_time, None) + elif self._mood == "errory": # 1 + assert not self._buddy + self.factory.recordUsage(self._started, "errory", 0, + total_time, None) + elif self._mood == "redundant": # 3 + assert not self._buddy + self.factory.recordUsage(self._started, "redundant", 0, + total_time, None) + elif self._mood == "jilted": # 4 or 6 + # we were connected, but our buddy hung up on us. They record the + # usage event, we do not + pass + elif self._mood == "lonely": # 2 + assert not self._buddy + self.factory.recordUsage(self._started, "lonely", 0, + total_time, None) + else: # 5 or 7 + # we were connected, we hung up first. We record the event. + assert self._mood == "happy", self._mood + assert self._buddy + starts = [self._started, self._buddy._started] + total_time = finished - min(starts) + waiting_time = max(starts) - min(starts) + total_bytes = self._total_sent + self._buddy._total_sent + self.factory.recordUsage(self._started, "happy", total_bytes, + total_time, waiting_time) + + if self._buddy: + self._buddy.buddy_disconnected() + # self.factory.transitFinished(self, self._got_token, self._got_side, + # self.describeToken()) - if self._mood == "empty": # 0 - assert not self._buddy - self.factory.recordUsage(self._started, "empty", 0, - total_time, None) - elif self._mood == "errory": # 1 - assert not self._buddy - self.factory.recordUsage(self._started, "errory", 0, - total_time, None) - elif self._mood == "redundant": # 3 - assert not self._buddy - self.factory.recordUsage(self._started, "redundant", 0, - total_time, None) - elif self._mood == "jilted": # 4 or 6 - # we were connected, but our buddy hung up on us. They record the - # usage event, we do not - pass - elif self._mood == "lonely": # 2 - assert not self._buddy - self.factory.recordUsage(self._started, "lonely", 0, - total_time, None) - else: # 5 or 7 - # we were connected, we hung up first. We record the event. - assert self._mood == "happy", self._mood - assert self._buddy - starts = [self._started, self._buddy._started] - total_time = finished - min(starts) - waiting_time = max(starts) - min(starts) - total_bytes = self._total_sent + self._buddy._total_sent - self.factory.recordUsage(self._started, "happy", total_bytes, - total_time, waiting_time) - if self._buddy: - self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, self._got_side, - self.describeToken()) class Transit(protocol.ServerFactory): # I manage pairs of simultaneous connections to a secondary TCP port, @@ -230,6 +251,8 @@ class Transit(protocol.ServerFactory): protocol = TransitConnection def __init__(self, blur_usage, log_file, usage_db): + self.active_connections = ActiveConnections() + self.pending_requests = PendingRequests(self.active_connections) self._blur_usage = blur_usage self._log_requests = blur_usage is None if self._blur_usage: @@ -247,39 +270,6 @@ class Transit(protocol.ServerFactory): self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection - def connection_got_token(self, token, new_side, new_tc): - potentials = self._pending_requests[token] - for old in potentials: - (old_side, old_tc) = old - if ((old_side is None) - or (new_side is None) - or (old_side != new_side)): - # we found a match - if self._debug_log: - log.msg("transit relay 2: %s" % new_tc.describeToken()) - - # drop and stop tracking the rest - potentials.remove(old) - for (_, leftover_tc) in potentials.copy(): - # Don't record this as errory. It's just a spare connection - # from the same side as a connection that got used. This - # can happen if the connection hint contains multiple - # addresses (we don't currently support those, but it'd - # probably be useful in the future). - leftover_tc.disconnect_redundant() - self._pending_requests.pop(token, None) - - # glue the two ends together - self._active_connections.add(new_tc) - self._active_connections.add(old_tc) - new_tc.buddy_connected(old_tc) - old_tc.buddy_connected(new_tc) - return - if self._debug_log: - log.msg("transit relay 1: %s" % new_tc.describeToken()) - potentials.add((new_side, new_tc)) - # TODO: timer - def transitFinished(self, tc, token, side, description): if token in self._pending_requests: side_tc = (side, tc)