fix some flake8 complaints

This commit is contained in:
Brian Warner 2018-06-30 16:19:41 -07:00
parent 39666f3fed
commit 05900bd08b
8 changed files with 339 additions and 195 deletions

View File

@ -67,7 +67,8 @@ class Boss(object):
self._I = Input(self._timing) self._I = Input(self._timing)
self._C = Code(self._timing) self._C = Code(self._timing)
self._T = Terminator() self._T = Terminator()
self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) self._D = Dilator(self._reactor, self._eventual_queue,
self._cooperator)
self._N.wire(self._M, self._I, self._RC, self._T) self._N.wire(self._M, self._I, self._RC, self._T)
self._M.wire(self._N, self._RC, self._O, self._T) self._M.wire(self._N, self._RC, self._O, self._T)
@ -90,7 +91,7 @@ class Boss(object):
self._rx_phases = {} # phase -> plaintext self._rx_phases = {} # phase -> plaintext
self._next_rx_dilate_seqnum = 0 self._next_rx_dilate_seqnum = 0
self._rx_dilate_seqnums = {} # seqnum -> plaintext self._rx_dilate_seqnums = {} # seqnum -> plaintext
self._result = "empty" self._result = "empty"
@ -205,7 +206,7 @@ class Boss(object):
self._C.set_code(code) self._C.set_code(code)
def dilate(self): def dilate(self):
return self._D.dilate() # fires with endpoints return self._D.dilate() # fires with endpoints
@m.input() @m.input()
def send(self, plaintext): def send(self, plaintext):

View File

@ -39,20 +39,28 @@ from .roles import FOLLOWER
# states). For the specific question of sending plaintext frames, Noise will # states). For the specific question of sending plaintext frames, Noise will
# refuse us unless it's ready anyways, so the question is probably moot. # refuse us unless it's ready anyways, so the question is probably moot.
class IFramer(Interface): class IFramer(Interface):
pass pass
class IRecord(Interface): class IRecord(Interface):
pass pass
def first(l): def first(l):
return l[0] return l[0]
class Disconnect(Exception): class Disconnect(Exception):
pass pass
RelayOK = namedtuple("RelayOk", []) RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", []) Prologue = namedtuple("Prologue", [])
Frame = namedtuple("Frame", ["frame"]) Frame = namedtuple("Frame", ["frame"])
@attrs @attrs
@implementer(IFramer) @implementer(IFramer)
class _Framer(object): class _Framer(object):
@ -69,30 +77,37 @@ class _Framer(object):
# out (shared): transport.write (relay handshake, prologue) # out (shared): transport.write (relay handshake, prologue)
# states: want_relay, want_prologue, want_frame # states: want_relay, want_prologue, want_frame
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
@m.state() @m.state()
def want_relay(self): pass # pragma: no cover def want_relay(self): pass # pragma: no cover
@m.state(initial=True) @m.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self): pass # pragma: no cover
@m.state() @m.state()
def want_frame(self): pass # pragma: no cover def want_frame(self): pass # pragma: no cover
@m.input() @m.input()
def use_relay(self, relay_handshake): pass def use_relay(self, relay_handshake): pass
@m.input() @m.input()
def connectionMade(self): pass def connectionMade(self): pass
@m.input() @m.input()
def parse(self): pass def parse(self): pass
@m.input() @m.input()
def got_relay_ok(self): pass def got_relay_ok(self): pass
@m.input() @m.input()
def got_prologue(self): pass def got_prologue(self): pass
@m.output() @m.output()
def store_relay_handshake(self, relay_handshake): def store_relay_handshake(self, relay_handshake):
self._outbound_relay_handshake = relay_handshake self._outbound_relay_handshake = relay_handshake
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
@m.output() @m.output()
def send_relay_handshake(self): def send_relay_handshake(self):
self._transport.write(self._outbound_relay_handshake) self._transport.write(self._outbound_relay_handshake)
@ -113,17 +128,17 @@ class _Framer(object):
@m.output() @m.output()
def can_send_frames(self): def can_send_frames(self):
self._can_send_frames = True # for assertion in send_frame() self._can_send_frames = True # for assertion in send_frame()
@m.output() @m.output()
def parse_frame(self): def parse_frame(self):
if len(self._buffer) < 4: if len(self._buffer) < 4:
return None return None
frame_length = from_be4(self._buffer[0:4]) frame_length = from_be4(self._buffer[0:4])
if len(self._buffer) < 4+frame_length: if len(self._buffer) < 4 + frame_length:
return None return None
frame = self._buffer[4:4+frame_length] frame = self._buffer[4:4 + frame_length]
self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy
return Frame(frame=frame) return Frame(frame=frame)
want_prologue.upon(use_relay, outputs=[store_relay_handshake], want_prologue.upon(use_relay, outputs=[store_relay_handshake],
@ -144,7 +159,6 @@ class _Framer(object):
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
collector=first) collector=first)
def _get_expected(self, name, expected): def _get_expected(self, name, expected):
lb = len(self._buffer) lb = len(self._buffer)
le = len(expected) le = len(expected)
@ -161,7 +175,7 @@ class _Framer(object):
if (b"\n" in self._buffer or lb >= le): if (b"\n" in self._buffer or lb >= le):
log.msg("bad {}: {}".format(name, self._buffer[:le])) log.msg("bad {}: {}".format(name, self._buffer[:le]))
raise Disconnect() raise Disconnect()
return False # wait a bit longer return False # wait a bit longer
# good so far, just waiting for the rest # good so far, just waiting for the rest
return False return False
@ -181,7 +195,7 @@ class _Framer(object):
self.got_relay_ok() self.got_relay_ok()
elif isinstance(token, Prologue): elif isinstance(token, Prologue):
self.got_prologue() self.got_prologue()
yield token # triggers send_handshake yield token # triggers send_handshake
elif isinstance(token, Frame): elif isinstance(token, Frame):
yield token yield token
else: else:
@ -202,15 +216,16 @@ class _Framer(object):
# from peer. Sent immediately by Follower, after Selection by Leader. # from peer. Sent immediately by Follower, after Selection by Leader.
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong # Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
Handshake = namedtuple("Handshake", []) Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack # decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", []) KCM = namedtuple("KCM", [])
Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"]) Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"]) Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack) Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records Handshake_or_Records = (Handshake,) + Records
@ -222,6 +237,7 @@ T_DATA = b"\x04"
T_CLOSE = b"\x05" T_CLOSE = b"\x05"
T_ACK = b"\x06" T_ACK = b"\x06"
def parse_record(plaintext): def parse_record(plaintext):
msgtype = plaintext[0:1] msgtype = plaintext[0:1]
if msgtype == T_KCM: if msgtype == T_KCM:
@ -251,6 +267,7 @@ def parse_record(plaintext):
log.err("received unknown message type: {}".format(plaintext)) log.err("received unknown message type: {}".format(plaintext))
raise ValueError() raise ValueError()
def encode_record(r): def encode_record(r):
if isinstance(r, KCM): if isinstance(r, KCM):
return b"\x00" return b"\x00"
@ -275,6 +292,7 @@ def encode_record(r):
return b"\x06" + to_be4(r.resp_seqnum) return b"\x06" + to_be4(r.resp_seqnum)
raise TypeError(r) raise TypeError(r)
@attrs @attrs
@implementer(IRecord) @implementer(IRecord)
class _Record(object): class _Record(object):
@ -294,22 +312,25 @@ class _Record(object):
# states: want_prologue, want_handshake, want_record # states: want_prologue, want_handshake, want_record
@n.state(initial=True) @n.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self): pass # pragma: no cover
@n.state() @n.state()
def want_handshake(self): pass # pragma: no cover def want_handshake(self): pass # pragma: no cover
@n.state() @n.state()
def want_message(self): pass # pragma: no cover def want_message(self): pass # pragma: no cover
@n.input() @n.input()
def got_prologue(self): def got_prologue(self):
pass pass
@n.input() @n.input()
def got_frame(self, frame): def got_frame(self, frame):
pass pass
@n.output() @n.output()
def send_handshake(self): def send_handshake(self):
handshake = self._noise.write_message() # generate the ephemeral key handshake = self._noise.write_message() # generate the ephemeral key
self._framer.send_frame(handshake) self._framer.send_frame(handshake)
@n.output() @n.output()
@ -351,10 +372,10 @@ class _Record(object):
def add_and_unframe(self, data): def add_and_unframe(self, data):
for token in self._framer.add_and_parse(data): for token in self._framer.add_and_parse(data):
if isinstance(token, Prologue): if isinstance(token, Prologue):
self.got_prologue() # triggers send_handshake self.got_prologue() # triggers send_handshake
else: else:
assert isinstance(token, Frame) assert isinstance(token, Frame)
yield self.got_frame(token.frame) # Handshake or a Record type yield self.got_frame(token.frame) # Handshake or a Record type
def send_record(self, r): def send_record(self, r):
message = encode_record(r) message = encode_record(r)
@ -388,26 +409,30 @@ class DilatedConnectionProtocol(Protocol, object):
_relay_handshake = None _relay_handshake = None
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._manager = None # set if/when we are selected self._manager = None # set if/when we are selected
self._disconnected = OneShotObserver(self._eventual_queue) self._disconnected = OneShotObserver(self._eventual_queue)
self._can_send_records = False self._can_send_records = False
@m.state(initial=True) @m.state(initial=True)
def unselected(self): pass # pragma: no cover def unselected(self): pass # pragma: no cover
@m.state() @m.state()
def selecting(self): pass # pragma: no cover def selecting(self): pass # pragma: no cover
@m.state() @m.state()
def selected(self): pass # pragma: no cover def selected(self): pass # pragma: no cover
@m.input() @m.input()
def got_kcm(self): def got_kcm(self):
pass pass
@m.input() @m.input()
def select(self, manager): def select(self, manager):
pass # fires set_manager() pass # fires set_manager()
@m.input() @m.input()
def got_record(self, record): def got_record(self, record):
pass pass
@ -472,9 +497,9 @@ class DilatedConnectionProtocol(Protocol, object):
elif isinstance(token, KCM): elif isinstance(token, KCM):
# if we're the leader, add this connection as a candiate. # if we're the leader, add this connection as a candiate.
# if we're the follower, accept this connection. # if we're the follower, accept this connection.
self.got_kcm() # connector.add_candidate() self.got_kcm() # connector.add_candidate()
else: else:
self.got_record(token) # manager.got_record() self.got_record(token) # manager.got_record()
except Disconnect: except Disconnect:
self.transport.loseConnection() self.transport.loseConnection()

