diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index ce650b7..226735b 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -81,8 +81,8 @@ class Boss(object): self._A.wire(self._RC, self._C) self._I.wire(self._C, self._L) self._C.wire(self, self._A, self._N, self._K, self._I) - self._T.wire(self, self._RC, self._N, self._M) - self._D.wire(self._S) + self._T.wire(self, self._RC, self._N, self._M, self._D) + self._D.wire(self._S, self._T) def _init_other_state(self): self._did_start_code = False diff --git a/src/wormhole/_dilation/_noise.py b/src/wormhole/_dilation/_noise.py index bb4cf58..1005264 100644 --- a/src/wormhole/_dilation/_noise.py +++ b/src/wormhole/_dilation/_noise.py @@ -4,6 +4,12 @@ except ImportError: class NoiseInvalidMessage(Exception): pass +try: + from noise.exceptions import NoiseHandshakeError +except ImportError: + class NoiseHandshakeError(Exception): + pass + try: from noise.connection import NoiseConnection except ImportError: diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index b8f3ec6..8bb82aa 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport from .._interfaces import IDilationConnector from ..observer import OneShotObserver from .encode import to_be4, from_be4 -from .roles import FOLLOWER -from ._noise import NoiseInvalidMessage +from .roles import LEADER, FOLLOWER +from ._noise import NoiseInvalidMessage, NoiseHandshakeError # InboundFraming is given data and returns Frames (Noise wire-side # bytestrings). It handles the relay handshake and the prologue. The Frames it @@ -56,6 +56,23 @@ def first(l): class Disconnect(Exception): pass +# all connections look like: +# (step 1: only for outbound connections) +# 1: if we're connecting to a transit relay: +# * send "sided relay handshake": "please relay TOKEN for side SIDE\n" +# * the relay will send "ok\n" if/when our peer connects +# * a non-relay will probably send junk +# * wait for "ok\n", hang up if we get anything different +# (all subsequent steps are for both inbound and outbound connections) +# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n" +# 3: wait for the opposite PROLOGUE string, else hang up +# (everything past this point is a Frame, with be4 length prefix. Frames are +# either noise handshake or an encrypted message) +# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it +# 5: if FOLLOWER, send noise response string. if LEADER, wait for it +# 6: ... + + RelayOK = namedtuple("RelayOk", []) Prologue = namedtuple("Prologue", []) @@ -193,7 +210,7 @@ class _Framer(object): def add_and_parse(self, data): # we can't make this an @m.input because we can't change the state # from within an input. Instead, let the state choose the parser to - # use, and use the parsed token drive a state transition. + # use, then use the parsed token to drive a state transition. self._buffer += data while True: # it'd be nice to use an iterator here, but since self.parse() @@ -233,7 +250,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value Pong = namedtuple("Pong", ["ping_id"]) Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer Data = namedtuple("Data", ["seqnum", "scid", "data"]) -Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer Records = (KCM, Ping, Pong, Open, Data, Close, Ack) Handshake_or_Records = (Handshake,) + Records @@ -258,16 +275,16 @@ def parse_record(plaintext): ping_id = plaintext[1:5] return Pong(ping_id) if msgtype == T_OPEN: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) return Open(seqnum, scid) if msgtype == T_DATA: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) data = plaintext[9:] return Data(seqnum, scid, data) if msgtype == T_CLOSE: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) return Close(seqnum, scid) if msgtype == T_ACK: @@ -285,28 +302,36 @@ def encode_record(r): if isinstance(r, Pong): return b"\x02" + r.ping_id if isinstance(r, Open): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum) + return b"\x03" + r.scid + to_be4(r.seqnum) if isinstance(r, Data): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data + return b"\x04" + r.scid + to_be4(r.seqnum) + r.data if isinstance(r, Close): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum) + return b"\x05" + r.scid + to_be4(r.seqnum) if isinstance(r, Ack): assert isinstance(r.resp_seqnum, six.integer_types) return b"\x06" + to_be4(r.resp_seqnum) raise TypeError(r) +def _is_role(_record, _attr, value): + if value not in [LEADER, FOLLOWER]: + raise ValueError("role must be LEADER or FOLLOWER") + @attrs @implementer(IRecord) class _Record(object): _framer = attrib(validator=provides(IFramer)) _noise = attrib() + _role = attrib(default="unspecified", validator=_is_role) # for debugging n = MethodicalMachine() # TODO: set_trace @@ -321,17 +346,37 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): + def no_role_set(self): pass # pragma: no cover @n.state() - def want_handshake(self): + def want_prologue_leader(self): + pass # pragma: no cover + + @n.state() + def want_prologue_follower(self): + pass # pragma: no cover + + @n.state() + def want_handshake_leader(self): + pass # pragma: no cover + + @n.state() + def want_handshake_follower(self): pass # pragma: no cover @n.state() def want_message(self): pass # pragma: no cover + @n.input() + def set_role_leader(self): + pass + + @n.input() + def set_role_follower(self): + pass + @n.input() def got_prologue(self): pass @@ -340,9 +385,20 @@ class _Record(object): def got_frame(self, frame): pass + @n.output() + def ignore_and_send_handshake(self, frame): + self._send_handshake() + @n.output() def send_handshake(self): - handshake = self._noise.write_message() # generate the ephemeral key + self._send_handshake() + + def _send_handshake(self): + try: + handshake = self._noise.write_message() # generate the ephemeral key + except NoiseHandshakeError as e: + log.err(e, "noise error during handshake") + raise self._framer.send_frame(handshake) @n.output() @@ -367,10 +423,19 @@ class _Record(object): raise Disconnect() return parse_record(message) - want_prologue.upon(got_prologue, outputs=[send_handshake], - enter=want_handshake) - want_handshake.upon(got_frame, outputs=[process_handshake], - collector=first, enter=want_message) + no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader) + want_prologue_leader.upon(got_prologue, outputs=[send_handshake], + enter=want_handshake_leader) + want_handshake_leader.upon(got_frame, outputs=[process_handshake], + collector=first, enter=want_message) + + no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower) + want_prologue_follower.upon(got_prologue, outputs=[], + enter=want_handshake_follower) + want_handshake_follower.upon(got_frame, outputs=[process_handshake, + ignore_and_send_handshake], + collector=first, enter=want_message) + want_message.upon(got_frame, outputs=[decrypt_message], collector=first, enter=want_message) @@ -393,7 +458,7 @@ class _Record(object): self._framer.send_frame(frame) -@attrs +@attrs(cmp=False) class DilatedConnectionProtocol(Protocol, object): """I manage an L2 connection. @@ -408,12 +473,13 @@ class DilatedConnectionProtocol(Protocol, object): At any given time, there is at most one active L2 connection. """ - _eventual_queue = attrib() + _eventual_queue = attrib(repr=False) _role = attrib() - _connector = attrib(validator=provides(IDilationConnector)) - _noise = attrib() - _outbound_prologue = attrib(validator=instance_of(bytes)) - _inbound_prologue = attrib(validator=instance_of(bytes)) + _description = attrib() + _connector = attrib(validator=provides(IDilationConnector), repr=False) + _noise = attrib(repr=False) + _outbound_prologue = attrib(validator=instance_of(bytes), repr=False) + _inbound_prologue = attrib(validator=instance_of(bytes), repr=False) _use_relay = False _relay_handshake = None @@ -457,6 +523,8 @@ class DilatedConnectionProtocol(Protocol, object): @m.output() def set_manager(self, manager): self._manager = manager + self.when_disconnected().addCallback(lambda c: + manager.connector_connection_lost()) @m.output() def can_send_records(self, manager): @@ -493,12 +561,20 @@ class DilatedConnectionProtocol(Protocol, object): # IProtocol methods def connectionMade(self): - framer = _Framer(self.transport, - self._outbound_prologue, self._inbound_prologue) - if self._use_relay: - framer.use_relay(self._relay_handshake) - self._record = _Record(framer, self._noise) - self._record.connectionMade() + try: + framer = _Framer(self.transport, + self._outbound_prologue, self._inbound_prologue) + if self._use_relay: + framer.use_relay(self._relay_handshake) + self._record = _Record(framer, self._noise, self._role) + if self._role is LEADER: + self._record.set_role_leader() + else: + self._record.set_role_follower() + self._record.connectionMade() + except: + log.err() + raise def dataReceived(self, data): try: diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index aa5f8e0..638a827 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -9,6 +9,7 @@ from twisted.internet.task import deferLater from twisted.internet.defer import DeferredList from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory +from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address from twisted.python import log from .. import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager @@ -39,9 +40,36 @@ NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" def build_noise(): return NoiseConnection.from_name(NOISEPROTO) -@attrs +@attrs(cmp=False) @implementer(IDilationConnector) class Connector(object): + """I manage a single generation of connection. + + The Manager creates one of me at a time, whenever it wants a connection + (which is always, once w.dilate() has been called and we know the remote + end can dilate, and is expressed by the Manager calling my .start() + method). I am discarded when my established connection is lost (and if we + still want to be connected, a new generation is started and a new + Connector is created). I am also discarded if we stop wanting to be + connected (which the Manager expresses by calling my .stop() method). + + I manage the race between multiple connections for a specific generation + of the dilated connection. + + I send connection hints when my InboundConnectionFactory yields addresses + (self.listener_ready), and I initiate outbond connections (with + OutboundConnectionFactory) as I receive connection hints from my peer + (self.got_hints). Both factories use my build_protocol() method to create + connection.DilatedConnectionProtocol instances. I track these protocol + instances until one finishes negotiation and wins the race. I then shut + down the others, remember the winner as self._winning_connection, and + deliver the winner to manager.connector_connection_made(c). + + When an active connection is lost, we call manager.connector_connection_lost, + allowing the manager to decide whether it wants to start a new generation + or not. + """ + _dilation_key = attrib(validator=instance_of(type(b""))) _transit_relay_location = attrib(validator=optional(instance_of(type(u"")))) _manager = attrib(validator=provides(IDilationManager)) @@ -83,7 +111,7 @@ class Connector(object): {"type": "relay-v1"}, ] - def build_protocol(self, addr): + def build_protocol(self, addr, description): # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses # ephemeral keys plus a pre-shared symmetric key (the Transit key), a # different one for each potential connection. @@ -98,6 +126,7 @@ class Connector(object): outbound_prologue = PROLOGUE_FOLLOWER inbound_prologue = PROLOGUE_LEADER p = DilatedConnectionProtocol(self._eventual_queue, self._role, + description, self, noise, outbound_prologue, inbound_prologue) return p @@ -181,10 +210,13 @@ class Connector(object): self.stop_pending_connections() c.select(self._manager) # subsequent frames go directly to the manager + # c.select also wires up when_disconnected() to fire + # manager.connector_connection_lost(). TODO: rename this, since the + # Connector is no longer the one calling it if self._role is LEADER: # TODO: this should live in Connection c.send_record(KCM()) # leader sends KCM now - self._manager.use_connection(c) # manager sends frames to Connection + self._manager.connector_connection_made(c) # manager sends frames to Connection @m.output() def stop_everything(self): @@ -199,11 +231,12 @@ class Connector(object): return d # synchronization for tests def stop_pending_connectors(self): - return DeferredList([d.cancel() for d in self._pending_connectors]) + for d in self._pending_connectors: + d.cancel() def stop_pending_connections(self): d = self._pending_connections.when_next_empty() - [c.loseConnection() for c in self._pending_connections] + [c.disconnect() for c in self._pending_connections] return d def break_cycles(self): @@ -337,7 +370,7 @@ class Connector(object): if is_relay: relay_handshake = build_sided_relay_handshake(self._dilation_key, self._side) - f = OutboundConnectionFactory(self, relay_handshake) + f = OutboundConnectionFactory(self, relay_handshake, description) d = ep.connect(f) # fires with protocol, or ConnectError @@ -368,20 +401,28 @@ class Connector(object): class OutboundConnectionFactory(ClientFactory, object): _connector = attrib(validator=provides(IDilationConnector)) _relay_handshake = attrib(validator=optional(instance_of(bytes))) + _description = attrib() def buildProtocol(self, addr): - p = self._connector.build_protocol(addr) + p = self._connector.build_protocol(addr, self._description) p.factory = self if self._relay_handshake is not None: p.use_relay(self._relay_handshake) return p +def describe_inbound(addr): + if isinstance(addr, HostnameAddress): + return "<-tcp:%s:%d" % (addr.hostname, addr.port) + elif isinstance(addr, (IPv4Address, IPv6Address)): + return "<-tcp:%s:%d" % (addr.host, addr.port) + return "<-%r" % addr @attrs class InboundConnectionFactory(ServerFactory, object): _connector = attrib(validator=provides(IDilationConnector)) def buildProtocol(self, addr): - p = self._connector.build_protocol(addr) + description = describe_inbound(addr) + p = self._connector.build_protocol(addr, description) p.factory = self return p diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py index 2f6ffaf..7adaade 100644 --- a/src/wormhole/_dilation/inbound.py +++ b/src/wormhole/_dilation/inbound.py @@ -60,9 +60,9 @@ class Inbound(object): return True return False - def update_ack_watermark(self, r): + def update_ack_watermark(self, seqnum): self._highest_inbound_acked = max(self._highest_inbound_acked, - r.seqnum) + seqnum) def handle_open(self, scid): if scid in self._open_subchannels: diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 18d1770..f5724be 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -7,7 +7,7 @@ from automat import MethodicalMachine from zope.interface import implementer from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.python import log -from .._interfaces import IDilator, IDilationManager, ISend +from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver from .._key import derive_key @@ -87,17 +87,17 @@ def make_side(): # * if follower calls w.dilate() but not leader, follower waits forever # in "want", leader waits forever in "wanted" -@attrs +@attrs(cmp=False) @implementer(IDilationManager) class Manager(object): - _S = attrib(validator=provides(ISend)) + _S = attrib(validator=provides(ISend), repr=False) _my_side = attrib(validator=instance_of(type(u""))) - _transit_key = attrib(validator=instance_of(bytes)) + _transit_key = attrib(validator=instance_of(bytes), repr=False) _transit_relay_location = attrib(validator=optional(instance_of(str))) - _reactor = attrib() - _eventual_queue = attrib() - _cooperator = attrib() - _no_listen = False # TODO + _reactor = attrib(repr=False) + _eventual_queue = attrib(repr=False) + _cooperator = attrib(repr=False) + _no_listen = attrib(default=False) _tor = None # TODO _timing = None # TODO _next_subchannel_id = None # initialized in choose_role @@ -113,6 +113,7 @@ class Manager(object): self._connection = None self._made_first_connection = False self._first_connected = OneShotObserver(self._eventual_queue) + self._stopped = OneShotObserver(self._eventual_queue) self._host_addr = _WormholeAddress() self._next_dilation_phase = 0 @@ -133,6 +134,9 @@ class Manager(object): def when_first_connected(self): return self._first_connected.when_fired() + def when_stopped(self): + return self._stopped.when_fired() + def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 @@ -160,12 +164,15 @@ class Manager(object): self._outbound.subchannel_unregisterProducer(sc) def send_open(self, scid): + assert isinstance(scid, bytes) self._queue_and_send(Open, scid) def send_data(self, scid, data): + assert isinstance(scid, bytes) self._queue_and_send(Data, scid, data) def send_close(self, scid): + assert isinstance(scid, bytes) self._queue_and_send(Close, scid) def _queue_and_send(self, record_type, *args): @@ -401,6 +408,10 @@ class Manager(object): # been told to shut down. self._connection.disconnect() # let connection_lost do cleanup + @m.output() + def notify_stopped(self): + self._stopped.fire(None) + # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[choose_role, start_connecting_ignore_message]) @@ -440,14 +451,14 @@ class Manager(object): ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) - WANTING.upon(stop, enter=STOPPED, outputs=[]) - CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped]) CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) ABANDONING.upon(stop, enter=STOPPING, outputs=[]) - FLUSHING.upon(stop, enter=STOPPED, outputs=[]) - LONELY.upon(stop, enter=STOPPED, outputs=[]) - STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) - STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) + FLUSHING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) + LONELY.upon(stop, enter=STOPPED, outputs=[notify_stopped]) + STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[notify_stopped]) + STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[notify_stopped]) @attrs @@ -466,6 +477,7 @@ class Dilator(object): _reactor = attrib() _eventual_queue = attrib() _cooperator = attrib() + _no_listen = attrib(default=False) def __attrs_post_init__(self): self._got_versions_d = Deferred() @@ -474,8 +486,9 @@ class Dilator(object): self._pending_inbound_dilate_messages = deque() self._manager = None - def wire(self, sender): + def wire(self, sender, terminator): self._S = ISend(sender) + self._T = ITerminator(terminator) # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None): @@ -509,7 +522,7 @@ class Dilator(object): self._transit_key, self._transit_relay_location, self._reactor, self._eventual_queue, - self._cooperator) + self._cooperator, no_listen=self._no_listen) self._manager.start() while self._pending_inbound_dilate_messages: @@ -519,7 +532,7 @@ class Dilator(object): yield self._manager.when_first_connected() # we can open subchannels as soon as we get our first connection - scid0 = b"\x00\x00\x00\x00" + scid0 = to_be4(0) self._host_addr = _WormholeAddress() # TODO: share with Manager peer_addr0 = _SubchannelAddress(scid0) control_ep = ControlEndpoint(peer_addr0) @@ -535,6 +548,19 @@ class Dilator(object): endpoints = (control_ep, connect_ep, listen_ep) returnValue(endpoints) + # Called by Terminator after everything else (mailbox, nameplate, server + # connection) has shut down. Expects to fire T.stoppedD() when Dilator is + # stopped too. + def stop(self): + if not self._started: + self._T.stoppedD() + return + if self._started: + self._manager.stop() + # TODO: avoid Deferreds for control flow, hard to serialize + self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) + # TODO: tolerate multiple calls + # from Boss def got_key(self, key): diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py index 96fbd3d..96786ca 100644 --- a/src/wormhole/_dilation/outbound.py +++ b/src/wormhole/_dilation/outbound.py @@ -154,7 +154,7 @@ from .connection import KCM, Ping, Pong, Ack @attrs -@implementer(IOutbound) +@implementer(IOutbound, IPushProducer) class Outbound(object): # Manage outbound data: subchannel writes to us, we write to transport _manager = attrib(validator=provides(IDilationManager)) @@ -265,12 +265,12 @@ class Outbound(object): assert not self._queued_unsent self._queued_unsent.extend(self._outbound_queue) # the connection can tell us to pause when we send too much data - c.registerProducer(self, True) # IPushProducer: pause+resume + c.transport.registerProducer(self, True) # IPushProducer: pause+resume # send our queued messages self.resumeProducing() def stop_using_connection(self): - self._connection.unregisterProducer() + self._connection.transport.unregisterProducer() self._connection = None self._queued_unsent.clear() self.pauseProducing() @@ -290,8 +290,8 @@ class Outbound(object): # Inbound is responsible for tracking the high watermark and deciding # whether to ignore inbound messages or not - # IProducer: the active connection calls these because we used - # c.registerProducer to ask for them + # IPushProducer: the active connection calls these because we used + # c.transport.registerProducer to ask for them def pauseProducing(self): if self._paused: diff --git a/src/wormhole/_dilation/roles.py b/src/wormhole/_dilation/roles.py index 8f9adac..001566b 100644 --- a/src/wormhole/_dilation/roles.py +++ b/src/wormhole/_dilation/roles.py @@ -1 +1,7 @@ -LEADER, FOLLOWER = object(), object() +class _Role(object): + def __init__(self, which): + self._which = which + def __repr__(self): + return "Role(%s)" % self._which + +LEADER, FOLLOWER = _Role("LEADER"), _Role("FOLLOWER") diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index abd1939..ddcc856 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -55,7 +55,7 @@ class _WormholeAddress(object): @implementer(IAddress) @attrs class _SubchannelAddress(object): - _scid = attrib() + _scid = attrib(validator=instance_of(bytes)) @attrs diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index db02166..b27f318 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -246,7 +246,7 @@ class RendezvousConnector(object): # internal def _stopped(self, res): - self._T.stopped() + self._T.stoppedRC() def _tx(self, mtype, **kwargs): assert self._ws diff --git a/src/wormhole/_terminator.py b/src/wormhole/_terminator.py index fe4bdcb..c45f6d7 100644 --- a/src/wormhole/_terminator.py +++ b/src/wormhole/_terminator.py @@ -15,15 +15,17 @@ class Terminator(object): def __init__(self): self._mood = None - def wire(self, boss, rendezvous_connector, nameplate, mailbox): + def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator): self._B = _interfaces.IBoss(boss) self._RC = _interfaces.IRendezvousConnector(rendezvous_connector) self._N = _interfaces.INameplate(nameplate) self._M = _interfaces.IMailbox(mailbox) + self._D = _interfaces.IDilator(dilator) - # 4*2-1 main states: - # (nm, m, n, 0): nameplate and/or mailbox is active + # 2*2-1+1 main states: + # (nm, m, n, d): nameplate and/or mailbox is active # (o, ""): open (not-yet-closing), or trying to close + # after closing the mailbox-server connection, we stop Dilation # S0 is special: we don't hang out in it # TODO: rename o to 0, "" to 1. "S1" is special/terminal @@ -64,7 +66,11 @@ class Terminator(object): # def S0(self): pass # unused @m.state() - def S_stopping(self): + def S_stoppingRC(self): + pass # pragma: no cover + + @m.state() + def S_stoppingD(self): pass # pragma: no cover @m.state() @@ -88,7 +94,11 @@ class Terminator(object): # from RendezvousConnector @m.input() - def stopped(self): + def stoppedRC(self): + pass + + @m.input() + def stoppedD(self): pass @m.output() @@ -107,6 +117,10 @@ class Terminator(object): def RC_stop(self): self._RC.stop() + @m.output() + def stop_dilator(self): + self._D.stop() + @m.output() def B_closed(self): self._B.closed() @@ -115,20 +129,19 @@ class Terminator(object): Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox]) Snmo.upon(nameplate_done, enter=Smo, outputs=[]) - Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox]) + Sno.upon(close, enter=Sn, outputs=[close_nameplate]) Sno.upon(nameplate_done, enter=S0o, outputs=[]) - Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox]) + Smo.upon(close, enter=Sm, outputs=[close_mailbox]) Smo.upon(mailbox_done, enter=S0o, outputs=[]) Snm.upon(mailbox_done, enter=Sn, outputs=[]) Snm.upon(nameplate_done, enter=Sm, outputs=[]) - Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop]) - S0o.upon( - close, - enter=S_stopping, - outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop]) - Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop]) + Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop]) + Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop]) + S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop]) - S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed]) + S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator]) + + S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed]) diff --git a/src/wormhole/test/dilate/test_connect.py b/src/wormhole/test/dilate/test_connect.py new file mode 100644 index 0000000..7a60400 --- /dev/null +++ b/src/wormhole/test/dilate/test_connect.py @@ -0,0 +1,92 @@ +import re +import mock +from twisted.internet import reactor +from twisted.trial import unittest +from twisted.internet.task import Cooperator +from twisted.internet.defer import Deferred, inlineCallbacks +from zope.interface import implementer + +from ... import _interfaces +from ...eventual import EventualQueue +from ..._interfaces import ITerminator +from ..._dilation import manager +from ..._dilation._noise import NoiseConnection + + +@implementer(_interfaces.ISend) +class MySend(object): + def __init__(self, side): + self.rx_phase = 0 + self.side = side + def send(self, phase, plaintext): + #print("SEND[%s]" % self.side, phase, plaintext) + self.peer.got(phase, plaintext) + def got(self, phase, plaintext): + d_mo = re.search(r'^dilate-(\d+)$', phase) + p = int(d_mo.group(1)) + assert p == self.rx_phase + self.rx_phase += 1 + self.dilator.received_dilate(plaintext) + +@implementer(ITerminator) +class FakeTerminator(object): + def __init__(self): + self.d = Deferred() + def stoppedD(self): + self.d.callback(None) + +class Connect(unittest.TestCase): + @inlineCallbacks + def test1(self): + if not NoiseConnection: + raise unittest.SkipTest("noiseprotocol unavailable") + #print() + send_left = MySend("left") + send_right = MySend("right") + send_left.peer = send_right + send_right.peer = send_left + key = b"\x00"*32 + eq = EventualQueue(reactor) + cooperator = Cooperator(scheduler=eq.eventually) + + t_left = FakeTerminator() + t_right = FakeTerminator() + + d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True) + d_left.wire(send_left, t_left) + d_left.got_key(key) + d_left.got_wormhole_versions({"can-dilate": ["1"]}) + send_left.dilator = d_left + + d_right = manager.Dilator(reactor, eq, cooperator) + d_right.wire(send_right, t_right) + d_right.got_key(key) + d_right.got_wormhole_versions({"can-dilate": ["1"]}) + send_right.dilator = d_right + + with mock.patch("wormhole._dilation.connector.ipaddrs.find_addresses", + return_value=["127.0.0.1"]): + eps_left_d = d_left.dilate() + eps_right_d = d_right.dilate() + + eps_left = yield eps_left_d + eps_right = yield eps_right_d + + #print("left connected", eps_left) + #print("right connected", eps_right) + + control_ep_left, connect_ep_left, listen_ep_left = eps_left + control_ep_right, connect_ep_right, listen_ep_right = eps_right + + #control_ep_left.connect( + + # we normally shut down with w.close(), which calls Dilator.stop(), + # which calls Terminator.stoppedD(), which (after everything else is + # done) calls Boss.stopped + d_left.stop() + d_right.stop() + + yield t_left.d + yield t_right.d + + diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py index ee761fd..345f18b 100644 --- a/src/wormhole/test/dilate/test_connection.py +++ b/src/wormhole/test/dilate/test_connection.py @@ -9,6 +9,7 @@ from ..._interfaces import IDilationConnector from ..._dilation.roles import LEADER, FOLLOWER from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, KCM, Open, Ack) +from ..._dilation.encode import to_be4 from .common import clear_mock_calls @@ -19,7 +20,7 @@ def make_con(role, use_relay=False): alsoProvides(connector, IDilationConnector) n = mock.Mock() # pretends to be a Noise object n.write_message = mock.Mock(side_effect=[b"handshake"]) - c = DilatedConnectionProtocol(eq, role, connector, n, + c = DilatedConnectionProtocol(eq, role, "desc", connector, n, b"outbound_prologue\n", b"inbound_prologue\n") if use_relay: c.use_relay(b"relay_handshake\n") @@ -29,6 +30,10 @@ def make_con(role, use_relay=False): class Connection(unittest.TestCase): + def test_hashable(self): + c, n, connector, t, eq = make_con(LEADER) + hash(c) + def test_bad_prologue(self): c, n, connector, t, eq = make_con(LEADER) c.makeConnection(t) @@ -52,7 +57,7 @@ class Connection(unittest.TestCase): def _test_no_relay(self, role): c, n, connector, t, eq = make_con(role) t_kcm = KCM() - t_open = Open(seqnum=1, scid=0x11223344) + t_open = Open(seqnum=1, scid=to_be4(0x11223344)) t_ack = Ack(resp_seqnum=2) n.decrypt = mock.Mock(side_effect=[ encode_record(t_kcm), @@ -69,10 +74,20 @@ class Connection(unittest.TestCase): clear_mock_calls(n, connector, t, m) c.dataReceived(b"inbound_prologue\n") - self.assertEqual(n.mock_calls, [mock.call.write_message()]) - self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" - self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + if role is LEADER: + # the LEADER sends the Noise handshake message immediately upon + # receipt of the prologue + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + else: + # however the FOLLOWER waits until receiving the leader's + # handshake before sending their own + self.assertEqual(n.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + clear_mock_calls(n, connector, t, m) c.dataReceived(b"\x00\x00\x00\x0Ahandshake2") @@ -84,13 +99,16 @@ class Connection(unittest.TestCase): self.assertEqual(t.mock_calls, []) self.assertEqual(c._manager, None) else: - # we're the follower, so we encrypt and send the KCM immediately + # we're the follower, so we send our Noise handshake, then + # encrypt and send the KCM immediately self.assertEqual(n.mock_calls, [ mock.call.read_message(b"handshake2"), + mock.call.write_message(), mock.call.encrypt(encode_record(t_kcm)), ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [ + mock.call.write(exp_handshake), mock.call.write(exp_kcm)]) self.assertEqual(c._manager, None) clear_mock_calls(n, connector, t, m) diff --git a/src/wormhole/test/dilate/test_connector.py b/src/wormhole/test/dilate/test_connector.py index 2bb8809..accef38 100644 --- a/src/wormhole/test/dilate/test_connector.py +++ b/src/wormhole/test/dilate/test_connector.py @@ -5,6 +5,7 @@ from zope.interface import alsoProvides from twisted.trial import unittest from twisted.internet.task import Clock from twisted.internet.defer import Deferred +from twisted.internet.address import IPv4Address from ...eventual import EventualQueue from ..._interfaces import IDilationManager, IDilationConnector from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint @@ -34,11 +35,11 @@ class Outbound(unittest.TestCase): p0 = mock.Mock() c.build_protocol = mock.Mock(return_value=p0) relay_handshake = None - f = OutboundConnectionFactory(c, relay_handshake) + f = OutboundConnectionFactory(c, relay_handshake, "desc") addr = object() p = f.buildProtocol(addr) self.assertIdentical(p, p0) - self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")]) self.assertEqual(p.mock_calls, []) self.assertIdentical(p.factory, f) @@ -48,11 +49,11 @@ class Outbound(unittest.TestCase): p0 = mock.Mock() c.build_protocol = mock.Mock(return_value=p0) relay_handshake = b"relay handshake" - f = OutboundConnectionFactory(c, relay_handshake) + f = OutboundConnectionFactory(c, relay_handshake, "desc") addr = object() p = f.buildProtocol(addr) self.assertIdentical(p, p0) - self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")]) self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)]) self.assertIdentical(p.factory, f) @@ -63,10 +64,10 @@ class Inbound(unittest.TestCase): p0 = mock.Mock() c.build_protocol = mock.Mock(return_value=p0) f = InboundConnectionFactory(c) - addr = object() + addr = IPv4Address("TCP", "1.2.3.4", 55) p = f.buildProtocol(addr) self.assertIdentical(p, p0) - self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "<-tcp:1.2.3.4:55")]) self.assertIdentical(p.factory, f) def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER): @@ -115,13 +116,13 @@ class TestConnector(unittest.TestCase): return_value=n0) as bn: with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", return_value=p0) as dcp: - p = c.build_protocol(addr) + p = c.build_protocol(addr, "desc") self.assertEqual(bn.mock_calls, [mock.call()]) self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), mock.call.set_as_initiator()]) self.assertIdentical(p, p0) self.assertEqual(dcp.mock_calls, - [mock.call(h.eq, h.role, c, n0, + [mock.call(h.eq, h.role, "desc", c, n0, PROLOGUE_LEADER, PROLOGUE_FOLLOWER)]) def test_build_protocol_follower(self): @@ -133,13 +134,13 @@ class TestConnector(unittest.TestCase): return_value=n0) as bn: with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", return_value=p0) as dcp: - p = c.build_protocol(addr) + p = c.build_protocol(addr, "desc") self.assertEqual(bn.mock_calls, [mock.call()]) self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), mock.call.set_as_responder()]) self.assertIdentical(p, p0) self.assertEqual(dcp.mock_calls, - [mock.call(h.eq, h.role, c, n0, + [mock.call(h.eq, h.role, "desc", c, n0, PROLOGUE_FOLLOWER, PROLOGUE_LEADER)]) def test_start_stop(self): @@ -244,7 +245,7 @@ class TestConnector(unittest.TestCase): with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", return_value=f) as ocf: h.clock.advance(1.0) - self.assertEqual(ocf.mock_calls, [mock.call(c, None)]) + self.assertEqual(ocf.mock_calls, [mock.call(c, None, "->tcp:foo:55")]) self.assertEqual(ep.connect.mock_calls, [mock.call(f)]) p = mock.Mock() d.callback(p) @@ -269,7 +270,7 @@ class TestConnector(unittest.TestCase): return_value=f) as ocf: h.clock.advance(1.0) handshake = build_sided_relay_handshake(h.dilation_key, h.side) - self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)]) + self.assertEqual(ocf.mock_calls, [mock.call(c, handshake, "->relay:tcp:foo:55")]) def test_listen_but_tor(self): c, h = make_connector(listen=True, tor=True, role=roles.LEADER) @@ -388,7 +389,7 @@ class Race(unittest.TestCase): c.add_candidate(p1) self.assertEqual(h.manager.mock_calls, []) h.eq.flush_sync() - self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)]) self.assertEqual(p1.mock_calls, [mock.call.select(h.manager), mock.call.send_record(KCM())]) @@ -409,7 +410,7 @@ class Race(unittest.TestCase): c.add_candidate(p1) self.assertEqual(h.manager.mock_calls, []) h.eq.flush_sync() - self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)]) # just like LEADER, but follower doesn't send KCM now (it sent one # earlier, to tell the leader that this connection looks viable) self.assertEqual(p1.mock_calls, @@ -432,7 +433,7 @@ class Race(unittest.TestCase): c.add_candidate(p1) self.assertEqual(h.manager.mock_calls, []) h.eq.flush_sync() - self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)]) clear_mock_calls(h.manager) self.assertEqual(p1.mock_calls, [mock.call.select(h.manager), @@ -454,10 +455,9 @@ class Race(unittest.TestCase): c.add_candidate(p1) self.assertEqual(h.manager.mock_calls, []) h.eq.flush_sync() - self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) self.assertEqual(p1.mock_calls, [mock.call.select(h.manager), mock.call.send_record(KCM())]) + self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)]) c.stop() - diff --git a/src/wormhole/test/dilate/test_full.py b/src/wormhole/test/dilate/test_full.py new file mode 100644 index 0000000..16d990a --- /dev/null +++ b/src/wormhole/test/dilate/test_full.py @@ -0,0 +1,77 @@ +from __future__ import print_function, absolute_import, unicode_literals +import wormhole +from twisted.internet import reactor +from twisted.internet.defer import Deferred, inlineCallbacks, gatherResults +from twisted.internet.protocol import Protocol, Factory +from twisted.trial import unittest + +from ..common import ServerBase +from ...eventual import EventualQueue +from ..._dilation._noise import NoiseConnection + +APPID = u"lothar.com/dilate-test" + +def doBoth(d1, d2): + return gatherResults([d1, d2], True) + +class L(Protocol): + def connectionMade(self): + print("got connection") + self.transport.write(b"hello\n") + def dataReceived(self, data): + print("dataReceived: {}".format(data)) + self.factory.d.callback(data) + def connectionLost(self, why): + print("connectionLost") + + +class Full(ServerBase, unittest.TestCase): + @inlineCallbacks + def setUp(self): + if not NoiseConnection: + raise unittest.SkipTest("noiseprotocol unavailable") + # test_welcome wants to see [current_cli_version] + yield self._setup_relay(None) + + @inlineCallbacks + def test_full(self): + eq = EventualQueue(reactor) + w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True) + w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True) + w1.allocate_code() + code = yield w1.get_code() + print("code is: {}".format(code)) + w2.set_code(code) + yield doBoth(w1.get_verifier(), w2.get_verifier()) + print("connected") + + eps1_d = w1.dilate() + eps2_d = w2.dilate() + (eps1, eps2) = yield doBoth(eps1_d, eps2_d) + (control_ep1, connect_ep1, listen_ep1) = eps1 + (control_ep2, connect_ep2, listen_ep2) = eps2 + print("w.dilate ready") + + f1 = Factory() + f1.protocol = L + f1.d = Deferred() + f1.d.addCallback(lambda data: eq.fire_eventually(data)) + d1 = control_ep1.connect(f1) + + f2 = Factory() + f2.protocol = L + f2.d = Deferred() + f2.d.addCallback(lambda data: eq.fire_eventually(data)) + d2 = control_ep2.connect(f2) + yield d1 + yield d2 + print("control endpoints connected") + data1 = yield f1.d + data2 = yield f2.d + self.assertEqual(data1, b"hello\n") + self.assertEqual(data2, b"hello\n") + + yield w1.close() + yield w2.close() + + test_full.timeout = 30 diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py index 392a661..f512575 100644 --- a/src/wormhole/test/dilate/test_inbound.py +++ b/src/wormhole/test/dilate/test_inbound.py @@ -27,12 +27,12 @@ class InboundTest(unittest.TestCase): self.assertFalse(i.is_record_old(r2)) self.assertFalse(i.is_record_old(r3)) - i.update_ack_watermark(r1) + i.update_ack_watermark(r1.seqnum) self.assertTrue(i.is_record_old(r1)) self.assertFalse(i.is_record_old(r2)) self.assertFalse(i.is_record_old(r3)) - i.update_ack_watermark(r2) + i.update_ack_watermark(r2.seqnum) self.assertTrue(i.is_record_old(r1)) self.assertTrue(i.is_record_old(r2)) self.assertFalse(i.is_record_old(r3)) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index e223a1c..b8258f1 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred from twisted.internet.task import Clock, Cooperator import mock from ...eventual import EventualQueue -from ..._interfaces import ISend, IDilationManager +from ..._interfaces import ISend, IDilationManager, ITerminator from ...util import dict_to_bytes from ..._dilation import roles from ..._dilation.encode import to_be4 @@ -32,7 +32,9 @@ def make_dilator(): send = mock.Mock() alsoProvides(send, ISend) dil = Dilator(reactor, eq, coop) - dil.wire(send) + terminator = mock.Mock() + alsoProvides(terminator, ITerminator) + dil.wire(send, terminator) return dil, send, reactor, eq, clock, coop @@ -64,7 +66,7 @@ class TestDilator(unittest.TestCase): dil.got_wormhole_versions({"can-dilate": ["1"]}) # that should create the Manager self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, - None, reactor, eq, coop)]) + None, reactor, eq, coop, no_listen=False)]) # and tell it to start, and get wait-for-it-to-connect Deferred self.assertEqual(m.mock_calls, [mock.call.start(), mock.call.when_first_connected(), @@ -180,7 +182,7 @@ class TestDilator(unittest.TestCase): return_value="us"): dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - None, reactor, eq, coop)]) + None, reactor, eq, coop, no_listen=False)]) self.assertEqual(m.mock_calls, [mock.call.start(), mock.call.rx_PLEASE(pleasemsg), mock.call.rx_HINTS(hintmsg), @@ -198,7 +200,7 @@ class TestDilator(unittest.TestCase): return_value="us"): dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", - relay, reactor, eq, coop), + relay, reactor, eq, coop, no_listen=False), mock.call().start(), mock.call().when_first_connected()]) diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py index 6ba5264..ed43a47 100644 --- a/src/wormhole/test/dilate/test_outbound.py +++ b/src/wormhole/test/dilate/test_outbound.py @@ -105,7 +105,7 @@ class OutboundTest(unittest.TestCase): # as soon as the connection is established, everything is sent o.use_connection(c) - self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True), mock.call.send_record(r1), mock.call.send_record(r2)]) self.assertEqual(list(o._outbound_queue), [r1, r2]) @@ -131,7 +131,7 @@ class OutboundTest(unittest.TestCase): # after each write. So only r1 should have been sent before getting # paused o.use_connection(c) - self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True), mock.call.send_record(r1)]) self.assertEqual(list(o._outbound_queue), [r1, r2]) self.assertEqual(list(o._queued_unsent), [r2]) @@ -172,7 +172,7 @@ class OutboundTest(unittest.TestCase): self.assertEqual(list(o._queued_unsent), []) o.use_connection(c) - self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True), mock.call.send_record(r1)]) self.assertEqual(list(o._outbound_queue), [r1, r2]) self.assertEqual(list(o._queued_unsent), [r2]) @@ -191,7 +191,7 @@ class OutboundTest(unittest.TestCase): def test_pause(self): o, m, c = make_outbound() o.use_connection(c) - self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)]) + self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True)]) self.assertEqual(list(o._outbound_queue), []) self.assertEqual(list(o._queued_unsent), []) clear_mock_calls(c) @@ -519,7 +519,7 @@ class OutboundTest(unittest.TestCase): o.use_connection(c) o.send_if_connected(KCM()) - self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True), mock.call.send_record(KCM())]) def test_tolerate_duplicate_pause_resume(self): diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py index f7276a6..f40c661 100644 --- a/src/wormhole/test/dilate/test_parse.py +++ b/src/wormhole/test/dilate/test_parse.py @@ -13,11 +13,11 @@ class Parse(unittest.TestCase): self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"), Pong(ping_id=b"\x55\x44\x33\x22")) self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"), - Open(scid=513, seqnum=256)) + Open(scid=b"\x00\x00\x02\x01", seqnum=256)) self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"), - Data(scid=514, seqnum=257, data=b"dataaa")) + Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa")) self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"), - Close(scid=515, seqnum=258)) + Close(scid=b"\x00\x00\x02\x03", seqnum=258)) self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"), Ack(resp_seqnum=259)) with mock.patch("wormhole._dilation.connection.log.err") as le: @@ -31,11 +31,11 @@ class Parse(unittest.TestCase): self.assertEqual(encode_record(KCM()), b"\x00") self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping") self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong") - self.assertEqual(encode_record(Open(scid=65536, seqnum=16)), + self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)), b"\x03\x00\x01\x00\x00\x00\x00\x00\x10") - self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")), + self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")), b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa") - self.assertEqual(encode_record(Close(scid=65538, seqnum=18)), + self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)), b"\x05\x00\x01\x00\x02\x00\x00\x00\x12") self.assertEqual(encode_record(Ack(resp_seqnum=19)), b"\x06\x00\x00\x00\x13") diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py index 63a784c..252a8b0 100644 --- a/src/wormhole/test/dilate/test_record.py +++ b/src/wormhole/test/dilate/test_record.py @@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage from ..._dilation.connection import (IFramer, Frame, Prologue, _Record, Handshake, Disconnect, Ping) +from ..._dilation.roles import LEADER def make_record(): f = mock.Mock() alsoProvides(f, IFramer) n = mock.Mock() # pretends to be a Noise object - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() return r, f, n @@ -30,7 +32,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") p1, p2 = object(), object() n.decrypt = mock.Mock(side_effect=[p1, p2]) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -79,7 +82,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") nvm = NoiseInvalidMessage() n.read_message = mock.Mock(side_effect=nvm) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -103,7 +107,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") nvm = NoiseInvalidMessage() n.decrypt = mock.Mock(side_effect=nvm) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -124,7 +129,8 @@ class Record(unittest.TestCase): f1 = object() n.encrypt = mock.Mock(return_value=f1) r1 = Ping(b"pingid") - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) m1 = object() with mock.patch("wormhole._dilation.connection.encode_record", diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index dff3fb0..9e417d6 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase): rc = Dummy("rc", events, IRendezvousConnector, "stop") n = Dummy("n", events, INameplate, "close") m = Dummy("m", events, IMailbox, "close") - t.wire(b, rc, n, m) + d = Dummy("d", events, IDilator, "stop") + t.wire(b, rc, n, m, d) return t, b, rc, n, m, events # there are three events, and we need to test all orderings of them @@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase): input_events = { "mailbox": lambda: t.mailbox_done(), "nameplate": lambda: t.nameplate_done(), - "close": lambda: t.close("happy"), + "rc": lambda: t.close("happy"), } close_events = [ ("n.close", ), ("m.close", "happy"), ] + if ev1 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev1 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev1]() expected = [] - if ev1 == "close": + if ev1 == "rc": expected.extend(close_events) self.assertEqual(events, expected) events[:] = [] + if ev2 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev2 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev2]() expected = [] - if ev2 == "close": + if ev2 == "rc": expected.extend(close_events) self.assertEqual(events, expected) events[:] = [] + if ev3 == "mailbox": + close_events.remove(("m.close", "happy")) + elif ev3 == "nameplate": + close_events.remove(("n.close",)) + input_events[ev3]() expected = [] - if ev3 == "close": + if ev3 == "rc": expected.extend(close_events) expected.append(("rc.stop", )) self.assertEqual(events, expected) events[:] = [] - t.stopped() + t.stoppedRC() + self.assertEqual(events, [("d.stop", )]) + events[:] = [] + + t.stoppedD() self.assertEqual(events, [("b.closed", )]) def test_terminate(self): - self._do_test("mailbox", "nameplate", "close") - self._do_test("mailbox", "close", "nameplate") - self._do_test("nameplate", "mailbox", "close") - self._do_test("nameplate", "close", "mailbox") - self._do_test("close", "nameplate", "mailbox") - self._do_test("close", "mailbox", "nameplate") + self._do_test("mailbox", "nameplate", "rc") + self._do_test("mailbox", "rc", "nameplate") + self._do_test("nameplate", "mailbox", "rc") + self._do_test("nameplate", "rc", "mailbox") + self._do_test("rc", "nameplate", "mailbox") + self._do_test("rc", "mailbox", "nameplate") # TODO: test moods