fix some flake8 complaints

This commit is contained in:
Brian Warner 2018-06-30 16:19:41 -07:00
parent 39666f3fed
commit 05900bd08b
8 changed files with 339 additions and 195 deletions

View File

@ -67,7 +67,8 @@ class Boss(object):
self._I = Input(self._timing)
self._C = Code(self._timing)
self._T = Terminator()
self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator)
self._D = Dilator(self._reactor, self._eventual_queue,
self._cooperator)
self._N.wire(self._M, self._I, self._RC, self._T)
self._M.wire(self._N, self._RC, self._O, self._T)

View File

@ -39,20 +39,28 @@ from .roles import FOLLOWER
# states). For the specific question of sending plaintext frames, Noise will
# refuse us unless it's ready anyways, so the question is probably moot.
class IFramer(Interface):
pass
class IRecord(Interface):
pass
def first(l):
return l[0]
class Disconnect(Exception):
pass
RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", [])
Frame = namedtuple("Frame", ["frame"])
@attrs
@implementer(IFramer)
class _Framer(object):
@ -73,19 +81,25 @@ class _Framer(object):
@m.state()
def want_relay(self): pass # pragma: no cover
@m.state(initial=True)
def want_prologue(self): pass # pragma: no cover
@m.state()
def want_frame(self): pass # pragma: no cover
@m.input()
def use_relay(self, relay_handshake): pass
@m.input()
def connectionMade(self): pass
@m.input()
def parse(self): pass
@m.input()
def got_relay_ok(self): pass
@m.input()
def got_prologue(self): pass
@ -93,6 +107,7 @@ class _Framer(object):
def store_relay_handshake(self, relay_handshake):
self._outbound_relay_handshake = relay_handshake
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
@m.output()
def send_relay_handshake(self):
self._transport.write(self._outbound_relay_handshake)
@ -120,10 +135,10 @@ class _Framer(object):
if len(self._buffer) < 4:
return None
frame_length = from_be4(self._buffer[0:4])
if len(self._buffer) < 4+frame_length:
if len(self._buffer) < 4 + frame_length:
return None
frame = self._buffer[4:4+frame_length]
self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy
frame = self._buffer[4:4 + frame_length]
self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy
return Frame(frame=frame)
want_prologue.upon(use_relay, outputs=[store_relay_handshake],
@ -144,7 +159,6 @@ class _Framer(object):
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
collector=first)
def _get_expected(self, name, expected):
lb = len(self._buffer)
le = len(expected)
@ -202,6 +216,7 @@ class _Framer(object):
# from peer. Sent immediately by Follower, after Selection by Leader.
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", [])
@ -222,6 +237,7 @@ T_DATA = b"\x04"
T_CLOSE = b"\x05"
T_ACK = b"\x06"
def parse_record(plaintext):
msgtype = plaintext[0:1]
if msgtype == T_KCM:
@ -251,6 +267,7 @@ def parse_record(plaintext):
log.err("received unknown message type: {}".format(plaintext))
raise ValueError()
def encode_record(r):
if isinstance(r, KCM):
return b"\x00"
@ -275,6 +292,7 @@ def encode_record(r):
return b"\x06" + to_be4(r.resp_seqnum)
raise TypeError(r)
@attrs
@implementer(IRecord)
class _Record(object):
@ -295,14 +313,17 @@ class _Record(object):
@n.state(initial=True)
def want_prologue(self): pass # pragma: no cover
@n.state()
def want_handshake(self): pass # pragma: no cover
@n.state()
def want_message(self): pass # pragma: no cover
@n.input()
def got_prologue(self):
pass
@n.input()
def got_frame(self, frame):
pass
@ -397,17 +418,21 @@ class DilatedConnectionProtocol(Protocol, object):
@m.state(initial=True)
def unselected(self): pass # pragma: no cover
@m.state()
def selecting(self): pass # pragma: no cover
@m.state()
def selected(self): pass # pragma: no cover
@m.input()
def got_kcm(self):
pass
@m.input()
def select(self, manager):
pass # fires set_manager()
@m.input()
def got_record(self, record):
pass

View File

