Merge branch 'dilate-xfer'

Add an integration test which exercises a full w.dilate connection and the
control endpoint.

Still untested:

* reconnecting after the initial TCP connection is lost
* resending data that wasn't acked before the connection was lost
* (re)sending data that was submitted while no connection was available
* the connect- and listen- endpoints
This commit is contained in:
Brian Warner 2019-02-10 18:07:03 -08:00
commit a5e011f786
21 changed files with 524 additions and 141 deletions

View File

@ -81,8 +81,8 @@ class Boss(object):
self._A.wire(self._RC, self._C)
self._I.wire(self._C, self._L)
self._C.wire(self, self._A, self._N, self._K, self._I)
self._T.wire(self, self._RC, self._N, self._M)
self._D.wire(self._S)
self._T.wire(self, self._RC, self._N, self._M, self._D)
self._D.wire(self._S, self._T)
def _init_other_state(self):
self._did_start_code = False

View File

@ -4,6 +4,12 @@ except ImportError:
class NoiseInvalidMessage(Exception):
pass
try:
from noise.exceptions import NoiseHandshakeError
except ImportError:
class NoiseHandshakeError(Exception):
pass
try:
from noise.connection import NoiseConnection
except ImportError:

View File

@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport
from .._interfaces import IDilationConnector
from ..observer import OneShotObserver
from .encode import to_be4, from_be4
from .roles import FOLLOWER
from ._noise import NoiseInvalidMessage
from .roles import LEADER, FOLLOWER
from ._noise import NoiseInvalidMessage, NoiseHandshakeError
# InboundFraming is given data and returns Frames (Noise wire-side
# bytestrings). It handles the relay handshake and the prologue. The Frames it
@ -56,6 +56,23 @@ def first(l):
class Disconnect(Exception):
pass
# all connections look like:
# (step 1: only for outbound connections)
# 1: if we're connecting to a transit relay:
# * send "sided relay handshake": "please relay TOKEN for side SIDE\n"
# * the relay will send "ok\n" if/when our peer connects
# * a non-relay will probably send junk
# * wait for "ok\n", hang up if we get anything different
# (all subsequent steps are for both inbound and outbound connections)
# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n"
# 3: wait for the opposite PROLOGUE string, else hang up
# (everything past this point is a Frame, with be4 length prefix. Frames are
# either noise handshake or an encrypted message)
# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it
# 5: if FOLLOWER, send noise response string. if LEADER, wait for it
# 6: ...
RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", [])
@ -193,7 +210,7 @@ class _Framer(object):
def add_and_parse(self, data):
# we can't make this an @m.input because we can't change the state
# from within an input. Instead, let the state choose the parser to
# use, and use the parsed token drive a state transition.
# use, then use the parsed token to drive a state transition.
self._buffer += data
while True:
# it'd be nice to use an iterator here, but since self.parse()
@ -233,7 +250,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records
@ -258,16 +275,16 @@ def parse_record(plaintext):
ping_id = plaintext[1:5]
return Pong(ping_id)
if msgtype == T_OPEN:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
return Open(seqnum, scid)
if msgtype == T_DATA:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
data = plaintext[9:]
return Data(seqnum, scid, data)
if msgtype == T_CLOSE:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
return Close(seqnum, scid)
if msgtype == T_ACK:
@ -285,28 +302,36 @@ def encode_record(r):
if isinstance(r, Pong):
return b"\x02" + r.ping_id
if isinstance(r, Open):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
return b"\x03" + r.scid + to_be4(r.seqnum)
if isinstance(r, Data):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
return b"\x04" + r.scid + to_be4(r.seqnum) + r.data
if isinstance(r, Close):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
return b"\x05" + r.scid + to_be4(r.seqnum)
if isinstance(r, Ack):
assert isinstance(r.resp_seqnum, six.integer_types)
return b"\x06" + to_be4(r.resp_seqnum)
raise TypeError(r)
def _is_role(_record, _attr, value):
if value not in [LEADER, FOLLOWER]:
raise ValueError("role must be LEADER or FOLLOWER")
@attrs
@implementer(IRecord)
class _Record(object):
_framer = attrib(validator=provides(IFramer))
_noise = attrib()
_role = attrib(default="unspecified", validator=_is_role) # for debugging
n = MethodicalMachine()
# TODO: set_trace
@ -321,17 +346,37 @@ class _Record(object):
# states: want_prologue, want_handshake, want_record
@n.state(initial=True)
def want_prologue(self):
def no_role_set(self):
pass # pragma: no cover
@n.state()
def want_handshake(self):
def want_prologue_leader(self):
pass # pragma: no cover
@n.state()
def want_prologue_follower(self):
pass # pragma: no cover
@n.state()
def want_handshake_leader(self):
pass # pragma: no cover
@n.state()
def want_handshake_follower(self):
pass # pragma: no cover
@n.state()
def want_message(self):
pass # pragma: no cover
@n.input()
def set_role_leader(self):
pass
@n.input()
def set_role_follower(self):
pass
@n.input()
def got_prologue(self):
pass
@ -340,9 +385,20 @@ class _Record(object):
def got_frame(self, frame):
pass
@n.output()
def ignore_and_send_handshake(self, frame):
self._send_handshake()
@n.output()
def send_handshake(self):
handshake = self._noise.write_message() # generate the ephemeral key
self._send_handshake()
def _send_handshake(self):
try:
handshake = self._noise.write_message() # generate the ephemeral key
except NoiseHandshakeError as e:
log.err(e, "noise error during handshake")
raise
self._framer.send_frame(handshake)
@n.output()
@ -367,10 +423,19 @@ class _Record(object):
raise Disconnect()
return parse_record(message)
want_prologue.upon(got_prologue, outputs=[send_handshake],
enter=want_handshake)
want_handshake.upon(got_frame, outputs=[process_handshake],
collector=first, enter=want_message)
no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader)
want_prologue_leader.upon(got_prologue, outputs=[send_handshake],
enter=want_handshake_leader)
want_handshake_leader.upon(got_frame, outputs=[process_handshake],
collector=first, enter=want_message)
no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower)
want_prologue_follower.upon(got_prologue, outputs=[],
enter=want_handshake_follower)
want_handshake_follower.upon(got_frame, outputs=[process_handshake,
ignore_and_send_handshake],
collector=first, enter=want_message)
want_message.upon(got_frame, outputs=[decrypt_message],
collector=first, enter=want_message)
@ -393,7 +458,7 @@ class _Record(object):
self._framer.send_frame(frame)
@attrs
@attrs(cmp=False)
class DilatedConnectionProtocol(Protocol, object):
"""I manage an L2 connection.
@ -408,12 +473,13 @@ class DilatedConnectionProtocol(Protocol, object):
At any given time, there is at most one active L2 connection.
"""
_eventual_queue = attrib()
_eventual_queue = attrib(repr=False)
_role = attrib()
_connector = attrib(validator=provides(IDilationConnector))
_noise = attrib()
_outbound_prologue = attrib(validator=instance_of(bytes))
_inbound_prologue = attrib(validator=instance_of(bytes))
_description = attrib()
_connector = attrib(validator=provides(IDilationConnector), repr=False)
_noise = attrib(repr=False)
_outbound_prologue = attrib(validator=instance_of(bytes), repr=False)
_inbound_prologue = attrib(validator=instance_of(bytes), repr=False)
_use_relay = False
_relay_handshake = None
@ -457,6 +523,8 @@ class DilatedConnectionProtocol(Protocol, object):
@m.output()
def set_manager(self, manager):
self._manager = manager
self.when_disconnected().addCallback(lambda c:
manager.connector_connection_lost())
@m.output()
def can_send_records(self, manager):
@ -493,12 +561,20 @@ class DilatedConnectionProtocol(Protocol, object):
# IProtocol methods
def connectionMade(self):
framer = _Framer(self.transport,
self._outbound_prologue, self._inbound_prologue)
if self._use_relay:
framer.use_relay(self._relay_handshake)
self._record = _Record(framer, self._noise)
self._record.connectionMade()
try:
framer = _Framer(self.transport,
self._outbound_prologue, self._inbound_prologue)
if self._use_relay:
framer.use_relay(self._relay_handshake)
self._record = _Record(framer, self._noise, self._role)
if self._role is LEADER:
self._record.set_role_leader()
else:
self._record.set_role_follower()
self._record.connectionMade()
except:
log.err()
raise
def dataReceived(self, data):
try:

View File

@ -9,6 +9,7 @@ from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
from twisted.python import log
from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager
@ -39,9 +40,36 @@ NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
def build_noise():
return NoiseConnection.from_name(NOISEPROTO)
@attrs
@attrs(cmp=False)
@implementer(IDilationConnector)
class Connector(object):
"""I manage a single generation of connection.
The Manager creates one of me at a time, whenever it wants a connection
(which is always, once w.dilate() has been called and we know the remote
end can dilate, and is expressed by the Manager calling my .start()
method). I am discarded when my established connection is lost (and if we
still want to be connected, a new generation is started and a new
Connector is created). I am also discarded if we stop wanting to be
connected (which the Manager expresses by calling my .stop() method).
I manage the race between multiple connections for a specific generation
of the dilated connection.
I send connection hints when my InboundConnectionFactory yields addresses
(self.listener_ready), and I initiate outbond connections (with
OutboundConnectionFactory) as I receive connection hints from my peer
(self.got_hints). Both factories use my build_protocol() method to create
connection.DilatedConnectionProtocol instances. I track these protocol
instances until one finishes negotiation and wins the race. I then shut
down the others, remember the winner as self._winning_connection, and
deliver the winner to manager.connector_connection_made(c).
When an active connection is lost, we call manager.connector_connection_lost,
allowing the manager to decide whether it wants to start a new generation
or not.
"""
_dilation_key = attrib(validator=instance_of(type(b"")))
_transit_relay_location = attrib(validator=optional(instance_of(type(u""))))
_manager = attrib(validator=provides(IDilationManager))
@ -83,7 +111,7 @@ class Connector(object):
{"type": "relay-v1"},
]
def build_protocol(self, addr):
def build_protocol(self, addr, description):
# encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
# ephemeral keys plus a pre-shared symmetric key (the Transit key), a
# different one for each potential connection.
@ -98,6 +126,7 @@ class Connector(object):
outbound_prologue = PROLOGUE_FOLLOWER
inbound_prologue = PROLOGUE_LEADER
p = DilatedConnectionProtocol(self._eventual_queue, self._role,
description,
self, noise,
outbound_prologue, inbound_prologue)
return p
@ -181,10 +210,13 @@ class Connector(object):
self.stop_pending_connections()
c.select(self._manager) # subsequent frames go directly to the manager
# c.select also wires up when_disconnected() to fire
# manager.connector_connection_lost(). TODO: rename this, since the
# Connector is no longer the one calling it
if self._role is LEADER:
# TODO: this should live in Connection
c.send_record(KCM()) # leader sends KCM now
self._manager.use_connection(c) # manager sends frames to Connection
self._manager.connector_connection_made(c) # manager sends frames to Connection
@m.output()
def stop_everything(self):
@ -199,11 +231,12 @@ class Connector(object):
return d # synchronization for tests
def stop_pending_connectors(self):
return DeferredList([d.cancel() for d in self._pending_connectors])
for d in self._pending_connectors:
d.cancel()
def stop_pending_connections(self):
d = self._pending_connections.when_next_empty()
[c.loseConnection() for c in self._pending_connections]
[c.disconnect() for c in self._pending_connections]
return d
def break_cycles(self):
@ -337,7 +370,7 @@ class Connector(object):
if is_relay:
relay_handshake = build_sided_relay_handshake(self._dilation_key,
self._side)
f = OutboundConnectionFactory(self, relay_handshake)
f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f)
# fires with protocol, or ConnectError
@ -368,20 +401,28 @@ class Connector(object):
class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector))
_relay_handshake = attrib(validator=optional(instance_of(bytes)))
_description = attrib()
def buildProtocol(self, addr):
p = self._connector.build_protocol(addr)
p = self._connector.build_protocol(addr, self._description)
p.factory = self
if self._relay_handshake is not None:
p.use_relay(self._relay_handshake)
return p
def describe_inbound(addr):
if isinstance(addr, HostnameAddress):
return "<-tcp:%s:%d" % (addr.hostname, addr.port)
elif isinstance(addr, (IPv4Address, IPv6Address)):
return "<-tcp:%s:%d" % (addr.host, addr.port)
return "<-%r" % addr
@attrs
class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector))
def buildProtocol(self, addr):
p = self._connector.build_protocol(addr)
description = describe_inbound(addr)
p = self._connector.build_protocol(addr, description)
p.factory = self
return p

