(wip) refactor to use Automat state-machine

This commit is contained in:
meejah 2021-01-25 17:59:14 -07:00
parent 555c23d4fe
commit 0e11f1b8f1
2 changed files with 327 additions and 139 deletions

View File

@ -1,3 +1,4 @@
from collections import defaultdict
import automat
from zope.interface import (
@ -12,7 +13,7 @@ class ITransitClient(Interface):
Send some byets to the client
"""
def disconnect(reason):
def disconnect():
"""
Disconnect the client transport
"""
@ -37,7 +38,7 @@ class TestClient(object):
def send_to_partner(self, data):
print("{} GOT:{}".format(id(self), repr(data)))
if self._partner:
self._partner.send(data)
self._partner._client.send(data)
def send(self, data):
print("{} SEND:{}".format(id(self), repr(data)))
@ -56,18 +57,94 @@ class TestClient(object):
print("disconnect_partner: {}".format(id(self._partner)))
class ActiveConnections(object):
"""
Tracks active connections. A connection is 'active' when both
sides have shown up and they are glued together.
"""
def __init__(self):
self._connections = set()
def register(self, side0, side1):
"""
A connection has become active so register both its sides
:param TransitConnection side0: one side of the connection
:param TransitConnection side1: one side of the connection
"""
self._connections.add(side0)
self._connections.add(side1)
def unregister(self, side):
"""
One side of a connection has become inactive.
:param TransitConnection side: an inactive side of a connection
"""
self._connections.discard(side)
class PendingRequests(object):
"""
Tracks the tokens we have received from client connections and
maps them to their partner connections
maps them to their partner connections for requests that haven't
yet been 'glued together' (that is, one side hasn't yet shown up).
"""
def register_token(self, *args):
def __init__(self, active_connections):
self._requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active = active_connections
def unregister(self, token, side, tc):
if token in self._requests:
self._requests[token].discard((side, tc))
self._active.unregister(tc)
def register_token(self, token, new_side, new_tc):
"""
A client has connected and successfully offered a token (and
optional 'side' token). If this is the first one for this
token, we merely remember it. If it is the second side for
this token we connect them together.
:returns bool: True if we are the first side to register this
token
"""
potentials = self._requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
# FIXME: debug-log this
# print("transit relay 2: %s" % new_tc.get_token())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._requests.pop(token, None)
# glue the two ends together
self._active.register(new_tc, old_tc)
new_tc.got_partner(old_tc)
old_tc.got_partner(new_tc)
return False
# FIXME: debug-log this
# print("transit relay 1: %s" % new_tc.get_token())
potentials.add((new_side, new_tc))
return True
# TODO: timer
class TransitServer(object):
class TransitServerState(object):
"""
Encapsulates the state-machine of the server side of a transit
relay connection.
@ -79,6 +156,36 @@ class TransitServer(object):
_machine = automat.MethodicalMachine()
_client = None
_buddy = None
_token = None
_side = None
_first = None
_mood = "empty"
def __init__(self, pending_requests):
self._pending_requests = pending_requests
def get_token(self):
"""
:returns str: a string describing our token. This will be "-" if
we have no token yet, or "{16 chars}-<unsided>" if we have
just a token or "{16 chars}-{16 chars}" if we have a token and
a side.
"""
d = "-"
if self._token is not None:
d = self._token[:16].decode("ascii")
if self._side is not None:
d += "-" + self._side.decode("ascii")
else:
d += "-<unsided>"
return d
def get_mood(self):
"""
:returns str: description of the current 'mood' of the connection
"""
return self._mood
@_machine.input()
def connection_made(self, client):
@ -93,16 +200,22 @@ class TransitServer(object):
@_machine.input()
def please_relay(self, token):
pass
"""
A 'please relay X' message has been received (the original version
of the protocol).
"""
@_machine.input()
def please_relay_for_side(self, token, side):
pass
"""
A 'please relay X for side Y' message has been received (the
second version of the protocol).
"""
@_machine.input()
def bad_token(self):
"""
A bad token / relay line was received
A bad token / relay line was received (e.g. couldn't be parsed)
"""
@_machine.input()
@ -113,11 +226,15 @@ class TransitServer(object):
@_machine.input()
def connection_lost(self):
pass
"""
Our transport has failed.
"""
@_machine.input()
def partner_connection_lost(self):
pass
"""
Our partner's transport has failed.
"""
@_machine.input()
def got_bytes(self, data):
@ -142,14 +259,20 @@ class TransitServer(object):
"""
remove us from the thing that remembers tokens and sides
"""
return self._pending_requests.unregister(self._token, self._side, self)
@_machine.output()
def _send_bad(self):
self._client.send("bad handshake\n")
self._mood = "errory"
self._client.send(b"bad handshake\n")
@_machine.output()
def _send_ok(self):
self._client.send("ok\n")
self._client.send(b"ok\n")
@_machine.output()
def _send_impatient(self):
self._client.send(b"impatient\n")
@_machine.output()
def _send(self, data):
@ -157,10 +280,11 @@ class TransitServer(object):
@_machine.output()
def _send_to_partner(self, data):
self._client.send_to_partner(data)
self._buddy._client.send(data)
@_machine.output()
def _connect_partner(self, client):
self._buddy = client
self._client.connect_partner(client)
@_machine.output()
@ -171,12 +295,60 @@ class TransitServer(object):
def _disconnect_partner(self):
self._client.disconnect_partner()
# some outputs to record the "mood" ..
@_machine.output()
def _mood_happy(self):
self._mood = "happy"
@_machine.output()
def _mood_lonely(self):
self._mood = "lonely"
@_machine.output()
def _mood_impatient(self):
self._mood = "impatient"
@_machine.output()
def _mood_errory(self):
self._mood = "errory"
@_machine.output()
def _mood_happy_if_first(self):
"""
We disconnected first so we're only happy if we also connected
first.
"""
if self._first:
self._mood = "happy"
else:
self._mood = "jilted"
@_machine.output()
def _mood_happy_if_second(self):
"""
We disconnected second so we're only happy if we also connected
second.
"""
if self._first:
self._mood = "jilted"
else:
self._mood = "happy"
def _real_register_token_for_side(self, token, side):
"""
basically, _got_handshake() + connection_got_token() from "real"
code ...and if this is the "second" side, hook them up and
pass .got_partner() input to both
A client has connected and sent a valid version 1 or version 2
handshake. If the former, `side` will be None.
In either case, we remember the tokens and register
ourselves. This might result in 'got_partner' notifications to
two state-machines if this is the second side for a given token.
:param bytes token: the token
:param bytes side: The side token (or None)
"""
self._token = token
self._side = side
self._first = self._pending_requests.register_token(token, side, self)
@_machine.state(initial=True)
def listening(self):
@ -217,17 +389,22 @@ class TransitServer(object):
wait_relay.upon(
please_relay,
enter=wait_partner,
outputs=[_register_token],
outputs=[_mood_lonely, _register_token],
)
wait_relay.upon(
please_relay_for_side,
enter=wait_partner,
outputs=[_register_token_for_side],
outputs=[_mood_lonely, _register_token_for_side],
)
wait_relay.upon(
bad_token,
enter=done,
outputs=[_send_bad, _disconnect],
outputs=[_mood_errory, _send_bad, _disconnect],
)
wait_relay.upon(
got_bytes,
enter=done,
outputs=[_mood_errory, _disconnect],
)
wait_relay.upon(
connection_lost,
@ -238,12 +415,17 @@ class TransitServer(object):
wait_partner.upon(
got_partner,
enter=relaying,
outputs=[_send_ok, _connect_partner],
outputs=[_mood_happy, _send_ok, _connect_partner],
)
wait_partner.upon(
connection_lost,
enter=done,
outputs=[_unregister],
outputs=[_mood_lonely, _unregister],
)
wait_partner.upon(
got_bytes,
enter=done,
outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister],
)
relaying.upon(
@ -254,12 +436,23 @@ class TransitServer(object):
relaying.upon(
connection_lost,
enter=done,
outputs=[_disconnect_partner, _unregister],
outputs=[_mood_happy_if_first, _disconnect_partner, _unregister],
)
relaying.upon(
partner_connection_lost,
enter=done,
outputs=[_disconnect, _unregister],
outputs=[_mood_happy_if_second, _disconnect, _unregister],
)
done.upon(
connection_lost,
enter=done,
outputs=[],
)
done.upon(
partner_connection_lost,
enter=done,
outputs=[],
)
@ -272,9 +465,12 @@ class TransitServer(object):
# - ...
if __name__ == "__main__":
server0 = TransitServer()
active = ActiveConnections()
pending = PendingRequests(active)
server0 = TransitServerState(pending)
client0 = TestClient()
server1 = TransitServer()
server1 = TransitServerState(pending)
client1 = TestClient()
server0.connection_made(client0)
server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
@ -282,12 +478,14 @@ if __name__ == "__main__":
# this would be an error, because our partner hasn't shown up yet
# print(server0.got_bytes(b"asdf"))
print("about to relay client1")
server1.connection_made(client1)
server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
print("done")
# XXX the PendingRequests stuff should do this, going "by hand" for now
server0.got_partner(client1)
server1.got_partner(client0)
# server0.got_partner(client1)
# server1.got_partner(client0)
# should be connected now
server0.got_bytes(b"asdf")

View File

@ -24,6 +24,17 @@ def blur_size(size):
return round_to(size, 1e6)
return round_to(size, 100e6)
from wormhole_transit_relay.server_state import (
TransitServerState,
PendingRequests,
ActiveConnections,
ITransitClient,
)
from zope.interface import implementer
@implementer(ITransitClient)
class TransitConnection(LineReceiver):
delimiter = b'\n'
# maximum length of a line we will accept before the handshake is complete.
@ -32,13 +43,33 @@ class TransitConnection(LineReceiver):
MAX_LENGTH = 1024
def __init__(self):
self._got_token = False
self._got_side = False
self._sent_ok = False
self._mood = "empty"
self._buddy = None
self._total_sent = 0
def send(self, data):
"""
ITransitClient API
"""
self.transport.write(data)
def disconnect(self):
"""
ITransitClient API
"""
self.transport.loseConnection()
def connect_partner(self, other):
"""
ITransitClient API
"""
self._buddy = other
def disconnect_partner(self):
"""
ITransitClient API
"""
self._buddy._client.transport.loseConnection()
self._buddy = None
def describeToken(self):
d = "-"
if self._got_token:
@ -50,6 +81,8 @@ class TransitConnection(LineReceiver):
return d
def connectionMade(self):
self._state = TransitServerState(self.factory.pending_requests)
self._state.connection_made(self)
self._started = time.time()
self._log_requests = self.factory._log_requests
try:
@ -71,10 +104,10 @@ class TransitConnection(LineReceiver):
side = new.group(2)
return self._got_handshake(token, side)
self.sendLine(b"bad handshake")
if self._log_requests:
log.msg("transit handshake failure")
return self.disconnect_error()
# state-machine calls us via ITransitClient interface to do
# bad handshake etc.
return self._state.bad_token()
#return self._state.got_bytes(line)
def rawDataReceived(self, data):
# We are an IPushProducer to our buddy's IConsumer, so they'll
@ -83,33 +116,15 @@ class TransitConnection(LineReceiver):
# practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the
# receiver can handle it.
if self._sent_ok:
# if self._buddy is None then our buddy disconnected
# (we're "jilted"), so we hung up too, but our incoming
# data hasn't stopped yet (it will in a moment, after our
# disconnect makes a roundtrip through the kernel). This
# probably means the file receiver hung up, and this
# connection is the file sender. In may-2020 this happened
# 11 times in 40 days.
if self._buddy:
self._state.got_bytes(data)
self._total_sent += len(data)
self._buddy.transport.write(data)
return
# handshake is complete but not yet sent_ok
self.sendLine(b"impatient")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect_error() # impatience yields failure
def _got_handshake(self, token, side):
self._got_token = token
self._got_side = side
self._mood = "lonely" # until buddy connects
self._state.please_relay_for_side(token, side)
# self._mood = "lonely" # until buddy connects
self.setRawMode()
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them):
def __buddy_connected(self, them):
self._buddy = them
self._mood = "happy"
self.sendLine(b"ok")
@ -121,7 +136,7 @@ class TransitConnection(LineReceiver):
# The Transit object calls buddy_connected() on both protocols, so
# there will be two producer/consumer pairs.
def buddy_disconnected(self):
def __buddy_disconnected(self):
if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None
@ -145,7 +160,11 @@ class TransitConnection(LineReceiver):
def connectionLost(self, reason):
finished = time.time()
total_time = finished - self._started
self._state.connection_lost()
# XXX FIXME record usage
if False:
# Record usage. There are eight cases:
# * n0: we haven't gotten a full handshake yet (empty)
# * n1: the handshake failed, not a real client (errory)
@ -193,8 +212,10 @@ class TransitConnection(LineReceiver):
if self._buddy:
self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
# self.factory.transitFinished(self, self._got_token, self._got_side,
# self.describeToken())
class Transit(protocol.ServerFactory):
# I manage pairs of simultaneous connections to a secondary TCP port,
@ -230,6 +251,8 @@ class Transit(protocol.ServerFactory):
protocol = TransitConnection
def __init__(self, blur_usage, log_file, usage_db):
self.active_connections = ActiveConnections()
self.pending_requests = PendingRequests(self.active_connections)
self._blur_usage = blur_usage
self._log_requests = blur_usage is None
if self._blur_usage:
@ -247,39 +270,6 @@ class Transit(protocol.ServerFactory):
self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
def connection_got_token(self, token, new_side, new_tc):
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._debug_log:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials.copy():
# Don't record this as errory. It's just a spare connection
# from the same side as a connection that got used. This
# can happen if the connection hint contains multiple
# addresses (we don't currently support those, but it'd
# probably be useful in the future).
leftover_tc.disconnect_redundant()
self._pending_requests.pop(token, None)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._debug_log:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)