@ -1,5 +1,6 @@
from __future__ import print_function, unicode_literals
import sys, re
import sys
import re
from collections import defaultdict, namedtuple
from binascii import hexlify
import six
@ -30,7 +31,8 @@ from .roles import LEADER
# * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
DirectTCPV1Hint = namedtuple(
"DirectTCPV1Hint", ["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each
@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
@ -45,9 +48,10 @@ def describe_hint_obj(hint, relay, tor):
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix+"tor:%s:%d" % (hint.hostname, hint.port)
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix+str(hint)
return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(""))
@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
print("unknown hint type '%s' in '%s'" % (hint_type, hint),
file=stderr)
return None
hint_value = mo.group(2)
pieces = hint_value.split(":")
@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr):
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint,))
return None
if not("hostname" in hint
and isinstance(hint["hostname"], type(""))):
if not("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint,))
return None
if not("port" in hint
and isinstance(hint["port"], six.integer_types)):
if not("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint,))
return None
priority = hint.get("priority", 0.0)
@ -103,6 +109,7 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
def parse_hint(hint_struct):
hint_type = hint_struct.get("type", "")
if hint_type == "relay-v1":
@ -112,6 +119,7 @@ def parse_hint(hint_struct):
return RelayV1Hint(rhints)
return parse_tcp_v1_hint(hint_struct)
def encode_hint(h):
if isinstance(h, DirectTCPV1Hint):
return {"type": "direct-tcp-v1",
@ -135,19 +143,24 @@ def encode_hint(h):
}
raise ValueError("unknown hint type", h)
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8*2
assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
@attrs
@implementer(IDilationConnector)
class Connector(object):
@ -178,7 +191,8 @@ class Connector(object):
self._transit_relays = []
self._listeners = set() # IListeningPorts that can be stopped
self._pending_connectors = set() # Deferreds that can be cancelled
self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped
self._pending_connections = EmptyableSet(
_eventual_queue=self._eventual_queue) # Protocols to be stopped
self._contenders = set() # viable connections
self._winning_connection = None
self._timing = self._timing or DebugTiming()
@ -212,26 +226,41 @@ class Connector(object):
return p
@m.state(initial=True)
def connecting(self): pass # pragma: no cover
def connecting(self):
pass # pragma: no cover
@m.state()
def connected(self): pass # pragma: no cover
def connected(self):
pass # pragma: no cover
@m.state(terminal=True)
def stopped(self): pass # pragma: no cover
def stopped(self):
pass # pragma: no cover
# TODO: unify the tense of these method-name verbs
@m.input()
def listener_ready(self, hint_objs): pass
def listener_ready(self, hint_objs):
pass
@m.input()
def add_relay(self, hint_objs): pass
def add_relay(self, hint_objs):
pass
@m.input()
def got_hints(self, hint_objs): pass
def got_hints(self, hint_objs):
pass
@m.input()
def add_candidate(self, c): # called by DilatedConnectionProtocol
pass
@m.input()
def accept(self, c): pass
def accept(self, c):
pass
@m.input()
def stop(self): pass
def stop(self):
pass
@m.output()
def use_hints(self, hint_objs):
@ -306,7 +335,8 @@ class Connector(object):
publish_hints])
connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
connecting.upon(add_candidate, enter=connecting, outputs=[consider])
connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining])
connecting.upon(accept, enter=connected, outputs=[
select_and_stop_remaining])
connecting.upon(stop, enter=stopped, outputs=[stop_everything])
# once connected, we ignore everything except stop
@ -317,9 +347,9 @@ class Connector(object):
connected.upon(accept, enter=connected, outputs=[])
connected.upon(stop, enter=stopped, outputs=[stop_everything])
# from Manager: start, got_hints, stop
# maybe add_candidate, accept
def start(self):
self._start_listener()
if self._transit_relays:
@ -341,6 +371,7 @@ class Connector(object):
ep = serverFromString(self._reactor, "tcp:0")
f = InboundConnectionFactory(self)
d = ep.listen(f)
def _listening(lp):
# lp is an IListeningPort
self._listeners.add(lp) # for shutdown and tests
@ -378,7 +409,7 @@ class Connector(object):
# one still running. But if we bail on that, we might consider
# putting an inter-direct-hint delay here to influence the
# process.
#delay += 1.0
# delay += 1.0
if delay > 0.0:
# Start trying the relays a few seconds after we start to try the
# direct hints. The idea is to prefer direct connections, but not
@ -408,7 +439,7 @@ class Connector(object):
self._connect, ep, desc, is_relay=True)
self._pending_connectors.add(d)
# TODO:
#if not contenders:
# if not contenders:
# raise TransitError("No contenders for connection")
# TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
@ -422,6 +453,7 @@ class Connector(object):
f = OutboundConnectionFactory(self, relay_handshake)
d = ep.connect(f)
# fires with protocol, or ConnectError
def _connected(p):
self._pending_connections.add(p)
# c might not be in _pending_connections, if it turned out to be a
@ -444,7 +476,6 @@ class Connector(object):
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
return None
# Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method.
@ -459,6 +490,7 @@ class Connector(object):
# our Connection protocols call: add_candidate
@attrs
class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector))
@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object):
p.use_relay(self._relay_handshake)
return p
@attrs
class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector))