View File

@ -60,9 +60,9 @@ class Inbound(object):
return True
return False
def update_ack_watermark(self, r):
def update_ack_watermark(self, seqnum):
self._highest_inbound_acked = max(self._highest_inbound_acked,
r.seqnum)
seqnum)
def handle_open(self, scid):
if scid in self._open_subchannels:

View File

@ -7,7 +7,7 @@ from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python import log
from .._interfaces import IDilator, IDilationManager, ISend
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
from .._key import derive_key
@ -87,17 +87,17 @@ def make_side():
# * if follower calls w.dilate() but not leader, follower waits forever
# in "want", leader waits forever in "wanted"
@attrs
@attrs(cmp=False)
@implementer(IDilationManager)
class Manager(object):
_S = attrib(validator=provides(ISend))
_S = attrib(validator=provides(ISend), repr=False)
_my_side = attrib(validator=instance_of(type(u"")))
_transit_key = attrib(validator=instance_of(bytes))
_transit_key = attrib(validator=instance_of(bytes), repr=False)
_transit_relay_location = attrib(validator=optional(instance_of(str)))
_reactor = attrib()
_eventual_queue = attrib()
_cooperator = attrib()
_no_listen = False # TODO
_reactor = attrib(repr=False)
_eventual_queue = attrib(repr=False)
_cooperator = attrib(repr=False)
_no_listen = attrib(default=False)
_tor = None # TODO
_timing = None # TODO
_next_subchannel_id = None # initialized in choose_role
@ -113,6 +113,7 @@ class Manager(object):
self._connection = None
self._made_first_connection = False
self._first_connected = OneShotObserver(self._eventual_queue)
self._stopped = OneShotObserver(self._eventual_queue)
self._host_addr = _WormholeAddress()
self._next_dilation_phase = 0
@ -133,6 +134,9 @@ class Manager(object):
def when_first_connected(self):
return self._first_connected.when_fired()
def when_stopped(self):
return self._stopped.when_fired()
def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1
@ -160,12 +164,15 @@ class Manager(object):
self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid):
assert isinstance(scid, bytes)
self._queue_and_send(Open, scid)
def send_data(self, scid, data):
assert isinstance(scid, bytes)
self._queue_and_send(Data, scid, data)
def send_close(self, scid):
assert isinstance(scid, bytes)
self._queue_and_send(Close, scid)
def _queue_and_send(self, record_type, *args):
@ -401,6 +408,10 @@ class Manager(object):
# been told to shut down.
self._connection.disconnect() # let connection_lost do cleanup
@m.output()
def notify_stopped(self):
self._stopped.fire(None)
# we start CONNECTING when we get rx_PLEASE
WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[choose_role, start_connecting_ignore_message])
@ -440,14 +451,14 @@ class Manager(object):
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
WANTING.upon(stop, enter=STOPPED, outputs=[])
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped])
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
ABANDONING.upon(stop, enter=STOPPING, outputs=[])
FLUSHING.upon(stop, enter=STOPPED, outputs=[])
LONELY.upon(stop, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
FLUSHING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
LONELY.upon(stop, enter=STOPPED, outputs=[notify_stopped])
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[notify_stopped])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[notify_stopped])
@attrs
@ -466,6 +477,7 @@ class Dilator(object):
_reactor = attrib()
_eventual_queue = attrib()
_cooperator = attrib()
_no_listen = attrib(default=False)
def __attrs_post_init__(self):
self._got_versions_d = Deferred()
@ -474,8 +486,9 @@ class Dilator(object):
self._pending_inbound_dilate_messages = deque()
self._manager = None
def wire(self, sender):
def wire(self, sender, terminator):
self._S = ISend(sender)
self._T = ITerminator(terminator)
# this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None):
@ -509,7 +522,7 @@ class Dilator(object):
self._transit_key,
self._transit_relay_location,
self._reactor, self._eventual_queue,
self._cooperator)
self._cooperator, no_listen=self._no_listen)
self._manager.start()
while self._pending_inbound_dilate_messages:
@ -519,7 +532,7 @@ class Dilator(object):
yield self._manager.when_first_connected()
# we can open subchannels as soon as we get our first connection
scid0 = b"\x00\x00\x00\x00"
scid0 = to_be4(0)
self._host_addr = _WormholeAddress() # TODO: share with Manager
peer_addr0 = _SubchannelAddress(scid0)
control_ep = ControlEndpoint(peer_addr0)
@ -535,6 +548,19 @@ class Dilator(object):
endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints)
# Called by Terminator after everything else (mailbox, nameplate, server
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
# stopped too.
def stop(self):
if not self._started:
self._T.stoppedD()
return
if self._started:
self._manager.stop()
# TODO: avoid Deferreds for control flow, hard to serialize
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
# TODO: tolerate multiple calls
# from Boss
def got_key(self, key):

