diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index fd287fc..7fc1384 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -3,6 +3,7 @@ import re, time, json from collections import defaultdict from twisted.python import log from twisted.internet import protocol +from twisted.protocols.basic import LineReceiver from .database import get_db SECONDS = 1.0 @@ -23,11 +24,12 @@ def blur_size(size): return round_to(size, 1e6) return round_to(size, 100e6) -class TransitConnection(protocol.Protocol): +class TransitConnection(LineReceiver): + delimiter = b'\n' + def __init__(self): self._got_token = False self._got_side = False - self._token_buffer = b"" self._sent_ok = False self._mood = None self._buddy = None @@ -48,99 +50,64 @@ class TransitConnection(protocol.Protocol): self._log_requests = self.factory._log_requests self.transport.setTcpKeepAlive(True) - def dataReceived(self, data): + def lineReceived(self, line): + old = self._check_old_handshake(line) + if old is not None: + token = old + return self._got_handshake(token, None) + + new = self._check_new_handshake(line) + if new is not None: + token, side = new + return self._got_handshake(token, side) + + self.sendLine(b"bad handshake") + if self._log_requests: + log.msg("transit handshake failure") + return self.disconnect_error() + + def rawDataReceived(self, data): if self._sent_ok: - # We are an IPushProducer to our buddy's IConsumer, so they'll - # throttle us (by calling pauseProducing()) when their outbound - # buffer is full (e.g. when their downstream pipe is full). In - # practice, this buffers about 10MB per connection, after which - # point the sender will only transmit data as fast as the - # receiver can handle it. self._total_sent += len(data) self._buddy.transport.write(data) return - if self._got_token: # but not yet sent_ok - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - - # else this should be (part of) the token - self._token_buffer += data - buf = self._token_buffer - - # old: "please relay {64}\n" - # new: "please relay {64} for side {16}\n" - (old, handshake_len, token) = self._check_old_handshake(buf) - assert old in ("yes", "waiting", "no") - if old == "yes": - # remember they aren't supposed to send anything past their - # handshake until we've said go - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - return self._got_handshake(token, None) - (new, handshake_len, token, side) = self._check_new_handshake(buf) - assert new in ("yes", "waiting", "no") - if new == "yes": - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - return self._got_handshake(token, side) - if (old == "no" and new == "no"): - self.transport.write(b"bad handshake\n") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect_error() # incorrectness yields failure - # else we'll keep waiting + self.sendLine(b"impatient") + if self._log_requests: + log.msg("transit impatience failure") def _check_old_handshake(self, buf): # old: "please relay {64}\n" - # return ("yes", handshake, token) if buf contains an old-style handshake - # return ("waiting", None, None) if it might eventually contain one - # return ("no", None, None) if it could never contain one - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None) - if len(buf) < wanted: - return ("waiting", None, None) - - mo = re.search(br"^please relay (\w{64})\n", buf, re.M) + # return token if buf contains an old-style handshake + # return None if buf does not contain one + mo = re.search(br"^please relay (\w{64})$", buf, re.M) if mo: token = mo.group(1) - return ("yes", wanted, token) - return ("no", None, None) + return token + return None def _check_new_handshake(self, buf): # new: "please relay {64} for side {16}\n" - wanted = len("please relay for side \n")+32*2+8*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None, None) - if len(buf) < wanted: - return ("waiting", None, None, None) - - mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M) + # return (token, side) if but contains a new-style handshake + # return None if buf does not contain one + mo = re.search(br"^please relay (\w{64}) for side (\w{16})$", buf, re.M) if mo: token = mo.group(1) side = mo.group(2) - return ("yes", wanted, token, side) - return ("no", None, None, None) - + return (token, side) + return None + def _got_handshake(self, token, side): self._got_token = token self._got_side = side self._mood = "lonely" # until buddy connects + self.setRawMode() self.factory.connection_got_token(token, side, self) def buddy_connected(self, them): self._buddy = them self._mood = "happy" - self.transport.write(b"ok\n") + self.sendLine(b"ok") self._sent_ok = True # Connect the two as a producer/consumer pair. We use streaming=True, # so this expects the IPushProducer interface, and uses