fix some flake8 complaints

This commit is contained in:
Brian Warner 2018-06-30 16:19:41 -07:00
parent 39666f3fed
commit 05900bd08b
8 changed files with 339 additions and 195 deletions

View File

@ -67,7 +67,8 @@ class Boss(object):
self._I = Input(self._timing) self._I = Input(self._timing)
self._C = Code(self._timing) self._C = Code(self._timing)
self._T = Terminator() self._T = Terminator()
self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) self._D = Dilator(self._reactor, self._eventual_queue,
self._cooperator)
self._N.wire(self._M, self._I, self._RC, self._T) self._N.wire(self._M, self._I, self._RC, self._T)
self._M.wire(self._N, self._RC, self._O, self._T) self._M.wire(self._N, self._RC, self._O, self._T)

View File

@ -39,20 +39,28 @@ from .roles import FOLLOWER
# states). For the specific question of sending plaintext frames, Noise will # states). For the specific question of sending plaintext frames, Noise will
# refuse us unless it's ready anyways, so the question is probably moot. # refuse us unless it's ready anyways, so the question is probably moot.
class IFramer(Interface): class IFramer(Interface):
pass pass
class IRecord(Interface): class IRecord(Interface):
pass pass
def first(l): def first(l):
return l[0] return l[0]
class Disconnect(Exception): class Disconnect(Exception):
pass pass
RelayOK = namedtuple("RelayOk", []) RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", []) Prologue = namedtuple("Prologue", [])
Frame = namedtuple("Frame", ["frame"]) Frame = namedtuple("Frame", ["frame"])
@attrs @attrs
@implementer(IFramer) @implementer(IFramer)
class _Framer(object): class _Framer(object):
@ -73,19 +81,25 @@ class _Framer(object):
@m.state() @m.state()
def want_relay(self): pass # pragma: no cover def want_relay(self): pass # pragma: no cover
@m.state(initial=True) @m.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self): pass # pragma: no cover
@m.state() @m.state()
def want_frame(self): pass # pragma: no cover def want_frame(self): pass # pragma: no cover
@m.input() @m.input()
def use_relay(self, relay_handshake): pass def use_relay(self, relay_handshake): pass
@m.input() @m.input()
def connectionMade(self): pass def connectionMade(self): pass
@m.input() @m.input()
def parse(self): pass def parse(self): pass
@m.input() @m.input()
def got_relay_ok(self): pass def got_relay_ok(self): pass
@m.input() @m.input()
def got_prologue(self): pass def got_prologue(self): pass
@ -93,6 +107,7 @@ class _Framer(object):
def store_relay_handshake(self, relay_handshake): def store_relay_handshake(self, relay_handshake):
self._outbound_relay_handshake = relay_handshake self._outbound_relay_handshake = relay_handshake
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
@m.output() @m.output()
def send_relay_handshake(self): def send_relay_handshake(self):
self._transport.write(self._outbound_relay_handshake) self._transport.write(self._outbound_relay_handshake)
@ -144,7 +159,6 @@ class _Framer(object):
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
collector=first) collector=first)
def _get_expected(self, name, expected): def _get_expected(self, name, expected):
lb = len(self._buffer) lb = len(self._buffer)
le = len(expected) le = len(expected)
@ -202,6 +216,7 @@ class _Framer(object):
# from peer. Sent immediately by Follower, after Selection by Leader. # from peer. Sent immediately by Follower, after Selection by Leader.
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong # Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
Handshake = namedtuple("Handshake", []) Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack # decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", []) KCM = namedtuple("KCM", [])
@ -222,6 +237,7 @@ T_DATA = b"\x04"
T_CLOSE = b"\x05" T_CLOSE = b"\x05"
T_ACK = b"\x06" T_ACK = b"\x06"
def parse_record(plaintext): def parse_record(plaintext):
msgtype = plaintext[0:1] msgtype = plaintext[0:1]
if msgtype == T_KCM: if msgtype == T_KCM:
@ -251,6 +267,7 @@ def parse_record(plaintext):
log.err("received unknown message type: {}".format(plaintext)) log.err("received unknown message type: {}".format(plaintext))
raise ValueError() raise ValueError()
def encode_record(r): def encode_record(r):
if isinstance(r, KCM): if isinstance(r, KCM):
return b"\x00" return b"\x00"
@ -275,6 +292,7 @@ def encode_record(r):
return b"\x06" + to_be4(r.resp_seqnum) return b"\x06" + to_be4(r.resp_seqnum)
raise TypeError(r) raise TypeError(r)
@attrs @attrs
@implementer(IRecord) @implementer(IRecord)
class _Record(object): class _Record(object):
@ -295,14 +313,17 @@ class _Record(object):
@n.state(initial=True) @n.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self): pass # pragma: no cover
@n.state() @n.state()
def want_handshake(self): pass # pragma: no cover def want_handshake(self): pass # pragma: no cover
@n.state() @n.state()
def want_message(self): pass # pragma: no cover def want_message(self): pass # pragma: no cover
@n.input() @n.input()
def got_prologue(self): def got_prologue(self):
pass pass
@n.input() @n.input()
def got_frame(self, frame): def got_frame(self, frame):
pass pass
@ -397,17 +418,21 @@ class DilatedConnectionProtocol(Protocol, object):
@m.state(initial=True) @m.state(initial=True)
def unselected(self): pass # pragma: no cover def unselected(self): pass # pragma: no cover
@m.state() @m.state()
def selecting(self): pass # pragma: no cover def selecting(self): pass # pragma: no cover
@m.state() @m.state()
def selected(self): pass # pragma: no cover def selected(self): pass # pragma: no cover
@m.input() @m.input()
def got_kcm(self): def got_kcm(self):
pass pass
@m.input() @m.input()
def select(self, manager): def select(self, manager):
pass # fires set_manager() pass # fires set_manager()
@m.input() @m.input()
def got_record(self, record): def got_record(self, record):
pass pass