View File

@ -154,7 +154,7 @@ from .connection import KCM, Ping, Pong, Ack
@attrs
@implementer(IOutbound)
@implementer(IOutbound, IPushProducer)
class Outbound(object):
# Manage outbound data: subchannel writes to us, we write to transport
_manager = attrib(validator=provides(IDilationManager))
@ -265,12 +265,12 @@ class Outbound(object):
assert not self._queued_unsent
self._queued_unsent.extend(self._outbound_queue)
# the connection can tell us to pause when we send too much data
c.registerProducer(self, True) # IPushProducer: pause+resume
c.transport.registerProducer(self, True) # IPushProducer: pause+resume
# send our queued messages
self.resumeProducing()
def stop_using_connection(self):
self._connection.unregisterProducer()
self._connection.transport.unregisterProducer()
self._connection = None
self._queued_unsent.clear()
self.pauseProducing()
@ -290,8 +290,8 @@ class Outbound(object):
# Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not
# IProducer: the active connection calls these because we used
# c.registerProducer to ask for them
# IPushProducer: the active connection calls these because we used
# c.transport.registerProducer to ask for them
def pauseProducing(self):
if self._paused:

View File

@ -1 +1,7 @@
LEADER, FOLLOWER = object(), object()
class _Role(object):
def __init__(self, which):
self._which = which
def __repr__(self):
return "Role(%s)" % self._which
LEADER, FOLLOWER = _Role("LEADER"), _Role("FOLLOWER")

View File

@ -55,7 +55,7 @@ class _WormholeAddress(object):
@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
_scid = attrib()
_scid = attrib(validator=instance_of(bytes))
@attrs

View File

@ -246,7 +246,7 @@ class RendezvousConnector(object):
# internal
def _stopped(self, res):
self._T.stopped()
self._T.stoppedRC()
def _tx(self, mtype, **kwargs):
assert self._ws

View File

