diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index cc6f4fa..d3cb5ff 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -67,7 +67,8 @@ class Boss(object): self._I = Input(self._timing) self._C = Code(self._timing) self._T = Terminator() - self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) + self._D = Dilator(self._reactor, self._eventual_queue, + self._cooperator) self._N.wire(self._M, self._I, self._RC, self._T) self._M.wire(self._N, self._RC, self._O, self._T) @@ -90,7 +91,7 @@ class Boss(object): self._rx_phases = {} # phase -> plaintext self._next_rx_dilate_seqnum = 0 - self._rx_dilate_seqnums = {} # seqnum -> plaintext + self._rx_dilate_seqnums = {} # seqnum -> plaintext self._result = "empty" @@ -205,7 +206,7 @@ class Boss(object): self._C.set_code(code) def dilate(self): - return self._D.dilate() # fires with endpoints + return self._D.dilate() # fires with endpoints @m.input() def send(self, plaintext): diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index f0c8e35..4b9ced7 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -39,20 +39,28 @@ from .roles import FOLLOWER # states). For the specific question of sending plaintext frames, Noise will # refuse us unless it's ready anyways, so the question is probably moot. + class IFramer(Interface): pass + + class IRecord(Interface): pass + def first(l): return l[0] + class Disconnect(Exception): pass + + RelayOK = namedtuple("RelayOk", []) Prologue = namedtuple("Prologue", []) Frame = namedtuple("Frame", ["frame"]) + @attrs @implementer(IFramer) class _Framer(object): @@ -69,30 +77,37 @@ class _Framer(object): # out (shared): transport.write (relay handshake, prologue) # states: want_relay, want_prologue, want_frame m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover @m.state() - def want_relay(self): pass # pragma: no cover + def want_relay(self): pass # pragma: no cover + @m.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): pass # pragma: no cover + @m.state() - def want_frame(self): pass # pragma: no cover + def want_frame(self): pass # pragma: no cover @m.input() def use_relay(self, relay_handshake): pass + @m.input() def connectionMade(self): pass + @m.input() def parse(self): pass + @m.input() def got_relay_ok(self): pass + @m.input() def got_prologue(self): pass @m.output() def store_relay_handshake(self, relay_handshake): self._outbound_relay_handshake = relay_handshake - self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + @m.output() def send_relay_handshake(self): self._transport.write(self._outbound_relay_handshake) @@ -113,17 +128,17 @@ class _Framer(object): @m.output() def can_send_frames(self): - self._can_send_frames = True # for assertion in send_frame() + self._can_send_frames = True # for assertion in send_frame() @m.output() def parse_frame(self): if len(self._buffer) < 4: return None frame_length = from_be4(self._buffer[0:4]) - if len(self._buffer) < 4+frame_length: + if len(self._buffer) < 4 + frame_length: return None - frame = self._buffer[4:4+frame_length] - self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy + frame = self._buffer[4:4 + frame_length] + self._buffer = self._buffer[4 + frame_length:] 
# TODO: avoid copy return Frame(frame=frame) want_prologue.upon(use_relay, outputs=[store_relay_handshake], @@ -144,7 +159,6 @@ class _Framer(object): want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, collector=first) - def _get_expected(self, name, expected): lb = len(self._buffer) le = len(expected) @@ -161,7 +175,7 @@ class _Framer(object): if (b"\n" in self._buffer or lb >= le): log.msg("bad {}: {}".format(name, self._buffer[:le])) raise Disconnect() - return False # wait a bit longer + return False # wait a bit longer # good so far, just waiting for the rest return False @@ -181,7 +195,7 @@ class _Framer(object): self.got_relay_ok() elif isinstance(token, Prologue): self.got_prologue() - yield token # triggers send_handshake + yield token # triggers send_handshake elif isinstance(token, Frame): yield token else: @@ -202,15 +216,16 @@ class _Framer(object): # from peer. Sent immediately by Follower, after Selection by Leader. # Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong + Handshake = namedtuple("Handshake", []) # decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack KCM = namedtuple("KCM", []) -Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value +Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value Pong = namedtuple("Pong", ["ping_id"]) -Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer +Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer Data = namedtuple("Data", ["seqnum", "scid", "data"]) -Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer -Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer Records = (KCM, Ping, Pong, Open, Data, Close, Ack) Handshake_or_Records = (Handshake,) + Records @@ -222,6 +237,7 @@ T_DATA = b"\x04" T_CLOSE = b"\x05" T_ACK = b"\x06" + def parse_record(plaintext): msgtype = plaintext[0:1] if msgtype == T_KCM: @@ -251,6 +267,7 @@ def parse_record(plaintext): log.err("received unknown message type: {}".format(plaintext)) raise ValueError() + def encode_record(r): if isinstance(r, KCM): return b"\x00" @@ -275,6 +292,7 @@ def encode_record(r): return b"\x06" + to_be4(r.resp_seqnum) raise TypeError(r) + @attrs @implementer(IRecord) class _Record(object): @@ -294,22 +312,25 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): pass # pragma: no cover + @n.state() - def want_handshake(self): pass # pragma: no cover + def want_handshake(self): pass # pragma: no cover + @n.state() - def want_message(self): pass # pragma: no cover + def want_message(self): pass # pragma: no cover @n.input() def got_prologue(self): pass + @n.input() def got_frame(self, frame): pass @n.output() def send_handshake(self): - handshake = self._noise.write_message() # generate the ephemeral key + handshake = self._noise.write_message() # generate the ephemeral key self._framer.send_frame(handshake) @n.output() @@ -351,10 +372,10 @@ class _Record(object): def add_and_unframe(self, data): for token in self._framer.add_and_parse(data): if isinstance(token, Prologue): - self.got_prologue() # triggers send_handshake + self.got_prologue() # triggers send_handshake else: assert isinstance(token, Frame) - yield self.got_frame(token.frame) # Handshake or a Record type + yield self.got_frame(token.frame) # 
Handshake or a Record type def send_record(self, r): message = encode_record(r) @@ -388,26 +409,30 @@ class DilatedConnectionProtocol(Protocol, object): _relay_handshake = None m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): - self._manager = None # set if/when we are selected + self._manager = None # set if/when we are selected self._disconnected = OneShotObserver(self._eventual_queue) self._can_send_records = False @m.state(initial=True) - def unselected(self): pass # pragma: no cover + def unselected(self): pass # pragma: no cover + @m.state() - def selecting(self): pass # pragma: no cover + def selecting(self): pass # pragma: no cover + @m.state() - def selected(self): pass # pragma: no cover + def selected(self): pass # pragma: no cover @m.input() def got_kcm(self): pass + @m.input() def select(self, manager): - pass # fires set_manager() + pass # fires set_manager() + @m.input() def got_record(self, record): pass @@ -472,9 +497,9 @@ class DilatedConnectionProtocol(Protocol, object): elif isinstance(token, KCM): # if we're the leader, add this connection as a candiate. # if we're the follower, accept this connection. - self.got_kcm() # connector.add_candidate() + self.got_kcm() # connector.add_candidate() else: - self.got_record(token) # manager.got_record() + self.got_record(token) # manager.got_record() except Disconnect: self.transport.loseConnection() diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 86f2a72..f530039 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -1,5 +1,6 @@ from __future__ import print_function, unicode_literals -import sys, re +import sys +import re from collections import defaultdict, namedtuple from binascii import hexlify import six @@ -13,7 +14,7 @@ from twisted.internet.endpoints import HostnameEndpoint, serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.python import log from hkdf import Hkdf -from .. import ipaddrs # TODO: move into _dilation/ +from .. import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager from ..timing import DebugTiming from ..observer import EmptyableSet @@ -30,7 +31,8 @@ from .roles import LEADER # * expect to see the receiver/sender handshake bytes from the other side # * the sender writes "go\n", the receiver waits for "go\n" # * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +DirectTCPV1Hint = namedtuple( + "DirectTCPV1Hint", ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # use a tuple rather than a list so they'll be hashable into a set). For each @@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. 
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + def describe_hint_obj(hint, relay, tor): prefix = "tor->" if tor else "->" if relay: @@ -45,9 +48,10 @@ def describe_hint_obj(hint, relay, tor): if isinstance(hint, DirectTCPV1Hint): return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) elif isinstance(hint, TorTCPV1Hint): - return prefix+"tor:%s:%d" % (hint.hostname, hint.port) + return prefix + "tor:%s:%d" % (hint.hostname, hint.port) else: - return prefix+str(hint) + return prefix + str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type("")) @@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr): return None hint_type = mo.group(1) if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + print("unknown hint type '%s' in '%s'" % (hint_type, hint), + file=stderr) return None hint_value = mo.group(2) pieces = hint_value.split(":") @@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr): return None return DirectTCPV1Hint(hint_host, hint_port, priority) -def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj hint_type = hint.get("type", "") if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: log.msg("unknown hint type: %r" % (hint,)) return None - if not("hostname" in hint - and isinstance(hint["hostname"], type(""))): + if not("hostname" in hint and + isinstance(hint["hostname"], type(""))): log.msg("invalid hostname in hint: %r" % (hint,)) return None - if not("port" in hint - and isinstance(hint["port"], six.integer_types)): + if not("port" in hint and + isinstance(hint["port"], six.integer_types)): log.msg("invalid port in hint: %r" % (hint,)) return None priority = hint.get("priority", 0.0) @@ -103,51 +109,58 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj else: return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + def parse_hint(hint_struct): hint_type = hint_struct.get("type", "") if hint_type == "relay-v1": # the struct can include multiple ways to reach the same relay - rhints = filter(lambda h: h, # drop None (unrecognized) + rhints = filter(lambda h: h, # drop None (unrecognized) [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) return RelayV1Hint(rhints) return parse_tcp_v1_hint(hint_struct) + def encode_hint(h): if isinstance(h, DirectTCPV1Hint): return {"type": "direct-tcp-v1", "priority": h.priority, "hostname": h.hostname, - "port": h.port, # integer + "port": h.port, # integer } elif isinstance(h, RelayV1Hint): rhint = {"type": "relay-v1", "hints": []} for rh in h.hints: rhint["hints"].append({"type": "direct-tcp-v1", - "priority": rh.priority, - "hostname": rh.hostname, - "port": rh.port}) + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) return rhint elif isinstance(h, TorTCPV1Hint): return {"type": "tor-tcp-v1", "priority": h.priority, "hostname": h.hostname, - "port": h.port, # integer + "port": h.port, # integer } raise ValueError("unknown hint type", h) + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) - assert len(side) == 8*2 + assert len(side) == 8 * 2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + return (b"please relay " + hexlify(token) + + b" for side " + side.encode("ascii") + b"\n") -PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 
Leader\n\n" + +PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + @attrs @implementer(IDilationConnector) class Connector(object): @@ -176,10 +189,11 @@ class Connector(object): self._transit_relays = [relay] else: self._transit_relays = [] - self._listeners = set() # IListeningPorts that can be stopped - self._pending_connectors = set() # Deferreds that can be cancelled - self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped - self._contenders = set() # viable connections + self._listeners = set() # IListeningPorts that can be stopped + self._pending_connectors = set() # Deferreds that can be cancelled + self._pending_connections = EmptyableSet( + _eventual_queue=self._eventual_queue) # Protocols to be stopped + self._contenders = set() # viable connections self._winning_connection = None self._timing = self._timing or DebugTiming() self._timing.add("transit") @@ -212,26 +226,41 @@ class Connector(object): return p @m.state(initial=True) - def connecting(self): pass # pragma: no cover + def connecting(self): + pass # pragma: no cover + @m.state() - def connected(self): pass # pragma: no cover + def connected(self): + pass # pragma: no cover + @m.state(terminal=True) - def stopped(self): pass # pragma: no cover + def stopped(self): + pass # pragma: no cover # TODO: unify the tense of these method-name verbs @m.input() - def listener_ready(self, hint_objs): pass - @m.input() - def add_relay(self, hint_objs): pass - @m.input() - def got_hints(self, hint_objs): pass - @m.input() - def add_candidate(self, c): # called by DilatedConnectionProtocol + def listener_ready(self, hint_objs): pass + @m.input() - def accept(self, c): pass + def add_relay(self, hint_objs): + pass + @m.input() - def stop(self): pass + def got_hints(self, hint_objs): + pass + + @m.input() + def add_candidate(self, c): # called by DilatedConnectionProtocol + pass + + @m.input() + def accept(self, c): + pass + + @m.input() + def stop(self): + pass @m.output() def use_hints(self, hint_objs): @@ -255,19 +284,19 @@ class Connector(object): @m.output() def select_and_stop_remaining(self, c): self._winning_connection = c - self._contenders.clear() # we no longer care who else came close + self._contenders.clear() # we no longer care who else came close # remove this winner from the losers, so we don't shut it down self._pending_connections.discard(c) # shut down losing connections - self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist + self.stop_listeners() # TODO: maybe keep it open? 
NAT/p2p assist self.stop_pending_connectors() self.stop_pending_connections() - c.select(self._manager) # subsequent frames go directly to the manager + c.select(self._manager) # subsequent frames go directly to the manager if self._role is LEADER: # TODO: this should live in Connection - c.send_record(KCM()) # leader sends KCM now - self._manager.use_connection(c) # manager sends frames to Connection + c.send_record(KCM()) # leader sends KCM now + self._manager.use_connection(c) # manager sends frames to Connection @m.output() def stop_everything(self): @@ -279,7 +308,7 @@ class Connector(object): def stop_listeners(self): d = DeferredList([l.stopListening() for l in self._listeners]) self._listeners.clear() - return d # synchronization for tests + return d # synchronization for tests def stop_pending_connectors(self): return DeferredList([d.cancel() for d in self._pending_connectors]) @@ -306,7 +335,8 @@ class Connector(object): publish_hints]) connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) connecting.upon(add_candidate, enter=connecting, outputs=[consider]) - connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) + connecting.upon(accept, enter=connected, outputs=[ + select_and_stop_remaining]) connecting.upon(stop, enter=stopped, outputs=[stop_everything]) # once connected, we ignore everything except stop @@ -317,9 +347,9 @@ class Connector(object): connected.upon(accept, enter=connected, outputs=[]) connected.upon(stop, enter=stopped, outputs=[stop_everything]) - # from Manager: start, got_hints, stop # maybe add_candidate, accept + def start(self): self._start_listener() if self._transit_relays: @@ -341,9 +371,10 @@ class Connector(object): ep = serverFromString(self._reactor, "tcp:0") f = InboundConnectionFactory(self) d = ep.listen(f) + def _listening(lp): # lp is an IListeningPort - self._listeners.add(lp) # for shutdown and tests + self._listeners.add(lp) # for shutdown and tests portnum = lp.getHost().port direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses] @@ -378,7 +409,7 @@ class Connector(object): # one still running. But if we bail on that, we might consider # putting an inter-direct-hint delay here to influence the # process. - #delay += 1.0 + # delay += 1.0 if delay > 0.0: # Start trying the relays a few seconds after we start to try the # direct hints. The idea is to prefer direct connections, but not @@ -408,7 +439,7 @@ class Connector(object): self._connect, ep, desc, is_relay=True) self._pending_connectors.add(d) # TODO: - #if not contenders: + # if not contenders: # raise TransitError("No contenders for connection") # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for @@ -422,6 +453,7 @@ class Connector(object): f = OutboundConnectionFactory(self, relay_handshake) d = ep.connect(f) # fires with protocol, or ConnectError + def _connected(p): self._pending_connections.add(p) # c might not be in _pending_connections, if it turned out to be a @@ -444,7 +476,6 @@ class Connector(object): return HostnameEndpoint(self._reactor, hint.hostname, hint.port) return None - # Connection selection. All instances of DilatedConnectionProtocol which # look viable get passed into our add_contender() method. 
@@ -459,6 +490,7 @@ class Connector(object): # our Connection protocols call: add_candidate + @attrs class OutboundConnectionFactory(ClientFactory, object): _connector = attrib(validator=provides(IDilationConnector)) @@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object): p.use_relay(self._relay_handshake) return p + @attrs class InboundConnectionFactory(ServerFactory, object): _connector = attrib(validator=provides(IDilationConnector)) diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py index eb1b0b2..80e9902 100644 --- a/src/wormhole/_dilation/encode.py +++ b/src/wormhole/_dilation/encode.py @@ -4,10 +4,13 @@ import struct assert len(struct.pack("<L", 0)) == 4 + def to_be4(value): if not 0 <= value < 2**32: raise ValueError return struct.pack(">L", value) + + def from_be4(b): if not isinstance(b, bytes): raise TypeError(repr(b)) diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py index 9235681..2f6ffaf 100644 --- a/src/wormhole/_dilation/inbound.py +++ b/src/wormhole/_dilation/inbound.py @@ -6,13 +6,19 @@ from twisted.python import log from .._interfaces import IDilationManager, IInbound from .subchannel import (SubChannel, _SubchannelAddress) + class DuplicateOpenError(Exception): pass + + class DataForMissingSubchannelError(Exception): pass + + class CloseForMissingSubchannelError(Exception): pass + @attrs @implementer(IInbound) class Inbound(object): @@ -24,8 +30,8 @@ class Inbound(object): def __attrs_post_init__(self): # we route inbound Data records to Subchannels .dataReceived - self._open_subchannels = {} # scid -> Subchannel - self._paused_subchannels = set() # Subchannels that have paused us + self._open_subchannels = {} # scid -> Subchannel + self._paused_subchannels = set() # Subchannels that have paused us # the set is non-empty, we pause the transport self._highest_inbound_acked = -1 self._connection = None @@ -37,7 +43,6 @@ class Inbound(object): def set_subchannel_zero(self, scid0, sc0): self._open_subchannels[scid0] = sc0 - def use_connection(self, c): self._connection = c # We can pause the connection's reads when we receive too much data.
If @@ -61,7 +66,8 @@ class Inbound(object): def handle_open(self, scid): if scid in self._open_subchannels: - log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) + log.err(DuplicateOpenError( + "received duplicate OPEN for {}".format(scid))) return peer_addr = _SubchannelAddress(scid) sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) @@ -71,14 +77,16 @@ class Inbound(object): def handle_data(self, scid, data): sc = self._open_subchannels.get(scid) if sc is None: - log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid))) + log.err(DataForMissingSubchannelError( + "received DATA for non-existent subchannel {}".format(scid))) return sc.remote_data(data) def handle_close(self, scid): sc = self._open_subchannels.get(scid) if sc is None: - log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid))) + log.err(CloseForMissingSubchannelError( + "received CLOSE for non-existent subchannel {}".format(scid))) return sc.remote_close() @@ -90,7 +98,6 @@ class Inbound(object): def stop_using_connection(self): self._connection = None - # from our Subchannel, or rather from the Protocol above it and sent # through the subchannel diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 860665a..b16d198 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack from .inbound import Inbound from .outbound import Outbound + class OldPeerCannotDilateError(Exception): pass + + class UnknownDilationMessageType(Exception): pass + + class ReceivedHintsTooEarly(Exception): pass + @attrs @implementer(IDilationManager) class _ManagerBase(object): @@ -37,14 +43,14 @@ class _ManagerBase(object): _reactor = attrib() _eventual_queue = attrib() _cooperator = attrib() - _no_listen = False # TODO - _tor = None # TODO - _timing = None # TODO + _no_listen = False # TODO + _tor = None # TODO + _timing = None # TODO def __attrs_post_init__(self): self._got_versions_d = Deferred() - self._my_role = None # determined upon rx_PLEASE + self._my_role = None # determined upon rx_PLEASE self._connection = None self._made_first_connection = False @@ -53,51 +59,56 @@ class _ManagerBase(object): self._next_dilation_phase = 0 - self._next_subchannel_id = 0 # increments by 2 + self._next_subchannel_id = 0 # increments by 2 # I kept getting confused about which methods were for inbound data # (and thus flow-control methods go "out") and which were for # outbound data (with flow-control going "in"), so I split them up # into separate pieces. 
self._inbound = Inbound(self, self._host_addr) - self._outbound = Outbound(self, self._cooperator) # from us to peer + self._outbound = Outbound(self, self._cooperator) # from us to peer def set_listener_endpoint(self, listener_endpoint): self._inbound.set_listener_endpoint(listener_endpoint) + def set_subchannel_zero(self, scid0, sc0): self._inbound.set_subchannel_zero(scid0, sc0) def when_first_connected(self): return self._first_connected.when_fired() - def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) - def send_hints(self, hints): # from Connector + def send_hints(self, hints): # from Connector self.send_dilation_phase(type="connection-hints", hints=hints) - # forward inbound-ish things to _Inbound + def subchannel_pauseProducing(self, sc): self._inbound.subchannel_pauseProducing(sc) + def subchannel_resumeProducing(self, sc): self._inbound.subchannel_resumeProducing(sc) + def subchannel_stopProducing(self, sc): self._inbound.subchannel_stopProducing(sc) # forward outbound-ish things to _Outbound def subchannel_registerProducer(self, sc, producer, streaming): self._outbound.subchannel_registerProducer(sc, producer, streaming) + def subchannel_unregisterProducer(self, sc): self._outbound.subchannel_unregisterProducer(sc) def send_open(self, scid): self._queue_and_send(Open, scid) + def send_data(self, scid, data): self._queue_and_send(Data, scid, data) + def send_close(self, scid): self._queue_and_send(Close, scid) @@ -106,7 +117,7 @@ class _ManagerBase(object): # Outbound owns the send_record() pipe, so that it can stall new # writes after a new connection is made until after all queued # messages are written (to preserve ordering). - self._outbound.queue_and_send_record(r) # may trigger pauseProducing + self._outbound.queue_and_send_record(r) # may trigger pauseProducing def subchannel_closed(self, scid, sc): # let everyone clean up. 
This happens just after we delivered @@ -116,7 +127,6 @@ class _ManagerBase(object): self._inbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc) - def _start_connecting(self, role): assert self._my_role is not None self._connector = Connector(self._transit_key, @@ -125,41 +135,41 @@ class _ManagerBase(object): self._reactor, self._eventual_queue, self._no_listen, self._tor, self._timing, - self._side, # needed for relay handshake + self._side, # needed for relay handshake self._my_role) self._connector.start() # our Connector calls these def connector_connection_made(self, c): - self.connection_made() # state machine update + self.connection_made() # state machine update self._connection = c self._inbound.use_connection(c) - self._outbound.use_connection(c) # does c.registerProducer + self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True self._first_connected.fire(None) pass + def connector_connection_lost(self): self._stop_using_connection() if self.role is LEADER: - self.connection_lost_leader() # state machine + self.connection_lost_leader() # state machine else: self.connection_lost_follower() - def _stop_using_connection(self): # the connection is already lost by this point self._connection = None self._inbound.stop_using_connection() - self._outbound.stop_using_connection() # does c.unregisterProducer + self._outbound.stop_using_connection() # does c.unregisterProducer # from our active Connection def got_record(self, r): # records with sequence numbers: always ack, ignore old ones if isinstance(r, (Open, Data, Close)): - self.send_ack(r.seqnum) # always ack, even for old ones + self.send_ack(r.seqnum) # always ack, even for old ones if self._inbound.is_record_old(r): return self._inbound.update_ack_watermark(r.seqnum) @@ -167,7 +177,7 @@ class _ManagerBase(object): self._inbound.handle_open(r.scid) elif isinstance(r, Data): self._inbound.handle_data(r.scid, r.data) - else: # isinstance(r, Close) + else: # isinstance(r, Close) self._inbound.handle_close(r.scid) if isinstance(r, KCM): log.err("got unexpected KCM") @@ -176,7 +186,7 @@ class _ManagerBase(object): elif isinstance(r, Pong): self.handle_pong(r.ping_id) elif isinstance(r, Ack): - self._outbound.handle_ack(r.resp_seqnum) # retire queued messages + self._outbound.handle_ack(r.resp_seqnum) # retire queued messages else: log.err("received unknown message type {}".format(r)) @@ -190,7 +200,6 @@ class _ManagerBase(object): def send_ack(self, resp_seqnum): self._outbound.send_if_connected(Ack(resp_seqnum)) - def handle_ping(self, ping_id): self.send_pong(ping_id) @@ -200,7 +209,7 @@ class _ManagerBase(object): # subchannel maintenance def allocate_subchannel_id(self): - raise NotImplemented # subclass knows if we're leader or follower + raise NotImplementedError # subclass knows if we're leader or follower # new scheme: # * both sides send PLEASE as soon as they have an unverified key and @@ -236,58 +245,91 @@ class _ManagerBase(object): # * if follower calls w.dilate() but not leader, follower waits forever # in "want", leader waits forever in "wanted" + class ManagerShared(_ManagerBase): m = MethodicalMachine() set_trace = getattr(m, "_setTrace", lambda self, f: None) @m.state(initial=True) - def IDLE(self): pass # pragma: no cover + def IDLE(self): + pass # pragma: no cover @m.state() - def WANTING(self): pass # pragma: no cover + def WANTING(self): + pass # pragma: no cover + @m.state() - def WANTED(self): pass # pragma: no 
cover + def WANTED(self): + pass # pragma: no cover + @m.state() - def CONNECTING(self): pass # pragma: no cover + def CONNECTING(self): + pass # pragma: no cover + @m.state() - def CONNECTED(self): pass # pragma: no cover + def CONNECTED(self): + pass # pragma: no cover + @m.state() - def FLUSHING(self): pass # pragma: no cover + def FLUSHING(self): + pass # pragma: no cover + @m.state() - def ABANDONING(self): pass # pragma: no cover + def ABANDONING(self): + pass # pragma: no cover + @m.state() - def LONELY(self): pass # pragme: no cover + def LONELY(self): + pass # pragme: no cover + @m.state() - def STOPPING(self): pass # pragma: no cover + def STOPPING(self): + pass # pragma: no cover + @m.state(terminal=True) - def STOPPED(self): pass # pragma: no cover + def STOPPED(self): + pass # pragma: no cover @m.input() - def start(self): pass # pragma: no cover + def start(self): + pass # pragma: no cover + @m.input() - def rx_PLEASE(self, message): pass # pragma: no cover - @m.input() # only sent by Follower - def rx_HINTS(self, hint_message): pass # pragma: no cover - @m.input() # only Leader sends RECONNECT, so only Follower receives it - def rx_RECONNECT(self): pass # pragma: no cover - @m.input() # only Follower sends RECONNECTING, so only Leader receives it - def rx_RECONNECTING(self): pass # pragma: no cover + def rx_PLEASE(self, message): + pass # pragma: no cover + + @m.input() # only sent by Follower + def rx_HINTS(self, hint_message): + pass # pragma: no cover + + @m.input() # only Leader sends RECONNECT, so only Follower receives it + def rx_RECONNECT(self): + pass # pragma: no cover + + @m.input() # only Follower sends RECONNECTING, so only Leader receives it + def rx_RECONNECTING(self): + pass # pragma: no cover # Connector gives us connection_made() @m.input() - def connection_made(self, c): pass # pragma: no cover + def connection_made(self, c): + pass # pragma: no cover # our connection_lost() fires connection_lost_leader or # connection_lost_follower depending upon our role. If either side sees a # problem with the connection (timeouts, bad authentication) then they # just drop it and let connection_lost() handle the cleanup. @m.input() - def connection_lost_leader(self): pass # pragma: no cover - @m.input() - def connection_lost_follower(self): pass + def connection_lost_leader(self): + pass # pragma: no cover @m.input() - def stop(self): pass # pragma: no cover + def connection_lost_follower(self): + pass + + @m.input() + def stop(self): + pass # pragma: no cover @m.output() def stash_side(self, message): @@ -301,38 +343,42 @@ class ManagerShared(_ManagerBase): @m.output() def start_connecting(self): - self._start_connecting() # TODO: merge + self._start_connecting() # TODO: merge + @m.output() def ignore_message_start_connecting(self, message): self.start_connecting() @m.output() def send_reconnect(self): - self.send_dilation_phase(type="reconnect") # TODO: generation number? + self.send_dilation_phase(type="reconnect") # TODO: generation number? + @m.output() def send_reconnecting(self): - self.send_dilation_phase(type="reconnecting") # TODO: generation? + self.send_dilation_phase(type="reconnecting") # TODO: generation? 
@m.output() def use_hints(self, hint_message): - hint_objs = filter(lambda h: h, # ignore None, unrecognizable + hint_objs = filter(lambda h: h, # ignore None, unrecognizable [parse_hint(hs) for hs in hint_message["hints"]]) hint_objs = list(hint_objs) self._connector.got_hints(hint_objs) + @m.output() def stop_connecting(self): self._connector.stop() + @m.output() def abandon_connection(self): # we think we're still connected, but the Leader disagrees. Or we've # been told to shut down. - self._connection.disconnect() # let connection_lost do cleanup - + self._connection.disconnect() # let connection_lost do cleanup # we don't start CONNECTING until a local start() plus rx_PLEASE IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) IDLE.upon(start, enter=WANTING, outputs=[send_please]) - WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) + WANTED.upon(start, enter=CONNECTING, outputs=[ + send_please, start_connecting]) WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[stash_side, ignore_message_start_connecting]) @@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase): # Leader CONNECTED.upon(connection_lost_leader, enter=FLUSHING, outputs=[send_reconnect]) - FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) + FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, + outputs=[start_connecting]) # Follower # if we notice a lost connection, just wait for the Leader to notice too @@ -350,7 +397,7 @@ class ManagerShared(_ManagerBase): LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) # but if they notice it first, abandon our (seemingly functional) # connection, then tell them that we're ready to try again - CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss outputs=[abandon_connection]) ABANDONING.upon(connection_lost_follower, enter=CONNECTING, outputs=[send_reconnecting, start_connecting]) @@ -362,16 +409,15 @@ class ManagerShared(_ManagerBase): send_reconnecting, start_connecting]) - # rx_HINTS never changes state, they're just accepted or ignored - IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early - WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early - WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early + IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early + WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early + WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) - CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore - FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore - LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore - ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore + LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore + ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) IDLE.upon(stop, enter=STOPPED, outputs=[]) @@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase): STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) - def allocate_subchannel_id(self): # scid 0 is reserved for the control channel. 
the leader uses odd # numbers starting with 1 @@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase): self._next_outbound_seqnum += 2 return to_be4(scid_num) + @attrs @implementer(IDilator) class Dilator(object): @@ -436,10 +482,10 @@ class Dilator(object): dilation_version = yield self._got_versions_d - if not dilation_version: # 1 or None + if not dilation_version: # 1 or None raise OldPeerCannotDilateError() - my_dilation_side = TODO # random + my_dilation_side = TODO # random self._manager = Manager(self._S, my_dilation_side, self._transit_key, self._transit_relay_location, @@ -455,14 +501,15 @@ class Dilator(object): yield self._manager.when_first_connected() # we can open subchannels as soon as we get our first connection scid0 = b"\x00\x00\x00\x00" - self._host_addr = _WormholeAddress() # TODO: share with Manager + self._host_addr = _WormholeAddress() # TODO: share with Manager peer_addr0 = _SubchannelAddress(scid0) control_ep = ControlEndpoint(peer_addr0) sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) control_ep._subchannel_zero_opened(sc0) self._manager.set_subchannel_zero(scid0, sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + connect_ep = SubchannelConnectorEndpoint( + self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) self._manager.set_listener_endpoint(listen_ep) @@ -476,7 +523,7 @@ class Dilator(object): # TODO: verify this happens before got_wormhole_versions, or add a gate # to tolerate either ordering purpose = b"dilation-v1" - LENGTH =32 # TODO: whatever Noise wants, I guess + LENGTH = 32 # TODO: whatever Noise wants, I guess self._transit_key = derive_key(key, purpose, LENGTH) def got_wormhole_versions(self, our_side, their_side, @@ -504,9 +551,9 @@ class Dilator(object): message = bytes_to_dict(plaintext) type = message["type"] if type == "please": - self._manager.rx_PLEASE() # message) + self._manager.rx_PLEASE() # message) elif type == "dilate": - self._manager.rx_DILATE() #message) + self._manager.rx_DILATE() # message) elif type == "connection-hints": self._manager.rx_HINTS(message) else: diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py index 6538ffe..96fbd3d 100644 --- a/src/wormhole/_dilation/outbound.py +++ b/src/wormhole/_dilation/outbound.py @@ -168,9 +168,9 @@ class Outbound(object): self._queued_unsent = deque() # outbound flow control: the Connection throttles our writes - self._subchannel_producers = {} # Subchannel -> IProducer - self._paused = True # our Connection called our pauseProducing - self._all_producers = deque() # rotates, left-is-next + self._subchannel_producers = {} # Subchannel -> IProducer + self._paused = True # our Connection called our pauseProducing + self._all_producers = deque() # rotates, left-is-next self._paused_producers = set() self._unpaused_producers = set() self._check_invariants() @@ -186,7 +186,7 @@ class Outbound(object): seqnum = self._next_outbound_seqnum self._next_outbound_seqnum += 1 r = record_type(seqnum, *args) - assert hasattr(r, "seqnum"), r # only Open/Data/Close + assert hasattr(r, "seqnum"), r # only Open/Data/Close return r def queue_and_send_record(self, r): @@ -203,7 +203,7 @@ class Outbound(object): self._connection.send_record(r) def send_if_connected(self, r): - assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum + assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum if self._connection: self._connection.send_record(r) @@ -235,7 
+235,7 @@ class Outbound(object): if self._paused: # IPushProducers need to be paused immediately, before they # speak - producer.pauseProducing() # you wake up sleeping + producer.pauseProducing() # you wake up sleeping else: # our PullToPush adapter must be started, but if we're paused then # we tell it to pause before it gets a chance to write anything @@ -265,7 +265,7 @@ class Outbound(object): assert not self._queued_unsent self._queued_unsent.extend(self._outbound_queue) # the connection can tell us to pause when we send too much data - c.registerProducer(self, True) # IPushProducer: pause+resume + c.registerProducer(self, True) # IPushProducer: pause+resume # send our queued messages self.resumeProducing() @@ -290,12 +290,12 @@ class Outbound(object): # Inbound is responsible for tracking the high watermark and deciding # whether to ignore inbound messages or not - # IProducer: the active connection calls these because we used # c.registerProducer to ask for them + def pauseProducing(self): if self._paused: - return # someone is confused and called us twice + return # someone is confused and called us twice self._paused = True for p in self._all_producers: if p in self._unpaused_producers: @@ -305,7 +305,7 @@ class Outbound(object): def resumeProducing(self): if not self._paused: - return # someone is confused and called us twice + return # someone is confused and called us twice self._paused = False while not self._paused: @@ -326,7 +326,7 @@ class Outbound(object): return None while True: p = self._all_producers[0] - self._all_producers.rotate(-1) # p moves to the end of the line + self._all_producers.rotate(-1) # p moves to the end of the line # the only unpaused Producers are at the end of the list assert p in self._paused_producers return p @@ -343,7 +343,7 @@ class Outbound(object): @attrs(cmp=False) class PullToPush(object): _producer = attrib(validator=provides(IPullProducer)) - _unregister = attrib(validator=lambda _a,_b,v: callable(v)) + _unregister = attrib(validator=lambda _a, _b, v: callable(v)) _cooperator = attrib() _finished = False @@ -351,14 +351,14 @@ class PullToPush(object): while True: try: self._producer.resumeProducing() - except: + except Exception: log.err(None, "%s failed, producing will be stopped:" % (safe_str(self._producer),)) try: self._unregister() # The consumer should now call stopStreaming() on us, # thus stopping the streaming. 
- except: + except Exception: # Since the consumer blew up, we may not have had # stopStreaming() called, so we just stop on our own: log.err(None, "%s failed to unregister producer:" % @@ -370,7 +370,7 @@ class PullToPush(object): def startStreaming(self, paused): self._coopTask = self._cooperator.cooperate(self._pull()) if paused: - self.pauseProducing() # timer is scheduled, but task is removed + self.pauseProducing() # timer is scheduled, but task is removed def stopStreaming(self): if self._finished: @@ -378,15 +378,12 @@ class PullToPush(object): self._finished = True self._coopTask.stop() - def pauseProducing(self): self._coopTask.pause() - def resumeProducing(self): self._coopTask.resume() - def stopProducing(self): self.stopStreaming() self._producer.stopProducing() diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index 94b4a03..abd1939 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -1,7 +1,8 @@ from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed +from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, + succeed) from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, IStreamClientEndpoint, @@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone from automat import MethodicalMachine from .._interfaces import ISubChannel, IDilationManager + @attrs class Once(object): _errtype = attrib() + def __attrs_post_init__(self): self._called = False @@ -21,6 +24,7 @@ class Once(object): raise self._errtype() self._called = True + class SingleUseEndpointError(Exception): pass @@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception): # (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) # object is deleted upon transition to (CLOSED) + class AlreadyClosedError(Exception): pass + @implementer(IAddress) class _WormholeAddress(object): pass + @implementer(IAddress) @attrs class _SubchannelAddress(object): @@ -63,35 +70,44 @@ class SubChannel(object): _peer_addr = attrib(validator=instance_of(_SubchannelAddress)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, + f: None) # pragma: no cover def __attrs_post_init__(self): - #self._mailbox = None - #self._pending_outbound = {} - #self._processed = set() + # self._mailbox = None + # self._pending_outbound = {} + # self._processed = set() self._protocol = None self._pending_dataReceived = [] self._pending_connectionLost = (False, None) @m.state(initial=True) - def open(self): pass # pragma: no cover + def open(self): + pass # pragma: no cover @m.state() - def closing(): pass # pragma: no cover + def closing(): + pass # pragma: no cover @m.state() - def closed(): pass # pragma: no cover + def closed(): + pass # pragma: no cover @m.input() - def remote_data(self, data): pass - @m.input() - def remote_close(self): pass + def remote_data(self, data): + pass @m.input() - def local_data(self, data): pass - @m.input() - def local_close(self): pass + def remote_close(self): + pass + @m.input() + def local_data(self, data): + pass + + @m.input() + def local_close(self): + pass @m.output() def send_data(self, data): @@ -120,9 +136,11 @@ class SubChannel(object): @m.output() def error_closed_write(self, data): raise 
AlreadyClosedError("write not allowed on closed subchannel") + @m.output() def error_closed_close(self): - raise AlreadyClosedError("loseConnection not allowed on closed subchannel") + raise AlreadyClosedError( + "loseConnection not allowed on closed subchannel") # primary transitions open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) @@ -146,7 +164,7 @@ class SubChannel(object): if self._pending_dataReceived: for data in self._pending_dataReceived: self._protocol.dataReceived(data) - self._pending_dataReceived = [] + self._pending_dataReceived = [] cl, what = self._pending_connectionLost if cl: self._protocol.connectionLost(what) @@ -155,13 +173,17 @@ class SubChannel(object): def write(self, data): assert isinstance(data, type(b"")) self.local_data(data) + def writeSequence(self, iovec): self.write(b"".join(iovec)) + def loseConnection(self): self.local_close() + def getHost(self): # we define "host addr" as the overall wormhole return self._host_addr + def getPeer(self): # and "peer addr" as the subchannel within that wormhole return self._peer_addr @@ -169,14 +191,17 @@ class SubChannel(object): # IProducer: throttle inbound data (wormhole "up" to local app's Protocol) def stopProducing(self): self._manager.subchannel_stopProducing(self) + def pauseProducing(self): self._manager.subchannel_pauseProducing(self) + def resumeProducing(self): self._manager.subchannel_resumeProducing(self) # IConsumer: allow the wormhole to throttle outbound data (app->wormhole) def registerProducer(self, producer, streaming): self._manager.subchannel_registerProducer(self, producer, streaming) + def unregisterProducer(self): self._manager.subchannel_unregisterProducer(self) @@ -184,6 +209,7 @@ class SubChannel(object): @implementer(IStreamClientEndpoint) class ControlEndpoint(object): _used = False + def __init__(self, peer_addr): self._subchannel_zero = Deferred() self._peer_addr = peer_addr @@ -201,9 +227,10 @@ class ControlEndpoint(object): t = yield self._subchannel_zero p = protocolFactory.buildProtocol(self._peer_addr) t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + p.makeConnection(t) # set p.transport = t and call connectionMade() returnValue(p) + @implementer(IStreamClientEndpoint) @attrs class SubchannelConnectorEndpoint(object): @@ -220,9 +247,10 @@ class SubchannelConnectorEndpoint(object): t = SubChannel(scid, self._manager, self._host_addr, peer_addr) p = protocolFactory.buildProtocol(peer_addr) t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + p.makeConnection(t) # set p.transport = t and call connectionMade() return succeed(p) + @implementer(IStreamServerEndpoint) @attrs class SubchannelListenerEndpoint(object): @@ -238,7 +266,7 @@ class SubchannelListenerEndpoint(object): if self._factory: self._connect(t, peer_addr) else: - self._pending_opens.append( (t, peer_addr) ) + self._pending_opens.append((t, peer_addr)) def _connect(self, t, peer_addr): p = self._factory.buildProtocol(peer_addr) @@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object): lp = SubchannelListeningPort(self._host_addr) return succeed(lp) + @implementer(IListeningPort) @attrs class SubchannelListeningPort(object): @@ -262,8 +291,10 @@ class SubchannelListeningPort(object): def startListening(self): pass + def stopListening(self): # TODO pass + def getHost(self): return self._host_addr
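
For reference: the wire framing that this patch's _Framer.send_frame()/parse_frame() pair agree on is a 4-byte big-endian length prefix (via to_be4/from_be4 from _dilation/encode.py) followed by the frame payload. The sketch below is illustrative only and not part of the patch; frame() and parse_frames() are hypothetical helper names, while to_be4()/from_be4() mirror encode.py.

    import struct

    def to_be4(value):
        # 4-byte big-endian unsigned int, as in _dilation/encode.py
        if not 0 <= value < 2**32:
            raise ValueError
        return struct.pack(">L", value)

    def from_be4(b):
        if len(b) != 4:
            raise ValueError
        return struct.unpack(">L", b)[0]

    def frame(payload):
        # what send_frame() writes: length prefix + payload
        return to_be4(len(payload)) + payload

    def parse_frames(buffer):
        # mirrors _Framer.parse_frame(): consume complete frames,
        # return (frames, leftover bytes awaiting more data)
        frames = []
        while len(buffer) >= 4:
            length = from_be4(buffer[0:4])
            if len(buffer) < 4 + length:
                break  # wait a bit longer
            frames.append(buffer[4:4 + length])
            buffer = buffer[4 + length:]
        return frames, buffer

    # example: two frames arriving split across reads
    buf = frame(b"hello") + frame(b"world")[:3]
    done, rest = parse_frames(buf)
    assert done == [b"hello"]
    done2, rest2 = parse_frames(rest + frame(b"world")[3:])
    assert done2 == [b"world"]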