View File

@ -1,5 +1,6 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import sys, re import sys
import re
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from binascii import hexlify from binascii import hexlify
import six import six
@ -13,7 +14,7 @@ from twisted.internet.endpoints import HostnameEndpoint, serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.python import log from twisted.python import log
from hkdf import Hkdf from hkdf import Hkdf
from .. import ipaddrs # TODO: move into _dilation/ from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager from .._interfaces import IDilationConnector, IDilationManager
from ..timing import DebugTiming from ..timing import DebugTiming
from ..observer import EmptyableSet from ..observer import EmptyableSet
@ -30,7 +31,8 @@ from .roles import LEADER
# * expect to see the receiver/sender handshake bytes from the other side # * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n" # * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data # * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) DirectTCPV1Hint = namedtuple(
"DirectTCPV1Hint", ["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each # use a tuple rather than a list so they'll be hashable into a set). For each
@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful. # rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor): def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->" prefix = "tor->" if tor else "->"
if relay: if relay:
@ -45,9 +48,10 @@ def describe_hint_obj(hint, relay, tor):
if isinstance(hint, DirectTCPV1Hint): if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint): elif isinstance(hint, TorTCPV1Hint):
return prefix+"tor:%s:%d" % (hint.hostname, hint.port) return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else: else:
return prefix+str(hint) return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr): def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type("")) assert isinstance(hint, type(""))
@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None return None
hint_type = mo.group(1) hint_type = mo.group(1)
if hint_type != "tcp": if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) print("unknown hint type '%s' in '%s'" % (hint_type, hint),
file=stderr)
return None return None
hint_value = mo.group(2) hint_value = mo.group(2)
pieces = hint_value.split(":") pieces = hint_value.split(":")
@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None return None
return DirectTCPV1Hint(hint_host, hint_port, priority) return DirectTCPV1Hint(hint_host, hint_port, priority)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "") hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,)) log.msg("unknown hint type: %r" % (hint,))
return None return None
if not("hostname" in hint if not("hostname" in hint and
and isinstance(hint["hostname"], type(""))): isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint,)) log.msg("invalid hostname in hint: %r" % (hint,))
return None return None
if not("port" in hint if not("port" in hint and
and isinstance(hint["port"], six.integer_types)): isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,)) log.msg("invalid port in hint: %r" % (hint,))
return None return None
priority = hint.get("priority", 0.0) priority = hint.get("priority", 0.0)
@ -103,51 +109,58 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
else: else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority) return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
def parse_hint(hint_struct): def parse_hint(hint_struct):
hint_type = hint_struct.get("type", "") hint_type = hint_struct.get("type", "")
if hint_type == "relay-v1": if hint_type == "relay-v1":
# the struct can include multiple ways to reach the same relay # the struct can include multiple ways to reach the same relay
rhints = filter(lambda h: h, # drop None (unrecognized) rhints = filter(lambda h: h, # drop None (unrecognized)
[parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
return RelayV1Hint(rhints) return RelayV1Hint(rhints)
return parse_tcp_v1_hint(hint_struct) return parse_tcp_v1_hint(hint_struct)
def encode_hint(h): def encode_hint(h):
if isinstance(h, DirectTCPV1Hint): if isinstance(h, DirectTCPV1Hint):
return {"type": "direct-tcp-v1", return {"type": "direct-tcp-v1",
"priority": h.priority, "priority": h.priority,
"hostname": h.hostname, "hostname": h.hostname,
"port": h.port, # integer "port": h.port, # integer
} }
elif isinstance(h, RelayV1Hint): elif isinstance(h, RelayV1Hint):
rhint = {"type": "relay-v1", "hints": []} rhint = {"type": "relay-v1", "hints": []}
for rh in h.hints: for rh in h.hints:
rhint["hints"].append({"type": "direct-tcp-v1", rhint["hints"].append({"type": "direct-tcp-v1",
"priority": rh.priority, "priority": rh.priority,
"hostname": rh.hostname, "hostname": rh.hostname,
"port": rh.port}) "port": rh.port})
return rhint return rhint
elif isinstance(h, TorTCPV1Hint): elif isinstance(h, TorTCPV1Hint):
return {"type": "tor-tcp-v1", return {"type": "tor-tcp-v1",
"priority": h.priority, "priority": h.priority,
"hostname": h.hostname, "hostname": h.hostname,
"port": h.port, # integer "port": h.port, # integer
} }
raise ValueError("unknown hint type", h) raise ValueError("unknown hint type", h)
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen) return Hkdf(salt, skm).expand(CTXinfo, outlen)
def build_sided_relay_handshake(key, side): def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u"")) assert isinstance(side, type(u""))
assert len(side) == 8*2 assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
@attrs @attrs
@implementer(IDilationConnector) @implementer(IDilationConnector)
class Connector(object): class Connector(object):
@ -176,10 +189,11 @@ class Connector(object):
self._transit_relays = [relay] self._transit_relays = [relay]
else: else:
self._transit_relays = [] self._transit_relays = []
self._listeners = set() # IListeningPorts that can be stopped self._listeners = set() # IListeningPorts that can be stopped
self._pending_connectors = set() # Deferreds that can be cancelled self._pending_connectors = set() # Deferreds that can be cancelled
self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped self._pending_connections = EmptyableSet(
self._contenders = set() # viable connections _eventual_queue=self._eventual_queue) # Protocols to be stopped
self._contenders = set() # viable connections
self._winning_connection = None self._winning_connection = None
self._timing = self._timing or DebugTiming() self._timing = self._timing or DebugTiming()
self._timing.add("transit") self._timing.add("transit")
@ -212,26 +226,41 @@ class Connector(object):
return p return p
@m.state(initial=True) @m.state(initial=True)
def connecting(self): pass # pragma: no cover def connecting(self):
pass # pragma: no cover
@m.state() @m.state()
def connected(self): pass # pragma: no cover def connected(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def stopped(self): pass # pragma: no cover def stopped(self):
pass # pragma: no cover
# TODO: unify the tense of these method-name verbs # TODO: unify the tense of these method-name verbs
@m.input() @m.input()
def listener_ready(self, hint_objs): pass def listener_ready(self, hint_objs):
@m.input()
def add_relay(self, hint_objs): pass
@m.input()
def got_hints(self, hint_objs): pass
@m.input()
def add_candidate(self, c): # called by DilatedConnectionProtocol
pass pass
@m.input() @m.input()
def accept(self, c): pass def add_relay(self, hint_objs):
pass
@m.input() @m.input()
def stop(self): pass def got_hints(self, hint_objs):
pass
@m.input()
def add_candidate(self, c): # called by DilatedConnectionProtocol
pass
@m.input()
def accept(self, c):
pass
@m.input()
def stop(self):
pass
@m.output() @m.output()
def use_hints(self, hint_objs): def use_hints(self, hint_objs):
@ -255,19 +284,19 @@ class Connector(object):
@m.output() @m.output()
def select_and_stop_remaining(self, c): def select_and_stop_remaining(self, c):
self._winning_connection = c self._winning_connection = c
self._contenders.clear() # we no longer care who else came close self._contenders.clear() # we no longer care who else came close
# remove this winner from the losers, so we don't shut it down # remove this winner from the losers, so we don't shut it down
self._pending_connections.discard(c) self._pending_connections.discard(c)
# shut down losing connections # shut down losing connections
self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
self.stop_pending_connectors() self.stop_pending_connectors()
self.stop_pending_connections() self.stop_pending_connections()
c.select(self._manager) # subsequent frames go directly to the manager c.select(self._manager) # subsequent frames go directly to the manager
if self._role is LEADER: if self._role is LEADER:
# TODO: this should live in Connection # TODO: this should live in Connection
c.send_record(KCM()) # leader sends KCM now c.send_record(KCM()) # leader sends KCM now
self._manager.use_connection(c) # manager sends frames to Connection self._manager.use_connection(c) # manager sends frames to Connection
@m.output() @m.output()
def stop_everything(self): def stop_everything(self):
@ -279,7 +308,7 @@ class Connector(object):
def stop_listeners(self): def stop_listeners(self):
d = DeferredList([l.stopListening() for l in self._listeners]) d = DeferredList([l.stopListening() for l in self._listeners])
self._listeners.clear() self._listeners.clear()
return d # synchronization for tests return d # synchronization for tests
def stop_pending_connectors(self): def stop_pending_connectors(self):
return DeferredList([d.cancel() for d in self._pending_connectors]) return DeferredList([d.cancel() for d in self._pending_connectors])
@ -306,7 +335,8 @@ class Connector(object):
publish_hints]) publish_hints])
connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
connecting.upon(add_candidate, enter=connecting, outputs=[consider]) connecting.upon(add_candidate, enter=connecting, outputs=[consider])
connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) connecting.upon(accept, enter=connected, outputs=[
select_and_stop_remaining])
connecting.upon(stop, enter=stopped, outputs=[stop_everything]) connecting.upon(stop, enter=stopped, outputs=[stop_everything])
# once connected, we ignore everything except stop # once connected, we ignore everything except stop
@ -317,9 +347,9 @@ class Connector(object):
connected.upon(accept, enter=connected, outputs=[]) connected.upon(accept, enter=connected, outputs=[])
connected.upon(stop, enter=stopped, outputs=[stop_everything]) connected.upon(stop, enter=stopped, outputs=[stop_everything])
# from Manager: start, got_hints, stop # from Manager: start, got_hints, stop
# maybe add_candidate, accept # maybe add_candidate, accept
def start(self): def start(self):
self._start_listener() self._start_listener()
if self._transit_relays: if self._transit_relays:
@ -341,9 +371,10 @@ class Connector(object):
ep = serverFromString(self._reactor, "tcp:0") ep = serverFromString(self._reactor, "tcp:0")
f = InboundConnectionFactory(self) f = InboundConnectionFactory(self)
d = ep.listen(f) d = ep.listen(f)
def _listening(lp): def _listening(lp):
# lp is an IListeningPort # lp is an IListeningPort
self._listeners.add(lp) # for shutdown and tests self._listeners.add(lp) # for shutdown and tests
portnum = lp.getHost().port portnum = lp.getHost().port
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
for addr in addresses] for addr in addresses]
@ -378,7 +409,7 @@ class Connector(object):
# one still running. But if we bail on that, we might consider # one still running. But if we bail on that, we might consider
# putting an inter-direct-hint delay here to influence the # putting an inter-direct-hint delay here to influence the
# process. # process.
#delay += 1.0 # delay += 1.0
if delay > 0.0: if delay > 0.0:
# Start trying the relays a few seconds after we start to try the # Start trying the relays a few seconds after we start to try the
# direct hints. The idea is to prefer direct connections, but not # direct hints. The idea is to prefer direct connections, but not
@ -408,7 +439,7 @@ class Connector(object):
self._connect, ep, desc, is_relay=True) self._connect, ep, desc, is_relay=True)
self._pending_connectors.add(d) self._pending_connectors.add(d)
# TODO: # TODO:
#if not contenders: # if not contenders:
# raise TransitError("No contenders for connection") # raise TransitError("No contenders for connection")
# TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
@ -422,6 +453,7 @@ class Connector(object):
f = OutboundConnectionFactory(self, relay_handshake) f = OutboundConnectionFactory(self, relay_handshake)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError
def _connected(p): def _connected(p):
self._pending_connections.add(p) self._pending_connections.add(p)
# c might not be in _pending_connections, if it turned out to be a # c might not be in _pending_connections, if it turned out to be a
@ -444,7 +476,6 @@ class Connector(object):
return HostnameEndpoint(self._reactor, hint.hostname, hint.port) return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
return None return None
# Connection selection. All instances of DilatedConnectionProtocol which # Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method. # look viable get passed into our add_contender() method.
@ -459,6 +490,7 @@ class Connector(object):
# our Connection protocols call: add_candidate # our Connection protocols call: add_candidate
@attrs @attrs
class OutboundConnectionFactory(ClientFactory, object): class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))
@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object):
p.use_relay(self._relay_handshake) p.use_relay(self._relay_handshake)
return p return p
@attrs @attrs
class InboundConnectionFactory(ServerFactory, object): class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector)) _connector = attrib(validator=provides(IDilationConnector))

