diff --git a/src/wormhole_transit_relay/transit_server.py b/src/wormhole_transit_relay/transit_server.py index 7fc1384..3a2b168 100644 --- a/src/wormhole_transit_relay/transit_server.py +++ b/src/wormhole_transit_relay/transit_server.py @@ -26,7 +26,11 @@ def blur_size(size): class TransitConnection(LineReceiver): delimiter = b'\n' - + # maximum length of a line we will accept before the handshake is complete. + # This must be >= to the longest possible handshake message. + + MAX_LENGTH = 1024 + def __init__(self): self._got_token = False self._got_side = False @@ -51,14 +55,17 @@ class TransitConnection(LineReceiver): self.transport.setTcpKeepAlive(True) def lineReceived(self, line): - old = self._check_old_handshake(line) - if old is not None: - token = old + # old: "please relay {64}\n" + old = re.fullmatch(br"please relay (\w{64})", line) + if old: + token = old.group(1) return self._got_handshake(token, None) - new = self._check_new_handshake(line) - if new is not None: - token, side = new + # new: "please relay {64} for side {16}\n" + new = re.fullmatch(br"please relay (\w{64}) for side (\w{16})", line) + if new: + token = new.group(1) + side = new.group(2) return self._got_handshake(token, side) self.sendLine(b"bad handshake") @@ -76,27 +83,6 @@ class TransitConnection(LineReceiver): if self._log_requests: log.msg("transit impatience failure") - def _check_old_handshake(self, buf): - # old: "please relay {64}\n" - # return token if buf contains an old-style handshake - # return None if buf does not contain one - mo = re.search(br"^please relay (\w{64})$", buf, re.M) - if mo: - token = mo.group(1) - return token - return None - - def _check_new_handshake(self, buf): - # new: "please relay {64} for side {16}\n" - # return (token, side) if but contains a new-style handshake - # return None if buf does not contain one - mo = re.search(br"^please relay (\w{64}) for side (\w{16})$", buf, re.M) - if mo: - token = mo.group(1) - side = mo.group(2) - return (token, side) - return None - def _got_handshake(self, token, side): self._got_token = token self._got_side = side