(wip) refactor to use Automat state-machine
This commit is contained in:
parent
555c23d4fe
commit
0e11f1b8f1
|
@ -1,3 +1,4 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import automat
|
||||
from zope.interface import (
|
||||
|
@ -12,7 +13,7 @@ class ITransitClient(Interface):
|
|||
Send some byets to the client
|
||||
"""
|
||||
|
||||
def disconnect(reason):
|
||||
def disconnect():
|
||||
"""
|
||||
Disconnect the client transport
|
||||
"""
|
||||
|
@ -37,7 +38,7 @@ class TestClient(object):
|
|||
def send_to_partner(self, data):
|
||||
print("{} GOT:{}".format(id(self), repr(data)))
|
||||
if self._partner:
|
||||
self._partner.send(data)
|
||||
self._partner._client.send(data)
|
||||
|
||||
def send(self, data):
|
||||
print("{} SEND:{}".format(id(self), repr(data)))
|
||||
|
@ -56,18 +57,94 @@ class TestClient(object):
|
|||
print("disconnect_partner: {}".format(id(self._partner)))
|
||||
|
||||
|
||||
class ActiveConnections(object):
|
||||
"""
|
||||
Tracks active connections. A connection is 'active' when both
|
||||
sides have shown up and they are glued together.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._connections = set()
|
||||
|
||||
def register(self, side0, side1):
|
||||
"""
|
||||
A connection has become active so register both its sides
|
||||
|
||||
:param TransitConnection side0: one side of the connection
|
||||
:param TransitConnection side1: one side of the connection
|
||||
"""
|
||||
self._connections.add(side0)
|
||||
self._connections.add(side1)
|
||||
|
||||
def unregister(self, side):
|
||||
"""
|
||||
One side of a connection has become inactive.
|
||||
|
||||
:param TransitConnection side: an inactive side of a connection
|
||||
"""
|
||||
self._connections.discard(side)
|
||||
|
||||
|
||||
class PendingRequests(object):
|
||||
"""
|
||||
Tracks the tokens we have received from client connections and
|
||||
maps them to their partner connections
|
||||
maps them to their partner connections for requests that haven't
|
||||
yet been 'glued together' (that is, one side hasn't yet shown up).
|
||||
"""
|
||||
|
||||
def register_token(self, *args):
|
||||
def __init__(self, active_connections):
|
||||
self._requests = defaultdict(set) # token -> set((side, TransitConnection))
|
||||
self._active = active_connections
|
||||
|
||||
def unregister(self, token, side, tc):
|
||||
if token in self._requests:
|
||||
self._requests[token].discard((side, tc))
|
||||
self._active.unregister(tc)
|
||||
|
||||
def register_token(self, token, new_side, new_tc):
|
||||
"""
|
||||
A client has connected and successfully offered a token (and
|
||||
optional 'side' token). If this is the first one for this
|
||||
token, we merely remember it. If it is the second side for
|
||||
this token we connect them together.
|
||||
|
||||
:returns bool: True if we are the first side to register this
|
||||
token
|
||||
"""
|
||||
potentials = self._requests[token]
|
||||
for old in potentials:
|
||||
(old_side, old_tc) = old
|
||||
if ((old_side is None)
|
||||
or (new_side is None)
|
||||
or (old_side != new_side)):
|
||||
# we found a match
|
||||
# FIXME: debug-log this
|
||||
# print("transit relay 2: %s" % new_tc.get_token())
|
||||
|
||||
# drop and stop tracking the rest
|
||||
potentials.remove(old)
|
||||
for (_, leftover_tc) in potentials.copy():
|
||||
# Don't record this as errory. It's just a spare connection
|
||||
# from the same side as a connection that got used. This
|
||||
# can happen if the connection hint contains multiple
|
||||
# addresses (we don't currently support those, but it'd
|
||||
# probably be useful in the future).
|
||||
leftover_tc.disconnect_redundant()
|
||||
self._requests.pop(token, None)
|
||||
|
||||
# glue the two ends together
|
||||
self._active.register(new_tc, old_tc)
|
||||
new_tc.got_partner(old_tc)
|
||||
old_tc.got_partner(new_tc)
|
||||
return False
|
||||
|
||||
# FIXME: debug-log this
|
||||
# print("transit relay 1: %s" % new_tc.get_token())
|
||||
potentials.add((new_side, new_tc))
|
||||
return True
|
||||
# TODO: timer
|
||||
|
||||
|
||||
class TransitServer(object):
|
||||
class TransitServerState(object):
|
||||
"""
|
||||
Encapsulates the state-machine of the server side of a transit
|
||||
relay connection.
|
||||
|
@ -79,6 +156,36 @@ class TransitServer(object):
|
|||
|
||||
_machine = automat.MethodicalMachine()
|
||||
_client = None
|
||||
_buddy = None
|
||||
_token = None
|
||||
_side = None
|
||||
_first = None
|
||||
_mood = "empty"
|
||||
|
||||
def __init__(self, pending_requests):
|
||||
self._pending_requests = pending_requests
|
||||
|
||||
def get_token(self):
|
||||
"""
|
||||
:returns str: a string describing our token. This will be "-" if
|
||||
we have no token yet, or "{16 chars}-<unsided>" if we have
|
||||
just a token or "{16 chars}-{16 chars}" if we have a token and
|
||||
a side.
|
||||
"""
|
||||
d = "-"
|
||||
if self._token is not None:
|
||||
d = self._token[:16].decode("ascii")
|
||||
if self._side is not None:
|
||||
d += "-" + self._side.decode("ascii")
|
||||
else:
|
||||
d += "-<unsided>"
|
||||
return d
|
||||
|
||||
def get_mood(self):
|
||||
"""
|
||||
:returns str: description of the current 'mood' of the connection
|
||||
"""
|
||||
return self._mood
|
||||
|
||||
@_machine.input()
|
||||
def connection_made(self, client):
|
||||
|
@ -93,16 +200,22 @@ class TransitServer(object):
|
|||
|
||||
@_machine.input()
|
||||
def please_relay(self, token):
|
||||
pass
|
||||
"""
|
||||
A 'please relay X' message has been received (the original version
|
||||
of the protocol).
|
||||
"""
|
||||
|
||||
@_machine.input()
|
||||
def please_relay_for_side(self, token, side):
|
||||
pass
|
||||
"""
|
||||
A 'please relay X for side Y' message has been received (the
|
||||
second version of the protocol).
|
||||
"""
|
||||
|
||||
@_machine.input()
|
||||
def bad_token(self):
|
||||
"""
|
||||
A bad token / relay line was received
|
||||
A bad token / relay line was received (e.g. couldn't be parsed)
|
||||
"""
|
||||
|
||||
@_machine.input()
|
||||
|
@ -113,11 +226,15 @@ class TransitServer(object):
|
|||
|
||||
@_machine.input()
|
||||
def connection_lost(self):
|
||||
pass
|
||||
"""
|
||||
Our transport has failed.
|
||||
"""
|
||||
|
||||
@_machine.input()
|
||||
def partner_connection_lost(self):
|
||||
pass
|
||||
"""
|
||||
Our partner's transport has failed.
|
||||
"""
|
||||
|
||||
@_machine.input()
|
||||
def got_bytes(self, data):
|
||||
|
@ -142,14 +259,20 @@ class TransitServer(object):
|
|||
"""
|
||||
remove us from the thing that remembers tokens and sides
|
||||
"""
|
||||
return self._pending_requests.unregister(self._token, self._side, self)
|
||||
|
||||
@_machine.output()
|
||||
def _send_bad(self):
|
||||
self._client.send("bad handshake\n")
|
||||
self._mood = "errory"
|
||||
self._client.send(b"bad handshake\n")
|
||||
|
||||
@_machine.output()
|
||||
def _send_ok(self):
|
||||
self._client.send("ok\n")
|
||||
self._client.send(b"ok\n")
|
||||
|
||||
@_machine.output()
|
||||
def _send_impatient(self):
|
||||
self._client.send(b"impatient\n")
|
||||
|
||||
@_machine.output()
|
||||
def _send(self, data):
|
||||
|
@ -157,10 +280,11 @@ class TransitServer(object):
|
|||
|
||||
@_machine.output()
|
||||
def _send_to_partner(self, data):
|
||||
self._client.send_to_partner(data)
|
||||
self._buddy._client.send(data)
|
||||
|
||||
@_machine.output()
|
||||
def _connect_partner(self, client):
|
||||
self._buddy = client
|
||||
self._client.connect_partner(client)
|
||||
|
||||
@_machine.output()
|
||||
|
@ -171,12 +295,60 @@ class TransitServer(object):
|
|||
def _disconnect_partner(self):
|
||||
self._client.disconnect_partner()
|
||||
|
||||
# some outputs to record the "mood" ..
|
||||
@_machine.output()
|
||||
def _mood_happy(self):
|
||||
self._mood = "happy"
|
||||
|
||||
@_machine.output()
|
||||
def _mood_lonely(self):
|
||||
self._mood = "lonely"
|
||||
|
||||
@_machine.output()
|
||||
def _mood_impatient(self):
|
||||
self._mood = "impatient"
|
||||
|
||||
@_machine.output()
|
||||
def _mood_errory(self):
|
||||
self._mood = "errory"
|
||||
|
||||
@_machine.output()
|
||||
def _mood_happy_if_first(self):
|
||||
"""
|
||||
We disconnected first so we're only happy if we also connected
|
||||
first.
|
||||
"""
|
||||
if self._first:
|
||||
self._mood = "happy"
|
||||
else:
|
||||
self._mood = "jilted"
|
||||
|
||||
@_machine.output()
|
||||
def _mood_happy_if_second(self):
|
||||
"""
|
||||
We disconnected second so we're only happy if we also connected
|
||||
second.
|
||||
"""
|
||||
if self._first:
|
||||
self._mood = "jilted"
|
||||
else:
|
||||
self._mood = "happy"
|
||||
|
||||
def _real_register_token_for_side(self, token, side):
|
||||
"""
|
||||
basically, _got_handshake() + connection_got_token() from "real"
|
||||
code ...and if this is the "second" side, hook them up and
|
||||
pass .got_partner() input to both
|
||||
A client has connected and sent a valid version 1 or version 2
|
||||
handshake. If the former, `side` will be None.
|
||||
|
||||
In either case, we remember the tokens and register
|
||||
ourselves. This might result in 'got_partner' notifications to
|
||||
two state-machines if this is the second side for a given token.
|
||||
|
||||
:param bytes token: the token
|
||||
:param bytes side: The side token (or None)
|
||||
"""
|
||||
self._token = token
|
||||
self._side = side
|
||||
self._first = self._pending_requests.register_token(token, side, self)
|
||||
|
||||
@_machine.state(initial=True)
|
||||
def listening(self):
|
||||
|
@ -217,17 +389,22 @@ class TransitServer(object):
|
|||
wait_relay.upon(
|
||||
please_relay,
|
||||
enter=wait_partner,
|
||||
outputs=[_register_token],
|
||||
outputs=[_mood_lonely, _register_token],
|
||||
)
|
||||
wait_relay.upon(
|
||||
please_relay_for_side,
|
||||
enter=wait_partner,
|
||||
outputs=[_register_token_for_side],
|
||||
outputs=[_mood_lonely, _register_token_for_side],
|
||||
)
|
||||
wait_relay.upon(
|
||||
bad_token,
|
||||
enter=done,
|
||||
outputs=[_send_bad, _disconnect],
|
||||
outputs=[_mood_errory, _send_bad, _disconnect],
|
||||
)
|
||||
wait_relay.upon(
|
||||
got_bytes,
|
||||
enter=done,
|
||||
outputs=[_mood_errory, _disconnect],
|
||||
)
|
||||
wait_relay.upon(
|
||||
connection_lost,
|
||||
|
@ -238,12 +415,17 @@ class TransitServer(object):
|
|||
wait_partner.upon(
|
||||
got_partner,
|
||||
enter=relaying,
|
||||
outputs=[_send_ok, _connect_partner],
|
||||
outputs=[_mood_happy, _send_ok, _connect_partner],
|
||||
)
|
||||
wait_partner.upon(
|
||||
connection_lost,
|
||||
enter=done,
|
||||
outputs=[_unregister],
|
||||
outputs=[_mood_lonely, _unregister],
|
||||
)
|
||||
wait_partner.upon(
|
||||
got_bytes,
|
||||
enter=done,
|
||||
outputs=[_mood_impatient, _send_impatient, _disconnect, _unregister],
|
||||
)
|
||||
|
||||
relaying.upon(
|
||||
|
@ -254,12 +436,23 @@ class TransitServer(object):
|
|||
relaying.upon(
|
||||
connection_lost,
|
||||
enter=done,
|
||||
outputs=[_disconnect_partner, _unregister],
|
||||
outputs=[_mood_happy_if_first, _disconnect_partner, _unregister],
|
||||
)
|
||||
relaying.upon(
|
||||
partner_connection_lost,
|
||||
enter=done,
|
||||
outputs=[_disconnect, _unregister],
|
||||
outputs=[_mood_happy_if_second, _disconnect, _unregister],
|
||||
)
|
||||
|
||||
done.upon(
|
||||
connection_lost,
|
||||
enter=done,
|
||||
outputs=[],
|
||||
)
|
||||
done.upon(
|
||||
partner_connection_lost,
|
||||
enter=done,
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -272,9 +465,12 @@ class TransitServer(object):
|
|||
# - ...
|
||||
|
||||
if __name__ == "__main__":
|
||||
server0 = TransitServer()
|
||||
active = ActiveConnections()
|
||||
pending = PendingRequests(active)
|
||||
|
||||
server0 = TransitServerState(pending)
|
||||
client0 = TestClient()
|
||||
server1 = TransitServer()
|
||||
server1 = TransitServerState(pending)
|
||||
client1 = TestClient()
|
||||
server0.connection_made(client0)
|
||||
server0.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
|
||||
|
@ -282,12 +478,14 @@ if __name__ == "__main__":
|
|||
# this would be an error, because our partner hasn't shown up yet
|
||||
# print(server0.got_bytes(b"asdf"))
|
||||
|
||||
print("about to relay client1")
|
||||
server1.connection_made(client1)
|
||||
server1.please_relay(b"bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
|
||||
print("done")
|
||||
|
||||
# XXX the PendingRequests stuff should do this, going "by hand" for now
|
||||
server0.got_partner(client1)
|
||||
server1.got_partner(client0)
|
||||
# server0.got_partner(client1)
|
||||
# server1.got_partner(client0)
|
||||
|
||||
# should be connected now
|
||||
server0.got_bytes(b"asdf")
|
||||
|
|
|
@ -24,6 +24,17 @@ def blur_size(size):
|
|||
return round_to(size, 1e6)
|
||||
return round_to(size, 100e6)
|
||||
|
||||
|
||||
from wormhole_transit_relay.server_state import (
|
||||
TransitServerState,
|
||||
PendingRequests,
|
||||
ActiveConnections,
|
||||
ITransitClient,
|
||||
)
|
||||
from zope.interface import implementer
|
||||
|
||||
|
||||
@implementer(ITransitClient)
|
||||
class TransitConnection(LineReceiver):
|
||||
delimiter = b'\n'
|
||||
# maximum length of a line we will accept before the handshake is complete.
|
||||
|
@ -32,13 +43,33 @@ class TransitConnection(LineReceiver):
|
|||
MAX_LENGTH = 1024
|
||||
|
||||
def __init__(self):
|
||||
self._got_token = False
|
||||
self._got_side = False
|
||||
self._sent_ok = False
|
||||
self._mood = "empty"
|
||||
self._buddy = None
|
||||
self._total_sent = 0
|
||||
|
||||
def send(self, data):
|
||||
"""
|
||||
ITransitClient API
|
||||
"""
|
||||
self.transport.write(data)
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
ITransitClient API
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connect_partner(self, other):
|
||||
"""
|
||||
ITransitClient API
|
||||
"""
|
||||
self._buddy = other
|
||||
|
||||
def disconnect_partner(self):
|
||||
"""
|
||||
ITransitClient API
|
||||
"""
|
||||
self._buddy._client.transport.loseConnection()
|
||||
self._buddy = None
|
||||
|
||||
def describeToken(self):
|
||||
d = "-"
|
||||
if self._got_token:
|
||||
|
@ -50,6 +81,8 @@ class TransitConnection(LineReceiver):
|
|||
return d
|
||||
|
||||
def connectionMade(self):
|
||||
self._state = TransitServerState(self.factory.pending_requests)
|
||||
self._state.connection_made(self)
|
||||
self._started = time.time()
|
||||
self._log_requests = self.factory._log_requests
|
||||
try:
|
||||
|
@ -71,10 +104,10 @@ class TransitConnection(LineReceiver):
|
|||
side = new.group(2)
|
||||
return self._got_handshake(token, side)
|
||||
|
||||
self.sendLine(b"bad handshake")
|
||||
if self._log_requests:
|
||||
log.msg("transit handshake failure")
|
||||
return self.disconnect_error()
|
||||
# state-machine calls us via ITransitClient interface to do
|
||||
# bad handshake etc.
|
||||
return self._state.bad_token()
|
||||
#return self._state.got_bytes(line)
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
# We are an IPushProducer to our buddy's IConsumer, so they'll
|
||||
|
@ -83,33 +116,15 @@ class TransitConnection(LineReceiver):
|
|||
# practice, this buffers about 10MB per connection, after which
|
||||
# point the sender will only transmit data as fast as the
|
||||
# receiver can handle it.
|
||||
if self._sent_ok:
|
||||
# if self._buddy is None then our buddy disconnected
|
||||
# (we're "jilted"), so we hung up too, but our incoming
|
||||
# data hasn't stopped yet (it will in a moment, after our
|
||||
# disconnect makes a roundtrip through the kernel). This
|
||||
# probably means the file receiver hung up, and this
|
||||
# connection is the file sender. In may-2020 this happened
|
||||
# 11 times in 40 days.
|
||||
if self._buddy:
|
||||
self._state.got_bytes(data)
|
||||
self._total_sent += len(data)
|
||||
self._buddy.transport.write(data)
|
||||
return
|
||||
|
||||
# handshake is complete but not yet sent_ok
|
||||
self.sendLine(b"impatient")
|
||||
if self._log_requests:
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect_error() # impatience yields failure
|
||||
|
||||
def _got_handshake(self, token, side):
|
||||
self._got_token = token
|
||||
self._got_side = side
|
||||
self._mood = "lonely" # until buddy connects
|
||||
self._state.please_relay_for_side(token, side)
|
||||
# self._mood = "lonely" # until buddy connects
|
||||
self.setRawMode()
|
||||
self.factory.connection_got_token(token, side, self)
|
||||
|
||||
def buddy_connected(self, them):
|
||||
def __buddy_connected(self, them):
|
||||
self._buddy = them
|
||||
self._mood = "happy"
|
||||
self.sendLine(b"ok")
|
||||
|
@ -121,7 +136,7 @@ class TransitConnection(LineReceiver):
|
|||
# The Transit object calls buddy_connected() on both protocols, so
|
||||
# there will be two producer/consumer pairs.
|
||||
|
||||
def buddy_disconnected(self):
|
||||
def __buddy_disconnected(self):
|
||||
if self._log_requests:
|
||||
log.msg("buddy_disconnected %s" % self.describeToken())
|
||||
self._buddy = None
|
||||
|
@ -145,7 +160,11 @@ class TransitConnection(LineReceiver):
|
|||
def connectionLost(self, reason):
|
||||
finished = time.time()
|
||||
total_time = finished - self._started
|
||||
self._state.connection_lost()
|
||||
|
||||
# XXX FIXME record usage
|
||||
|
||||
if False:
|
||||
# Record usage. There are eight cases:
|
||||
# * n0: we haven't gotten a full handshake yet (empty)
|
||||
# * n1: the handshake failed, not a real client (errory)
|
||||
|
@ -193,8 +212,10 @@ class TransitConnection(LineReceiver):
|
|||
|
||||
if self._buddy:
|
||||
self._buddy.buddy_disconnected()
|
||||
self.factory.transitFinished(self, self._got_token, self._got_side,
|
||||
self.describeToken())
|
||||
# self.factory.transitFinished(self, self._got_token, self._got_side,
|
||||
# self.describeToken())
|
||||
|
||||
|
||||
|
||||
class Transit(protocol.ServerFactory):
|
||||
# I manage pairs of simultaneous connections to a secondary TCP port,
|
||||
|
@ -230,6 +251,8 @@ class Transit(protocol.ServerFactory):
|
|||
protocol = TransitConnection
|
||||
|
||||
def __init__(self, blur_usage, log_file, usage_db):
|
||||
self.active_connections = ActiveConnections()
|
||||
self.pending_requests = PendingRequests(self.active_connections)
|
||||
self._blur_usage = blur_usage
|
||||
self._log_requests = blur_usage is None
|
||||
if self._blur_usage:
|
||||
|
@ -247,39 +270,6 @@ class Transit(protocol.ServerFactory):
|
|||
self._pending_requests = defaultdict(set) # token -> set((side, TransitConnection))
|
||||
self._active_connections = set() # TransitConnection
|
||||
|
||||
def connection_got_token(self, token, new_side, new_tc):
|
||||
potentials = self._pending_requests[token]
|
||||
for old in potentials:
|
||||
(old_side, old_tc) = old
|
||||
if ((old_side is None)
|
||||
or (new_side is None)
|
||||
or (old_side != new_side)):
|
||||
# we found a match
|
||||
if self._debug_log:
|
||||
log.msg("transit relay 2: %s" % new_tc.describeToken())
|
||||
|
||||
# drop and stop tracking the rest
|
||||
potentials.remove(old)
|
||||
for (_, leftover_tc) in potentials.copy():
|
||||
# Don't record this as errory. It's just a spare connection
|
||||
# from the same side as a connection that got used. This
|
||||
# can happen if the connection hint contains multiple
|
||||
# addresses (we don't currently support those, but it'd
|
||||
# probably be useful in the future).
|
||||
leftover_tc.disconnect_redundant()
|
||||
self._pending_requests.pop(token, None)
|
||||
|
||||
# glue the two ends together
|
||||
self._active_connections.add(new_tc)
|
||||
self._active_connections.add(old_tc)
|
||||
new_tc.buddy_connected(old_tc)
|
||||
old_tc.buddy_connected(new_tc)
|
||||
return
|
||||
if self._debug_log:
|
||||
log.msg("transit relay 1: %s" % new_tc.describeToken())
|
||||
potentials.add((new_side, new_tc))
|
||||
# TODO: timer
|
||||
|
||||
def transitFinished(self, tc, token, side, description):
|
||||
if token in self._pending_requests:
|
||||
side_tc = (side, tc)
|
||||
|
|
Loading…
Reference in New Issue
Block a user