View File

@ -1,5 +1,6 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import sys, re import sys
import re
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from binascii import hexlify from binascii import hexlify
import six import six
@ -30,7 +31,8 @@ from .roles import LEADER
# * expect to see the receiver/sender handshake bytes from the other side # * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n" # * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data # * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) DirectTCPV1Hint = namedtuple(
"DirectTCPV1Hint", ["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each # use a tuple rather than a list so they'll be hashable into a set). For each
@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful. # rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor): def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->" prefix = "tor->" if tor else "->"
if relay: if relay:
@ -49,6 +52,7 @@ def describe_hint_obj(hint, relay, tor):
else: else:
return prefix + str(hint) return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr): def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type("")) assert isinstance(hint, type(""))
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None return None
hint_type = mo.group(1) hint_type = mo.group(1)
if hint_type != "tcp": if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) print("unknown hint type '%s' in '%s'" % (hint_type, hint),
file=stderr)
return None return None
hint_value = mo.group(2) hint_value = mo.group(2)
pieces = hint_value.split(":") pieces = hint_value.split(":")
@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None return None
return DirectTCPV1Hint(hint_host, hint_port, priority) return DirectTCPV1Hint(hint_host, hint_port, priority)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "") hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,)) log.msg("unknown hint type: %r" % (hint,))
return None return None
if not("hostname" in hint if not("hostname" in hint and
and isinstance(hint["hostname"], type(""))): isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint,)) log.msg("invalid hostname in hint: %r" % (hint,))
return None return None
if not("port" in hint if not("port" in hint and
and isinstance(hint["port"], six.integer_types)): isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,)) log.msg("invalid port in hint: %r" % (hint,))
return None return None
priority = hint.get("priority", 0.0) priority = hint.get("priority", 0.0)
@ -103,6 +109,7 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
else: else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority) return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
def parse_hint(hint_struct): def parse_hint(hint_struct):
hint_type = hint_struct.get("type", "") hint_type = hint_struct.get("type", "")
if hint_type == "relay-v1": if hint_type == "relay-v1":
@ -112,6 +119,7 @@ def parse_hint(hint_struct):
return RelayV1Hint(rhints) return RelayV1Hint(rhints)
return parse_tcp_v1_hint(hint_struct) return parse_tcp_v1_hint(hint_struct)
def encode_hint(h): def encode_hint(h):
if isinstance(h, DirectTCPV1Hint): if isinstance(h, DirectTCPV1Hint):
return {"type": "direct-tcp-v1", return {"type": "direct-tcp-v1",
@ -135,19 +143,24 @@ def encode_hint(h):
} }
raise ValueError("unknown hint type", h) raise ValueError("unknown hint type", h)
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen) return Hkdf(salt, skm).expand(CTXinfo, outlen)
def build_sided_relay_handshake(key, side): def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u"")) assert isinstance(side, type(u""))
assert len(side) == 8 * 2 assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
@attrs @attrs
@implementer(IDilationConnector) @implementer(IDilationConnector)
class Connector(object): class Connector(object):
@ -178,7 +191,8 @@ class Connector(object):
self._transit_relays = [] self._transit_relays = []
self._listeners = set() # IListeningPorts that can be stopped self._listeners = set() # IListeningPorts that can be stopped
self._pending_connectors = set() # Deferreds that can be cancelled self._pending_connectors = set() # Deferreds that can be cancelled
self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped self._pending_connections = EmptyableSet(
_eventual_queue=self._eventual_queue) # Protocols to be stopped
self._contenders = set() # viable connections self._contenders = set() # viable connections
self._winning_connection = None self._winning_connection = None
self._timing = self._timing or DebugTiming() self._timing = self._timing or DebugTiming()
@ -212,26 +226,41 @@ class Connector(object):
return p return p
@m.state(initial=True) @m.state(initial=True)
def connecting(self): pass # pragma: no cover def connecting(self):
pass # pragma: no cover
@m.state() @m.state()
def connected(self): pass # pragma: no cover def connected(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def stopped(self): pass # pragma: no cover def stopped(self):
pass # pragma: no cover
# TODO: unify the tense of these method-name verbs # TODO: unify the tense of these method-name verbs
@m.input() @m.input()
def listener_ready(self, hint_objs): pass def listener_ready(self, hint_objs):
pass
@m.input() @m.input()
def add_relay(self, hint_objs): pass def add_relay(self, hint_objs):
pass
@m.input() @m.input()
def got_hints(self, hint_objs): pass def got_hints(self, hint_objs):
pass
@m.input() @m.input()
def add_candidate(self, c): # called by DilatedConnectionProtocol def add_candidate(self, c): # called by DilatedConnectionProtocol
pass pass
@m.input() @m.input()
def accept(self, c): pass def accept(self, c):
pass
@m.input() @m.input()
def stop(self): pass def stop(self):
pass
@m.output() @m.output()
def use_hints(self, hint_objs): def use_hints(self, hint_objs):
@ -306,7 +335,8 @@ class Connector(object):
publish_hints]) publish_hints])
connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
connecting.upon(add_candidate, enter=connecting, outputs=[consider]) connecting.upon(add_candidate, enter=connecting, outputs=[consider])
connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) connecting.upon(accept, enter=connected, outputs=[
select_and_stop_remaining])
connecting.upon(stop, enter=stopped, outputs=[stop_everything]) connecting.upon(stop, enter=stopped, outputs=[stop_everything])
# once connected, we ignore everything except stop # once connected, we ignore everything except stop
@ -317,9 +347,9 @@ class Connector(object):
connected.upon(accept, enter=connected, outputs=[]) connected.upon(accept, enter=connected, outputs=[])
connected.upon(stop, enter=stopped, outputs=[stop_everything]) connected.upon(stop, enter=stopped, outputs=[stop_everything])
# from Manager: start, got_hints, stop # from Manager: start, got_hints, stop
# maybe add_candidate, accept # maybe add_candidate, accept
def start(self): def start(self):
self._start_listener() self._start_listener()
if self._transit_relays: if self._transit_relays:
@ -341,6 +371,7 @@ class Connector(object):
ep = serverFromString(self._reactor, "tcp:0") ep = serverFromString(self._reactor, "tcp:0")
f = InboundConnectionFactory(self) f = InboundConnectionFactory(self)
d = ep.listen(f) d = ep.listen(f)
def _listening(lp): def _listening(lp):
# lp is an IListeningPort # lp is an IListeningPort
self._listeners.add(lp) # for shutdown and tests self._listeners.add(lp) # for shutdown and tests
@ -422,6 +453,7 @@ class Connector(object):
f = OutboundConnectionFactory(self, relay_handshake) f = OutboundConnectionFactory(self, relay_handshake)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError
def _connected(p): def _connected(p):
self._pending_connections.add(p) self._pending_connections.add(p)
# c might not be in _pending_connections, if it turned out to be a # c might not be in _pending_connections, if it turned out to be a
@ -444,7 +476,6 @@ class Connector(object):
return HostnameEndpoint(self._reactor, hint.hostname, hint.port) return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
return None return None
# Connection selection. All instances of DilatedConnectionProtocol which # Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method. # look viable get passed into our add_contender() method.
@ -459,6 +490,7 @@ class Connector(object):
# our Connection protocols call: add_candidate # our Connection protocols call: add_candidate
@attrs @attrs
class OutboundConnectionFactory(ClientFactory, object): class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))
@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object):
p.use_relay(self._relay_handshake) p.use_relay(self._relay_handshake)
return p return p
@attrs @attrs
class InboundConnectionFactory(ServerFactory, object): class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))