View File

@ -4,10 +4,13 @@ import struct
assert len(struct.pack("<L", 0)) == 4
assert len(struct.pack("<Q", 0)) == 8
def to_be4(value):
if not 0 <= value < 2**32:
raise ValueError
return struct.pack(">L", value)
def from_be4(b):
if not isinstance(b, bytes):
raise TypeError(repr(b))

View File

@ -6,13 +6,19 @@ from twisted.python import log
from .._interfaces import IDilationManager, IInbound
from .subchannel import (SubChannel, _SubchannelAddress)
class DuplicateOpenError(Exception):
pass
class DataForMissingSubchannelError(Exception):
pass
class CloseForMissingSubchannelError(Exception):
pass
@attrs
@implementer(IInbound)
class Inbound(object):
@ -37,7 +43,6 @@ class Inbound(object):
def set_subchannel_zero(self, scid0, sc0):
self._open_subchannels[scid0] = sc0
def use_connection(self, c):
self._connection = c
# We can pause the connection's reads when we receive too much data. If
@ -61,7 +66,8 @@ class Inbound(object):
def handle_open(self, scid):
if scid in self._open_subchannels:
log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid)))
log.err(DuplicateOpenError(
"received duplicate OPEN for {}".format(scid)))
return
peer_addr = _SubchannelAddress(scid)
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
@ -71,14 +77,16 @@ class Inbound(object):
def handle_data(self, scid, data):
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid)))
log.err(DataForMissingSubchannelError(
"received DATA for non-existent subchannel {}".format(scid)))
return
sc.remote_data(data)
def handle_close(self, scid):
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid)))
log.err(CloseForMissingSubchannelError(
"received CLOSE for non-existent subchannel {}".format(scid)))
return
sc.remote_close()
@ -90,7 +98,6 @@ class Inbound(object):
def stop_using_connection(self):
self._connection = None
# from our Subchannel, or rather from the Protocol above it and sent
# through the subchannel

View File