View File

@ -4,10 +4,13 @@ import struct
assert len(struct.pack("<L", 0)) == 4 assert len(struct.pack("<L", 0)) == 4
assert len(struct.pack("<Q", 0)) == 8 assert len(struct.pack("<Q", 0)) == 8
def to_be4(value): def to_be4(value):
if not 0 <= value < 2**32: if not 0 <= value < 2**32:
raise ValueError raise ValueError
return struct.pack(">L", value) return struct.pack(">L", value)
def from_be4(b): def from_be4(b):
if not isinstance(b, bytes): if not isinstance(b, bytes):
raise TypeError(repr(b)) raise TypeError(repr(b))

View File

@ -6,13 +6,19 @@ from twisted.python import log
from .._interfaces import IDilationManager, IInbound from .._interfaces import IDilationManager, IInbound
from .subchannel import (SubChannel, _SubchannelAddress) from .subchannel import (SubChannel, _SubchannelAddress)
class DuplicateOpenError(Exception): class DuplicateOpenError(Exception):
pass pass
class DataForMissingSubchannelError(Exception): class DataForMissingSubchannelError(Exception):
pass pass
class CloseForMissingSubchannelError(Exception): class CloseForMissingSubchannelError(Exception):
pass pass
@attrs @attrs
@implementer(IInbound) @implementer(IInbound)
class Inbound(object): class Inbound(object):
@ -24,8 +30,8 @@ class Inbound(object):
def __attrs_post_init__(self): def __attrs_post_init__(self):
# we route inbound Data records to Subchannels .dataReceived # we route inbound Data records to Subchannels .dataReceived
self._open_subchannels = {} # scid -> Subchannel self._open_subchannels = {} # scid -> Subchannel
self._paused_subchannels = set() # Subchannels that have paused us self._paused_subchannels = set() # Subchannels that have paused us
# the set is non-empty, we pause the transport # the set is non-empty, we pause the transport
self._highest_inbound_acked = -1 self._highest_inbound_acked = -1
self._connection = None self._connection = None
@ -37,7 +43,6 @@ class Inbound(object):
def set_subchannel_zero(self, scid0, sc0): def set_subchannel_zero(self, scid0, sc0):
self._open_subchannels[scid0] = sc0 self._open_subchannels[scid0] = sc0
def use_connection(self, c): def use_connection(self, c):
self._connection = c self._connection = c
# We can pause the connection's reads when we receive too much data. If # We can pause the connection's reads when we receive too much data. If
@ -61,7 +66,8 @@ class Inbound(object):
def handle_open(self, scid): def handle_open(self, scid):
if scid in self._open_subchannels: if scid in self._open_subchannels:
log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) log.err(DuplicateOpenError(
"received duplicate OPEN for {}".format(scid)))
return return
peer_addr = _SubchannelAddress(scid) peer_addr = _SubchannelAddress(scid)
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
@ -71,14 +77,16 @@ class Inbound(object):
def handle_data(self, scid, data): def handle_data(self, scid, data):
sc = self._open_subchannels.get(scid) sc = self._open_subchannels.get(scid)
if sc is None: if sc is None:
log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid))) log.err(DataForMissingSubchannelError(
"received DATA for non-existent subchannel {}".format(scid)))
return return
sc.remote_data(data) sc.remote_data(data)
def handle_close(self, scid): def handle_close(self, scid):
sc = self._open_subchannels.get(scid) sc = self._open_subchannels.get(scid)
if sc is None: if sc is None:
log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid))) log.err(CloseForMissingSubchannelError(
"received CLOSE for non-existent subchannel {}".format(scid)))
return return
sc.remote_close() sc.remote_close()
@ -90,7 +98,6 @@ class Inbound(object):
def stop_using_connection(self): def stop_using_connection(self):
self._connection = None self._connection = None
# from our Subchannel, or rather from the Protocol above it and sent # from our Subchannel, or rather from the Protocol above it and sent
# through the subchannel # through the subchannel