View File

@ -4,10 +4,13 @@ import struct
assert len(struct.pack("<L", 0)) == 4 assert len(struct.pack("<L", 0)) == 4
assert len(struct.pack("<Q", 0)) == 8 assert len(struct.pack("<Q", 0)) == 8
def to_be4(value): def to_be4(value):
if not 0 <= value < 2**32: if not 0 <= value < 2**32:
raise ValueError raise ValueError
return struct.pack(">L", value) return struct.pack(">L", value)
def from_be4(b): def from_be4(b):
if not isinstance(b, bytes): if not isinstance(b, bytes):
raise TypeError(repr(b)) raise TypeError(repr(b))

View File

@ -6,13 +6,19 @@ from twisted.python import log
from .._interfaces import IDilationManager, IInbound from .._interfaces import IDilationManager, IInbound
from .subchannel import (SubChannel, _SubchannelAddress) from .subchannel import (SubChannel, _SubchannelAddress)
class DuplicateOpenError(Exception): class DuplicateOpenError(Exception):
pass pass
class DataForMissingSubchannelError(Exception): class DataForMissingSubchannelError(Exception):
pass pass
class CloseForMissingSubchannelError(Exception): class CloseForMissingSubchannelError(Exception):
pass pass
@attrs @attrs
@implementer(IInbound) @implementer(IInbound)
class Inbound(object): class Inbound(object):
@ -37,7 +43,6 @@ class Inbound(object):
def set_subchannel_zero(self, scid0, sc0): def set_subchannel_zero(self, scid0, sc0):
self._open_subchannels[scid0] = sc0 self._open_subchannels[scid0] = sc0
def use_connection(self, c): def use_connection(self, c):
self._connection = c self._connection = c
# We can pause the connection's reads when we receive too much data. If # We can pause the connection's reads when we receive too much data. If
@ -61,7 +66,8 @@ class Inbound(object):
def handle_open(self, scid): def handle_open(self, scid):
if scid in self._open_subchannels: if scid in self._open_subchannels:
log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) log.err(DuplicateOpenError(
"received duplicate OPEN for {}".format(scid)))
return return
peer_addr = _SubchannelAddress(scid) peer_addr = _SubchannelAddress(scid)
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
@ -71,14 +77,16 @@ class Inbound(object):
def handle_data(self, scid, data): def handle_data(self, scid, data):
sc = self._open_subchannels.get(scid) sc = self._open_subchannels.get(scid)
if sc is None: if sc is None:
log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid))) log.err(DataForMissingSubchannelError(
"received DATA for non-existent subchannel {}".format(scid)))
return return
sc.remote_data(data) sc.remote_data(data)
def handle_close(self, scid): def handle_close(self, scid):
sc = self._open_subchannels.get(scid) sc = self._open_subchannels.get(scid)
if sc is None: if sc is None:
log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid))) log.err(CloseForMissingSubchannelError(
"received CLOSE for non-existent subchannel {}".format(scid)))
return return
sc.remote_close() sc.remote_close()
@ -90,7 +98,6 @@ class Inbound(object):
def stop_using_connection(self): def stop_using_connection(self):
self._connection = None self._connection = None
# from our Subchannel, or rather from the Protocol above it and sent # from our Subchannel, or rather from the Protocol above it and sent
# through the subchannel # through the subchannel