@ -15,15 +15,17 @@ class Terminator(object):
def __init__(self):
self._mood = None
def wire(self, boss, rendezvous_connector, nameplate, mailbox):
def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator):
self._B = _interfaces.IBoss(boss)
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
self._N = _interfaces.INameplate(nameplate)
self._M = _interfaces.IMailbox(mailbox)
self._D = _interfaces.IDilator(dilator)
# 4*2-1 main states:
# (nm, m, n, 0): nameplate and/or mailbox is active
# 2*2-1+1 main states:
# (nm, m, n, d): nameplate and/or mailbox is active
# (o, ""): open (not-yet-closing), or trying to close
# after closing the mailbox-server connection, we stop Dilation
# S0 is special: we don't hang out in it
# TODO: rename o to 0, "" to 1. "S1" is special/terminal
@ -64,7 +66,11 @@ class Terminator(object):
# def S0(self): pass # unused
@m.state()
def S_stopping(self):
def S_stoppingRC(self):
pass # pragma: no cover
@m.state()
def S_stoppingD(self):
pass # pragma: no cover
@m.state()
@ -88,7 +94,11 @@ class Terminator(object):
# from RendezvousConnector
@m.input()
def stopped(self):
def stoppedRC(self):
pass
@m.input()
def stoppedD(self):
pass
@m.output()
@ -107,6 +117,10 @@ class Terminator(object):
def RC_stop(self):
self._RC.stop()
@m.output()
def stop_dilator(self):
self._D.stop()
@m.output()
def B_closed(self):
self._B.closed()
@ -115,20 +129,19 @@ class Terminator(object):
Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox])
Snmo.upon(nameplate_done, enter=Smo, outputs=[])
Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox])
Sno.upon(close, enter=Sn, outputs=[close_nameplate])
Sno.upon(nameplate_done, enter=S0o, outputs=[])
Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox])
Smo.upon(close, enter=Sm, outputs=[close_mailbox])
Smo.upon(mailbox_done, enter=S0o, outputs=[])
Snm.upon(mailbox_done, enter=Sn, outputs=[])
Snm.upon(nameplate_done, enter=Sm, outputs=[])
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop])
S0o.upon(
close,
enter=S_stopping,
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])
Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop])
Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop])
S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop])
S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed])
S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator])
S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed])

View File

@ -0,0 +1,92 @@
import re
import mock
from twisted.internet import reactor
from twisted.trial import unittest
from twisted.internet.task import Cooperator
from twisted.internet.defer import Deferred, inlineCallbacks
from zope.interface import implementer
from ... import _interfaces
from ...eventual import EventualQueue
from ..._interfaces import ITerminator
from ..._dilation import manager
from ..._dilation._noise import NoiseConnection
@implementer(_interfaces.ISend)
class MySend(object):
def __init__(self, side):
self.rx_phase = 0
self.side = side
def send(self, phase, plaintext):
#print("SEND[%s]" % self.side, phase, plaintext)
self.peer.got(phase, plaintext)
def got(self, phase, plaintext):
d_mo = re.search(r'^dilate-(\d+)$', phase)
p = int(d_mo.group(1))
assert p == self.rx_phase
self.rx_phase += 1
self.dilator.received_dilate(plaintext)
@implementer(ITerminator)
class FakeTerminator(object):
def __init__(self):
self.d = Deferred()
def stoppedD(self):
self.d.callback(None)
class Connect(unittest.TestCase):
@inlineCallbacks
def test1(self):
if not NoiseConnection:
raise unittest.SkipTest("noiseprotocol unavailable")
#print()
send_left = MySend("left")
send_right = MySend("right")
send_left.peer = send_right
send_right.peer = send_left
key = b"\x00"*32
eq = EventualQueue(reactor)
cooperator = Cooperator(scheduler=eq.eventually)
t_left = FakeTerminator()
t_right = FakeTerminator()
d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True)
d_left.wire(send_left, t_left)
d_left.got_key(key)
d_left.got_wormhole_versions({"can-dilate": ["1"]})
send_left.dilator = d_left
d_right = manager.Dilator(reactor, eq, cooperator)
d_right.wire(send_right, t_right)
d_right.got_key(key)
d_right.got_wormhole_versions({"can-dilate": ["1"]})
send_right.dilator = d_right
with mock.patch("wormhole._dilation.connector.ipaddrs.find_addresses",
return_value=["127.0.0.1"]):
eps_left_d = d_left.dilate()
eps_right_d = d_right.dilate()
eps_left = yield eps_left_d
eps_right = yield eps_right_d
#print("left connected", eps_left)
#print("right connected", eps_right)
control_ep_left, connect_ep_left, listen_ep_left = eps_left
control_ep_right, connect_ep_right, listen_ep_right = eps_right
#control_ep_left.connect(
# we normally shut down with w.close(), which calls Dilator.stop(),
# which calls Terminator.stoppedD(), which (after everything else is
# done) calls Boss.stopped
d_left.stop()
d_right.stop()
yield t_left.d
yield t_right.d

View File

