diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 4075979..c3ca635 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -3,6 +3,7 @@ import re, time, json from collections import defaultdict from twisted.python import log from twisted.internet import protocol +from twisted.protocols.basic import LineReceiver from .database import get_db SECONDS = 1.0 @@ -23,11 +24,16 @@ def blur_size(size): return round_to(size, 1e6) return round_to(size, 100e6) -class TransitConnection(protocol.Protocol): +class TransitConnection(LineReceiver): + delimiter = b'\n' + # maximum length of a line we will accept before the handshake is complete. + # This must be >= to the longest possible handshake message. + + MAX_LENGTH = 1024 + def __init__(self): self._got_token = False self._got_side = False - self._token_buffer = b"" self._sent_ok = False self._mood = "empty" self._buddy = None @@ -48,14 +54,33 @@ class TransitConnection(protocol.Protocol): self._log_requests = self.factory._log_requests self.transport.setTcpKeepAlive(True) - def dataReceived(self, data): + def lineReceived(self, line): + # old: "please relay {64}\n" + old = re.search(br"^please relay (\w{64})$", line) + if old: + token = old.group(1) + return self._got_handshake(token, None) + + # new: "please relay {64} for side {16}\n" + new = re.search(br"^please relay (\w{64}) for side (\w{16})$", line) + if new: + token = new.group(1) + side = new.group(2) + return self._got_handshake(token, side) + + self.sendLine(b"bad handshake") + if self._log_requests: + log.msg("transit handshake failure") + return self.disconnect_error() + + def rawDataReceived(self, data): + # We are an IPushProducer to our buddy's IConsumer, so they'll + # throttle us (by calling pauseProducing()) when their outbound + # buffer is full (e.g. when their downstream pipe is full). In + # practice, this buffers about 10MB per connection, after which + # point the sender will only transmit data as fast as the + # receiver can handle it. if self._sent_ok: - # We are an IPushProducer to our buddy's IConsumer, so they'll - # throttle us (by calling pauseProducing()) when their outbound - # buffer is full (e.g. when their downstream pipe is full). In - # practice, this buffers about 10MB per connection, after which - # point the sender will only transmit data as fast as the - # receiver can handle it. if not self._buddy: # Our buddy disconnected (we're "jilted"), so we hung up too, # but our incoming data hasn't stopped yet (it will in a @@ -68,87 +93,23 @@ class TransitConnection(protocol.Protocol): self._buddy.transport.write(data) return - if self._got_token: # but not yet sent_ok - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - - # else this should be (part of) the token - self._token_buffer += data - buf = self._token_buffer - - # old: "please relay {64}\n" - # new: "please relay {64} for side {16}\n" - (old, handshake_len, token) = self._check_old_handshake(buf) - assert old in ("yes", "waiting", "no") - if old == "yes": - # remember they aren't supposed to send anything past their - # handshake until we've said go - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - return self._got_handshake(token, None) - (new, handshake_len, token, side) = self._check_new_handshake(buf) - assert new in ("yes", "waiting", "no") - if new == "yes": - if len(buf) > handshake_len: - self.transport.write(b"impatient\n") - if self._log_requests: - log.msg("transit impatience failure") - return self.disconnect_error() # impatience yields failure - return self._got_handshake(token, side) - if (old == "no" and new == "no"): - self.transport.write(b"bad handshake\n") - if self._log_requests: - log.msg("transit handshake failure") - return self.disconnect_error() # incorrectness yields failure - # else we'll keep waiting - - def _check_old_handshake(self, buf): - # old: "please relay {64}\n" - # return ("yes", handshake, token) if buf contains an old-style handshake - # return ("waiting", None, None) if it might eventually contain one - # return ("no", None, None) if it could never contain one - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None) - if len(buf) < wanted: - return ("waiting", None, None) - - mo = re.search(br"^please relay (\w{64})\n", buf, re.M) - if mo: - token = mo.group(1) - return ("yes", wanted, token) - return ("no", None, None) - - def _check_new_handshake(self, buf): - # new: "please relay {64} for side {16}\n" - wanted = len("please relay for side \n")+32*2+8*2 - if len(buf) < wanted-1 and b"\n" in buf: - return ("no", None, None, None) - if len(buf) < wanted: - return ("waiting", None, None, None) - - mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M) - if mo: - token = mo.group(1) - side = mo.group(2) - return ("yes", wanted, token, side) - return ("no", None, None, None) + # handshake is complete but not yet sent_ok + self.sendLine(b"impatient") + if self._log_requests: + log.msg("transit impatience failure") + return self.disconnect_error() # impatience yields failure def _got_handshake(self, token, side): self._got_token = token self._got_side = side self._mood = "lonely" # until buddy connects + self.setRawMode() self.factory.connection_got_token(token, side, self) def buddy_connected(self, them): self._buddy = them self._mood = "happy" - self.transport.write(b"ok\n") + self.sendLine(b"ok") self._sent_ok = True # Connect the two as a producer/consumer pair. We use streaming=True, # so this expects the IPushProducer interface, and uses