(wip) refactor to use Automat state-machine

This commit is contained in:
meejah 2021-01-25 17:59:14 -07:00
parent 555c23d4fe
commit 0e11f1b8f1
2 changed files with 327 additions and 139 deletions

View File

@ -1,3 +1,4 @@
from collections import defaultdict
import automat import automat
from zope.interface import ( from zope.interface import (
@ -12,7 +13,7 @@ class ITransitClient(Interface):
Send some byets to the client Send some byets to the client
""" """
def disconnect(reason): def disconnect():
""" """
Disconnect the client transport Disconnect the client transport
""" """
@ -37,7 +38,7 @@ class TestClient(object):
def send_to_partner(self, data): def send_to_partner(self, data):
print("{} GOT:{}".format(id(self), repr(data))) print("{} GOT:{}".format(id(self), repr(data)))
if self._partner: if self._partner:
self._partner.send(data) self._partner._client.send(data)
def send(self, data): def send(self, data):
print("{} SEND:{}".format(id(self), repr(data))) print("{} SEND:{}".format(id(self), repr(data)))
@ -56,18 +57,94 @@ class TestClient(object):
print("disconnect_partner: {}".format(id(self._partner))) print("disconnect_partner: {}".format(id(self._partner)))
class ActiveConnections(object):
"""
Tracks active connections. A connection is 'active' when both
sides have shown up and they are glued together.
"""
def __init__(self):
self._connections = set()
def register(self, side0, side1):
"""
A connection has become active so register both its sides
:param TransitConnection side0: one side of the connection
:param TransitConnection side1: one side of the connection
"""
self._connections.add(side0)
self._connections.add(side1)
def unregister(self, side):
"""
One side of a connection has become inactive.
:param TransitConnection side: an inactive side of a connection
"""
self._connections.discard(side)
class PendingRequests(object): class PendingRequests(object):
""" """
Tracks the tokens we have received from client connections and Tracks the tokens we have received from client connections and
maps them to their partner connections maps them to their partner connections for requests that haven't
yet been 'glued together' (that is, one side hasn't yet shown up).
""" """
def register_token(self, *args): def __init__(self, active_connections):
self._requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active = active_connections
def unregister(self, token, side, tc):
if token in self._requests:
self._requests[token].discard((side, tc))
self._active.unregister(tc)
def register_token(self, token, new_side, new_tc):
""" """
A client has connected and successfully offered a token (and
optional 'side' token). If this is the first one for this
token, we merely remember it. If it is the second side for
this token we connect them together.
:returns bool: True if we are the first side to register this
token
""" """
potentials = self._requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
# FIXME: debug-log this
# print("transit relay 2: %s" % new_tc.get_token())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._requests.pop(token, None)
# glue the two ends together
self._active.register(new_tc, old_tc)
new_tc.got_partner(old_tc)
old_tc.got_partner(new_tc)
return False
# FIXME: debug-log this
# print("transit relay 1: %s" % new_tc.get_token())
potentials.add((new_side, new_tc))
return True
# TODO: timer
class TransitServer(object): class TransitServerState(object):
""" """
Encapsulates the state-machine of the server side of a transit Encapsulates the state-machine of the server side of a transit
relay connection. relay connection.
@ -79,6 +156,36 @@ class TransitServer(object):
_machine = automat.MethodicalMachine() _machine = automat.MethodicalMachine()
_client = None _client = None
_buddy = None
_token = None
_side = None
_first = None
_mood = "empty"
def __init__(self, pending_requests):
self._pending_requests = pending_requests
def get_token(self):
"""
:returns str: a string describing our token. This will be "-" if
we have no token yet, or "{16 chars}-<unsided>" if we have
just a token or "{16 chars}-{16 chars}" if we have a token and
a side.
"""
d = "-"
if self._token is not None:
d = self._token[:16].decode("ascii")
if self._side is not None:
d += "-" + self._side.decode("ascii")
else:
d += "-<unsided>"
return d
def get_mood(self):
"""
:returns str: description of the current 'mood' of the connection
"""
return self._mood
@_machine.input() @_machine.input()
def connection_made(self, client): def connection_made(self, client):
@ -93,16 +200,22 @@ class TransitServer(object):
@_machine.input() @_machine.input()
def please_relay(self, token): def please_relay(self, token):
pass """
A 'please relay X' message has been received (the original version
of the protocol).
"""
@_machine.input() @_machine.input()
def please_relay_for_side(self, token, side): def please_relay_for_side(self, token, side):
pass """
A 'please relay X for side Y' message has been received (the
second version of the protocol).
"""
@_machine.input() @_machine.input()
def bad_token(self): def bad_token(self):
""" """
A bad token / relay line was received A bad token / relay line was received (e.g. couldn't be parsed)
""" """
@_machine.input() @_machine.input()
@ -113,11 +226,15 @@ class TransitServer(object):
@_machine.input() @_machine.input()
def connection_lost(self): def connection_lost(self):
pass """
Our transport has failed.
"""
@_machine.input() @_machine.input()
def partner_connection_lost(self): def partner_connection_lost(self):
pass """
Our partner's transport has failed.
"""
@_machine.input() @_machine.input()
def got_bytes(self, data): def got_bytes(self, data):
@ -142,14 +259,20 @@ class TransitServer(object):
""" """
remove us from the thing that remembers tokens and sides remove us from the thing that remembers tokens and sides
""" """
return self._pending_requests.unregister(self._token, self._side, self)
@_machine.output() @_machine.output()
def _send_bad(self): def _send_bad(self):
self._client.send("bad handshake\n") self._mood = "errory"
self._client.send(b"bad handshake\n")
@_machine.output() @_machine.output()
def _send_ok(self): def _send_ok(self):
self._client.send("ok\n") self._client.send(b"ok\n")
@_machine.output()
def _send_impatient(self):
self._client.send(b"impatient\n")
@_machine.output() @_machine.output()
def _send(self, data): def _send(self, data):
@ -157,10 +280,11 @@ class TransitServer(object):
@_machine.output() @_machine.output()
def _send_to_partner(self, data): def _send_to_partner(self, data):
self._client.send_to_partner(data) self._buddy._client.send(data)
@_machine.output() @_machine.output()
def _connect_partner(self, client): def _connect_partner(self, client):
self._buddy = client
self._client.connect_partner(client) self._client.connect_partner(client)
@_machine.output() @_machine.output()
@ -171,12 +295,60 @@ class TransitServer(object):
def _disconnect_partner(self): def _disconnect_partner(self):
self._client.disconnect_partner() self._client.disconnect_partner()
# some outputs to record the "mood" ..
@_machine.output()
def _mood_happy(self):
self._mood = "happy"
@_machine.output()
def _mood_lonely(self):
self._mood = "lonely"
@_machine.output()
def _mood_impatient(self):
self._mood = "impatient"
@_machine.output()
def _mood_errory(self):
self._mood = "errory"
@_machine.output()
def _mood_happy_if_first(self):
"""
We disconnected first so we're only happy if we also connected
first.
"""
if self._first:
self._mood = "happy"
else:
self._mood = "jilted"
@_machine.output()
def _mood_happy_if_second(self):
"""
We disconnected second so we're only happy if we also connected
second.
"""
if self._first:
self._mood = "jilted"
else:
self._mood = "happy"
def _real_register_token_for_side(self, token, side): def _real_register_token_for_side(self, token, side):
""" """
basically, _got_handshake() + connection_got_token() from "real" A client has connected and sent a valid version 1 or version 2
code ...and if this is the "second" side, hook them up and handshake. If the former, `side` will be None.
pass .got_partner() input to both
In either case, we remember the tokens and register
ourselves. This might result in 'got_partner' notifications to
two state-machines if this is the second side for a given token.
:param bytes token: the token
:param bytes side: The side token (or None)
""" """
self._token = token
self._side = side
self._first = self._pending_requests.register_token(token, side, self)
@_machine.state(initial=True) @_machine.state(initial=True)
def listening(self): def listening(self):
@ -217,17 +389,22 @@ class TransitServer(object):
wait_relay.upon( wait_relay.upon(
please_relay, please_relay,
enter=wait_partner, enter=wait_partner,
outputs=[_register_token], outputs=[_mood_lonely, _register_token],
) )
wait_relay.upon( wait_relay.upon(
please_relay_for_side, please_relay_for_side,
enter=wait_partner, enter=wait_partner,
outputs=[_register_token_for_side], outputs=[_mood_lonely, _register_token_for_side],
) )
wait_relay.upon( wait_relay.upon(
bad_token, bad_token,
enter=done, enter=done,
outputs=[_send_bad, _disconnect], outputs=[_mood_errory, _send_bad, _disconnect],
)
wait_relay.upon(
got_bytes,
enter=done,
outputs=[_mood_errory, _disconnect],
) )
wait_relay.upon( wait_relay.upon(
connection_lost, connection_lost,
@ -238,12 +415,17 @@ class TransitServer(object):
wait_partner.upon( wait_partner.upon(
got_partner, got_partner,
enter=relaying, enter=relaying,
outputs=[_send_ok, _connect_partner], outputs=[_mood_happy, _send_ok, _connect_partner],
) )
wait_partner.upon( wait_partner.upon(
connection_lost, connection_lost,
enter=done, enter=done,
outputs=[_unregister], outputs=[_mood_lonely, _unregister],
)
wait_partner.upon(
got_bytes,
enter=done,
outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister],
) )
relaying.upon( relaying.upon(
@ -254,12 +436,23 @@ class TransitServer(object):
relaying.upon( relaying.upon(
connection_lost, connection_lost,
enter=done, enter=done,
outputs=[_disconnect_partner, _unregister], outputs=[_mood_happy_if_first, _disconnect_partner, _unregister],
) )
relaying.upon( relaying.upon(
partner_connection_lost, partner_connection_lost,
enter=done, enter=done,
outputs=[_disconnect, _unregister], outputs=[_mood_happy_if_second, _disconnect, _unregister],
)
done.upon(
connection_lost,
enter=done,
outputs=[],
)
done.upon(
partner_connection_lost,
enter=done,
outputs=[],
) )
@ -272,9 +465,12 @@ class TransitServer(object):
# - ... # - ...
if __name__ == "__main__": if __name__ == "__main__":
server0 = TransitServer() active = ActiveConnections()
pending = PendingRequests(active)
server0 = TransitServerState(pending)
client0 = TestClient() client0 = TestClient()
server1 = TransitServer() server1 = TransitServerState(pending)
client1 = TestClient() client1 = TestClient()
server0.connection_made(client0) server0.connection_made(client0)
server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
@ -282,12 +478,14 @@ if __name__ == "__main__":
# this would be an error, because our partner hasn't shown up yet # this would be an error, because our partner hasn't shown up yet
# print(server0.got_bytes(b"asdf")) # print(server0.got_bytes(b"asdf"))
print("about to relay client1")
server1.connection_made(client1) server1.connection_made(client1)
server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
print("done")
# XXX the PendingRequests stuff should do this, going "by hand" for now # XXX the PendingRequests stuff should do this, going "by hand" for now
server0.got_partner(client1) # server0.got_partner(client1)
server1.got_partner(client0) # server1.got_partner(client0)
# should be connected now # should be connected now
server0.got_bytes(b"asdf") server0.got_bytes(b"asdf")

View File

@ -24,6 +24,17 @@ def blur_size(size):
return round_to(size, 1e6) return round_to(size, 1e6)
return round_to(size, 100e6) return round_to(size, 100e6)
from wormhole_transit_relay.server_state import (
TransitServerState,
PendingRequests,
ActiveConnections,
ITransitClient,
)
from zope.interface import implementer
@implementer(ITransitClient)
class TransitConnection(LineReceiver): class TransitConnection(LineReceiver):
delimiter = b'\n' delimiter = b'\n'
# maximum length of a line we will accept before the handshake is complete. # maximum length of a line we will accept before the handshake is complete.
@ -32,13 +43,33 @@ class TransitConnection(LineReceiver):
MAX_LENGTH = 1024 MAX_LENGTH = 1024
def __init__(self): def __init__(self):
self._got_token = False
self._got_side = False
self._sent_ok = False
self._mood = "empty"
self._buddy = None
self._total_sent = 0 self._total_sent = 0
def send(self, data):
"""
ITransitClient API
"""
self.transport.write(data)
def disconnect(self):
"""
ITransitClient API
"""
self.transport.loseConnection()
def connect_partner(self, other):
"""
ITransitClient API
"""
self._buddy = other
def disconnect_partner(self):
"""
ITransitClient API
"""
self._buddy._client.transport.loseConnection()
self._buddy = None
def describeToken(self): def describeToken(self):
d = "-" d = "-"
if self._got_token: if self._got_token:
@ -50,6 +81,8 @@ class TransitConnection(LineReceiver):
return d return d
def connectionMade(self): def connectionMade(self):
self._state = TransitServerState(self.factory.pending_requests)
self._state.connection_made(self)
self._started = time.time() self._started = time.time()
self._log_requests = self.factory._log_requests self._log_requests = self.factory._log_requests
try: try:
@ -71,10 +104,10 @@ class TransitConnection(LineReceiver):
side = new.group(2) side = new.group(2)
return self._got_handshake(token, side) return self._got_handshake(token, side)
self.sendLine(b"bad handshake") # state-machine calls us via ITransitClient interface to do
if self._log_requests: # bad handshake etc.
log.msg("transit handshake failure") return self._state.bad_token()
return self.disconnect_error() #return self._state.got_bytes(line)
def rawDataReceived(self, data): def rawDataReceived(self, data):
# We are an IPushProducer to our buddy's IConsumer, so they'll # We are an IPushProducer to our buddy's IConsumer, so they'll
@ -83,33 +116,15 @@ class TransitConnection(LineReceiver):
# practice, this buffers about 10MB per connection, after which # practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the # point the sender will only transmit data as fast as the
# receiver can handle it. # receiver can handle it.
if self._sent_ok: self._state.got_bytes(data)
# if self._buddy is None then our buddy disconnected self._total_sent += len(data)
# (we're "jilted"), so we hung up too, but our incoming
# data hasn't stopped yet (it will in a moment, after our
# disconnect makes a roundtrip through the kernel). This
# probably means the file receiver hung up, and this
# connection is the file sender. In may-2020 this happened
# 11 times in 40 days.
if self._buddy:
self._total_sent += len(data)
self._buddy.transport.write(data)
return
# handshake is complete but not yet sent_ok
self.sendLine(b"impatient")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
def _got_handshake(self, token, side): def _got_handshake(self, token, side):
self._got_token = token self._state.please_relay_for_side(token, side)
self._got_side = side # self._mood = "lonely" # until buddy connects
self._mood = "lonely" # until buddy connects
self.setRawMode() self.setRawMode()
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them): def __buddy_connected(self, them):
self._buddy = them self._buddy = them
self._mood = "happy" self._mood = "happy"
self.sendLine(b"ok") self.sendLine(b"ok")
@ -121,7 +136,7 @@ class TransitConnection(LineReceiver):
# The Transit object calls buddy_connected() on both protocols, so # The Transit object calls buddy_connected() on both protocols, so
# there will be two producer/consumer pairs. # there will be two producer/consumer pairs.
def buddy_disconnected(self): def __buddy_disconnected(self):
if self._log_requests: if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken()) log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None self._buddy = None
@ -145,56 +160,62 @@ class TransitConnection(LineReceiver):
def connectionLost(self, reason): def connectionLost(self, reason):
finished = time.time() finished = time.time()
total_time = finished - self._started total_time = finished - self._started
self._state.connection_lost()
# Record usage. There are eight cases: # XXX FIXME record usage
# * n0: we haven't gotten a full handshake yet (empty)
# * n1: the handshake failed, not a real client (errory)
# * n2: real client disconnected before any buddy appeared (lonely)
# * n3: real client closed as redundant after buddy appears (redundant)
# * n4: real client connected first, buddy closes first (jilted)
# * n5: real client connected first, buddy close last (happy)
# * n6: real client connected last, buddy closes first (jilted)
# * n7: real client connected last, buddy closes last (happy)
# * non-connected clients (0,1,2,3) always write a usage record if False:
# * for connected clients, whoever disconnects first gets to write the # Record usage. There are eight cases:
# usage record (5, 7). The last disconnect doesn't write a record. # * n0: we haven't gotten a full handshake yet (empty)
# * n1: the handshake failed, not a real client (errory)
# * n2: real client disconnected before any buddy appeared (lonely)
# * n3: real client closed as redundant after buddy appears (redundant)
# * n4: real client connected first, buddy closes first (jilted)
# * n5: real client connected first, buddy close last (happy)
# * n6: real client connected last, buddy closes first (jilted)
# * n7: real client connected last, buddy closes last (happy)
# * non-connected clients (0,1,2,3) always write a usage record
# * for connected clients, whoever disconnects first gets to write the
# usage record (5, 7). The last disconnect doesn't write a record.
if self._mood == "empty": # 0
assert not self._buddy
self.factory.recordUsage(self._started, "empty", 0,
total_time, None)
elif self._mood == "errory": # 1
assert not self._buddy
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
elif self._mood == "redundant": # 3
assert not self._buddy
self.factory.recordUsage(self._started, "redundant", 0,
total_time, None)
elif self._mood == "jilted": # 4 or 6
# we were connected, but our buddy hung up on us. They record the
# usage event, we do not
pass
elif self._mood == "lonely": # 2
assert not self._buddy
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
else: # 5 or 7
# we were connected, we hung up first. We record the event.
assert self._mood == "happy", self._mood
assert self._buddy
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
total_bytes = self._total_sent + self._buddy._total_sent
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
if self._buddy:
self._buddy.buddy_disconnected()
# self.factory.transitFinished(self, self._got_token, self._got_side,
# self.describeToken())
if self._mood == "empty": # 0
assert not self._buddy
self.factory.recordUsage(self._started, "empty", 0,
total_time, None)
elif self._mood == "errory": # 1
assert not self._buddy
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
elif self._mood == "redundant": # 3
assert not self._buddy
self.factory.recordUsage(self._started, "redundant", 0,
total_time, None)
elif self._mood == "jilted": # 4 or 6
# we were connected, but our buddy hung up on us. They record the
# usage event, we do not
pass
elif self._mood == "lonely": # 2
assert not self._buddy
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
else: # 5 or 7
# we were connected, we hung up first. We record the event.
assert self._mood == "happy", self._mood
assert self._buddy
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
total_bytes = self._total_sent + self._buddy._total_sent
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
if self._buddy:
self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
class Transit(protocol.ServerFactory): class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
@ -230,6 +251,8 @@ class Transit(protocol.ServerFactory):
protocol = TransitConnection protocol = TransitConnection
def __init__(self, blur_usage, log_file, usage_db): def __init__(self, blur_usage, log_file, usage_db):
self.active_connections = ActiveConnections()
self.pending_requests = PendingRequests(self.active_connections)
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._log_requests = blur_usage is None self._log_requests = blur_usage is None
if self._blur_usage: if self._blur_usage:
@ -247,39 +270,6 @@ class Transit(protocol.ServerFactory):
self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection)) self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection self._active_connections = set() # TransitConnection
def connection_got_token(self, token, new_side, new_tc):
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._debug_log:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._pending_requests.pop(token, None)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._debug_log:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def transitFinished(self, tc, token, side, description): def transitFinished(self, tc, token, side, description):
if token in self._pending_requests: if token in self._pending_requests:
side_tc = (side, tc) side_tc = (side, tc)