View File

@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
from .inbound import Inbound from .inbound import Inbound
from .outbound import Outbound from .outbound import Outbound
class OldPeerCannotDilateError(Exception): class OldPeerCannotDilateError(Exception):
pass pass
class UnknownDilationMessageType(Exception): class UnknownDilationMessageType(Exception):
pass pass
class ReceivedHintsTooEarly(Exception): class ReceivedHintsTooEarly(Exception):
pass pass
@attrs @attrs
@implementer(IDilationManager) @implementer(IDilationManager)
class _ManagerBase(object): class _ManagerBase(object):
@ -64,13 +70,13 @@ class _ManagerBase(object):
def set_listener_endpoint(self, listener_endpoint): def set_listener_endpoint(self, listener_endpoint):
self._inbound.set_listener_endpoint(listener_endpoint) self._inbound.set_listener_endpoint(listener_endpoint)
def set_subchannel_zero(self, scid0, sc0): def set_subchannel_zero(self, scid0, sc0):
self._inbound.set_subchannel_zero(scid0, sc0) self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self): def when_first_connected(self):
return self._first_connected.when_fired() return self._first_connected.when_fired()
def send_dilation_phase(self, **fields): def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1 self._next_dilation_phase += 1
@ -79,25 +85,30 @@ class _ManagerBase(object):
def send_hints(self, hints): # from Connector def send_hints(self, hints): # from Connector
self.send_dilation_phase(type="connection-hints", hints=hints) self.send_dilation_phase(type="connection-hints", hints=hints)
# forward inbound-ish things to _Inbound # forward inbound-ish things to _Inbound
def subchannel_pauseProducing(self, sc): def subchannel_pauseProducing(self, sc):
self._inbound.subchannel_pauseProducing(sc) self._inbound.subchannel_pauseProducing(sc)
def subchannel_resumeProducing(self, sc): def subchannel_resumeProducing(self, sc):
self._inbound.subchannel_resumeProducing(sc) self._inbound.subchannel_resumeProducing(sc)
def subchannel_stopProducing(self, sc): def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc) self._inbound.subchannel_stopProducing(sc)
# forward outbound-ish things to _Outbound # forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming): def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming) self._outbound.subchannel_registerProducer(sc, producer, streaming)
def subchannel_unregisterProducer(self, sc): def subchannel_unregisterProducer(self, sc):
self._outbound.subchannel_unregisterProducer(sc) self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid): def send_open(self, scid):
self._queue_and_send(Open, scid) self._queue_and_send(Open, scid)
def send_data(self, scid, data): def send_data(self, scid, data):
self._queue_and_send(Data, scid, data) self._queue_and_send(Data, scid, data)
def send_close(self, scid): def send_close(self, scid):
self._queue_and_send(Close, scid) self._queue_and_send(Close, scid)
@ -116,7 +127,6 @@ class _ManagerBase(object):
self._inbound.subchannel_closed(scid, sc) self._inbound.subchannel_closed(scid, sc)
self._outbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc)
def _start_connecting(self, role): def _start_connecting(self, role):
assert self._my_role is not None assert self._my_role is not None
self._connector = Connector(self._transit_key, self._connector = Connector(self._transit_key,
@ -140,6 +150,7 @@ class _ManagerBase(object):
self._made_first_connection = True self._made_first_connection = True
self._first_connected.fire(None) self._first_connected.fire(None)
pass pass
def connector_connection_lost(self): def connector_connection_lost(self):
self._stop_using_connection() self._stop_using_connection()
if self.role is LEADER: if self.role is LEADER:
@ -147,7 +158,6 @@ class _ManagerBase(object):
else: else:
self.connection_lost_follower() self.connection_lost_follower()
def _stop_using_connection(self): def _stop_using_connection(self):
# the connection is already lost by this point # the connection is already lost by this point
self._connection = None self._connection = None
@ -190,7 +200,6 @@ class _ManagerBase(object):
def send_ack(self, resp_seqnum): def send_ack(self, resp_seqnum):
self._outbound.send_if_connected(Ack(resp_seqnum)) self._outbound.send_if_connected(Ack(resp_seqnum))
def handle_ping(self, ping_id): def handle_ping(self, ping_id):
self.send_pong(ping_id) self.send_pong(ping_id)
@ -200,7 +209,7 @@ class _ManagerBase(object):
# subchannel maintenance # subchannel maintenance
def allocate_subchannel_id(self): def allocate_subchannel_id(self):
raise NotImplemented # subclass knows if we're leader or follower raise NotImplementedError # subclass knows if we're leader or follower
# new scheme: # new scheme:
# * both sides send PLEASE as soon as they have an unverified key and # * both sides send PLEASE as soon as they have an unverified key and
@ -236,58 +245,91 @@ class _ManagerBase(object):
# * if follower calls w.dilate() but not leader, follower waits forever # * if follower calls w.dilate() but not leader, follower waits forever
# in "want", leader waits forever in "wanted" # in "want", leader waits forever in "wanted"
class ManagerShared(_ManagerBase): class ManagerShared(_ManagerBase):
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) set_trace = getattr(m, "_setTrace", lambda self, f: None)
@m.state(initial=True) @m.state(initial=True)
def IDLE(self): pass # pragma: no cover def IDLE(self):
pass # pragma: no cover
@m.state() @m.state()
def WANTING(self): pass # pragma: no cover def WANTING(self):
pass # pragma: no cover
@m.state() @m.state()
def WANTED(self): pass # pragma: no cover def WANTED(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTING(self): pass # pragma: no cover def CONNECTING(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTED(self): pass # pragma: no cover def CONNECTED(self):
pass # pragma: no cover
@m.state() @m.state()
def FLUSHING(self): pass # pragma: no cover def FLUSHING(self):
pass # pragma: no cover
@m.state() @m.state()
def ABANDONING(self): pass # pragma: no cover def ABANDONING(self):
pass # pragma: no cover
@m.state() @m.state()
def LONELY(self): pass # pragme: no cover def LONELY(self):
pass # pragme: no cover
@m.state() @m.state()
def STOPPING(self): pass # pragma: no cover def STOPPING(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def STOPPED(self): pass # pragma: no cover def STOPPED(self):
pass # pragma: no cover
@m.input() @m.input()
def start(self): pass # pragma: no cover def start(self):
pass # pragma: no cover
@m.input() @m.input()
def rx_PLEASE(self, message): pass # pragma: no cover def rx_PLEASE(self, message):
pass # pragma: no cover
@m.input() # only sent by Follower @m.input() # only sent by Follower
def rx_HINTS(self, hint_message): pass # pragma: no cover def rx_HINTS(self, hint_message):
pass # pragma: no cover
@m.input() # only Leader sends RECONNECT, so only Follower receives it @m.input() # only Leader sends RECONNECT, so only Follower receives it
def rx_RECONNECT(self): pass # pragma: no cover def rx_RECONNECT(self):
pass # pragma: no cover
@m.input() # only Follower sends RECONNECTING, so only Leader receives it @m.input() # only Follower sends RECONNECTING, so only Leader receives it
def rx_RECONNECTING(self): pass # pragma: no cover def rx_RECONNECTING(self):
pass # pragma: no cover
# Connector gives us connection_made() # Connector gives us connection_made()
@m.input() @m.input()
def connection_made(self, c): pass # pragma: no cover def connection_made(self, c):
pass # pragma: no cover
# our connection_lost() fires connection_lost_leader or # our connection_lost() fires connection_lost_leader or
# connection_lost_follower depending upon our role. If either side sees a # connection_lost_follower depending upon our role. If either side sees a
# problem with the connection (timeouts, bad authentication) then they # problem with the connection (timeouts, bad authentication) then they
# just drop it and let connection_lost() handle the cleanup. # just drop it and let connection_lost() handle the cleanup.
@m.input() @m.input()
def connection_lost_leader(self): pass # pragma: no cover def connection_lost_leader(self):
@m.input() pass # pragma: no cover
def connection_lost_follower(self): pass
@m.input() @m.input()
def stop(self): pass # pragma: no cover def connection_lost_follower(self):
pass
@m.input()
def stop(self):
pass # pragma: no cover
@m.output() @m.output()
def stash_side(self, message): def stash_side(self, message):
@ -302,6 +344,7 @@ class ManagerShared(_ManagerBase):
@m.output() @m.output()
def start_connecting(self): def start_connecting(self):
self._start_connecting() # TODO: merge self._start_connecting() # TODO: merge
@m.output() @m.output()
def ignore_message_start_connecting(self, message): def ignore_message_start_connecting(self, message):
self.start_connecting() self.start_connecting()
@ -309,6 +352,7 @@ class ManagerShared(_ManagerBase):
@m.output() @m.output()
def send_reconnect(self): def send_reconnect(self):
self.send_dilation_phase(type="reconnect") # TODO: generation number? self.send_dilation_phase(type="reconnect") # TODO: generation number?
@m.output() @m.output()
def send_reconnecting(self): def send_reconnecting(self):
self.send_dilation_phase(type="reconnecting") # TODO: generation? self.send_dilation_phase(type="reconnecting") # TODO: generation?
@ -319,20 +363,22 @@ class ManagerShared(_ManagerBase):
[parse_hint(hs) for hs in hint_message["hints"]]) [parse_hint(hs) for hs in hint_message["hints"]])
hint_objs = list(hint_objs) hint_objs = list(hint_objs)
self._connector.got_hints(hint_objs) self._connector.got_hints(hint_objs)
@m.output() @m.output()
def stop_connecting(self): def stop_connecting(self):
self._connector.stop() self._connector.stop()
@m.output() @m.output()
def abandon_connection(self): def abandon_connection(self):
# we think we're still connected, but the Leader disagrees. Or we've # we think we're still connected, but the Leader disagrees. Or we've
# been told to shut down. # been told to shut down.
self._connection.disconnect() # let connection_lost do cleanup self._connection.disconnect() # let connection_lost do cleanup
# we don't start CONNECTING until a local start() plus rx_PLEASE # we don't start CONNECTING until a local start() plus rx_PLEASE
IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
IDLE.upon(start, enter=WANTING, outputs=[send_please]) IDLE.upon(start, enter=WANTING, outputs=[send_please])
WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) WANTED.upon(start, enter=CONNECTING, outputs=[
send_please, start_connecting])
WANTING.upon(rx_PLEASE, enter=CONNECTING, WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[stash_side, outputs=[stash_side,
ignore_message_start_connecting]) ignore_message_start_connecting])
@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase):
# Leader # Leader
CONNECTED.upon(connection_lost_leader, enter=FLUSHING, CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
outputs=[send_reconnect]) outputs=[send_reconnect])
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
outputs=[start_connecting])
# Follower # Follower
# if we notice a lost connection, just wait for the Leader to notice too # if we notice a lost connection, just wait for the Leader to notice too
@ -362,7 +409,6 @@ class ManagerShared(_ManagerBase):
send_reconnecting, send_reconnecting,
start_connecting]) start_connecting])
# rx_HINTS never changes state, they're just accepted or ignored # rx_HINTS never changes state, they're just accepted or ignored
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase):
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
def allocate_subchannel_id(self): def allocate_subchannel_id(self):
# scid 0 is reserved for the control channel. the leader uses odd # scid 0 is reserved for the control channel. the leader uses odd
# numbers starting with 1 # numbers starting with 1
@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase):
self._next_outbound_seqnum += 2 self._next_outbound_seqnum += 2
return to_be4(scid_num) return to_be4(scid_num)
@attrs @attrs
@implementer(IDilator) @implementer(IDilator)
class Dilator(object): class Dilator(object):
@ -462,7 +508,8 @@ class Dilator(object):
control_ep._subchannel_zero_opened(sc0) control_ep._subchannel_zero_opened(sc0)
self._manager.set_subchannel_zero(scid0, sc0) self._manager.set_subchannel_zero(scid0, sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) connect_ep = SubchannelConnectorEndpoint(
self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep) self._manager.set_listener_endpoint(listen_ep)