@ -9,6 +9,7 @@ from ..._interfaces import IDilationConnector
from ..._dilation.roles import LEADER, FOLLOWER
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
KCM, Open, Ack)
from ..._dilation.encode import to_be4
from .common import clear_mock_calls
@ -19,7 +20,7 @@ def make_con(role, use_relay=False):
alsoProvides(connector, IDilationConnector)
n = mock.Mock() # pretends to be a Noise object
n.write_message = mock.Mock(side_effect=[b"handshake"])
c = DilatedConnectionProtocol(eq, role, connector, n,
c = DilatedConnectionProtocol(eq, role, "desc", connector, n,
b"outbound_prologue\n", b"inbound_prologue\n")
if use_relay:
c.use_relay(b"relay_handshake\n")
@ -29,6 +30,10 @@ def make_con(role, use_relay=False):
class Connection(unittest.TestCase):
def test_hashable(self):
c, n, connector, t, eq = make_con(LEADER)
hash(c)
def test_bad_prologue(self):
c, n, connector, t, eq = make_con(LEADER)
c.makeConnection(t)
@ -52,7 +57,7 @@ class Connection(unittest.TestCase):
def _test_no_relay(self, role):
c, n, connector, t, eq = make_con(role)
t_kcm = KCM()
t_open = Open(seqnum=1, scid=0x11223344)
t_open = Open(seqnum=1, scid=to_be4(0x11223344))
t_ack = Ack(resp_seqnum=2)
n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm),
@ -69,10 +74,20 @@ class Connection(unittest.TestCase):
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"inbound_prologue\n")
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(connector.mock_calls, [])
exp_handshake = b"\x00\x00\x00\x09handshake"
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
if role is LEADER:
# the LEADER sends the Noise handshake message immediately upon
# receipt of the prologue
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
else:
# however the FOLLOWER waits until receiving the leader's
# handshake before sending their own
self.assertEqual(n.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
@ -84,13 +99,16 @@ class Connection(unittest.TestCase):
self.assertEqual(t.mock_calls, [])
self.assertEqual(c._manager, None)
else:
# we're the follower, so we encrypt and send the KCM immediately
# we're the follower, so we send our Noise handshake, then
# encrypt and send the KCM immediately
self.assertEqual(n.mock_calls, [
mock.call.read_message(b"handshake2"),
mock.call.write_message(),
mock.call.encrypt(encode_record(t_kcm)),
])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [
mock.call.write(exp_handshake),
mock.call.write(exp_kcm)])
self.assertEqual(c._manager, None)
clear_mock_calls(n, connector, t, m)

View File

@ -5,6 +5,7 @@ from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address
from ...eventual import EventualQueue
from ..._interfaces import IDilationManager, IDilationConnector
from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint
@ -34,11 +35,11 @@ class Outbound(unittest.TestCase):
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = None
f = OutboundConnectionFactory(c, relay_handshake)
f = OutboundConnectionFactory(c, relay_handshake, "desc")
addr = object()
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
self.assertEqual(p.mock_calls, [])
self.assertIdentical(p.factory, f)
@ -48,11 +49,11 @@ class Outbound(unittest.TestCase):
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = b"relay handshake"
f = OutboundConnectionFactory(c, relay_handshake)
f = OutboundConnectionFactory(c, relay_handshake, "desc")
addr = object()
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)])
self.assertIdentical(p.factory, f)
@ -63,10 +64,10 @@ class Inbound(unittest.TestCase):
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
f = InboundConnectionFactory(c)
addr = object()
addr = IPv4Address("TCP", "1.2.3.4", 55)
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "<-tcp:1.2.3.4:55")])
self.assertIdentical(p.factory, f)
def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER):
@ -115,13 +116,13 @@ class TestConnector(unittest.TestCase):
return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp:
p = c.build_protocol(addr)
p = c.build_protocol(addr, "desc")
self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_initiator()])
self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0,
[mock.call(h.eq, h.role, "desc", c, n0,
PROLOGUE_LEADER, PROLOGUE_FOLLOWER)])
def test_build_protocol_follower(self):
@ -133,13 +134,13 @@ class TestConnector(unittest.TestCase):
return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp:
p = c.build_protocol(addr)
p = c.build_protocol(addr, "desc")
self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_responder()])
self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0,
[mock.call(h.eq, h.role, "desc", c, n0,
PROLOGUE_FOLLOWER, PROLOGUE_LEADER)])
def test_start_stop(self):
@ -244,7 +245,7 @@ class TestConnector(unittest.TestCase):
with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
return_value=f) as ocf:
h.clock.advance(1.0)
self.assertEqual(ocf.mock_calls, [mock.call(c, None)])
self.assertEqual(ocf.mock_calls, [mock.call(c, None, "->tcp:foo:55")])
self.assertEqual(ep.connect.mock_calls, [mock.call(f)])
p = mock.Mock()
d.callback(p)
@ -269,7 +270,7 @@ class TestConnector(unittest.TestCase):
return_value=f) as ocf:
h.clock.advance(1.0)
handshake = build_sided_relay_handshake(h.dilation_key, h.side)
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)])
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake, "->relay:tcp:foo:55")])
def test_listen_but_tor(self):
c, h = make_connector(listen=True, tor=True, role=roles.LEADER)
@ -388,7 +389,7 @@ class Race(unittest.TestCase):
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
mock.call.send_record(KCM())])
@ -409,7 +410,7 @@ class Race(unittest.TestCase):
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
# just like LEADER, but follower doesn't send KCM now (it sent one
# earlier, to tell the leader that this connection looks viable)
self.assertEqual(p1.mock_calls,
@ -432,7 +433,7 @@ class Race(unittest.TestCase):
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
clear_mock_calls(h.manager)
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
@ -454,10 +455,9 @@ class Race(unittest.TestCase):
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
mock.call.send_record(KCM())])
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
c.stop()

