From 0e647074597977fc4da34116decc8b3431c9c8be Mon Sep 17 00:00:00 2001 From: meejah Date: Mon, 1 Feb 2021 16:55:15 -0700 Subject: [PATCH] count totals in state-machine --- src/wormhole_transit_relay/server_state.py | 7 ++++++- src/wormhole_transit_relay/transit_server.py | 10 +++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/wormhole_transit_relay/server_state.py b/src/wormhole_transit_relay/server_state.py index cf95d1d..ffa0819 100644 --- a/src/wormhole_transit_relay/server_state.py +++ b/src/wormhole_transit_relay/server_state.py @@ -161,6 +161,7 @@ class TransitServerState(object): _side = None _first = None _mood = "empty" + _total_sent = 0 def __init__(self, pending_requests): self._pending_requests = pending_requests @@ -274,6 +275,10 @@ class TransitServerState(object): def _send_impatient(self): self._client.send(b"impatient\n") + @_machine.output() + def _count_bytes(self, data): + self._total_sent += len(data) + @_machine.output() def _send(self, data): self._client.send(data) @@ -404,7 +409,7 @@ class TransitServerState(object): wait_relay.upon( got_bytes, enter=done, - outputs=[_mood_errory, _disconnect], + outputs=[_count_bytes, _mood_errory, _disconnect], ) wait_relay.upon( connection_lost, diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 9de3b1e..abd6406 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -42,9 +42,6 @@ class TransitConnection(LineReceiver): MAX_LENGTH = 1024 - def __init__(self): - self._total_sent = 0 - def send(self, data): """ ITransitClient API @@ -104,10 +101,10 @@ class TransitConnection(LineReceiver): side = new.group(2) return self._got_handshake(token, side) - # state-machine calls us via ITransitClient interface to do - # bad handshake etc. + # we should have been switched to "raw data" mode on the first + # line received (after which rawDataReceived() is called for + # all bytes) so getting here means a bad handshake. return self._state.bad_token() - #return self._state.got_bytes(line) def rawDataReceived(self, data): # We are an IPushProducer to our buddy's IConsumer, so they'll @@ -117,7 +114,6 @@ class TransitConnection(LineReceiver): # point the sender will only transmit data as fast as the # receiver can handle it. self._state.got_bytes(data) - self._total_sent += len(data) def _got_handshake(self, token, side): self._state.please_relay_for_side(token, side)