View File

@ -290,9 +290,9 @@ class Outbound(object):
# Inbound is responsible for tracking the high watermark and deciding # Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not # whether to ignore inbound messages or not
# IProducer: the active connection calls these because we used # IProducer: the active connection calls these because we used
# c.registerProducer to ask for them # c.registerProducer to ask for them
def pauseProducing(self): def pauseProducing(self):
if self._paused: if self._paused:
return # someone is confused and called us twice return # someone is confused and called us twice
@ -351,14 +351,14 @@ class PullToPush(object):
while True: while True:
try: try:
self._producer.resumeProducing() self._producer.resumeProducing()
except: except Exception:
log.err(None, "%s failed, producing will be stopped:" % log.err(None, "%s failed, producing will be stopped:" %
(safe_str(self._producer),)) (safe_str(self._producer),))
try: try:
self._unregister() self._unregister()
# The consumer should now call stopStreaming() on us, # The consumer should now call stopStreaming() on us,
# thus stopping the streaming. # thus stopping the streaming.
except: except Exception:
# Since the consumer blew up, we may not have had # Since the consumer blew up, we may not have had
# stopStreaming() called, so we just stop on our own: # stopStreaming() called, so we just stop on our own:
log.err(None, "%s failed to unregister producer:" % log.err(None, "%s failed to unregister producer:" %
@ -378,15 +378,12 @@ class PullToPush(object):
self._finished = True self._finished = True
self._coopTask.stop() self._coopTask.stop()
def pauseProducing(self): def pauseProducing(self):
self._coopTask.pause() self._coopTask.pause()
def resumeProducing(self): def resumeProducing(self):
self._coopTask.resume() self._coopTask.resume()
def stopProducing(self): def stopProducing(self):
self.stopStreaming() self.stopStreaming()
self._producer.stopProducing() self._producer.stopProducing()

View File

@ -1,7 +1,8 @@
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import instance_of, provides from attr.validators import instance_of, provides
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
succeed)
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort, IAddress, IListeningPort,
IStreamClientEndpoint, IStreamClientEndpoint,
@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager from .._interfaces import ISubChannel, IDilationManager
@attrs @attrs
class Once(object): class Once(object):
_errtype = attrib() _errtype = attrib()
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._called = False self._called = False
@ -21,6 +24,7 @@ class Once(object):
raise self._errtype() raise self._errtype()
self._called = True self._called = True
class SingleUseEndpointError(Exception): class SingleUseEndpointError(Exception):
pass pass
@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception):
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) # (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
# object is deleted upon transition to (CLOSED) # object is deleted upon transition to (CLOSED)
class AlreadyClosedError(Exception): class AlreadyClosedError(Exception):
pass pass
@implementer(IAddress) @implementer(IAddress)
class _WormholeAddress(object): class _WormholeAddress(object):
pass pass
@implementer(IAddress) @implementer(IAddress)
@attrs @attrs
class _SubchannelAddress(object): class _SubchannelAddress(object):
@ -63,7 +70,8 @@ class SubChannel(object):
_peer_addr = attrib(validator=instance_of(_SubchannelAddress)) _peer_addr = attrib(validator=instance_of(_SubchannelAddress))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace", lambda self,
f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
# self._mailbox = None # self._mailbox = None
@ -74,24 +82,32 @@ class SubChannel(object):
self._pending_connectionLost = (False, None) self._pending_connectionLost = (False, None)
@m.state(initial=True) @m.state(initial=True)
def open(self): pass # pragma: no cover def open(self):
pass # pragma: no cover
@m.state() @m.state()
def closing(): pass # pragma: no cover def closing():
pass # pragma: no cover
@m.state() @m.state()
def closed(): pass # pragma: no cover def closed():
pass # pragma: no cover
@m.input() @m.input()
def remote_data(self, data): pass def remote_data(self, data):
@m.input() pass
def remote_close(self): pass
@m.input() @m.input()
def local_data(self, data): pass def remote_close(self):
@m.input() pass
def local_close(self): pass
@m.input()
def local_data(self, data):
pass
@m.input()
def local_close(self):
pass
@m.output() @m.output()
def send_data(self, data): def send_data(self, data):
@ -120,9 +136,11 @@ class SubChannel(object):
@m.output() @m.output()
def error_closed_write(self, data): def error_closed_write(self, data):
raise AlreadyClosedError("write not allowed on closed subchannel") raise AlreadyClosedError("write not allowed on closed subchannel")
@m.output() @m.output()
def error_closed_close(self): def error_closed_close(self):
raise AlreadyClosedError("loseConnection not allowed on closed subchannel") raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions # primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
@ -155,13 +173,17 @@ class SubChannel(object):
def write(self, data): def write(self, data):
assert isinstance(data, type(b"")) assert isinstance(data, type(b""))
self.local_data(data) self.local_data(data)
def writeSequence(self, iovec): def writeSequence(self, iovec):
self.write(b"".join(iovec)) self.write(b"".join(iovec))
def loseConnection(self): def loseConnection(self):
self.local_close() self.local_close()
def getHost(self): def getHost(self):
# we define "host addr" as the overall wormhole # we define "host addr" as the overall wormhole
return self._host_addr return self._host_addr
def getPeer(self): def getPeer(self):
# and "peer addr" as the subchannel within that wormhole # and "peer addr" as the subchannel within that wormhole
return self._peer_addr return self._peer_addr
@ -169,14 +191,17 @@ class SubChannel(object):
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol) # IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
def stopProducing(self): def stopProducing(self):
self._manager.subchannel_stopProducing(self) self._manager.subchannel_stopProducing(self)
def pauseProducing(self): def pauseProducing(self):
self._manager.subchannel_pauseProducing(self) self._manager.subchannel_pauseProducing(self)
def resumeProducing(self): def resumeProducing(self):
self._manager.subchannel_resumeProducing(self) self._manager.subchannel_resumeProducing(self)
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole) # IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
self._manager.subchannel_registerProducer(self, producer, streaming) self._manager.subchannel_registerProducer(self, producer, streaming)
def unregisterProducer(self): def unregisterProducer(self):
self._manager.subchannel_unregisterProducer(self) self._manager.subchannel_unregisterProducer(self)
@ -184,6 +209,7 @@ class SubChannel(object):
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
class ControlEndpoint(object): class ControlEndpoint(object):
_used = False _used = False
def __init__(self, peer_addr): def __init__(self, peer_addr):
self._subchannel_zero = Deferred() self._subchannel_zero = Deferred()
self._peer_addr = peer_addr self._peer_addr = peer_addr
@ -204,6 +230,7 @@ class ControlEndpoint(object):
p.makeConnection(t) # set p.transport = t and call connectionMade() p.makeConnection(t) # set p.transport = t and call connectionMade()
returnValue(p) returnValue(p)
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
@attrs @attrs
class SubchannelConnectorEndpoint(object): class SubchannelConnectorEndpoint(object):
@ -223,6 +250,7 @@ class SubchannelConnectorEndpoint(object):
p.makeConnection(t) # set p.transport = t and call connectionMade() p.makeConnection(t) # set p.transport = t and call connectionMade()
return succeed(p) return succeed(p)
@implementer(IStreamServerEndpoint) @implementer(IStreamServerEndpoint)
@attrs @attrs
class SubchannelListenerEndpoint(object): class SubchannelListenerEndpoint(object):
@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object):
lp = SubchannelListeningPort(self._host_addr) lp = SubchannelListeningPort(self._host_addr)
return succeed(lp) return succeed(lp)
@implementer(IListeningPort) @implementer(IListeningPort)
@attrs @attrs
class SubchannelListeningPort(object): class SubchannelListeningPort(object):
@ -262,8 +291,10 @@ class SubchannelListeningPort(object):
def startListening(self): def startListening(self):
pass pass
def stopListening(self): def stopListening(self):
# TODO # TODO
pass pass
def getHost(self): def getHost(self):
return self._host_addr return self._host_addr