View File

@ -0,0 +1,77 @@
from __future__ import print_function, absolute_import, unicode_literals
import wormhole
from twisted.internet import reactor
from twisted.internet.defer import Deferred, inlineCallbacks, gatherResults
from twisted.internet.protocol import Protocol, Factory
from twisted.trial import unittest
from ..common import ServerBase
from ...eventual import EventualQueue
from ..._dilation._noise import NoiseConnection
APPID = u"lothar.com/dilate-test"
def doBoth(d1, d2):
return gatherResults([d1, d2], True)
class L(Protocol):
def connectionMade(self):
print("got connection")
self.transport.write(b"hello\n")
def dataReceived(self, data):
print("dataReceived: {}".format(data))
self.factory.d.callback(data)
def connectionLost(self, why):
print("connectionLost")
class Full(ServerBase, unittest.TestCase):
@inlineCallbacks
def setUp(self):
if not NoiseConnection:
raise unittest.SkipTest("noiseprotocol unavailable")
# test_welcome wants to see [current_cli_version]
yield self._setup_relay(None)
@inlineCallbacks
def test_full(self):
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w1.allocate_code()
code = yield w1.get_code()
print("code is: {}".format(code))
w2.set_code(code)
yield doBoth(w1.get_verifier(), w2.get_verifier())
print("connected")
eps1_d = w1.dilate()
eps2_d = w2.dilate()
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
(control_ep1, connect_ep1, listen_ep1) = eps1
(control_ep2, connect_ep2, listen_ep2) = eps2
print("w.dilate ready")
f1 = Factory()
f1.protocol = L
f1.d = Deferred()
f1.d.addCallback(lambda data: eq.fire_eventually(data))
d1 = control_ep1.connect(f1)
f2 = Factory()
f2.protocol = L
f2.d = Deferred()
f2.d.addCallback(lambda data: eq.fire_eventually(data))
d2 = control_ep2.connect(f2)
yield d1
yield d2
print("control endpoints connected")
data1 = yield f1.d
data2 = yield f2.d
self.assertEqual(data1, b"hello\n")
self.assertEqual(data2, b"hello\n")
yield w1.close()
yield w2.close()
test_full.timeout = 30

View File

@ -27,12 +27,12 @@ class InboundTest(unittest.TestCase):
self.assertFalse(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))
i.update_ack_watermark(r1)
i.update_ack_watermark(r1.seqnum)
self.assertTrue(i.is_record_old(r1))
self.assertFalse(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))
i.update_ack_watermark(r2)
i.update_ack_watermark(r2.seqnum)
self.assertTrue(i.is_record_old(r1))
self.assertTrue(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))

View File

@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator
import mock
from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager
from ..._interfaces import ISend, IDilationManager, ITerminator
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.encode import to_be4
@ -32,7 +32,9 @@ def make_dilator():
send = mock.Mock()
alsoProvides(send, ISend)
dil = Dilator(reactor, eq, coop)
dil.wire(send)
terminator = mock.Mock()
alsoProvides(terminator, ITerminator)
dil.wire(send, terminator)
return dil, send, reactor, eq, clock, coop
@ -64,7 +66,7 @@ class TestDilator(unittest.TestCase):
dil.got_wormhole_versions({"can-dilate": ["1"]})
# that should create the Manager
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
None, reactor, eq, coop)])
None, reactor, eq, coop, no_listen=False)])
# and tell it to start, and get wait-for-it-to-connect Deferred
self.assertEqual(m.mock_calls, [mock.call.start(),
mock.call.when_first_connected(),
@ -180,7 +182,7 @@ class TestDilator(unittest.TestCase):
return_value="us"):
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
None, reactor, eq, coop)])
None, reactor, eq, coop, no_listen=False)])
self.assertEqual(m.mock_calls, [mock.call.start(),
mock.call.rx_PLEASE(pleasemsg),
mock.call.rx_HINTS(hintmsg),
@ -198,7 +200,7 @@ class TestDilator(unittest.TestCase):
return_value="us"):
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
relay, reactor, eq, coop),
relay, reactor, eq, coop, no_listen=False),
mock.call().start(),
mock.call().when_first_connected()])

View File

