fix some flake8 complaints
This commit is contained in:
parent
39666f3fed
commit
05900bd08b
|
@ -67,7 +67,8 @@ class Boss(object):
|
|||
self._I = Input(self._timing)
|
||||
self._C = Code(self._timing)
|
||||
self._T = Terminator()
|
||||
self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator)
|
||||
self._D = Dilator(self._reactor, self._eventual_queue,
|
||||
self._cooperator)
|
||||
|
||||
self._N.wire(self._M, self._I, self._RC, self._T)
|
||||
self._M.wire(self._N, self._RC, self._O, self._T)
|
||||
|
@ -90,7 +91,7 @@ class Boss(object):
|
|||
self._rx_phases = {} # phase -> plaintext
|
||||
|
||||
self._next_rx_dilate_seqnum = 0
|
||||
self._rx_dilate_seqnums = {} # seqnum -> plaintext
|
||||
self._rx_dilate_seqnums = {} # seqnum -> plaintext
|
||||
|
||||
self._result = "empty"
|
||||
|
||||
|
@ -205,7 +206,7 @@ class Boss(object):
|
|||
self._C.set_code(code)
|
||||
|
||||
def dilate(self):
|
||||
return self._D.dilate() # fires with endpoints
|
||||
return self._D.dilate() # fires with endpoints
|
||||
|
||||
@m.input()
|
||||
def send(self, plaintext):
|
||||
|
|
|
@ -39,20 +39,28 @@ from .roles import FOLLOWER
|
|||
# states). For the specific question of sending plaintext frames, Noise will
|
||||
# refuse us unless it's ready anyways, so the question is probably moot.
|
||||
|
||||
|
||||
class IFramer(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IRecord(Interface):
|
||||
pass
|
||||
|
||||
|
||||
def first(l):
|
||||
return l[0]
|
||||
|
||||
|
||||
class Disconnect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
RelayOK = namedtuple("RelayOk", [])
|
||||
Prologue = namedtuple("Prologue", [])
|
||||
Frame = namedtuple("Frame", ["frame"])
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IFramer)
|
||||
class _Framer(object):
|
||||
|
@ -69,30 +77,37 @@ class _Framer(object):
|
|||
# out (shared): transport.write (relay handshake, prologue)
|
||||
# states: want_relay, want_prologue, want_frame
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def want_relay(self): pass # pragma: no cover
|
||||
def want_relay(self): pass # pragma: no cover
|
||||
|
||||
@m.state(initial=True)
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def want_frame(self): pass # pragma: no cover
|
||||
def want_frame(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def use_relay(self, relay_handshake): pass
|
||||
|
||||
@m.input()
|
||||
def connectionMade(self): pass
|
||||
|
||||
@m.input()
|
||||
def parse(self): pass
|
||||
|
||||
@m.input()
|
||||
def got_relay_ok(self): pass
|
||||
|
||||
@m.input()
|
||||
def got_prologue(self): pass
|
||||
|
||||
@m.output()
|
||||
def store_relay_handshake(self, relay_handshake):
|
||||
self._outbound_relay_handshake = relay_handshake
|
||||
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
|
||||
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
|
||||
|
||||
@m.output()
|
||||
def send_relay_handshake(self):
|
||||
self._transport.write(self._outbound_relay_handshake)
|
||||
|
@ -113,17 +128,17 @@ class _Framer(object):
|
|||
|
||||
@m.output()
|
||||
def can_send_frames(self):
|
||||
self._can_send_frames = True # for assertion in send_frame()
|
||||
self._can_send_frames = True # for assertion in send_frame()
|
||||
|
||||
@m.output()
|
||||
def parse_frame(self):
|
||||
if len(self._buffer) < 4:
|
||||
return None
|
||||
frame_length = from_be4(self._buffer[0:4])
|
||||
if len(self._buffer) < 4+frame_length:
|
||||
if len(self._buffer) < 4 + frame_length:
|
||||
return None
|
||||
frame = self._buffer[4:4+frame_length]
|
||||
self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy
|
||||
frame = self._buffer[4:4 + frame_length]
|
||||
self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy
|
||||
return Frame(frame=frame)
|
||||
|
||||
want_prologue.upon(use_relay, outputs=[store_relay_handshake],
|
||||
|
@ -144,7 +159,6 @@ class _Framer(object):
|
|||
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
|
||||
collector=first)
|
||||
|
||||
|
||||
def _get_expected(self, name, expected):
|
||||
lb = len(self._buffer)
|
||||
le = len(expected)
|
||||
|
@ -161,7 +175,7 @@ class _Framer(object):
|
|||
if (b"\n" in self._buffer or lb >= le):
|
||||
log.msg("bad {}: {}".format(name, self._buffer[:le]))
|
||||
raise Disconnect()
|
||||
return False # wait a bit longer
|
||||
return False # wait a bit longer
|
||||
# good so far, just waiting for the rest
|
||||
return False
|
||||
|
||||
|
@ -181,7 +195,7 @@ class _Framer(object):
|
|||
self.got_relay_ok()
|
||||
elif isinstance(token, Prologue):
|
||||
self.got_prologue()
|
||||
yield token # triggers send_handshake
|
||||
yield token # triggers send_handshake
|
||||
elif isinstance(token, Frame):
|
||||
yield token
|
||||
else:
|
||||
|
@ -202,15 +216,16 @@ class _Framer(object):
|
|||
# from peer. Sent immediately by Follower, after Selection by Leader.
|
||||
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
|
||||
|
||||
|
||||
Handshake = namedtuple("Handshake", [])
|
||||
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
|
||||
KCM = namedtuple("KCM", [])
|
||||
Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
|
||||
Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
|
||||
Pong = namedtuple("Pong", ["ping_id"])
|
||||
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
|
||||
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
|
||||
Data = namedtuple("Data", ["seqnum", "scid", "data"])
|
||||
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
|
||||
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
|
||||
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
|
||||
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
|
||||
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
|
||||
Handshake_or_Records = (Handshake,) + Records
|
||||
|
||||
|
@ -222,6 +237,7 @@ T_DATA = b"\x04"
|
|||
T_CLOSE = b"\x05"
|
||||
T_ACK = b"\x06"
|
||||
|
||||
|
||||
def parse_record(plaintext):
|
||||
msgtype = plaintext[0:1]
|
||||
if msgtype == T_KCM:
|
||||
|
@ -251,6 +267,7 @@ def parse_record(plaintext):
|
|||
log.err("received unknown message type: {}".format(plaintext))
|
||||
raise ValueError()
|
||||
|
||||
|
||||
def encode_record(r):
|
||||
if isinstance(r, KCM):
|
||||
return b"\x00"
|
||||
|
@ -275,6 +292,7 @@ def encode_record(r):
|
|||
return b"\x06" + to_be4(r.resp_seqnum)
|
||||
raise TypeError(r)
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IRecord)
|
||||
class _Record(object):
|
||||
|
@ -294,22 +312,25 @@ class _Record(object):
|
|||
# states: want_prologue, want_handshake, want_record
|
||||
|
||||
@n.state(initial=True)
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake(self): pass # pragma: no cover
|
||||
def want_handshake(self): pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_message(self): pass # pragma: no cover
|
||||
def want_message(self): pass # pragma: no cover
|
||||
|
||||
@n.input()
|
||||
def got_prologue(self):
|
||||
pass
|
||||
|
||||
@n.input()
|
||||
def got_frame(self, frame):
|
||||
pass
|
||||
|
||||
@n.output()
|
||||
def send_handshake(self):
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
self._framer.send_frame(handshake)
|
||||
|
||||
@n.output()
|
||||
|
@ -351,10 +372,10 @@ class _Record(object):
|
|||
def add_and_unframe(self, data):
|
||||
for token in self._framer.add_and_parse(data):
|
||||
if isinstance(token, Prologue):
|
||||
self.got_prologue() # triggers send_handshake
|
||||
self.got_prologue() # triggers send_handshake
|
||||
else:
|
||||
assert isinstance(token, Frame)
|
||||
yield self.got_frame(token.frame) # Handshake or a Record type
|
||||
yield self.got_frame(token.frame) # Handshake or a Record type
|
||||
|
||||
def send_record(self, r):
|
||||
message = encode_record(r)
|
||||
|
@ -388,26 +409,30 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
_relay_handshake = None
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._manager = None # set if/when we are selected
|
||||
self._manager = None # set if/when we are selected
|
||||
self._disconnected = OneShotObserver(self._eventual_queue)
|
||||
self._can_send_records = False
|
||||
|
||||
@m.state(initial=True)
|
||||
def unselected(self): pass # pragma: no cover
|
||||
def unselected(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def selecting(self): pass # pragma: no cover
|
||||
def selecting(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def selected(self): pass # pragma: no cover
|
||||
def selected(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def got_kcm(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def select(self, manager):
|
||||
pass # fires set_manager()
|
||||
pass # fires set_manager()
|
||||
|
||||
@m.input()
|
||||
def got_record(self, record):
|
||||
pass
|
||||
|
@ -472,9 +497,9 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
elif isinstance(token, KCM):
|
||||
# if we're the leader, add this connection as a candiate.
|
||||
# if we're the follower, accept this connection.
|
||||
self.got_kcm() # connector.add_candidate()
|
||||
self.got_kcm() # connector.add_candidate()
|
||||
else:
|
||||
self.got_record(token) # manager.got_record()
|
||||
self.got_record(token) # manager.got_record()
|
||||
except Disconnect:
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import sys, re
|
||||
import sys
|
||||
import re
|
||||
from collections import defaultdict, namedtuple
|
||||
from binascii import hexlify
|
||||
import six
|
||||
|
@ -13,7 +14,7 @@ from twisted.internet.endpoints import HostnameEndpoint, serverFromString
|
|||
from twisted.internet.protocol import ClientFactory, ServerFactory
|
||||
from twisted.python import log
|
||||
from hkdf import Hkdf
|
||||
from .. import ipaddrs # TODO: move into _dilation/
|
||||
from .. import ipaddrs # TODO: move into _dilation/
|
||||
from .._interfaces import IDilationConnector, IDilationManager
|
||||
from ..timing import DebugTiming
|
||||
from ..observer import EmptyableSet
|
||||
|
@ -30,7 +31,8 @@ from .roles import LEADER
|
|||
# * expect to see the receiver/sender handshake bytes from the other side
|
||||
# * the sender writes "go\n", the receiver waits for "go\n"
|
||||
# * the rest of the connection contains transit data
|
||||
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
|
||||
DirectTCPV1Hint = namedtuple(
|
||||
"DirectTCPV1Hint", ["hostname", "port", "priority"])
|
||||
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
||||
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
||||
# use a tuple rather than a list so they'll be hashable into a set). For each
|
||||
|
@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
|||
# rest of the V1 protocol. Only one hint per relay is useful.
|
||||
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
||||
|
||||
|
||||
def describe_hint_obj(hint, relay, tor):
|
||||
prefix = "tor->" if tor else "->"
|
||||
if relay:
|
||||
|
@ -45,9 +48,10 @@ def describe_hint_obj(hint, relay, tor):
|
|||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return prefix+"tor:%s:%d" % (hint.hostname, hint.port)
|
||||
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return prefix+str(hint)
|
||||
return prefix + str(hint)
|
||||
|
||||
|
||||
def parse_hint_argv(hint, stderr=sys.stderr):
|
||||
assert isinstance(hint, type(""))
|
||||
|
@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr):
|
|||
return None
|
||||
hint_type = mo.group(1)
|
||||
if hint_type != "tcp":
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint),
|
||||
file=stderr)
|
||||
return None
|
||||
hint_value = mo.group(2)
|
||||
pieces = hint_value.split(":")
|
||||
|
@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr):
|
|||
return None
|
||||
return DirectTCPV1Hint(hint_host, hint_port, priority)
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get("type", "")
|
||||
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint,))
|
||||
return None
|
||||
if not("hostname" in hint
|
||||
and isinstance(hint["hostname"], type(""))):
|
||||
if not("hostname" in hint and
|
||||
isinstance(hint["hostname"], type(""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint,))
|
||||
return None
|
||||
if not("port" in hint
|
||||
and isinstance(hint["port"], six.integer_types)):
|
||||
if not("port" in hint and
|
||||
isinstance(hint["port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint,))
|
||||
return None
|
||||
priority = hint.get("priority", 0.0)
|
||||
|
@ -103,51 +109,58 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
|||
else:
|
||||
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
|
||||
|
||||
def parse_hint(hint_struct):
|
||||
hint_type = hint_struct.get("type", "")
|
||||
if hint_type == "relay-v1":
|
||||
# the struct can include multiple ways to reach the same relay
|
||||
rhints = filter(lambda h: h, # drop None (unrecognized)
|
||||
rhints = filter(lambda h: h, # drop None (unrecognized)
|
||||
[parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
|
||||
return RelayV1Hint(rhints)
|
||||
return parse_tcp_v1_hint(hint_struct)
|
||||
|
||||
|
||||
def encode_hint(h):
|
||||
if isinstance(h, DirectTCPV1Hint):
|
||||
return {"type": "direct-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
"port": h.port, # integer
|
||||
}
|
||||
elif isinstance(h, RelayV1Hint):
|
||||
rhint = {"type": "relay-v1", "hints": []}
|
||||
for rh in h.hints:
|
||||
rhint["hints"].append({"type": "direct-tcp-v1",
|
||||
"priority": rh.priority,
|
||||
"hostname": rh.hostname,
|
||||
"port": rh.port})
|
||||
"priority": rh.priority,
|
||||
"hostname": rh.hostname,
|
||||
"port": rh.port})
|
||||
return rhint
|
||||
elif isinstance(h, TorTCPV1Hint):
|
||||
return {"type": "tor-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
"port": h.port, # integer
|
||||
}
|
||||
raise ValueError("unknown hint type", h)
|
||||
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
||||
|
||||
def build_sided_relay_handshake(key, side):
|
||||
assert isinstance(side, type(u""))
|
||||
assert len(side) == 8*2
|
||||
assert len(side) == 8 * 2
|
||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
|
||||
return (b"please relay " + hexlify(token) +
|
||||
b" for side " + side.encode("ascii") + b"\n")
|
||||
|
||||
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
|
||||
|
||||
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
|
||||
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
|
||||
NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IDilationConnector)
|
||||
class Connector(object):
|
||||
|
@ -176,10 +189,11 @@ class Connector(object):
|
|||
self._transit_relays = [relay]
|
||||
else:
|
||||
self._transit_relays = []
|
||||
self._listeners = set() # IListeningPorts that can be stopped
|
||||
self._pending_connectors = set() # Deferreds that can be cancelled
|
||||
self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped
|
||||
self._contenders = set() # viable connections
|
||||
self._listeners = set() # IListeningPorts that can be stopped
|
||||
self._pending_connectors = set() # Deferreds that can be cancelled
|
||||
self._pending_connections = EmptyableSet(
|
||||
_eventual_queue=self._eventual_queue) # Protocols to be stopped
|
||||
self._contenders = set() # viable connections
|
||||
self._winning_connection = None
|
||||
self._timing = self._timing or DebugTiming()
|
||||
self._timing.add("transit")
|
||||
|
@ -212,26 +226,41 @@ class Connector(object):
|
|||
return p
|
||||
|
||||
@m.state(initial=True)
|
||||
def connecting(self): pass # pragma: no cover
|
||||
def connecting(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def connected(self): pass # pragma: no cover
|
||||
def connected(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state(terminal=True)
|
||||
def stopped(self): pass # pragma: no cover
|
||||
def stopped(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
# TODO: unify the tense of these method-name verbs
|
||||
@m.input()
|
||||
def listener_ready(self, hint_objs): pass
|
||||
@m.input()
|
||||
def add_relay(self, hint_objs): pass
|
||||
@m.input()
|
||||
def got_hints(self, hint_objs): pass
|
||||
@m.input()
|
||||
def add_candidate(self, c): # called by DilatedConnectionProtocol
|
||||
def listener_ready(self, hint_objs):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def accept(self, c): pass
|
||||
def add_relay(self, hint_objs):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def stop(self): pass
|
||||
def got_hints(self, hint_objs):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def add_candidate(self, c): # called by DilatedConnectionProtocol
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def accept(self, c):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def use_hints(self, hint_objs):
|
||||
|
@ -255,19 +284,19 @@ class Connector(object):
|
|||
@m.output()
|
||||
def select_and_stop_remaining(self, c):
|
||||
self._winning_connection = c
|
||||
self._contenders.clear() # we no longer care who else came close
|
||||
self._contenders.clear() # we no longer care who else came close
|
||||
# remove this winner from the losers, so we don't shut it down
|
||||
self._pending_connections.discard(c)
|
||||
# shut down losing connections
|
||||
self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
|
||||
self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
|
||||
self.stop_pending_connectors()
|
||||
self.stop_pending_connections()
|
||||
|
||||
c.select(self._manager) # subsequent frames go directly to the manager
|
||||
c.select(self._manager) # subsequent frames go directly to the manager
|
||||
if self._role is LEADER:
|
||||
# TODO: this should live in Connection
|
||||
c.send_record(KCM()) # leader sends KCM now
|
||||
self._manager.use_connection(c) # manager sends frames to Connection
|
||||
c.send_record(KCM()) # leader sends KCM now
|
||||
self._manager.use_connection(c) # manager sends frames to Connection
|
||||
|
||||
@m.output()
|
||||
def stop_everything(self):
|
||||
|
@ -279,7 +308,7 @@ class Connector(object):
|
|||
def stop_listeners(self):
|
||||
d = DeferredList([l.stopListening() for l in self._listeners])
|
||||
self._listeners.clear()
|
||||
return d # synchronization for tests
|
||||
return d # synchronization for tests
|
||||
|
||||
def stop_pending_connectors(self):
|
||||
return DeferredList([d.cancel() for d in self._pending_connectors])
|
||||
|
@ -306,7 +335,8 @@ class Connector(object):
|
|||
publish_hints])
|
||||
connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
|
||||
connecting.upon(add_candidate, enter=connecting, outputs=[consider])
|
||||
connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining])
|
||||
connecting.upon(accept, enter=connected, outputs=[
|
||||
select_and_stop_remaining])
|
||||
connecting.upon(stop, enter=stopped, outputs=[stop_everything])
|
||||
|
||||
# once connected, we ignore everything except stop
|
||||
|
@ -317,9 +347,9 @@ class Connector(object):
|
|||
connected.upon(accept, enter=connected, outputs=[])
|
||||
connected.upon(stop, enter=stopped, outputs=[stop_everything])
|
||||
|
||||
|
||||
# from Manager: start, got_hints, stop
|
||||
# maybe add_candidate, accept
|
||||
|
||||
def start(self):
|
||||
self._start_listener()
|
||||
if self._transit_relays:
|
||||
|
@ -341,9 +371,10 @@ class Connector(object):
|
|||
ep = serverFromString(self._reactor, "tcp:0")
|
||||
f = InboundConnectionFactory(self)
|
||||
d = ep.listen(f)
|
||||
|
||||
def _listening(lp):
|
||||
# lp is an IListeningPort
|
||||
self._listeners.add(lp) # for shutdown and tests
|
||||
self._listeners.add(lp) # for shutdown and tests
|
||||
portnum = lp.getHost().port
|
||||
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
|
||||
for addr in addresses]
|
||||
|
@ -378,7 +409,7 @@ class Connector(object):
|
|||
# one still running. But if we bail on that, we might consider
|
||||
# putting an inter-direct-hint delay here to influence the
|
||||
# process.
|
||||
#delay += 1.0
|
||||
# delay += 1.0
|
||||
if delay > 0.0:
|
||||
# Start trying the relays a few seconds after we start to try the
|
||||
# direct hints. The idea is to prefer direct connections, but not
|
||||
|
@ -408,7 +439,7 @@ class Connector(object):
|
|||
self._connect, ep, desc, is_relay=True)
|
||||
self._pending_connectors.add(d)
|
||||
# TODO:
|
||||
#if not contenders:
|
||||
# if not contenders:
|
||||
# raise TransitError("No contenders for connection")
|
||||
|
||||
# TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
|
||||
|
@ -422,6 +453,7 @@ class Connector(object):
|
|||
f = OutboundConnectionFactory(self, relay_handshake)
|
||||
d = ep.connect(f)
|
||||
# fires with protocol, or ConnectError
|
||||
|
||||
def _connected(p):
|
||||
self._pending_connections.add(p)
|
||||
# c might not be in _pending_connections, if it turned out to be a
|
||||
|
@ -444,7 +476,6 @@ class Connector(object):
|
|||
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
||||
|
||||
# Connection selection. All instances of DilatedConnectionProtocol which
|
||||
# look viable get passed into our add_contender() method.
|
||||
|
||||
|
@ -459,6 +490,7 @@ class Connector(object):
|
|||
|
||||
# our Connection protocols call: add_candidate
|
||||
|
||||
|
||||
@attrs
|
||||
class OutboundConnectionFactory(ClientFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
|
@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object):
|
|||
p.use_relay(self._relay_handshake)
|
||||
return p
|
||||
|
||||
|
||||
@attrs
|
||||
class InboundConnectionFactory(ServerFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
|
|
|
@ -4,10 +4,13 @@ import struct
|
|||
assert len(struct.pack("<L", 0)) == 4
|
||||
assert len(struct.pack("<Q", 0)) == 8
|
||||
|
||||
|
||||
def to_be4(value):
|
||||
if not 0 <= value < 2**32:
|
||||
raise ValueError
|
||||
return struct.pack(">L", value)
|
||||
|
||||
|
||||
def from_be4(b):
|
||||
if not isinstance(b, bytes):
|
||||
raise TypeError(repr(b))
|
||||
|
|
|
@ -6,13 +6,19 @@ from twisted.python import log
|
|||
from .._interfaces import IDilationManager, IInbound
|
||||
from .subchannel import (SubChannel, _SubchannelAddress)
|
||||
|
||||
|
||||
class DuplicateOpenError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataForMissingSubchannelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CloseForMissingSubchannelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IInbound)
|
||||
class Inbound(object):
|
||||
|
@ -24,8 +30,8 @@ class Inbound(object):
|
|||
|
||||
def __attrs_post_init__(self):
|
||||
# we route inbound Data records to Subchannels .dataReceived
|
||||
self._open_subchannels = {} # scid -> Subchannel
|
||||
self._paused_subchannels = set() # Subchannels that have paused us
|
||||
self._open_subchannels = {} # scid -> Subchannel
|
||||
self._paused_subchannels = set() # Subchannels that have paused us
|
||||
# the set is non-empty, we pause the transport
|
||||
self._highest_inbound_acked = -1
|
||||
self._connection = None
|
||||
|
@ -37,7 +43,6 @@ class Inbound(object):
|
|||
def set_subchannel_zero(self, scid0, sc0):
|
||||
self._open_subchannels[scid0] = sc0
|
||||
|
||||
|
||||
def use_connection(self, c):
|
||||
self._connection = c
|
||||
# We can pause the connection's reads when we receive too much data. If
|
||||
|
@ -61,7 +66,8 @@ class Inbound(object):
|
|||
|
||||
def handle_open(self, scid):
|
||||
if scid in self._open_subchannels:
|
||||
log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid)))
|
||||
log.err(DuplicateOpenError(
|
||||
"received duplicate OPEN for {}".format(scid)))
|
||||
return
|
||||
peer_addr = _SubchannelAddress(scid)
|
||||
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
|
@ -71,14 +77,16 @@ class Inbound(object):
|
|||
def handle_data(self, scid, data):
|
||||
sc = self._open_subchannels.get(scid)
|
||||
if sc is None:
|
||||
log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid)))
|
||||
log.err(DataForMissingSubchannelError(
|
||||
"received DATA for non-existent subchannel {}".format(scid)))
|
||||
return
|
||||
sc.remote_data(data)
|
||||
|
||||
def handle_close(self, scid):
|
||||
sc = self._open_subchannels.get(scid)
|
||||
if sc is None:
|
||||
log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid)))
|
||||
log.err(CloseForMissingSubchannelError(
|
||||
"received CLOSE for non-existent subchannel {}".format(scid)))
|
||||
return
|
||||
sc.remote_close()
|
||||
|
||||
|
@ -90,7 +98,6 @@ class Inbound(object):
|
|||
def stop_using_connection(self):
|
||||
self._connection = None
|
||||
|
||||
|
||||
# from our Subchannel, or rather from the Protocol above it and sent
|
||||
# through the subchannel
|
||||
|
||||
|
|
|
@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
|
|||
from .inbound import Inbound
|
||||
from .outbound import Outbound
|
||||
|
||||
|
||||
class OldPeerCannotDilateError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownDilationMessageType(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ReceivedHintsTooEarly(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IDilationManager)
|
||||
class _ManagerBase(object):
|
||||
|
@ -37,14 +43,14 @@ class _ManagerBase(object):
|
|||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_cooperator = attrib()
|
||||
_no_listen = False # TODO
|
||||
_tor = None # TODO
|
||||
_timing = None # TODO
|
||||
_no_listen = False # TODO
|
||||
_tor = None # TODO
|
||||
_timing = None # TODO
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._got_versions_d = Deferred()
|
||||
|
||||
self._my_role = None # determined upon rx_PLEASE
|
||||
self._my_role = None # determined upon rx_PLEASE
|
||||
|
||||
self._connection = None
|
||||
self._made_first_connection = False
|
||||
|
@ -53,51 +59,56 @@ class _ManagerBase(object):
|
|||
|
||||
self._next_dilation_phase = 0
|
||||
|
||||
self._next_subchannel_id = 0 # increments by 2
|
||||
self._next_subchannel_id = 0 # increments by 2
|
||||
|
||||
# I kept getting confused about which methods were for inbound data
|
||||
# (and thus flow-control methods go "out") and which were for
|
||||
# outbound data (with flow-control going "in"), so I split them up
|
||||
# into separate pieces.
|
||||
self._inbound = Inbound(self, self._host_addr)
|
||||
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
||||
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
||||
|
||||
def set_listener_endpoint(self, listener_endpoint):
|
||||
self._inbound.set_listener_endpoint(listener_endpoint)
|
||||
|
||||
def set_subchannel_zero(self, scid0, sc0):
|
||||
self._inbound.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
def when_first_connected(self):
|
||||
return self._first_connected.when_fired()
|
||||
|
||||
|
||||
def send_dilation_phase(self, **fields):
|
||||
dilation_phase = self._next_dilation_phase
|
||||
self._next_dilation_phase += 1
|
||||
self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))
|
||||
|
||||
def send_hints(self, hints): # from Connector
|
||||
def send_hints(self, hints): # from Connector
|
||||
self.send_dilation_phase(type="connection-hints", hints=hints)
|
||||
|
||||
|
||||
# forward inbound-ish things to _Inbound
|
||||
|
||||
def subchannel_pauseProducing(self, sc):
|
||||
self._inbound.subchannel_pauseProducing(sc)
|
||||
|
||||
def subchannel_resumeProducing(self, sc):
|
||||
self._inbound.subchannel_resumeProducing(sc)
|
||||
|
||||
def subchannel_stopProducing(self, sc):
|
||||
self._inbound.subchannel_stopProducing(sc)
|
||||
|
||||
# forward outbound-ish things to _Outbound
|
||||
def subchannel_registerProducer(self, sc, producer, streaming):
|
||||
self._outbound.subchannel_registerProducer(sc, producer, streaming)
|
||||
|
||||
def subchannel_unregisterProducer(self, sc):
|
||||
self._outbound.subchannel_unregisterProducer(sc)
|
||||
|
||||
def send_open(self, scid):
|
||||
self._queue_and_send(Open, scid)
|
||||
|
||||
def send_data(self, scid, data):
|
||||
self._queue_and_send(Data, scid, data)
|
||||
|
||||
def send_close(self, scid):
|
||||
self._queue_and_send(Close, scid)
|
||||
|
||||
|
@ -106,7 +117,7 @@ class _ManagerBase(object):
|
|||
# Outbound owns the send_record() pipe, so that it can stall new
|
||||
# writes after a new connection is made until after all queued
|
||||
# messages are written (to preserve ordering).
|
||||
self._outbound.queue_and_send_record(r) # may trigger pauseProducing
|
||||
self._outbound.queue_and_send_record(r) # may trigger pauseProducing
|
||||
|
||||
def subchannel_closed(self, scid, sc):
|
||||
# let everyone clean up. This happens just after we delivered
|
||||
|
@ -116,7 +127,6 @@ class _ManagerBase(object):
|
|||
self._inbound.subchannel_closed(scid, sc)
|
||||
self._outbound.subchannel_closed(scid, sc)
|
||||
|
||||
|
||||
def _start_connecting(self, role):
|
||||
assert self._my_role is not None
|
||||
self._connector = Connector(self._transit_key,
|
||||
|
@ -125,41 +135,41 @@ class _ManagerBase(object):
|
|||
self._reactor, self._eventual_queue,
|
||||
self._no_listen, self._tor,
|
||||
self._timing,
|
||||
self._side, # needed for relay handshake
|
||||
self._side, # needed for relay handshake
|
||||
self._my_role)
|
||||
self._connector.start()
|
||||
|
||||
# our Connector calls these
|
||||
|
||||
def connector_connection_made(self, c):
|
||||
self.connection_made() # state machine update
|
||||
self.connection_made() # state machine update
|
||||
self._connection = c
|
||||
self._inbound.use_connection(c)
|
||||
self._outbound.use_connection(c) # does c.registerProducer
|
||||
self._outbound.use_connection(c) # does c.registerProducer
|
||||
if not self._made_first_connection:
|
||||
self._made_first_connection = True
|
||||
self._first_connected.fire(None)
|
||||
pass
|
||||
|
||||
def connector_connection_lost(self):
|
||||
self._stop_using_connection()
|
||||
if self.role is LEADER:
|
||||
self.connection_lost_leader() # state machine
|
||||
self.connection_lost_leader() # state machine
|
||||
else:
|
||||
self.connection_lost_follower()
|
||||
|
||||
|
||||
def _stop_using_connection(self):
|
||||
# the connection is already lost by this point
|
||||
self._connection = None
|
||||
self._inbound.stop_using_connection()
|
||||
self._outbound.stop_using_connection() # does c.unregisterProducer
|
||||
self._outbound.stop_using_connection() # does c.unregisterProducer
|
||||
|
||||
# from our active Connection
|
||||
|
||||
def got_record(self, r):
|
||||
# records with sequence numbers: always ack, ignore old ones
|
||||
if isinstance(r, (Open, Data, Close)):
|
||||
self.send_ack(r.seqnum) # always ack, even for old ones
|
||||
self.send_ack(r.seqnum) # always ack, even for old ones
|
||||
if self._inbound.is_record_old(r):
|
||||
return
|
||||
self._inbound.update_ack_watermark(r.seqnum)
|
||||
|
@ -167,7 +177,7 @@ class _ManagerBase(object):
|
|||
self._inbound.handle_open(r.scid)
|
||||
elif isinstance(r, Data):
|
||||
self._inbound.handle_data(r.scid, r.data)
|
||||
else: # isinstance(r, Close)
|
||||
else: # isinstance(r, Close)
|
||||
self._inbound.handle_close(r.scid)
|
||||
if isinstance(r, KCM):
|
||||
log.err("got unexpected KCM")
|
||||
|
@ -176,7 +186,7 @@ class _ManagerBase(object):
|
|||
elif isinstance(r, Pong):
|
||||
self.handle_pong(r.ping_id)
|
||||
elif isinstance(r, Ack):
|
||||
self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
|
||||
self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
|
||||
else:
|
||||
log.err("received unknown message type {}".format(r))
|
||||
|
||||
|
@ -190,7 +200,6 @@ class _ManagerBase(object):
|
|||
def send_ack(self, resp_seqnum):
|
||||
self._outbound.send_if_connected(Ack(resp_seqnum))
|
||||
|
||||
|
||||
def handle_ping(self, ping_id):
|
||||
self.send_pong(ping_id)
|
||||
|
||||
|
@ -200,7 +209,7 @@ class _ManagerBase(object):
|
|||
|
||||
# subchannel maintenance
|
||||
def allocate_subchannel_id(self):
|
||||
raise NotImplemented # subclass knows if we're leader or follower
|
||||
raise NotImplementedError # subclass knows if we're leader or follower
|
||||
|
||||
# new scheme:
|
||||
# * both sides send PLEASE as soon as they have an unverified key and
|
||||
|
@ -236,58 +245,91 @@ class _ManagerBase(object):
|
|||
# * if follower calls w.dilate() but not leader, follower waits forever
|
||||
# in "want", leader waits forever in "wanted"
|
||||
|
||||
|
||||
class ManagerShared(_ManagerBase):
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def IDLE(self): pass # pragma: no cover
|
||||
def IDLE(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def WANTING(self): pass # pragma: no cover
|
||||
def WANTING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def WANTED(self): pass # pragma: no cover
|
||||
def WANTED(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def CONNECTING(self): pass # pragma: no cover
|
||||
def CONNECTING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def CONNECTED(self): pass # pragma: no cover
|
||||
def CONNECTED(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def FLUSHING(self): pass # pragma: no cover
|
||||
def FLUSHING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def ABANDONING(self): pass # pragma: no cover
|
||||
def ABANDONING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def LONELY(self): pass # pragme: no cover
|
||||
def LONELY(self):
|
||||
pass # pragme: no cover
|
||||
|
||||
@m.state()
|
||||
def STOPPING(self): pass # pragma: no cover
|
||||
def STOPPING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state(terminal=True)
|
||||
def STOPPED(self): pass # pragma: no cover
|
||||
def STOPPED(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def start(self): pass # pragma: no cover
|
||||
def start(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def rx_PLEASE(self, message): pass # pragma: no cover
|
||||
@m.input() # only sent by Follower
|
||||
def rx_HINTS(self, hint_message): pass # pragma: no cover
|
||||
@m.input() # only Leader sends RECONNECT, so only Follower receives it
|
||||
def rx_RECONNECT(self): pass # pragma: no cover
|
||||
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
|
||||
def rx_RECONNECTING(self): pass # pragma: no cover
|
||||
def rx_PLEASE(self, message):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input() # only sent by Follower
|
||||
def rx_HINTS(self, hint_message):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input() # only Leader sends RECONNECT, so only Follower receives it
|
||||
def rx_RECONNECT(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
|
||||
def rx_RECONNECTING(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
# Connector gives us connection_made()
|
||||
@m.input()
|
||||
def connection_made(self, c): pass # pragma: no cover
|
||||
def connection_made(self, c):
|
||||
pass # pragma: no cover
|
||||
|
||||
# our connection_lost() fires connection_lost_leader or
|
||||
# connection_lost_follower depending upon our role. If either side sees a
|
||||
# problem with the connection (timeouts, bad authentication) then they
|
||||
# just drop it and let connection_lost() handle the cleanup.
|
||||
@m.input()
|
||||
def connection_lost_leader(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def connection_lost_follower(self): pass
|
||||
def connection_lost_leader(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def stop(self): pass # pragma: no cover
|
||||
def connection_lost_follower(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def stop(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.output()
|
||||
def stash_side(self, message):
|
||||
|
@ -301,38 +343,42 @@ class ManagerShared(_ManagerBase):
|
|||
|
||||
@m.output()
|
||||
def start_connecting(self):
|
||||
self._start_connecting() # TODO: merge
|
||||
self._start_connecting() # TODO: merge
|
||||
|
||||
@m.output()
|
||||
def ignore_message_start_connecting(self, message):
|
||||
self.start_connecting()
|
||||
|
||||
@m.output()
|
||||
def send_reconnect(self):
|
||||
self.send_dilation_phase(type="reconnect") # TODO: generation number?
|
||||
self.send_dilation_phase(type="reconnect") # TODO: generation number?
|
||||
|
||||
@m.output()
|
||||
def send_reconnecting(self):
|
||||
self.send_dilation_phase(type="reconnecting") # TODO: generation?
|
||||
self.send_dilation_phase(type="reconnecting") # TODO: generation?
|
||||
|
||||
@m.output()
|
||||
def use_hints(self, hint_message):
|
||||
hint_objs = filter(lambda h: h, # ignore None, unrecognizable
|
||||
hint_objs = filter(lambda h: h, # ignore None, unrecognizable
|
||||
[parse_hint(hs) for hs in hint_message["hints"]])
|
||||
hint_objs = list(hint_objs)
|
||||
self._connector.got_hints(hint_objs)
|
||||
|
||||
@m.output()
|
||||
def stop_connecting(self):
|
||||
self._connector.stop()
|
||||
|
||||
@m.output()
|
||||
def abandon_connection(self):
|
||||
# we think we're still connected, but the Leader disagrees. Or we've
|
||||
# been told to shut down.
|
||||
self._connection.disconnect() # let connection_lost do cleanup
|
||||
|
||||
self._connection.disconnect() # let connection_lost do cleanup
|
||||
|
||||
# we don't start CONNECTING until a local start() plus rx_PLEASE
|
||||
IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
|
||||
IDLE.upon(start, enter=WANTING, outputs=[send_please])
|
||||
WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting])
|
||||
WANTED.upon(start, enter=CONNECTING, outputs=[
|
||||
send_please, start_connecting])
|
||||
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
||||
outputs=[stash_side,
|
||||
ignore_message_start_connecting])
|
||||
|
@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase):
|
|||
# Leader
|
||||
CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
|
||||
outputs=[send_reconnect])
|
||||
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting])
|
||||
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
|
||||
outputs=[start_connecting])
|
||||
|
||||
# Follower
|
||||
# if we notice a lost connection, just wait for the Leader to notice too
|
||||
|
@ -350,7 +397,7 @@ class ManagerShared(_ManagerBase):
|
|||
LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting])
|
||||
# but if they notice it first, abandon our (seemingly functional)
|
||||
# connection, then tell them that we're ready to try again
|
||||
CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss
|
||||
CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss
|
||||
outputs=[abandon_connection])
|
||||
ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
|
||||
outputs=[send_reconnecting, start_connecting])
|
||||
|
@ -362,16 +409,15 @@ class ManagerShared(_ManagerBase):
|
|||
send_reconnecting,
|
||||
start_connecting])
|
||||
|
||||
|
||||
# rx_HINTS never changes state, they're just accepted or ignored
|
||||
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
|
||||
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
|
||||
WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
|
||||
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
|
||||
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
|
||||
WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
|
||||
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
|
||||
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
|
||||
FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
|
||||
LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
|
||||
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
|
||||
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
|
||||
FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
|
||||
LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
|
||||
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
|
||||
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
|
||||
|
||||
IDLE.upon(stop, enter=STOPPED, outputs=[])
|
||||
|
@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase):
|
|||
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
|
||||
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
|
||||
|
||||
|
||||
def allocate_subchannel_id(self):
|
||||
# scid 0 is reserved for the control channel. the leader uses odd
|
||||
# numbers starting with 1
|
||||
|
@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase):
|
|||
self._next_outbound_seqnum += 2
|
||||
return to_be4(scid_num)
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IDilator)
|
||||
class Dilator(object):
|
||||
|
@ -436,10 +482,10 @@ class Dilator(object):
|
|||
|
||||
dilation_version = yield self._got_versions_d
|
||||
|
||||
if not dilation_version: # 1 or None
|
||||
if not dilation_version: # 1 or None
|
||||
raise OldPeerCannotDilateError()
|
||||
|
||||
my_dilation_side = TODO # random
|
||||
my_dilation_side = TODO # random
|
||||
self._manager = Manager(self._S, my_dilation_side,
|
||||
self._transit_key,
|
||||
self._transit_relay_location,
|
||||
|
@ -455,14 +501,15 @@ class Dilator(object):
|
|||
yield self._manager.when_first_connected()
|
||||
# we can open subchannels as soon as we get our first connection
|
||||
scid0 = b"\x00\x00\x00\x00"
|
||||
self._host_addr = _WormholeAddress() # TODO: share with Manager
|
||||
self._host_addr = _WormholeAddress() # TODO: share with Manager
|
||||
peer_addr0 = _SubchannelAddress(scid0)
|
||||
control_ep = ControlEndpoint(peer_addr0)
|
||||
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
|
||||
control_ep._subchannel_zero_opened(sc0)
|
||||
self._manager.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
|
||||
connect_ep = SubchannelConnectorEndpoint(
|
||||
self._manager, self._host_addr)
|
||||
|
||||
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
|
||||
self._manager.set_listener_endpoint(listen_ep)
|
||||
|
@ -476,7 +523,7 @@ class Dilator(object):
|
|||
# TODO: verify this happens before got_wormhole_versions, or add a gate
|
||||
# to tolerate either ordering
|
||||
purpose = b"dilation-v1"
|
||||
LENGTH =32 # TODO: whatever Noise wants, I guess
|
||||
LENGTH = 32 # TODO: whatever Noise wants, I guess
|
||||
self._transit_key = derive_key(key, purpose, LENGTH)
|
||||
|
||||
def got_wormhole_versions(self, our_side, their_side,
|
||||
|
@ -504,9 +551,9 @@ class Dilator(object):
|
|||
message = bytes_to_dict(plaintext)
|
||||
type = message["type"]
|
||||
if type == "please":
|
||||
self._manager.rx_PLEASE() # message)
|
||||
self._manager.rx_PLEASE() # message)
|
||||
elif type == "dilate":
|
||||
self._manager.rx_DILATE() #message)
|
||||
self._manager.rx_DILATE() # message)
|
||||
elif type == "connection-hints":
|
||||
self._manager.rx_HINTS(message)
|
||||
else:
|
||||
|
|
|
@ -168,9 +168,9 @@ class Outbound(object):
|
|||
self._queued_unsent = deque()
|
||||
|
||||
# outbound flow control: the Connection throttles our writes
|
||||
self._subchannel_producers = {} # Subchannel -> IProducer
|
||||
self._paused = True # our Connection called our pauseProducing
|
||||
self._all_producers = deque() # rotates, left-is-next
|
||||
self._subchannel_producers = {} # Subchannel -> IProducer
|
||||
self._paused = True # our Connection called our pauseProducing
|
||||
self._all_producers = deque() # rotates, left-is-next
|
||||
self._paused_producers = set()
|
||||
self._unpaused_producers = set()
|
||||
self._check_invariants()
|
||||
|
@ -186,7 +186,7 @@ class Outbound(object):
|
|||
seqnum = self._next_outbound_seqnum
|
||||
self._next_outbound_seqnum += 1
|
||||
r = record_type(seqnum, *args)
|
||||
assert hasattr(r, "seqnum"), r # only Open/Data/Close
|
||||
assert hasattr(r, "seqnum"), r # only Open/Data/Close
|
||||
return r
|
||||
|
||||
def queue_and_send_record(self, r):
|
||||
|
@ -203,7 +203,7 @@ class Outbound(object):
|
|||
self._connection.send_record(r)
|
||||
|
||||
def send_if_connected(self, r):
|
||||
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
|
||||
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
|
||||
if self._connection:
|
||||
self._connection.send_record(r)
|
||||
|
||||
|
@ -235,7 +235,7 @@ class Outbound(object):
|
|||
if self._paused:
|
||||
# IPushProducers need to be paused immediately, before they
|
||||
# speak
|
||||
producer.pauseProducing() # you wake up sleeping
|
||||
producer.pauseProducing() # you wake up sleeping
|
||||
else:
|
||||
# our PullToPush adapter must be started, but if we're paused then
|
||||
# we tell it to pause before it gets a chance to write anything
|
||||
|
@ -265,7 +265,7 @@ class Outbound(object):
|
|||
assert not self._queued_unsent
|
||||
self._queued_unsent.extend(self._outbound_queue)
|
||||
# the connection can tell us to pause when we send too much data
|
||||
c.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
c.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
# send our queued messages
|
||||
self.resumeProducing()
|
||||
|
||||
|
@ -290,12 +290,12 @@ class Outbound(object):
|
|||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
||||
|
||||
|
||||
# IProducer: the active connection calls these because we used
|
||||
# c.registerProducer to ask for them
|
||||
|
||||
def pauseProducing(self):
|
||||
if self._paused:
|
||||
return # someone is confused and called us twice
|
||||
return # someone is confused and called us twice
|
||||
self._paused = True
|
||||
for p in self._all_producers:
|
||||
if p in self._unpaused_producers:
|
||||
|
@ -305,7 +305,7 @@ class Outbound(object):
|
|||
|
||||
def resumeProducing(self):
|
||||
if not self._paused:
|
||||
return # someone is confused and called us twice
|
||||
return # someone is confused and called us twice
|
||||
self._paused = False
|
||||
|
||||
while not self._paused:
|
||||
|
@ -326,7 +326,7 @@ class Outbound(object):
|
|||
return None
|
||||
while True:
|
||||
p = self._all_producers[0]
|
||||
self._all_producers.rotate(-1) # p moves to the end of the line
|
||||
self._all_producers.rotate(-1) # p moves to the end of the line
|
||||
# the only unpaused Producers are at the end of the list
|
||||
assert p in self._paused_producers
|
||||
return p
|
||||
|
@ -343,7 +343,7 @@ class Outbound(object):
|
|||
@attrs(cmp=False)
|
||||
class PullToPush(object):
|
||||
_producer = attrib(validator=provides(IPullProducer))
|
||||
_unregister = attrib(validator=lambda _a,_b,v: callable(v))
|
||||
_unregister = attrib(validator=lambda _a, _b, v: callable(v))
|
||||
_cooperator = attrib()
|
||||
_finished = False
|
||||
|
||||
|
@ -351,14 +351,14 @@ class PullToPush(object):
|
|||
while True:
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except:
|
||||
except Exception:
|
||||
log.err(None, "%s failed, producing will be stopped:" %
|
||||
(safe_str(self._producer),))
|
||||
try:
|
||||
self._unregister()
|
||||
# The consumer should now call stopStreaming() on us,
|
||||
# thus stopping the streaming.
|
||||
except:
|
||||
except Exception:
|
||||
# Since the consumer blew up, we may not have had
|
||||
# stopStreaming() called, so we just stop on our own:
|
||||
log.err(None, "%s failed to unregister producer:" %
|
||||
|
@ -370,7 +370,7 @@ class PullToPush(object):
|
|||
def startStreaming(self, paused):
|
||||
self._coopTask = self._cooperator.cooperate(self._pull())
|
||||
if paused:
|
||||
self.pauseProducing() # timer is scheduled, but task is removed
|
||||
self.pauseProducing() # timer is scheduled, but task is removed
|
||||
|
||||
def stopStreaming(self):
|
||||
if self._finished:
|
||||
|
@ -378,15 +378,12 @@ class PullToPush(object):
|
|||
self._finished = True
|
||||
self._coopTask.stop()
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
self._coopTask.pause()
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
self._coopTask.resume()
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
self.stopStreaming()
|
||||
self._producer.stopProducing()
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed
|
||||
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
|
||||
succeed)
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IStreamClientEndpoint,
|
||||
|
@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone
|
|||
from automat import MethodicalMachine
|
||||
from .._interfaces import ISubChannel, IDilationManager
|
||||
|
||||
|
||||
@attrs
|
||||
class Once(object):
|
||||
_errtype = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._called = False
|
||||
|
||||
|
@ -21,6 +24,7 @@ class Once(object):
|
|||
raise self._errtype()
|
||||
self._called = True
|
||||
|
||||
|
||||
class SingleUseEndpointError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception):
|
|||
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
|
||||
# object is deleted upon transition to (CLOSED)
|
||||
|
||||
|
||||
class AlreadyClosedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class _WormholeAddress(object):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
@attrs
|
||||
class _SubchannelAddress(object):
|
||||
|
@ -63,35 +70,44 @@ class SubChannel(object):
|
|||
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
set_trace = getattr(m, "_setTrace", lambda self,
|
||||
f: None) # pragma: no cover
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
#self._mailbox = None
|
||||
#self._pending_outbound = {}
|
||||
#self._processed = set()
|
||||
# self._mailbox = None
|
||||
# self._pending_outbound = {}
|
||||
# self._processed = set()
|
||||
self._protocol = None
|
||||
self._pending_dataReceived = []
|
||||
self._pending_connectionLost = (False, None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def open(self): pass # pragma: no cover
|
||||
def open(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closing(): pass # pragma: no cover
|
||||
def closing():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closed(): pass # pragma: no cover
|
||||
def closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def remote_data(self, data): pass
|
||||
@m.input()
|
||||
def remote_close(self): pass
|
||||
def remote_data(self, data):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def local_data(self, data): pass
|
||||
@m.input()
|
||||
def local_close(self): pass
|
||||
def remote_close(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def local_data(self, data):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def local_close(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def send_data(self, data):
|
||||
|
@ -120,9 +136,11 @@ class SubChannel(object):
|
|||
@m.output()
|
||||
def error_closed_write(self, data):
|
||||
raise AlreadyClosedError("write not allowed on closed subchannel")
|
||||
|
||||
@m.output()
|
||||
def error_closed_close(self):
|
||||
raise AlreadyClosedError("loseConnection not allowed on closed subchannel")
|
||||
raise AlreadyClosedError(
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
|
@ -146,7 +164,7 @@ class SubChannel(object):
|
|||
if self._pending_dataReceived:
|
||||
for data in self._pending_dataReceived:
|
||||
self._protocol.dataReceived(data)
|
||||
self._pending_dataReceived = []
|
||||
self._pending_dataReceived = []
|
||||
cl, what = self._pending_connectionLost
|
||||
if cl:
|
||||
self._protocol.connectionLost(what)
|
||||
|
@ -155,13 +173,17 @@ class SubChannel(object):
|
|||
def write(self, data):
|
||||
assert isinstance(data, type(b""))
|
||||
self.local_data(data)
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.local_close()
|
||||
|
||||
def getHost(self):
|
||||
# we define "host addr" as the overall wormhole
|
||||
return self._host_addr
|
||||
|
||||
def getPeer(self):
|
||||
# and "peer addr" as the subchannel within that wormhole
|
||||
return self._peer_addr
|
||||
|
@ -169,14 +191,17 @@ class SubChannel(object):
|
|||
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
|
||||
def stopProducing(self):
|
||||
self._manager.subchannel_stopProducing(self)
|
||||
|
||||
def pauseProducing(self):
|
||||
self._manager.subchannel_pauseProducing(self)
|
||||
|
||||
def resumeProducing(self):
|
||||
self._manager.subchannel_resumeProducing(self)
|
||||
|
||||
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
|
||||
def registerProducer(self, producer, streaming):
|
||||
self._manager.subchannel_registerProducer(self, producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self._manager.subchannel_unregisterProducer(self)
|
||||
|
||||
|
@ -184,6 +209,7 @@ class SubChannel(object):
|
|||
@implementer(IStreamClientEndpoint)
|
||||
class ControlEndpoint(object):
|
||||
_used = False
|
||||
|
||||
def __init__(self, peer_addr):
|
||||
self._subchannel_zero = Deferred()
|
||||
self._peer_addr = peer_addr
|
||||
|
@ -201,9 +227,10 @@ class ControlEndpoint(object):
|
|||
t = yield self._subchannel_zero
|
||||
p = protocolFactory.buildProtocol(self._peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
returnValue(p)
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
@attrs
|
||||
class SubchannelConnectorEndpoint(object):
|
||||
|
@ -220,9 +247,10 @@ class SubchannelConnectorEndpoint(object):
|
|||
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
p = protocolFactory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
return succeed(p)
|
||||
|
||||
|
||||
@implementer(IStreamServerEndpoint)
|
||||
@attrs
|
||||
class SubchannelListenerEndpoint(object):
|
||||
|
@ -238,7 +266,7 @@ class SubchannelListenerEndpoint(object):
|
|||
if self._factory:
|
||||
self._connect(t, peer_addr)
|
||||
else:
|
||||
self._pending_opens.append( (t, peer_addr) )
|
||||
self._pending_opens.append((t, peer_addr))
|
||||
|
||||
def _connect(self, t, peer_addr):
|
||||
p = self._factory.buildProtocol(peer_addr)
|
||||
|
@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object):
|
|||
lp = SubchannelListeningPort(self._host_addr)
|
||||
return succeed(lp)
|
||||
|
||||
|
||||
@implementer(IListeningPort)
|
||||
@attrs
|
||||
class SubchannelListeningPort(object):
|
||||
|
@ -262,8 +291,10 @@ class SubchannelListeningPort(object):
|
|||
|
||||
def startListening(self):
|
||||
pass
|
||||
|
||||
def stopListening(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def getHost(self):
|
||||
return self._host_addr
|
||||
|
|
Loading…
Reference in New Issue
Block a user