View File

@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
from .inbound import Inbound from .inbound import Inbound
from .outbound import Outbound from .outbound import Outbound
class OldPeerCannotDilateError(Exception): class OldPeerCannotDilateError(Exception):
pass pass
class UnknownDilationMessageType(Exception): class UnknownDilationMessageType(Exception):
pass pass
class ReceivedHintsTooEarly(Exception): class ReceivedHintsTooEarly(Exception):
pass pass
@attrs @attrs
@implementer(IDilationManager) @implementer(IDilationManager)
class _ManagerBase(object): class _ManagerBase(object):
@ -37,14 +43,14 @@ class _ManagerBase(object):
_reactor = attrib() _reactor = attrib()
_eventual_queue = attrib() _eventual_queue = attrib()
_cooperator = attrib() _cooperator = attrib()
_no_listen = False # TODO _no_listen = False # TODO
_tor = None # TODO _tor = None # TODO
_timing = None # TODO _timing = None # TODO
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._got_versions_d = Deferred() self._got_versions_d = Deferred()
self._my_role = None # determined upon rx_PLEASE self._my_role = None # determined upon rx_PLEASE
self._connection = None self._connection = None
self._made_first_connection = False self._made_first_connection = False
@ -53,51 +59,56 @@ class _ManagerBase(object):
self._next_dilation_phase = 0 self._next_dilation_phase = 0
self._next_subchannel_id = 0 # increments by 2 self._next_subchannel_id = 0 # increments by 2
# I kept getting confused about which methods were for inbound data # I kept getting confused about which methods were for inbound data
# (and thus flow-control methods go "out") and which were for # (and thus flow-control methods go "out") and which were for
# outbound data (with flow-control going "in"), so I split them up # outbound data (with flow-control going "in"), so I split them up
# into separate pieces. # into separate pieces.
self._inbound = Inbound(self, self._host_addr) self._inbound = Inbound(self, self._host_addr)
self._outbound = Outbound(self, self._cooperator) # from us to peer self._outbound = Outbound(self, self._cooperator) # from us to peer
def set_listener_endpoint(self, listener_endpoint): def set_listener_endpoint(self, listener_endpoint):
self._inbound.set_listener_endpoint(listener_endpoint) self._inbound.set_listener_endpoint(listener_endpoint)
def set_subchannel_zero(self, scid0, sc0): def set_subchannel_zero(self, scid0, sc0):
self._inbound.set_subchannel_zero(scid0, sc0) self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self): def when_first_connected(self):
return self._first_connected.when_fired() return self._first_connected.when_fired()
def send_dilation_phase(self, **fields): def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1 self._next_dilation_phase += 1
self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))
def send_hints(self, hints): # from Connector def send_hints(self, hints): # from Connector
self.send_dilation_phase(type="connection-hints", hints=hints) self.send_dilation_phase(type="connection-hints", hints=hints)
# forward inbound-ish things to _Inbound # forward inbound-ish things to _Inbound
def subchannel_pauseProducing(self, sc): def subchannel_pauseProducing(self, sc):
self._inbound.subchannel_pauseProducing(sc) self._inbound.subchannel_pauseProducing(sc)
def subchannel_resumeProducing(self, sc): def subchannel_resumeProducing(self, sc):
self._inbound.subchannel_resumeProducing(sc) self._inbound.subchannel_resumeProducing(sc)
def subchannel_stopProducing(self, sc): def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc) self._inbound.subchannel_stopProducing(sc)
# forward outbound-ish things to _Outbound # forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming): def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming) self._outbound.subchannel_registerProducer(sc, producer, streaming)
def subchannel_unregisterProducer(self, sc): def subchannel_unregisterProducer(self, sc):
self._outbound.subchannel_unregisterProducer(sc) self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid): def send_open(self, scid):
self._queue_and_send(Open, scid) self._queue_and_send(Open, scid)
def send_data(self, scid, data): def send_data(self, scid, data):
self._queue_and_send(Data, scid, data) self._queue_and_send(Data, scid, data)
def send_close(self, scid): def send_close(self, scid):
self._queue_and_send(Close, scid) self._queue_and_send(Close, scid)
@ -106,7 +117,7 @@ class _ManagerBase(object):
# Outbound owns the send_record() pipe, so that it can stall new # Outbound owns the send_record() pipe, so that it can stall new
# writes after a new connection is made until after all queued # writes after a new connection is made until after all queued
# messages are written (to preserve ordering). # messages are written (to preserve ordering).
self._outbound.queue_and_send_record(r) # may trigger pauseProducing self._outbound.queue_and_send_record(r) # may trigger pauseProducing
def subchannel_closed(self, scid, sc): def subchannel_closed(self, scid, sc):
# let everyone clean up. This happens just after we delivered # let everyone clean up. This happens just after we delivered
@ -116,7 +127,6 @@ class _ManagerBase(object):
self._inbound.subchannel_closed(scid, sc) self._inbound.subchannel_closed(scid, sc)
self._outbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc)
def _start_connecting(self, role): def _start_connecting(self, role):
assert self._my_role is not None assert self._my_role is not None
self._connector = Connector(self._transit_key, self._connector = Connector(self._transit_key,
@ -125,41 +135,41 @@ class _ManagerBase(object):
self._reactor, self._eventual_queue, self._reactor, self._eventual_queue,
self._no_listen, self._tor, self._no_listen, self._tor,
self._timing, self._timing,
self._side, # needed for relay handshake self._side, # needed for relay handshake
self._my_role) self._my_role)
self._connector.start() self._connector.start()
# our Connector calls these # our Connector calls these
def connector_connection_made(self, c): def connector_connection_made(self, c):
self.connection_made() # state machine update self.connection_made() # state machine update
self._connection = c self._connection = c
self._inbound.use_connection(c) self._inbound.use_connection(c)
self._outbound.use_connection(c) # does c.registerProducer self._outbound.use_connection(c) # does c.registerProducer
if not self._made_first_connection: if not self._made_first_connection:
self._made_first_connection = True self._made_first_connection = True
self._first_connected.fire(None) self._first_connected.fire(None)
pass pass
def connector_connection_lost(self): def connector_connection_lost(self):
self._stop_using_connection() self._stop_using_connection()
if self.role is LEADER: if self.role is LEADER:
self.connection_lost_leader() # state machine self.connection_lost_leader() # state machine
else: else:
self.connection_lost_follower() self.connection_lost_follower()
def _stop_using_connection(self): def _stop_using_connection(self):
# the connection is already lost by this point # the connection is already lost by this point
self._connection = None self._connection = None
self._inbound.stop_using_connection() self._inbound.stop_using_connection()
self._outbound.stop_using_connection() # does c.unregisterProducer self._outbound.stop_using_connection() # does c.unregisterProducer
# from our active Connection # from our active Connection
def got_record(self, r): def got_record(self, r):
# records with sequence numbers: always ack, ignore old ones # records with sequence numbers: always ack, ignore old ones
if isinstance(r, (Open, Data, Close)): if isinstance(r, (Open, Data, Close)):
self.send_ack(r.seqnum) # always ack, even for old ones self.send_ack(r.seqnum) # always ack, even for old ones
if self._inbound.is_record_old(r): if self._inbound.is_record_old(r):
return return
self._inbound.update_ack_watermark(r.seqnum) self._inbound.update_ack_watermark(r.seqnum)
@ -167,7 +177,7 @@ class _ManagerBase(object):
self._inbound.handle_open(r.scid) self._inbound.handle_open(r.scid)
elif isinstance(r, Data): elif isinstance(r, Data):
self._inbound.handle_data(r.scid, r.data) self._inbound.handle_data(r.scid, r.data)
else: # isinstance(r, Close) else: # isinstance(r, Close)
self._inbound.handle_close(r.scid) self._inbound.handle_close(r.scid)
if isinstance(r, KCM): if isinstance(r, KCM):
log.err("got unexpected KCM") log.err("got unexpected KCM")
@ -176,7 +186,7 @@ class _ManagerBase(object):
elif isinstance(r, Pong): elif isinstance(r, Pong):
self.handle_pong(r.ping_id) self.handle_pong(r.ping_id)
elif isinstance(r, Ack): elif isinstance(r, Ack):
self._outbound.handle_ack(r.resp_seqnum) # retire queued messages self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
else: else:
log.err("received unknown message type {}".format(r)) log.err("received unknown message type {}".format(r))
@ -190,7 +200,6 @@ class _ManagerBase(object):
def send_ack(self, resp_seqnum): def send_ack(self, resp_seqnum):
self._outbound.send_if_connected(Ack(resp_seqnum)) self._outbound.send_if_connected(Ack(resp_seqnum))
def handle_ping(self, ping_id): def handle_ping(self, ping_id):
self.send_pong(ping_id) self.send_pong(ping_id)
@ -200,7 +209,7 @@ class _ManagerBase(object):
# subchannel maintenance # subchannel maintenance
def allocate_subchannel_id(self): def allocate_subchannel_id(self):
raise NotImplemented # subclass knows if we're leader or follower raise NotImplementedError # subclass knows if we're leader or follower
# new scheme: # new scheme:
# * both sides send PLEASE as soon as they have an unverified key and # * both sides send PLEASE as soon as they have an unverified key and
@ -236,58 +245,91 @@ class _ManagerBase(object):
# * if follower calls w.dilate() but not leader, follower waits forever # * if follower calls w.dilate() but not leader, follower waits forever
# in "want", leader waits forever in "wanted" # in "want", leader waits forever in "wanted"
class ManagerShared(_ManagerBase): class ManagerShared(_ManagerBase):
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) set_trace = getattr(m, "_setTrace", lambda self, f: None)
@m.state(initial=True) @m.state(initial=True)
def IDLE(self): pass # pragma: no cover def IDLE(self):
pass # pragma: no cover
@m.state() @m.state()
def WANTING(self): pass # pragma: no cover def WANTING(self):
pass # pragma: no cover
@m.state() @m.state()
def WANTED(self): pass # pragma: no cover def WANTED(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTING(self): pass # pragma: no cover def CONNECTING(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTED(self): pass # pragma: no cover def CONNECTED(self):
pass # pragma: no cover
@m.state() @m.state()
def FLUSHING(self): pass # pragma: no cover def FLUSHING(self):
pass # pragma: no cover
@m.state() @m.state()
def ABANDONING(self): pass # pragma: no cover def ABANDONING(self):
pass # pragma: no cover
@m.state() @m.state()
def LONELY(self): pass # pragme: no cover def LONELY(self):
pass # pragme: no cover
@m.state() @m.state()
def STOPPING(self): pass # pragma: no cover def STOPPING(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def STOPPED(self): pass # pragma: no cover def STOPPED(self):
pass # pragma: no cover
@m.input() @m.input()
def start(self): pass # pragma: no cover def start(self):
pass # pragma: no cover
@m.input() @m.input()
def rx_PLEASE(self, message): pass # pragma: no cover def rx_PLEASE(self, message):
@m.input() # only sent by Follower pass # pragma: no cover
def rx_HINTS(self, hint_message): pass # pragma: no cover
@m.input() # only Leader sends RECONNECT, so only Follower receives it @m.input() # only sent by Follower
def rx_RECONNECT(self): pass # pragma: no cover def rx_HINTS(self, hint_message):
@m.input() # only Follower sends RECONNECTING, so only Leader receives it pass # pragma: no cover
def rx_RECONNECTING(self): pass # pragma: no cover
@m.input() # only Leader sends RECONNECT, so only Follower receives it
def rx_RECONNECT(self):
pass # pragma: no cover
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
def rx_RECONNECTING(self):
pass # pragma: no cover
# Connector gives us connection_made() # Connector gives us connection_made()
@m.input() @m.input()
def connection_made(self, c): pass # pragma: no cover def connection_made(self, c):
pass # pragma: no cover
# our connection_lost() fires connection_lost_leader or # our connection_lost() fires connection_lost_leader or
# connection_lost_follower depending upon our role. If either side sees a # connection_lost_follower depending upon our role. If either side sees a
# problem with the connection (timeouts, bad authentication) then they # problem with the connection (timeouts, bad authentication) then they
# just drop it and let connection_lost() handle the cleanup. # just drop it and let connection_lost() handle the cleanup.
@m.input() @m.input()
def connection_lost_leader(self): pass # pragma: no cover def connection_lost_leader(self):
@m.input() pass # pragma: no cover
def connection_lost_follower(self): pass
@m.input() @m.input()
def stop(self): pass # pragma: no cover def connection_lost_follower(self):
pass
@m.input()
def stop(self):
pass # pragma: no cover
@m.output() @m.output()
def stash_side(self, message): def stash_side(self, message):
@ -301,38 +343,42 @@ class ManagerShared(_ManagerBase):
@m.output() @m.output()
def start_connecting(self): def start_connecting(self):
self._start_connecting() # TODO: merge self._start_connecting() # TODO: merge
@m.output() @m.output()
def ignore_message_start_connecting(self, message): def ignore_message_start_connecting(self, message):
self.start_connecting() self.start_connecting()
@m.output() @m.output()
def send_reconnect(self): def send_reconnect(self):
self.send_dilation_phase(type="reconnect") # TODO: generation number? self.send_dilation_phase(type="reconnect") # TODO: generation number?
@m.output() @m.output()
def send_reconnecting(self): def send_reconnecting(self):
self.send_dilation_phase(type="reconnecting") # TODO: generation? self.send_dilation_phase(type="reconnecting") # TODO: generation?
@m.output() @m.output()
def use_hints(self, hint_message): def use_hints(self, hint_message):
hint_objs = filter(lambda h: h, # ignore None, unrecognizable hint_objs = filter(lambda h: h, # ignore None, unrecognizable
[parse_hint(hs) for hs in hint_message["hints"]]) [parse_hint(hs) for hs in hint_message["hints"]])
hint_objs = list(hint_objs) hint_objs = list(hint_objs)
self._connector.got_hints(hint_objs) self._connector.got_hints(hint_objs)
@m.output() @m.output()
def stop_connecting(self): def stop_connecting(self):
self._connector.stop() self._connector.stop()
@m.output() @m.output()
def abandon_connection(self): def abandon_connection(self):
# we think we're still connected, but the Leader disagrees. Or we've # we think we're still connected, but the Leader disagrees. Or we've
# been told to shut down. # been told to shut down.
self._connection.disconnect() # let connection_lost do cleanup self._connection.disconnect() # let connection_lost do cleanup
# we don't start CONNECTING until a local start() plus rx_PLEASE # we don't start CONNECTING until a local start() plus rx_PLEASE
IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
IDLE.upon(start, enter=WANTING, outputs=[send_please]) IDLE.upon(start, enter=WANTING, outputs=[send_please])
WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) WANTED.upon(start, enter=CONNECTING, outputs=[
send_please, start_connecting])
WANTING.upon(rx_PLEASE, enter=CONNECTING, WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[stash_side, outputs=[stash_side,
ignore_message_start_connecting]) ignore_message_start_connecting])
@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase):
# Leader # Leader
CONNECTED.upon(connection_lost_leader, enter=FLUSHING, CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
outputs=[send_reconnect]) outputs=[send_reconnect])
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
outputs=[start_connecting])
# Follower # Follower
# if we notice a lost connection, just wait for the Leader to notice too # if we notice a lost connection, just wait for the Leader to notice too
@ -350,7 +397,7 @@ class ManagerShared(_ManagerBase):
LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting])
# but if they notice it first, abandon our (seemingly functional) # but if they notice it first, abandon our (seemingly functional)
# connection, then tell them that we're ready to try again # connection, then tell them that we're ready to try again
CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss
outputs=[abandon_connection]) outputs=[abandon_connection])
ABANDONING.upon(connection_lost_follower, enter=CONNECTING, ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
outputs=[send_reconnecting, start_connecting]) outputs=[send_reconnecting, start_connecting])
@ -362,16 +409,15 @@ class ManagerShared(_ManagerBase):
send_reconnecting, send_reconnecting,
start_connecting]) start_connecting])
# rx_HINTS never changes state, they're just accepted or ignored # rx_HINTS never changes state, they're just accepted or ignored
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
IDLE.upon(stop, enter=STOPPED, outputs=[]) IDLE.upon(stop, enter=STOPPED, outputs=[])
@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase):
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
def allocate_subchannel_id(self): def allocate_subchannel_id(self):
# scid 0 is reserved for the control channel. the leader uses odd # scid 0 is reserved for the control channel. the leader uses odd
# numbers starting with 1 # numbers starting with 1
@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase):
self._next_outbound_seqnum += 2 self._next_outbound_seqnum += 2
return to_be4(scid_num) return to_be4(scid_num)
@attrs @attrs
@implementer(IDilator) @implementer(IDilator)
class Dilator(object): class Dilator(object):
@ -436,10 +482,10 @@ class Dilator(object):
dilation_version = yield self._got_versions_d dilation_version = yield self._got_versions_d
if not dilation_version: # 1 or None if not dilation_version: # 1 or None
raise OldPeerCannotDilateError() raise OldPeerCannotDilateError()
my_dilation_side = TODO # random my_dilation_side = TODO # random
self._manager = Manager(self._S, my_dilation_side, self._manager = Manager(self._S, my_dilation_side,
self._transit_key, self._transit_key,
self._transit_relay_location, self._transit_relay_location,
@ -455,14 +501,15 @@ class Dilator(object):
yield self._manager.when_first_connected() yield self._manager.when_first_connected()
# we can open subchannels as soon as we get our first connection # we can open subchannels as soon as we get our first connection
scid0 = b"\x00\x00\x00\x00" scid0 = b"\x00\x00\x00\x00"
self._host_addr = _WormholeAddress() # TODO: share with Manager self._host_addr = _WormholeAddress() # TODO: share with Manager
peer_addr0 = _SubchannelAddress(scid0) peer_addr0 = _SubchannelAddress(scid0)
control_ep = ControlEndpoint(peer_addr0) control_ep = ControlEndpoint(peer_addr0)
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
control_ep._subchannel_zero_opened(sc0) control_ep._subchannel_zero_opened(sc0)
self._manager.set_subchannel_zero(scid0, sc0) self._manager.set_subchannel_zero(scid0, sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) connect_ep = SubchannelConnectorEndpoint(
self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep) self._manager.set_listener_endpoint(listen_ep)
@ -476,7 +523,7 @@ class Dilator(object):
# TODO: verify this happens before got_wormhole_versions, or add a gate # TODO: verify this happens before got_wormhole_versions, or add a gate
# to tolerate either ordering # to tolerate either ordering
purpose = b"dilation-v1" purpose = b"dilation-v1"
LENGTH =32 # TODO: whatever Noise wants, I guess LENGTH = 32 # TODO: whatever Noise wants, I guess
self._transit_key = derive_key(key, purpose, LENGTH) self._transit_key = derive_key(key, purpose, LENGTH)
def got_wormhole_versions(self, our_side, their_side, def got_wormhole_versions(self, our_side, their_side,
@ -504,9 +551,9 @@ class Dilator(object):
message = bytes_to_dict(plaintext) message = bytes_to_dict(plaintext)
type = message["type"] type = message["type"]
if type == "please": if type == "please":
self._manager.rx_PLEASE() # message) self._manager.rx_PLEASE() # message)
elif type == "dilate": elif type == "dilate":
self._manager.rx_DILATE() #message) self._manager.rx_DILATE() # message)
elif type == "connection-hints": elif type == "connection-hints":
self._manager.rx_HINTS(message) self._manager.rx_HINTS(message)
else: else:

View File

@ -168,9 +168,9 @@ class Outbound(object):
self._queued_unsent = deque() self._queued_unsent = deque()
# outbound flow control: the Connection throttles our writes # outbound flow control: the Connection throttles our writes
self._subchannel_producers = {} # Subchannel -> IProducer self._subchannel_producers = {} # Subchannel -> IProducer
self._paused = True # our Connection called our pauseProducing self._paused = True # our Connection called our pauseProducing
self._all_producers = deque() # rotates, left-is-next self._all_producers = deque() # rotates, left-is-next
self._paused_producers = set() self._paused_producers = set()
self._unpaused_producers = set() self._unpaused_producers = set()
self._check_invariants() self._check_invariants()
@ -186,7 +186,7 @@ class Outbound(object):
seqnum = self._next_outbound_seqnum seqnum = self._next_outbound_seqnum
self._next_outbound_seqnum += 1 self._next_outbound_seqnum += 1
r = record_type(seqnum, *args) r = record_type(seqnum, *args)
assert hasattr(r, "seqnum"), r # only Open/Data/Close assert hasattr(r, "seqnum"), r # only Open/Data/Close
return r return r
def queue_and_send_record(self, r): def queue_and_send_record(self, r):
@ -203,7 +203,7 @@ class Outbound(object):
self._connection.send_record(r) self._connection.send_record(r)
def send_if_connected(self, r): def send_if_connected(self, r):
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
if self._connection: if self._connection:
self._connection.send_record(r) self._connection.send_record(r)
@ -235,7 +235,7 @@ class Outbound(object):
if self._paused: if self._paused:
# IPushProducers need to be paused immediately, before they # IPushProducers need to be paused immediately, before they
# speak # speak
producer.pauseProducing() # you wake up sleeping producer.pauseProducing() # you wake up sleeping
else: else:
# our PullToPush adapter must be started, but if we're paused then # our PullToPush adapter must be started, but if we're paused then
# we tell it to pause before it gets a chance to write anything # we tell it to pause before it gets a chance to write anything
@ -265,7 +265,7 @@ class Outbound(object):
assert not self._queued_unsent assert not self._queued_unsent
self._queued_unsent.extend(self._outbound_queue) self._queued_unsent.extend(self._outbound_queue)
# the connection can tell us to pause when we send too much data # the connection can tell us to pause when we send too much data
c.registerProducer(self, True) # IPushProducer: pause+resume c.registerProducer(self, True) # IPushProducer: pause+resume
# send our queued messages # send our queued messages
self.resumeProducing() self.resumeProducing()
@ -290,12 +290,12 @@ class Outbound(object):
# Inbound is responsible for tracking the high watermark and deciding # Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not # whether to ignore inbound messages or not
# IProducer: the active connection calls these because we used # IProducer: the active connection calls these because we used
# c.registerProducer to ask for them # c.registerProducer to ask for them
def pauseProducing(self): def pauseProducing(self):
if self._paused: if self._paused:
return # someone is confused and called us twice return # someone is confused and called us twice
self._paused = True self._paused = True
for p in self._all_producers: for p in self._all_producers:
if p in self._unpaused_producers: if p in self._unpaused_producers:
@ -305,7 +305,7 @@ class Outbound(object):
def resumeProducing(self): def resumeProducing(self):
if not self._paused: if not self._paused:
return # someone is confused and called us twice return # someone is confused and called us twice
self._paused = False self._paused = False
while not self._paused: while not self._paused:
@ -326,7 +326,7 @@ class Outbound(object):
return None return None
while True: while True:
p = self._all_producers[0] p = self._all_producers[0]
self._all_producers.rotate(-1) # p moves to the end of the line self._all_producers.rotate(-1) # p moves to the end of the line
# the only unpaused Producers are at the end of the list # the only unpaused Producers are at the end of the list
assert p in self._paused_producers assert p in self._paused_producers
return p return p
@ -343,7 +343,7 @@ class Outbound(object):
@attrs(cmp=False) @attrs(cmp=False)
class PullToPush(object): class PullToPush(object):
_producer = attrib(validator=provides(IPullProducer)) _producer = attrib(validator=provides(IPullProducer))
_unregister = attrib(validator=lambda _a,_b,v: callable(v)) _unregister = attrib(validator=lambda _a, _b, v: callable(v))
_cooperator = attrib() _cooperator = attrib()
_finished = False _finished = False
@ -351,14 +351,14 @@ class PullToPush(object):
while True: while True:
try: try:
self._producer.resumeProducing() self._producer.resumeProducing()
except: except Exception:
log.err(None, "%s failed, producing will be stopped:" % log.err(None, "%s failed, producing will be stopped:" %
(safe_str(self._producer),)) (safe_str(self._producer),))
try: try:
self._unregister() self._unregister()
# The consumer should now call stopStreaming() on us, # The consumer should now call stopStreaming() on us,
# thus stopping the streaming. # thus stopping the streaming.
except: except Exception:
# Since the consumer blew up, we may not have had # Since the consumer blew up, we may not have had
# stopStreaming() called, so we just stop on our own: # stopStreaming() called, so we just stop on our own:
log.err(None, "%s failed to unregister producer:" % log.err(None, "%s failed to unregister producer:" %
@ -370,7 +370,7 @@ class PullToPush(object):
def startStreaming(self, paused): def startStreaming(self, paused):
self._coopTask = self._cooperator.cooperate(self._pull()) self._coopTask = self._cooperator.cooperate(self._pull())
if paused: if paused:
self.pauseProducing() # timer is scheduled, but task is removed self.pauseProducing() # timer is scheduled, but task is removed
def stopStreaming(self): def stopStreaming(self):
if self._finished: if self._finished:
@ -378,15 +378,12 @@ class PullToPush(object):
self._finished = True self._finished = True
self._coopTask.stop() self._coopTask.stop()
def pauseProducing(self): def pauseProducing(self):
self._coopTask.pause() self._coopTask.pause()
def resumeProducing(self): def resumeProducing(self):
self._coopTask.resume() self._coopTask.resume()
def stopProducing(self): def stopProducing(self):
self.stopStreaming() self.stopStreaming()
self._producer.stopProducing() self._producer.stopProducing()

View File

@ -1,7 +1,8 @@
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import instance_of, provides from attr.validators import instance_of, provides
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
succeed)
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort, IAddress, IListeningPort,
IStreamClientEndpoint, IStreamClientEndpoint,
@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager from .._interfaces import ISubChannel, IDilationManager
@attrs @attrs
class Once(object): class Once(object):
_errtype = attrib() _errtype = attrib()
def __attrs_post_init__(self): def __attrs_post_init__(self):
self._called = False self._called = False
@ -21,6 +24,7 @@ class Once(object):
raise self._errtype() raise self._errtype()
self._called = True self._called = True
class SingleUseEndpointError(Exception): class SingleUseEndpointError(Exception):
pass pass
@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception):
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) # (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
# object is deleted upon transition to (CLOSED) # object is deleted upon transition to (CLOSED)
class AlreadyClosedError(Exception): class AlreadyClosedError(Exception):
pass pass
@implementer(IAddress) @implementer(IAddress)
class _WormholeAddress(object): class _WormholeAddress(object):
pass pass
@implementer(IAddress) @implementer(IAddress)
@attrs @attrs
class _SubchannelAddress(object): class _SubchannelAddress(object):
@ -63,35 +70,44 @@ class SubChannel(object):
_peer_addr = attrib(validator=instance_of(_SubchannelAddress)) _peer_addr = attrib(validator=instance_of(_SubchannelAddress))
m = MethodicalMachine() m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace", lambda self,
f: None) # pragma: no cover
def __attrs_post_init__(self): def __attrs_post_init__(self):
#self._mailbox = None # self._mailbox = None
#self._pending_outbound = {} # self._pending_outbound = {}
#self._processed = set() # self._processed = set()
self._protocol = None self._protocol = None
self._pending_dataReceived = [] self._pending_dataReceived = []
self._pending_connectionLost = (False, None) self._pending_connectionLost = (False, None)
@m.state(initial=True) @m.state(initial=True)
def open(self): pass # pragma: no cover def open(self):
pass # pragma: no cover
@m.state() @m.state()
def closing(): pass # pragma: no cover def closing():
pass # pragma: no cover
@m.state() @m.state()
def closed(): pass # pragma: no cover def closed():
pass # pragma: no cover
@m.input() @m.input()
def remote_data(self, data): pass def remote_data(self, data):
@m.input() pass
def remote_close(self): pass
@m.input() @m.input()
def local_data(self, data): pass def remote_close(self):
@m.input() pass
def local_close(self): pass
@m.input()
def local_data(self, data):
pass
@m.input()
def local_close(self):
pass
@m.output() @m.output()
def send_data(self, data): def send_data(self, data):
@ -120,9 +136,11 @@ class SubChannel(object):
@m.output() @m.output()
def error_closed_write(self, data): def error_closed_write(self, data):
raise AlreadyClosedError("write not allowed on closed subchannel") raise AlreadyClosedError("write not allowed on closed subchannel")
@m.output() @m.output()
def error_closed_close(self): def error_closed_close(self):
raise AlreadyClosedError("loseConnection not allowed on closed subchannel") raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions # primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
@ -146,7 +164,7 @@ class SubChannel(object):
if self._pending_dataReceived: if self._pending_dataReceived:
for data in self._pending_dataReceived: for data in self._pending_dataReceived:
self._protocol.dataReceived(data) self._protocol.dataReceived(data)
self._pending_dataReceived = [] self._pending_dataReceived = []
cl, what = self._pending_connectionLost cl, what = self._pending_connectionLost
if cl: if cl:
self._protocol.connectionLost(what) self._protocol.connectionLost(what)
@ -155,13 +173,17 @@ class SubChannel(object):
def write(self, data): def write(self, data):
assert isinstance(data, type(b"")) assert isinstance(data, type(b""))
self.local_data(data) self.local_data(data)
def writeSequence(self, iovec): def writeSequence(self, iovec):
self.write(b"".join(iovec)) self.write(b"".join(iovec))
def loseConnection(self): def loseConnection(self):
self.local_close() self.local_close()
def getHost(self): def getHost(self):
# we define "host addr" as the overall wormhole # we define "host addr" as the overall wormhole
return self._host_addr return self._host_addr
def getPeer(self): def getPeer(self):
# and "peer addr" as the subchannel within that wormhole # and "peer addr" as the subchannel within that wormhole
return self._peer_addr return self._peer_addr
@ -169,14 +191,17 @@ class SubChannel(object):
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol) # IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
def stopProducing(self): def stopProducing(self):
self._manager.subchannel_stopProducing(self) self._manager.subchannel_stopProducing(self)
def pauseProducing(self): def pauseProducing(self):
self._manager.subchannel_pauseProducing(self) self._manager.subchannel_pauseProducing(self)
def resumeProducing(self): def resumeProducing(self):
self._manager.subchannel_resumeProducing(self) self._manager.subchannel_resumeProducing(self)
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole) # IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
self._manager.subchannel_registerProducer(self, producer, streaming) self._manager.subchannel_registerProducer(self, producer, streaming)
def unregisterProducer(self): def unregisterProducer(self):
self._manager.subchannel_unregisterProducer(self) self._manager.subchannel_unregisterProducer(self)
@ -184,6 +209,7 @@ class SubChannel(object):
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
class ControlEndpoint(object): class ControlEndpoint(object):
_used = False _used = False
def __init__(self, peer_addr): def __init__(self, peer_addr):
self._subchannel_zero = Deferred() self._subchannel_zero = Deferred()
self._peer_addr = peer_addr self._peer_addr = peer_addr
@ -201,9 +227,10 @@ class ControlEndpoint(object):
t = yield self._subchannel_zero t = yield self._subchannel_zero
p = protocolFactory.buildProtocol(self._peer_addr) p = protocolFactory.buildProtocol(self._peer_addr)
t._set_protocol(p) t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade() p.makeConnection(t) # set p.transport = t and call connectionMade()
returnValue(p) returnValue(p)
@implementer(IStreamClientEndpoint) @implementer(IStreamClientEndpoint)
@attrs @attrs
class SubchannelConnectorEndpoint(object): class SubchannelConnectorEndpoint(object):
@ -220,9 +247,10 @@ class SubchannelConnectorEndpoint(object):
t = SubChannel(scid, self._manager, self._host_addr, peer_addr) t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
p = protocolFactory.buildProtocol(peer_addr) p = protocolFactory.buildProtocol(peer_addr)
t._set_protocol(p) t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade() p.makeConnection(t) # set p.transport = t and call connectionMade()
return succeed(p) return succeed(p)
@implementer(IStreamServerEndpoint) @implementer(IStreamServerEndpoint)
@attrs @attrs
class SubchannelListenerEndpoint(object): class SubchannelListenerEndpoint(object):
@ -238,7 +266,7 @@ class SubchannelListenerEndpoint(object):
if self._factory: if self._factory:
self._connect(t, peer_addr) self._connect(t, peer_addr)
else: else:
self._pending_opens.append( (t, peer_addr) ) self._pending_opens.append((t, peer_addr))
def _connect(self, t, peer_addr): def _connect(self, t, peer_addr):
p = self._factory.buildProtocol(peer_addr) p = self._factory.buildProtocol(peer_addr)
@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object):
lp = SubchannelListeningPort(self._host_addr) lp = SubchannelListeningPort(self._host_addr)
return succeed(lp) return succeed(lp)
@implementer(IListeningPort) @implementer(IListeningPort)
@attrs @attrs
class SubchannelListeningPort(object): class SubchannelListeningPort(object):
@ -262,8 +291,10 @@ class SubchannelListeningPort(object):
def startListening(self): def startListening(self):
pass pass
def stopListening(self): def stopListening(self):
# TODO # TODO
pass pass
def getHost(self): def getHost(self):
return self._host_addr return self._host_addr