@ -105,7 +105,7 @@ class OutboundTest(unittest.TestCase):
# as soon as the connection is established, everything is sent
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
mock.call.send_record(r1),
mock.call.send_record(r2)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
@ -131,7 +131,7 @@ class OutboundTest(unittest.TestCase):
# after each write. So only r1 should have been sent before getting
# paused
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
mock.call.send_record(r1)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [r2])
@ -172,7 +172,7 @@ class OutboundTest(unittest.TestCase):
self.assertEqual(list(o._queued_unsent), [])
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
mock.call.send_record(r1)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [r2])
@ -191,7 +191,7 @@ class OutboundTest(unittest.TestCase):
def test_pause(self):
o, m, c = make_outbound()
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True)])
self.assertEqual(list(o._outbound_queue), [])
self.assertEqual(list(o._queued_unsent), [])
clear_mock_calls(c)
@ -519,7 +519,7 @@ class OutboundTest(unittest.TestCase):
o.use_connection(c)
o.send_if_connected(KCM())
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
mock.call.send_record(KCM())])
def test_tolerate_duplicate_pause_resume(self):

View File

@ -13,11 +13,11 @@ class Parse(unittest.TestCase):
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
Pong(ping_id=b"\x55\x44\x33\x22"))
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
Open(scid=513, seqnum=256))
Open(scid=b"\x00\x00\x02\x01", seqnum=256))
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
Data(scid=514, seqnum=257, data=b"dataaa"))
Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa"))
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
Close(scid=515, seqnum=258))
Close(scid=b"\x00\x00\x02\x03", seqnum=258))
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
Ack(resp_seqnum=259))
with mock.patch("wormhole._dilation.connection.log.err") as le:
@ -31,11 +31,11 @@ class Parse(unittest.TestCase):
self.assertEqual(encode_record(KCM()), b"\x00")
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)),
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")),
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)),
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
b"\x06\x00\x00\x00\x13")

View File

@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage
from ..._dilation.connection import (IFramer, Frame, Prologue,
_Record, Handshake,
Disconnect, Ping)
from ..._dilation.roles import LEADER
def make_record():
f = mock.Mock()
alsoProvides(f, IFramer)
n = mock.Mock() # pretends to be a Noise object
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
return r, f, n
@ -30,7 +32,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
p1, p2 = object(), object()
n.decrypt = mock.Mock(side_effect=[p1, p2])
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -79,7 +82,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.read_message = mock.Mock(side_effect=nvm)
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -103,7 +107,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.decrypt = mock.Mock(side_effect=nvm)
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -124,7 +129,8 @@ class Record(unittest.TestCase):
f1 = object()
n.encrypt = mock.Mock(return_value=f1)
r1 = Ping(b"pingid")
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
m1 = object()
with mock.patch("wormhole._dilation.connection.encode_record",

View File

@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase):
rc = Dummy("rc", events, IRendezvousConnector, "stop")
n = Dummy("n", events, INameplate, "close")
m = Dummy("m", events, IMailbox, "close")
t.wire(b, rc, n, m)
d = Dummy("d", events, IDilator, "stop")
t.wire(b, rc, n, m, d)
return t, b, rc, n, m, events
# there are three events, and we need to test all orderings of them
@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase):
input_events = {
"mailbox": lambda: t.mailbox_done(),
"nameplate": lambda: t.nameplate_done(),
"close": lambda: t.close("happy"),
"rc": lambda: t.close("happy"),
}
close_events = [
("n.close", ),
("m.close", "happy"),
]
if ev1 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev1 == "nameplate":
close_events.remove(("n.close",))
input_events[ev1]()
expected = []
if ev1 == "close":
if ev1 == "rc":
expected.extend(close_events)
self.assertEqual(events, expected)
events[:] = []
if ev2 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev2 == "nameplate":
close_events.remove(("n.close",))
input_events[ev2]()
expected = []
if ev2 == "close":
if ev2 == "rc":
expected.extend(close_events)
self.assertEqual(events, expected)
events[:] = []
if ev3 == "mailbox":
close_events.remove(("m.close", "happy"))
elif ev3 == "nameplate":
close_events.remove(("n.close",))
input_events[ev3]()
expected = []
if ev3 == "close":
if ev3 == "rc":
expected.extend(close_events)
expected.append(("rc.stop", ))
self.assertEqual(events, expected)
events[:] = []
t.stopped()
t.stoppedRC()
self.assertEqual(events, [("d.stop", )])
events[:] = []
t.stoppedD()
self.assertEqual(events, [("b.closed", )])
def test_terminate(self):
self._do_test("mailbox", "nameplate", "close")
self._do_test("mailbox", "close", "nameplate")
self._do_test("nameplate", "mailbox", "close")
self._do_test("nameplate", "close", "mailbox")
self._do_test("close", "nameplate", "mailbox")
self._do_test("close", "mailbox", "nameplate")
self._do_test("mailbox", "nameplate", "rc")
self._do_test("mailbox", "rc", "nameplate")
self._do_test("nameplate", "mailbox", "rc")
self._do_test("nameplate", "rc", "mailbox")
self._do_test("rc", "nameplate", "mailbox")
self._do_test("rc", "mailbox", "nameplate")
# TODO: test moods