use line receiver to simplify handshake logic

This commit is contained in:
Joe Harrison 2020-03-07 02:14:04 +00:00
parent c6445321d7
commit 4fdd89cb35

View File

@ -3,6 +3,7 @@ import re, time, json
from collections import defaultdict from collections import defaultdict
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from twisted.protocols.basic import LineReceiver
from .database import get_db from .database import get_db
SECONDS = 1.0 SECONDS = 1.0
@ -23,11 +24,12 @@ def blur_size(size):
return round_to(size, 1e6) return round_to(size, 1e6)
return round_to(size, 100e6) return round_to(size, 100e6)
class TransitConnection(protocol.Protocol): class TransitConnection(LineReceiver):
delimiter = b'\n'
def __init__(self): def __init__(self):
self._got_token = False self._got_token = False
self._got_side = False self._got_side = False
self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._mood = None self._mood = None
self._buddy = None self._buddy = None
@ -48,99 +50,64 @@ class TransitConnection(protocol.Protocol):
self._log_requests = self.factory._log_requests self._log_requests = self.factory._log_requests
self.transport.setTcpKeepAlive(True) self.transport.setTcpKeepAlive(True)
def dataReceived(self, data): def lineReceived(self, line):
old = self._check_old_handshake(line)
if old is not None:
token = old
return self._got_handshake(token, None)
new = self._check_new_handshake(line)
if new is not None:
token, side = new
return self._got_handshake(token, side)
self.sendLine(b"bad handshake")
if self._log_requests:
log.msg("transit handshake failure")
return self.disconnect_error()
def rawDataReceived(self, data):
if self._sent_ok: if self._sent_ok:
# We are an IPushProducer to our buddy's IConsumer, so they'll
# throttle us (by calling pauseProducing()) when their outbound
# buffer is full (e.g. when their downstream pipe is full). In
# practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the
# receiver can handle it.
self._total_sent += len(data) self._total_sent += len(data)
self._buddy.transport.write(data) self._buddy.transport.write(data)
return return
if self._got_token: # but not yet sent_ok self.sendLine(b"impatient")
self.transport.write(b"impatient\n") if self._log_requests:
if self._log_requests: log.msg("transit impatience failure")
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
# else this should be (part of) the token
self._token_buffer += data
buf = self._token_buffer
# old: "please relay {64}\n"
# new: "please relay {64} for side {16}\n"
(old, handshake_len, token) = self._check_old_handshake(buf)
assert old in ("yes", "waiting", "no")
if old == "yes":
# remember they aren't supposed to send anything past their
# handshake until we've said go
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
return self._got_handshake(token, None)
(new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no")
if new == "yes":
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
return self._got_handshake(token, side)
if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n")
if self._log_requests:
log.msg("transit handshake failure")
return self.disconnect_error() # incorrectness yields failure
# else we'll keep waiting
def _check_old_handshake(self, buf): def _check_old_handshake(self, buf):
# old: "please relay {64}\n" # old: "please relay {64}\n"
# return ("yes", handshake, token) if buf contains an old-style handshake # return token if buf contains an old-style handshake
# return ("waiting", None, None) if it might eventually contain one # return None if buf does not contain one
# return ("no", None, None) if it could never contain one mo = re.search(br"^please relay (\w{64})$", buf, re.M)
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None)
if len(buf) < wanted:
return ("waiting", None, None)
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if mo: if mo:
token = mo.group(1) token = mo.group(1)
return ("yes", wanted, token) return token
return ("no", None, None) return None
def _check_new_handshake(self, buf): def _check_new_handshake(self, buf):
# new: "please relay {64} for side {16}\n" # new: "please relay {64} for side {16}\n"
wanted = len("please relay for side \n")+32*2+8*2 # return (token, side) if but contains a new-style handshake
if len(buf) < wanted-1 and b"\n" in buf: # return None if buf does not contain one
return ("no", None, None, None) mo = re.search(br"^please relay (\w{64}) for side (\w{16})$", buf, re.M)
if len(buf) < wanted:
return ("waiting", None, None, None)
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
if mo: if mo:
token = mo.group(1) token = mo.group(1)
side = mo.group(2) side = mo.group(2)
return ("yes", wanted, token, side) return (token, side)
return ("no", None, None, None) return None
def _got_handshake(self, token, side): def _got_handshake(self, token, side):
self._got_token = token self._got_token = token
self._got_side = side self._got_side = side
self._mood = "lonely" # until buddy connects self._mood = "lonely" # until buddy connects
self.setRawMode()
self.factory.connection_got_token(token, side, self) self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
self._mood = "happy" self._mood = "happy"
self.transport.write(b"ok\n") self.sendLine(b"ok")
self._sent_ok = True self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True, # Connect the two as a producer/consumer pair. We use streaming=True,
# so this expects the IPushProducer interface, and uses # so this expects the IPushProducer interface, and uses