@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
from .inbound import Inbound
from .outbound import Outbound
class OldPeerCannotDilateError(Exception):
pass
class UnknownDilationMessageType(Exception):
pass
class ReceivedHintsTooEarly(Exception):
pass
@attrs
@implementer(IDilationManager)
class _ManagerBase(object):
@ -64,13 +70,13 @@ class _ManagerBase(object):
def set_listener_endpoint(self, listener_endpoint):
self._inbound.set_listener_endpoint(listener_endpoint)
def set_subchannel_zero(self, scid0, sc0):
self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self):
return self._first_connected.when_fired()
def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1
@ -79,25 +85,30 @@ class _ManagerBase(object):
def send_hints(self, hints): # from Connector
self.send_dilation_phase(type="connection-hints", hints=hints)
# forward inbound-ish things to _Inbound
def subchannel_pauseProducing(self, sc):
self._inbound.subchannel_pauseProducing(sc)
def subchannel_resumeProducing(self, sc):
self._inbound.subchannel_resumeProducing(sc)
def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc)
# forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming)
def subchannel_unregisterProducer(self, sc):
self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid):
self._queue_and_send(Open, scid)
def send_data(self, scid, data):
self._queue_and_send(Data, scid, data)
def send_close(self, scid):
self._queue_and_send(Close, scid)
@ -116,7 +127,6 @@ class _ManagerBase(object):
self._inbound.subchannel_closed(scid, sc)
self._outbound.subchannel_closed(scid, sc)
def _start_connecting(self, role):
assert self._my_role is not None
self._connector = Connector(self._transit_key,
@ -140,6 +150,7 @@ class _ManagerBase(object):
self._made_first_connection = True
self._first_connected.fire(None)
pass
def connector_connection_lost(self):
self._stop_using_connection()
if self.role is LEADER:
@ -147,7 +158,6 @@ class _ManagerBase(object):
else:
self.connection_lost_follower()
def _stop_using_connection(self):
# the connection is already lost by this point
self._connection = None
@ -190,7 +200,6 @@ class _ManagerBase(object):
def send_ack(self, resp_seqnum):
self._outbound.send_if_connected(Ack(resp_seqnum))
def handle_ping(self, ping_id):
self.send_pong(ping_id)
@ -200,7 +209,7 @@ class _ManagerBase(object):
# subchannel maintenance
def allocate_subchannel_id(self):
raise NotImplemented # subclass knows if we're leader or follower
raise NotImplementedError # subclass knows if we're leader or follower
# new scheme:
# * both sides send PLEASE as soon as they have an unverified key and
@ -236,58 +245,91 @@ class _ManagerBase(object):
# * if follower calls w.dilate() but not leader, follower waits forever
# in "want", leader waits forever in "wanted"
class ManagerShared(_ManagerBase):
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None)
@m.state(initial=True)
def IDLE(self): pass # pragma: no cover
def IDLE(self):
pass # pragma: no cover
@m.state()
def WANTING(self): pass # pragma: no cover
def WANTING(self):
pass # pragma: no cover
@m.state()
def WANTED(self): pass # pragma: no cover
def WANTED(self):
pass # pragma: no cover
@m.state()
def CONNECTING(self): pass # pragma: no cover
def CONNECTING(self):
pass # pragma: no cover
@m.state()
def CONNECTED(self): pass # pragma: no cover
def CONNECTED(self):
pass # pragma: no cover
@m.state()
def FLUSHING(self): pass # pragma: no cover
def FLUSHING(self):
pass # pragma: no cover
@m.state()
def ABANDONING(self): pass # pragma: no cover
def ABANDONING(self):
pass # pragma: no cover
@m.state()
def LONELY(self): pass # pragme: no cover
def LONELY(self):
pass # pragme: no cover
@m.state()
def STOPPING(self): pass # pragma: no cover
def STOPPING(self):
pass # pragma: no cover
@m.state(terminal=True)
def STOPPED(self): pass # pragma: no cover
def STOPPED(self):
pass # pragma: no cover
@m.input()
def start(self): pass # pragma: no cover
def start(self):
pass # pragma: no cover
@m.input()
def rx_PLEASE(self, message): pass # pragma: no cover
def rx_PLEASE(self, message):
pass # pragma: no cover
@m.input() # only sent by Follower
def rx_HINTS(self, hint_message): pass # pragma: no cover
def rx_HINTS(self, hint_message):
pass # pragma: no cover
@m.input() # only Leader sends RECONNECT, so only Follower receives it
def rx_RECONNECT(self): pass # pragma: no cover
def rx_RECONNECT(self):
pass # pragma: no cover
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
def rx_RECONNECTING(self): pass # pragma: no cover
def rx_RECONNECTING(self):
pass # pragma: no cover
# Connector gives us connection_made()
@m.input()
def connection_made(self, c): pass # pragma: no cover
def connection_made(self, c):
pass # pragma: no cover
# our connection_lost() fires connection_lost_leader or
# connection_lost_follower depending upon our role. If either side sees a
# problem with the connection (timeouts, bad authentication) then they
# just drop it and let connection_lost() handle the cleanup.
@m.input()
def connection_lost_leader(self): pass # pragma: no cover
@m.input()
def connection_lost_follower(self): pass
def connection_lost_leader(self):
pass # pragma: no cover
@m.input()
def stop(self): pass # pragma: no cover
def connection_lost_follower(self):
pass
@m.input()
def stop(self):
pass # pragma: no cover
@m.output()
def stash_side(self, message):
@ -302,6 +344,7 @@ class ManagerShared(_ManagerBase):
@m.output()
def start_connecting(self):
self._start_connecting() # TODO: merge
@m.output()
def ignore_message_start_connecting(self, message):
self.start_connecting()
@ -309,6 +352,7 @@ class ManagerShared(_ManagerBase):
@m.output()
def send_reconnect(self):
self.send_dilation_phase(type="reconnect") # TODO: generation number?
@m.output()
def send_reconnecting(self):
self.send_dilation_phase(type="reconnecting") # TODO: generation?
@ -319,20 +363,22 @@ class ManagerShared(_ManagerBase):
[parse_hint(hs) for hs in hint_message["hints"]])
hint_objs = list(hint_objs)
self._connector.got_hints(hint_objs)
@m.output()
def stop_connecting(self):
self._connector.stop()
@m.output()
def abandon_connection(self):
# we think we're still connected, but the Leader disagrees. Or we've
# been told to shut down.
self._connection.disconnect() # let connection_lost do cleanup
# we don't start CONNECTING until a local start() plus rx_PLEASE
IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
IDLE.upon(start, enter=WANTING, outputs=[send_please])
WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting])
WANTED.upon(start, enter=CONNECTING, outputs=[
send_please, start_connecting])
WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[stash_side,
ignore_message_start_connecting])
@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase):
# Leader
CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
outputs=[send_reconnect])
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting])
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
outputs=[start_connecting])
# Follower
# if we notice a lost connection, just wait for the Leader to notice too
@ -362,7 +409,6 @@ class ManagerShared(_ManagerBase):
send_reconnecting,
start_connecting])
# rx_HINTS never changes state, they're just accepted or ignored
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase):
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
def allocate_subchannel_id(self):
# scid 0 is reserved for the control channel. the leader uses odd
# numbers starting with 1
@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase):
self._next_outbound_seqnum += 2
return to_be4(scid_num)
@attrs
@implementer(IDilator)
class Dilator(object):
@ -462,7 +508,8 @@ class Dilator(object):
control_ep._subchannel_zero_opened(sc0)
self._manager.set_subchannel_zero(scid0, sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
connect_ep = SubchannelConnectorEndpoint(
self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep)
@ -476,7 +523,7 @@ class Dilator(object):
# TODO: verify this happens before got_wormhole_versions, or add a gate
# to tolerate either ordering
purpose = b"dilation-v1"
LENGTH =32 # TODO: whatever Noise wants, I guess
LENGTH = 32 # TODO: whatever Noise wants, I guess
self._transit_key = derive_key(key, purpose, LENGTH)
def got_wormhole_versions(self, our_side, their_side,
@ -506,7 +553,7 @@ class Dilator(object):
if type == "please":
self._manager.rx_PLEASE() # message)
elif type == "dilate":
self._manager.rx_DILATE() #message)
self._manager.rx_DILATE() # message)
elif type == "connection-hints":
self._manager.rx_HINTS(message)
else:

View File

@ -290,9 +290,9 @@ class Outbound(object):
# Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not
# IProducer: the active connection calls these because we used
# c.registerProducer to ask for them
def pauseProducing(self):
if self._paused:
return # someone is confused and called us twice
@ -343,7 +343,7 @@ class Outbound(object):
@attrs(cmp=False)
class PullToPush(object):
_producer = attrib(validator=provides(IPullProducer))
_unregister = attrib(validator=lambda _a,_b,v: callable(v))
_unregister = attrib(validator=lambda _a, _b, v: callable(v))
_cooperator = attrib()
_finished = False
@ -351,14 +351,14 @@ class PullToPush(object):
while True:
try:
self._producer.resumeProducing()
except:
except Exception:
log.err(None, "%s failed, producing will be stopped:" %
(safe_str(self._producer),))
try:
self._unregister()
# The consumer should now call stopStreaming() on us,
# thus stopping the streaming.
except:
except Exception:
# Since the consumer blew up, we may not have had
# stopStreaming() called, so we just stop on our own:
log.err(None, "%s failed to unregister producer:" %
@ -378,15 +378,12 @@ class PullToPush(object):
self._finished = True
self._coopTask.stop()
def pauseProducing(self):
self._coopTask.pause()
def resumeProducing(self):
self._coopTask.resume()
def stopProducing(self):
self.stopStreaming()
self._producer.stopProducing()

View File

@ -1,7 +1,8 @@
from attr import attrs, attrib
from attr.validators import instance_of, provides
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
succeed)
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IStreamClientEndpoint,
@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager
@attrs
class Once(object):
_errtype = attrib()
def __attrs_post_init__(self):
self._called = False
@ -21,6 +24,7 @@ class Once(object):
raise self._errtype()
self._called = True
class SingleUseEndpointError(Exception):
pass
@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception):
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
# object is deleted upon transition to (CLOSED)
class AlreadyClosedError(Exception):
pass
@implementer(IAddress)
class _WormholeAddress(object):
pass
@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
@ -63,35 +70,44 @@ class SubChannel(object):
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
set_trace = getattr(m, "_setTrace", lambda self,
f: None) # pragma: no cover
def __attrs_post_init__(self):
#self._mailbox = None
#self._pending_outbound = {}
#self._processed = set()
# self._mailbox = None
# self._pending_outbound = {}
# self._processed = set()
self._protocol = None
self._pending_dataReceived = []
self._pending_connectionLost = (False, None)
@m.state(initial=True)
def open(self): pass # pragma: no cover
def open(self):
pass # pragma: no cover
@m.state()
def closing(): pass # pragma: no cover
def closing():
pass # pragma: no cover
@m.state()
def closed(): pass # pragma: no cover
def closed():
pass # pragma: no cover
@m.input()
def remote_data(self, data): pass
@m.input()
def remote_close(self): pass
def remote_data(self, data):
pass
@m.input()
def local_data(self, data): pass
@m.input()
def local_close(self): pass
def remote_close(self):
pass
@m.input()
def local_data(self, data):
pass
@m.input()
def local_close(self):
pass
@m.output()
def send_data(self, data):
@ -120,9 +136,11 @@ class SubChannel(object):
@m.output()
def error_closed_write(self, data):
raise AlreadyClosedError("write not allowed on closed subchannel")
@m.output()
def error_closed_close(self):
raise AlreadyClosedError("loseConnection not allowed on closed subchannel")
raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
@ -155,13 +173,17 @@ class SubChannel(object):
def write(self, data):
assert isinstance(data, type(b""))
self.local_data(data)
def writeSequence(self, iovec):
self.write(b"".join(iovec))
def loseConnection(self):
self.local_close()
def getHost(self):
# we define "host addr" as the overall wormhole
return self._host_addr
def getPeer(self):
# and "peer addr" as the subchannel within that wormhole
return self._peer_addr
@ -169,14 +191,17 @@ class SubChannel(object):
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
def stopProducing(self):
self._manager.subchannel_stopProducing(self)
def pauseProducing(self):
self._manager.subchannel_pauseProducing(self)
def resumeProducing(self):
self._manager.subchannel_resumeProducing(self)
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
def registerProducer(self, producer, streaming):
self._manager.subchannel_registerProducer(self, producer, streaming)
def unregisterProducer(self):
self._manager.subchannel_unregisterProducer(self)
@ -184,6 +209,7 @@ class SubChannel(object):
@implementer(IStreamClientEndpoint)
class ControlEndpoint(object):
_used = False
def __init__(self, peer_addr):
self._subchannel_zero = Deferred()
self._peer_addr = peer_addr
@ -204,6 +230,7 @@ class ControlEndpoint(object):
p.makeConnection(t) # set p.transport = t and call connectionMade()
returnValue(p)
@implementer(IStreamClientEndpoint)
@attrs
class SubchannelConnectorEndpoint(object):
@ -223,6 +250,7 @@ class SubchannelConnectorEndpoint(object):
p.makeConnection(t) # set p.transport = t and call connectionMade()
return succeed(p)
@implementer(IStreamServerEndpoint)
@attrs
class SubchannelListenerEndpoint(object):
@ -238,7 +266,7 @@ class SubchannelListenerEndpoint(object):
if self._factory:
self._connect(t, peer_addr)
else:
self._pending_opens.append( (t, peer_addr) )
self._pending_opens.append((t, peer_addr))
def _connect(self, t, peer_addr):
p = self._factory.buildProtocol(peer_addr)
@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object):
lp = SubchannelListeningPort(self._host_addr)
return succeed(lp)
@implementer(IListeningPort)
@attrs
class SubchannelListeningPort(object):
@ -262,8 +291,10 @@ class SubchannelListeningPort(object):
def startListening(self):
pass
def stopListening(self):
# TODO
pass
def getHost(self):
return self._host_addr