Merge branch 'dilate-xfer'
Add an integration test which exercises a full w.dilate connection and the control endpoint. Still untested: * reconnecting after the initial TCP connection is lost * resending data that wasn't acked before the connection was lost * (re)sending data that was submitted while no connection was available * the connect- and listen- endpoints
This commit is contained in:
commit
a5e011f786
|
@ -81,8 +81,8 @@ class Boss(object):
|
|||
self._A.wire(self._RC, self._C)
|
||||
self._I.wire(self._C, self._L)
|
||||
self._C.wire(self, self._A, self._N, self._K, self._I)
|
||||
self._T.wire(self, self._RC, self._N, self._M)
|
||||
self._D.wire(self._S)
|
||||
self._T.wire(self, self._RC, self._N, self._M, self._D)
|
||||
self._D.wire(self._S, self._T)
|
||||
|
||||
def _init_other_state(self):
|
||||
self._did_start_code = False
|
||||
|
|
|
@ -4,6 +4,12 @@ except ImportError:
|
|||
class NoiseInvalidMessage(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
from noise.exceptions import NoiseHandshakeError
|
||||
except ImportError:
|
||||
class NoiseHandshakeError(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
from noise.connection import NoiseConnection
|
||||
except ImportError:
|
||||
|
|
|
@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport
|
|||
from .._interfaces import IDilationConnector
|
||||
from ..observer import OneShotObserver
|
||||
from .encode import to_be4, from_be4
|
||||
from .roles import FOLLOWER
|
||||
from ._noise import NoiseInvalidMessage
|
||||
from .roles import LEADER, FOLLOWER
|
||||
from ._noise import NoiseInvalidMessage, NoiseHandshakeError
|
||||
|
||||
# InboundFraming is given data and returns Frames (Noise wire-side
|
||||
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
||||
|
@ -56,6 +56,23 @@ def first(l):
|
|||
class Disconnect(Exception):
|
||||
pass
|
||||
|
||||
# all connections look like:
|
||||
# (step 1: only for outbound connections)
|
||||
# 1: if we're connecting to a transit relay:
|
||||
# * send "sided relay handshake": "please relay TOKEN for side SIDE\n"
|
||||
# * the relay will send "ok\n" if/when our peer connects
|
||||
# * a non-relay will probably send junk
|
||||
# * wait for "ok\n", hang up if we get anything different
|
||||
# (all subsequent steps are for both inbound and outbound connections)
|
||||
# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n"
|
||||
# 3: wait for the opposite PROLOGUE string, else hang up
|
||||
# (everything past this point is a Frame, with be4 length prefix. Frames are
|
||||
# either noise handshake or an encrypted message)
|
||||
# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it
|
||||
# 5: if FOLLOWER, send noise response string. if LEADER, wait for it
|
||||
# 6: ...
|
||||
|
||||
|
||||
|
||||
RelayOK = namedtuple("RelayOk", [])
|
||||
Prologue = namedtuple("Prologue", [])
|
||||
|
@ -193,7 +210,7 @@ class _Framer(object):
|
|||
def add_and_parse(self, data):
|
||||
# we can't make this an @m.input because we can't change the state
|
||||
# from within an input. Instead, let the state choose the parser to
|
||||
# use, and use the parsed token drive a state transition.
|
||||
# use, then use the parsed token to drive a state transition.
|
||||
self._buffer += data
|
||||
while True:
|
||||
# it'd be nice to use an iterator here, but since self.parse()
|
||||
|
@ -233,7 +250,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
|
|||
Pong = namedtuple("Pong", ["ping_id"])
|
||||
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
|
||||
Data = namedtuple("Data", ["seqnum", "scid", "data"])
|
||||
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
|
||||
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value
|
||||
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
|
||||
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
|
||||
Handshake_or_Records = (Handshake,) + Records
|
||||
|
@ -258,16 +275,16 @@ def parse_record(plaintext):
|
|||
ping_id = plaintext[1:5]
|
||||
return Pong(ping_id)
|
||||
if msgtype == T_OPEN:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
scid = plaintext[1:5]
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
return Open(seqnum, scid)
|
||||
if msgtype == T_DATA:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
scid = plaintext[1:5]
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
data = plaintext[9:]
|
||||
return Data(seqnum, scid, data)
|
||||
if msgtype == T_CLOSE:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
scid = plaintext[1:5]
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
return Close(seqnum, scid)
|
||||
if msgtype == T_ACK:
|
||||
|
@ -285,28 +302,36 @@ def encode_record(r):
|
|||
if isinstance(r, Pong):
|
||||
return b"\x02" + r.ping_id
|
||||
if isinstance(r, Open):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.scid, bytes)
|
||||
assert len(r.scid) == 4
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
|
||||
return b"\x03" + r.scid + to_be4(r.seqnum)
|
||||
if isinstance(r, Data):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.scid, bytes)
|
||||
assert len(r.scid) == 4
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
|
||||
return b"\x04" + r.scid + to_be4(r.seqnum) + r.data
|
||||
if isinstance(r, Close):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.scid, bytes)
|
||||
assert len(r.scid) == 4
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
|
||||
return b"\x05" + r.scid + to_be4(r.seqnum)
|
||||
if isinstance(r, Ack):
|
||||
assert isinstance(r.resp_seqnum, six.integer_types)
|
||||
return b"\x06" + to_be4(r.resp_seqnum)
|
||||
raise TypeError(r)
|
||||
|
||||
|
||||
def _is_role(_record, _attr, value):
|
||||
if value not in [LEADER, FOLLOWER]:
|
||||
raise ValueError("role must be LEADER or FOLLOWER")
|
||||
|
||||
@attrs
|
||||
@implementer(IRecord)
|
||||
class _Record(object):
|
||||
_framer = attrib(validator=provides(IFramer))
|
||||
_noise = attrib()
|
||||
_role = attrib(default="unspecified", validator=_is_role) # for debugging
|
||||
|
||||
n = MethodicalMachine()
|
||||
# TODO: set_trace
|
||||
|
@ -321,17 +346,37 @@ class _Record(object):
|
|||
# states: want_prologue, want_handshake, want_record
|
||||
|
||||
@n.state(initial=True)
|
||||
def want_prologue(self):
|
||||
def no_role_set(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake(self):
|
||||
def want_prologue_leader(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_prologue_follower(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake_leader(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake_follower(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_message(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.input()
|
||||
def set_role_leader(self):
|
||||
pass
|
||||
|
||||
@n.input()
|
||||
def set_role_follower(self):
|
||||
pass
|
||||
|
||||
@n.input()
|
||||
def got_prologue(self):
|
||||
pass
|
||||
|
@ -340,9 +385,20 @@ class _Record(object):
|
|||
def got_frame(self, frame):
|
||||
pass
|
||||
|
||||
@n.output()
|
||||
def ignore_and_send_handshake(self, frame):
|
||||
self._send_handshake()
|
||||
|
||||
@n.output()
|
||||
def send_handshake(self):
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
self._send_handshake()
|
||||
|
||||
def _send_handshake(self):
|
||||
try:
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
except NoiseHandshakeError as e:
|
||||
log.err(e, "noise error during handshake")
|
||||
raise
|
||||
self._framer.send_frame(handshake)
|
||||
|
||||
@n.output()
|
||||
|
@ -367,10 +423,19 @@ class _Record(object):
|
|||
raise Disconnect()
|
||||
return parse_record(message)
|
||||
|
||||
want_prologue.upon(got_prologue, outputs=[send_handshake],
|
||||
enter=want_handshake)
|
||||
want_handshake.upon(got_frame, outputs=[process_handshake],
|
||||
collector=first, enter=want_message)
|
||||
no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader)
|
||||
want_prologue_leader.upon(got_prologue, outputs=[send_handshake],
|
||||
enter=want_handshake_leader)
|
||||
want_handshake_leader.upon(got_frame, outputs=[process_handshake],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower)
|
||||
want_prologue_follower.upon(got_prologue, outputs=[],
|
||||
enter=want_handshake_follower)
|
||||
want_handshake_follower.upon(got_frame, outputs=[process_handshake,
|
||||
ignore_and_send_handshake],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
want_message.upon(got_frame, outputs=[decrypt_message],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
|
@ -393,7 +458,7 @@ class _Record(object):
|
|||
self._framer.send_frame(frame)
|
||||
|
||||
|
||||
@attrs
|
||||
@attrs(cmp=False)
|
||||
class DilatedConnectionProtocol(Protocol, object):
|
||||
"""I manage an L2 connection.
|
||||
|
||||
|
@ -408,12 +473,13 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
At any given time, there is at most one active L2 connection.
|
||||
"""
|
||||
|
||||
_eventual_queue = attrib()
|
||||
_eventual_queue = attrib(repr=False)
|
||||
_role = attrib()
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
_noise = attrib()
|
||||
_outbound_prologue = attrib(validator=instance_of(bytes))
|
||||
_inbound_prologue = attrib(validator=instance_of(bytes))
|
||||
_description = attrib()
|
||||
_connector = attrib(validator=provides(IDilationConnector), repr=False)
|
||||
_noise = attrib(repr=False)
|
||||
_outbound_prologue = attrib(validator=instance_of(bytes), repr=False)
|
||||
_inbound_prologue = attrib(validator=instance_of(bytes), repr=False)
|
||||
|
||||
_use_relay = False
|
||||
_relay_handshake = None
|
||||
|
@ -457,6 +523,8 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
@m.output()
|
||||
def set_manager(self, manager):
|
||||
self._manager = manager
|
||||
self.when_disconnected().addCallback(lambda c:
|
||||
manager.connector_connection_lost())
|
||||
|
||||
@m.output()
|
||||
def can_send_records(self, manager):
|
||||
|
@ -493,12 +561,20 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
# IProtocol methods
|
||||
|
||||
def connectionMade(self):
|
||||
framer = _Framer(self.transport,
|
||||
self._outbound_prologue, self._inbound_prologue)
|
||||
if self._use_relay:
|
||||
framer.use_relay(self._relay_handshake)
|
||||
self._record = _Record(framer, self._noise)
|
||||
self._record.connectionMade()
|
||||
try:
|
||||
framer = _Framer(self.transport,
|
||||
self._outbound_prologue, self._inbound_prologue)
|
||||
if self._use_relay:
|
||||
framer.use_relay(self._relay_handshake)
|
||||
self._record = _Record(framer, self._noise, self._role)
|
||||
if self._role is LEADER:
|
||||
self._record.set_role_leader()
|
||||
else:
|
||||
self._record.set_role_follower()
|
||||
self._record.connectionMade()
|
||||
except:
|
||||
log.err()
|
||||
raise
|
||||
|
||||
def dataReceived(self, data):
|
||||
try:
|
||||
|
|
|
@ -9,6 +9,7 @@ from twisted.internet.task import deferLater
|
|||
from twisted.internet.defer import DeferredList
|
||||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.internet.protocol import ClientFactory, ServerFactory
|
||||
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
|
||||
from twisted.python import log
|
||||
from .. import ipaddrs # TODO: move into _dilation/
|
||||
from .._interfaces import IDilationConnector, IDilationManager
|
||||
|
@ -39,9 +40,36 @@ NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
|
|||
def build_noise():
|
||||
return NoiseConnection.from_name(NOISEPROTO)
|
||||
|
||||
@attrs
|
||||
@attrs(cmp=False)
|
||||
@implementer(IDilationConnector)
|
||||
class Connector(object):
|
||||
"""I manage a single generation of connection.
|
||||
|
||||
The Manager creates one of me at a time, whenever it wants a connection
|
||||
(which is always, once w.dilate() has been called and we know the remote
|
||||
end can dilate, and is expressed by the Manager calling my .start()
|
||||
method). I am discarded when my established connection is lost (and if we
|
||||
still want to be connected, a new generation is started and a new
|
||||
Connector is created). I am also discarded if we stop wanting to be
|
||||
connected (which the Manager expresses by calling my .stop() method).
|
||||
|
||||
I manage the race between multiple connections for a specific generation
|
||||
of the dilated connection.
|
||||
|
||||
I send connection hints when my InboundConnectionFactory yields addresses
|
||||
(self.listener_ready), and I initiate outbond connections (with
|
||||
OutboundConnectionFactory) as I receive connection hints from my peer
|
||||
(self.got_hints). Both factories use my build_protocol() method to create
|
||||
connection.DilatedConnectionProtocol instances. I track these protocol
|
||||
instances until one finishes negotiation and wins the race. I then shut
|
||||
down the others, remember the winner as self._winning_connection, and
|
||||
deliver the winner to manager.connector_connection_made(c).
|
||||
|
||||
When an active connection is lost, we call manager.connector_connection_lost,
|
||||
allowing the manager to decide whether it wants to start a new generation
|
||||
or not.
|
||||
"""
|
||||
|
||||
_dilation_key = attrib(validator=instance_of(type(b"")))
|
||||
_transit_relay_location = attrib(validator=optional(instance_of(type(u""))))
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
|
@ -83,7 +111,7 @@ class Connector(object):
|
|||
{"type": "relay-v1"},
|
||||
]
|
||||
|
||||
def build_protocol(self, addr):
|
||||
def build_protocol(self, addr, description):
|
||||
# encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
|
||||
# ephemeral keys plus a pre-shared symmetric key (the Transit key), a
|
||||
# different one for each potential connection.
|
||||
|
@ -98,6 +126,7 @@ class Connector(object):
|
|||
outbound_prologue = PROLOGUE_FOLLOWER
|
||||
inbound_prologue = PROLOGUE_LEADER
|
||||
p = DilatedConnectionProtocol(self._eventual_queue, self._role,
|
||||
description,
|
||||
self, noise,
|
||||
outbound_prologue, inbound_prologue)
|
||||
return p
|
||||
|
@ -181,10 +210,13 @@ class Connector(object):
|
|||
self.stop_pending_connections()
|
||||
|
||||
c.select(self._manager) # subsequent frames go directly to the manager
|
||||
# c.select also wires up when_disconnected() to fire
|
||||
# manager.connector_connection_lost(). TODO: rename this, since the
|
||||
# Connector is no longer the one calling it
|
||||
if self._role is LEADER:
|
||||
# TODO: this should live in Connection
|
||||
c.send_record(KCM()) # leader sends KCM now
|
||||
self._manager.use_connection(c) # manager sends frames to Connection
|
||||
self._manager.connector_connection_made(c) # manager sends frames to Connection
|
||||
|
||||
@m.output()
|
||||
def stop_everything(self):
|
||||
|
@ -199,11 +231,12 @@ class Connector(object):
|
|||
return d # synchronization for tests
|
||||
|
||||
def stop_pending_connectors(self):
|
||||
return DeferredList([d.cancel() for d in self._pending_connectors])
|
||||
for d in self._pending_connectors:
|
||||
d.cancel()
|
||||
|
||||
def stop_pending_connections(self):
|
||||
d = self._pending_connections.when_next_empty()
|
||||
[c.loseConnection() for c in self._pending_connections]
|
||||
[c.disconnect() for c in self._pending_connections]
|
||||
return d
|
||||
|
||||
def break_cycles(self):
|
||||
|
@ -337,7 +370,7 @@ class Connector(object):
|
|||
if is_relay:
|
||||
relay_handshake = build_sided_relay_handshake(self._dilation_key,
|
||||
self._side)
|
||||
f = OutboundConnectionFactory(self, relay_handshake)
|
||||
f = OutboundConnectionFactory(self, relay_handshake, description)
|
||||
d = ep.connect(f)
|
||||
# fires with protocol, or ConnectError
|
||||
|
||||
|
@ -368,20 +401,28 @@ class Connector(object):
|
|||
class OutboundConnectionFactory(ClientFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
_relay_handshake = attrib(validator=optional(instance_of(bytes)))
|
||||
_description = attrib()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self._connector.build_protocol(addr)
|
||||
p = self._connector.build_protocol(addr, self._description)
|
||||
p.factory = self
|
||||
if self._relay_handshake is not None:
|
||||
p.use_relay(self._relay_handshake)
|
||||
return p
|
||||
|
||||
def describe_inbound(addr):
|
||||
if isinstance(addr, HostnameAddress):
|
||||
return "<-tcp:%s:%d" % (addr.hostname, addr.port)
|
||||
elif isinstance(addr, (IPv4Address, IPv6Address)):
|
||||
return "<-tcp:%s:%d" % (addr.host, addr.port)
|
||||
return "<-%r" % addr
|
||||
|
||||
@attrs
|
||||
class InboundConnectionFactory(ServerFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self._connector.build_protocol(addr)
|
||||
description = describe_inbound(addr)
|
||||
p = self._connector.build_protocol(addr, description)
|
||||
p.factory = self
|
||||
return p
|
||||
|
|
|
@ -60,9 +60,9 @@ class Inbound(object):
|
|||
return True
|
||||
return False
|
||||
|
||||
def update_ack_watermark(self, r):
|
||||
def update_ack_watermark(self, seqnum):
|
||||
self._highest_inbound_acked = max(self._highest_inbound_acked,
|
||||
r.seqnum)
|
||||
seqnum)
|
||||
|
||||
def handle_open(self, scid):
|
||||
if scid in self._open_subchannels:
|
||||
|
|
|
@ -7,7 +7,7 @@ from automat import MethodicalMachine
|
|||
from zope.interface import implementer
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
||||
from twisted.python import log
|
||||
from .._interfaces import IDilator, IDilationManager, ISend
|
||||
from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
|
||||
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
|
||||
from ..observer import OneShotObserver
|
||||
from .._key import derive_key
|
||||
|
@ -87,17 +87,17 @@ def make_side():
|
|||
# * if follower calls w.dilate() but not leader, follower waits forever
|
||||
# in "want", leader waits forever in "wanted"
|
||||
|
||||
@attrs
|
||||
@attrs(cmp=False)
|
||||
@implementer(IDilationManager)
|
||||
class Manager(object):
|
||||
_S = attrib(validator=provides(ISend))
|
||||
_S = attrib(validator=provides(ISend), repr=False)
|
||||
_my_side = attrib(validator=instance_of(type(u"")))
|
||||
_transit_key = attrib(validator=instance_of(bytes))
|
||||
_transit_key = attrib(validator=instance_of(bytes), repr=False)
|
||||
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
||||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_cooperator = attrib()
|
||||
_no_listen = False # TODO
|
||||
_reactor = attrib(repr=False)
|
||||
_eventual_queue = attrib(repr=False)
|
||||
_cooperator = attrib(repr=False)
|
||||
_no_listen = attrib(default=False)
|
||||
_tor = None # TODO
|
||||
_timing = None # TODO
|
||||
_next_subchannel_id = None # initialized in choose_role
|
||||
|
@ -113,6 +113,7 @@ class Manager(object):
|
|||
self._connection = None
|
||||
self._made_first_connection = False
|
||||
self._first_connected = OneShotObserver(self._eventual_queue)
|
||||
self._stopped = OneShotObserver(self._eventual_queue)
|
||||
self._host_addr = _WormholeAddress()
|
||||
|
||||
self._next_dilation_phase = 0
|
||||
|
@ -133,6 +134,9 @@ class Manager(object):
|
|||
def when_first_connected(self):
|
||||
return self._first_connected.when_fired()
|
||||
|
||||
def when_stopped(self):
|
||||
return self._stopped.when_fired()
|
||||
|
||||
def send_dilation_phase(self, **fields):
|
||||
dilation_phase = self._next_dilation_phase
|
||||
self._next_dilation_phase += 1
|
||||
|
@ -160,12 +164,15 @@ class Manager(object):
|
|||
self._outbound.subchannel_unregisterProducer(sc)
|
||||
|
||||
def send_open(self, scid):
|
||||
assert isinstance(scid, bytes)
|
||||
self._queue_and_send(Open, scid)
|
||||
|
||||
def send_data(self, scid, data):
|
||||
assert isinstance(scid, bytes)
|
||||
self._queue_and_send(Data, scid, data)
|
||||
|
||||
def send_close(self, scid):
|
||||
assert isinstance(scid, bytes)
|
||||
self._queue_and_send(Close, scid)
|
||||
|
||||
def _queue_and_send(self, record_type, *args):
|
||||
|
@ -401,6 +408,10 @@ class Manager(object):
|
|||
# been told to shut down.
|
||||
self._connection.disconnect() # let connection_lost do cleanup
|
||||
|
||||
@m.output()
|
||||
def notify_stopped(self):
|
||||
self._stopped.fire(None)
|
||||
|
||||
# we start CONNECTING when we get rx_PLEASE
|
||||
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
||||
outputs=[choose_role, start_connecting_ignore_message])
|
||||
|
@ -440,14 +451,14 @@ class Manager(object):
|
|||
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
|
||||
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
|
||||
|
||||
WANTING.upon(stop, enter=STOPPED, outputs=[])
|
||||
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
|
||||
WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
|
||||
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped])
|
||||
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
|
||||
ABANDONING.upon(stop, enter=STOPPING, outputs=[])
|
||||
FLUSHING.upon(stop, enter=STOPPED, outputs=[])
|
||||
LONELY.upon(stop, enter=STOPPED, outputs=[])
|
||||
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
|
||||
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
|
||||
FLUSHING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
|
||||
LONELY.upon(stop, enter=STOPPED, outputs=[notify_stopped])
|
||||
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[notify_stopped])
|
||||
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[notify_stopped])
|
||||
|
||||
|
||||
@attrs
|
||||
|
@ -466,6 +477,7 @@ class Dilator(object):
|
|||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_cooperator = attrib()
|
||||
_no_listen = attrib(default=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._got_versions_d = Deferred()
|
||||
|
@ -474,8 +486,9 @@ class Dilator(object):
|
|||
self._pending_inbound_dilate_messages = deque()
|
||||
self._manager = None
|
||||
|
||||
def wire(self, sender):
|
||||
def wire(self, sender, terminator):
|
||||
self._S = ISend(sender)
|
||||
self._T = ITerminator(terminator)
|
||||
|
||||
# this is the primary entry point, called when w.dilate() is invoked
|
||||
def dilate(self, transit_relay_location=None):
|
||||
|
@ -509,7 +522,7 @@ class Dilator(object):
|
|||
self._transit_key,
|
||||
self._transit_relay_location,
|
||||
self._reactor, self._eventual_queue,
|
||||
self._cooperator)
|
||||
self._cooperator, no_listen=self._no_listen)
|
||||
self._manager.start()
|
||||
|
||||
while self._pending_inbound_dilate_messages:
|
||||
|
@ -519,7 +532,7 @@ class Dilator(object):
|
|||
yield self._manager.when_first_connected()
|
||||
|
||||
# we can open subchannels as soon as we get our first connection
|
||||
scid0 = b"\x00\x00\x00\x00"
|
||||
scid0 = to_be4(0)
|
||||
self._host_addr = _WormholeAddress() # TODO: share with Manager
|
||||
peer_addr0 = _SubchannelAddress(scid0)
|
||||
control_ep = ControlEndpoint(peer_addr0)
|
||||
|
@ -535,6 +548,19 @@ class Dilator(object):
|
|||
endpoints = (control_ep, connect_ep, listen_ep)
|
||||
returnValue(endpoints)
|
||||
|
||||
# Called by Terminator after everything else (mailbox, nameplate, server
|
||||
# connection) has shut down. Expects to fire T.stoppedD() when Dilator is
|
||||
# stopped too.
|
||||
def stop(self):
|
||||
if not self._started:
|
||||
self._T.stoppedD()
|
||||
return
|
||||
if self._started:
|
||||
self._manager.stop()
|
||||
# TODO: avoid Deferreds for control flow, hard to serialize
|
||||
self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD())
|
||||
# TODO: tolerate multiple calls
|
||||
|
||||
# from Boss
|
||||
|
||||
def got_key(self, key):
|
||||
|
|
|
@ -154,7 +154,7 @@ from .connection import KCM, Ping, Pong, Ack
|
|||
|
||||
|
||||
@attrs
|
||||
@implementer(IOutbound)
|
||||
@implementer(IOutbound, IPushProducer)
|
||||
class Outbound(object):
|
||||
# Manage outbound data: subchannel writes to us, we write to transport
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
|
@ -265,12 +265,12 @@ class Outbound(object):
|
|||
assert not self._queued_unsent
|
||||
self._queued_unsent.extend(self._outbound_queue)
|
||||
# the connection can tell us to pause when we send too much data
|
||||
c.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
c.transport.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
# send our queued messages
|
||||
self.resumeProducing()
|
||||
|
||||
def stop_using_connection(self):
|
||||
self._connection.unregisterProducer()
|
||||
self._connection.transport.unregisterProducer()
|
||||
self._connection = None
|
||||
self._queued_unsent.clear()
|
||||
self.pauseProducing()
|
||||
|
@ -290,8 +290,8 @@ class Outbound(object):
|
|||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
||||
|
||||
# IProducer: the active connection calls these because we used
|
||||
# c.registerProducer to ask for them
|
||||
# IPushProducer: the active connection calls these because we used
|
||||
# c.transport.registerProducer to ask for them
|
||||
|
||||
def pauseProducing(self):
|
||||
if self._paused:
|
||||
|
|
|
@ -1 +1,7 @@
|
|||
LEADER, FOLLOWER = object(), object()
|
||||
class _Role(object):
|
||||
def __init__(self, which):
|
||||
self._which = which
|
||||
def __repr__(self):
|
||||
return "Role(%s)" % self._which
|
||||
|
||||
LEADER, FOLLOWER = _Role("LEADER"), _Role("FOLLOWER")
|
||||
|
|
|
@ -55,7 +55,7 @@ class _WormholeAddress(object):
|
|||
@implementer(IAddress)
|
||||
@attrs
|
||||
class _SubchannelAddress(object):
|
||||
_scid = attrib()
|
||||
_scid = attrib(validator=instance_of(bytes))
|
||||
|
||||
|
||||
@attrs
|
||||
|
|
|
@ -246,7 +246,7 @@ class RendezvousConnector(object):
|
|||
|
||||
# internal
|
||||
def _stopped(self, res):
|
||||
self._T.stopped()
|
||||
self._T.stoppedRC()
|
||||
|
||||
def _tx(self, mtype, **kwargs):
|
||||
assert self._ws
|
||||
|
|
|
@ -15,15 +15,17 @@ class Terminator(object):
|
|||
def __init__(self):
|
||||
self._mood = None
|
||||
|
||||
def wire(self, boss, rendezvous_connector, nameplate, mailbox):
|
||||
def wire(self, boss, rendezvous_connector, nameplate, mailbox, dilator):
|
||||
self._B = _interfaces.IBoss(boss)
|
||||
self._RC = _interfaces.IRendezvousConnector(rendezvous_connector)
|
||||
self._N = _interfaces.INameplate(nameplate)
|
||||
self._M = _interfaces.IMailbox(mailbox)
|
||||
self._D = _interfaces.IDilator(dilator)
|
||||
|
||||
# 4*2-1 main states:
|
||||
# (nm, m, n, 0): nameplate and/or mailbox is active
|
||||
# 2*2-1+1 main states:
|
||||
# (nm, m, n, d): nameplate and/or mailbox is active
|
||||
# (o, ""): open (not-yet-closing), or trying to close
|
||||
# after closing the mailbox-server connection, we stop Dilation
|
||||
# S0 is special: we don't hang out in it
|
||||
|
||||
# TODO: rename o to 0, "" to 1. "S1" is special/terminal
|
||||
|
@ -64,7 +66,11 @@ class Terminator(object):
|
|||
# def S0(self): pass # unused
|
||||
|
||||
@m.state()
|
||||
def S_stopping(self):
|
||||
def S_stoppingRC(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def S_stoppingD(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
|
@ -88,7 +94,11 @@ class Terminator(object):
|
|||
|
||||
# from RendezvousConnector
|
||||
@m.input()
|
||||
def stopped(self):
|
||||
def stoppedRC(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def stoppedD(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
|
@ -107,6 +117,10 @@ class Terminator(object):
|
|||
def RC_stop(self):
|
||||
self._RC.stop()
|
||||
|
||||
@m.output()
|
||||
def stop_dilator(self):
|
||||
self._D.stop()
|
||||
|
||||
@m.output()
|
||||
def B_closed(self):
|
||||
self._B.closed()
|
||||
|
@ -115,20 +129,19 @@ class Terminator(object):
|
|||
Snmo.upon(close, enter=Snm, outputs=[close_nameplate, close_mailbox])
|
||||
Snmo.upon(nameplate_done, enter=Smo, outputs=[])
|
||||
|
||||
Sno.upon(close, enter=Sn, outputs=[close_nameplate, close_mailbox])
|
||||
Sno.upon(close, enter=Sn, outputs=[close_nameplate])
|
||||
Sno.upon(nameplate_done, enter=S0o, outputs=[])
|
||||
|
||||
Smo.upon(close, enter=Sm, outputs=[close_nameplate, close_mailbox])
|
||||
Smo.upon(close, enter=Sm, outputs=[close_mailbox])
|
||||
Smo.upon(mailbox_done, enter=S0o, outputs=[])
|
||||
|
||||
Snm.upon(mailbox_done, enter=Sn, outputs=[])
|
||||
Snm.upon(nameplate_done, enter=Sm, outputs=[])
|
||||
|
||||
Sn.upon(nameplate_done, enter=S_stopping, outputs=[RC_stop])
|
||||
S0o.upon(
|
||||
close,
|
||||
enter=S_stopping,
|
||||
outputs=[close_nameplate, close_mailbox, ignore_mood_and_RC_stop])
|
||||
Sm.upon(mailbox_done, enter=S_stopping, outputs=[RC_stop])
|
||||
Sn.upon(nameplate_done, enter=S_stoppingRC, outputs=[RC_stop])
|
||||
Sm.upon(mailbox_done, enter=S_stoppingRC, outputs=[RC_stop])
|
||||
S0o.upon(close, enter=S_stoppingRC, outputs=[ignore_mood_and_RC_stop])
|
||||
|
||||
S_stopping.upon(stopped, enter=S_stopped, outputs=[B_closed])
|
||||
S_stoppingRC.upon(stoppedRC, enter=S_stoppingD, outputs=[stop_dilator])
|
||||
|
||||
S_stoppingD.upon(stoppedD, enter=S_stopped, outputs=[B_closed])
|
||||
|
|
92
src/wormhole/test/dilate/test_connect.py
Normal file
92
src/wormhole/test/dilate/test_connect.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
import re
|
||||
import mock
|
||||
from twisted.internet import reactor
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Cooperator
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks
|
||||
from zope.interface import implementer
|
||||
|
||||
from ... import _interfaces
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import ITerminator
|
||||
from ..._dilation import manager
|
||||
from ..._dilation._noise import NoiseConnection
|
||||
|
||||
|
||||
@implementer(_interfaces.ISend)
|
||||
class MySend(object):
|
||||
def __init__(self, side):
|
||||
self.rx_phase = 0
|
||||
self.side = side
|
||||
def send(self, phase, plaintext):
|
||||
#print("SEND[%s]" % self.side, phase, plaintext)
|
||||
self.peer.got(phase, plaintext)
|
||||
def got(self, phase, plaintext):
|
||||
d_mo = re.search(r'^dilate-(\d+)$', phase)
|
||||
p = int(d_mo.group(1))
|
||||
assert p == self.rx_phase
|
||||
self.rx_phase += 1
|
||||
self.dilator.received_dilate(plaintext)
|
||||
|
||||
@implementer(ITerminator)
|
||||
class FakeTerminator(object):
|
||||
def __init__(self):
|
||||
self.d = Deferred()
|
||||
def stoppedD(self):
|
||||
self.d.callback(None)
|
||||
|
||||
class Connect(unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def test1(self):
|
||||
if not NoiseConnection:
|
||||
raise unittest.SkipTest("noiseprotocol unavailable")
|
||||
#print()
|
||||
send_left = MySend("left")
|
||||
send_right = MySend("right")
|
||||
send_left.peer = send_right
|
||||
send_right.peer = send_left
|
||||
key = b"\x00"*32
|
||||
eq = EventualQueue(reactor)
|
||||
cooperator = Cooperator(scheduler=eq.eventually)
|
||||
|
||||
t_left = FakeTerminator()
|
||||
t_right = FakeTerminator()
|
||||
|
||||
d_left = manager.Dilator(reactor, eq, cooperator, no_listen=True)
|
||||
d_left.wire(send_left, t_left)
|
||||
d_left.got_key(key)
|
||||
d_left.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
send_left.dilator = d_left
|
||||
|
||||
d_right = manager.Dilator(reactor, eq, cooperator)
|
||||
d_right.wire(send_right, t_right)
|
||||
d_right.got_key(key)
|
||||
d_right.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
send_right.dilator = d_right
|
||||
|
||||
with mock.patch("wormhole._dilation.connector.ipaddrs.find_addresses",
|
||||
return_value=["127.0.0.1"]):
|
||||
eps_left_d = d_left.dilate()
|
||||
eps_right_d = d_right.dilate()
|
||||
|
||||
eps_left = yield eps_left_d
|
||||
eps_right = yield eps_right_d
|
||||
|
||||
#print("left connected", eps_left)
|
||||
#print("right connected", eps_right)
|
||||
|
||||
control_ep_left, connect_ep_left, listen_ep_left = eps_left
|
||||
control_ep_right, connect_ep_right, listen_ep_right = eps_right
|
||||
|
||||
#control_ep_left.connect(
|
||||
|
||||
# we normally shut down with w.close(), which calls Dilator.stop(),
|
||||
# which calls Terminator.stoppedD(), which (after everything else is
|
||||
# done) calls Boss.stopped
|
||||
d_left.stop()
|
||||
d_right.stop()
|
||||
|
||||
yield t_left.d
|
||||
yield t_right.d
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ from ..._interfaces import IDilationConnector
|
|||
from ..._dilation.roles import LEADER, FOLLOWER
|
||||
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
|
||||
KCM, Open, Ack)
|
||||
from ..._dilation.encode import to_be4
|
||||
from .common import clear_mock_calls
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ def make_con(role, use_relay=False):
|
|||
alsoProvides(connector, IDilationConnector)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
n.write_message = mock.Mock(side_effect=[b"handshake"])
|
||||
c = DilatedConnectionProtocol(eq, role, connector, n,
|
||||
c = DilatedConnectionProtocol(eq, role, "desc", connector, n,
|
||||
b"outbound_prologue\n", b"inbound_prologue\n")
|
||||
if use_relay:
|
||||
c.use_relay(b"relay_handshake\n")
|
||||
|
@ -29,6 +30,10 @@ def make_con(role, use_relay=False):
|
|||
|
||||
|
||||
class Connection(unittest.TestCase):
|
||||
def test_hashable(self):
|
||||
c, n, connector, t, eq = make_con(LEADER)
|
||||
hash(c)
|
||||
|
||||
def test_bad_prologue(self):
|
||||
c, n, connector, t, eq = make_con(LEADER)
|
||||
c.makeConnection(t)
|
||||
|
@ -52,7 +57,7 @@ class Connection(unittest.TestCase):
|
|||
def _test_no_relay(self, role):
|
||||
c, n, connector, t, eq = make_con(role)
|
||||
t_kcm = KCM()
|
||||
t_open = Open(seqnum=1, scid=0x11223344)
|
||||
t_open = Open(seqnum=1, scid=to_be4(0x11223344))
|
||||
t_ack = Ack(resp_seqnum=2)
|
||||
n.decrypt = mock.Mock(side_effect=[
|
||||
encode_record(t_kcm),
|
||||
|
@ -69,10 +74,20 @@ class Connection(unittest.TestCase):
|
|||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
if role is LEADER:
|
||||
# the LEADER sends the Noise handshake message immediately upon
|
||||
# receipt of the prologue
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
else:
|
||||
# however the FOLLOWER waits until receiving the leader's
|
||||
# handshake before sending their own
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
||||
|
@ -84,13 +99,16 @@ class Connection(unittest.TestCase):
|
|||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(c._manager, None)
|
||||
else:
|
||||
# we're the follower, so we encrypt and send the KCM immediately
|
||||
# we're the follower, so we send our Noise handshake, then
|
||||
# encrypt and send the KCM immediately
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2"),
|
||||
mock.call.write_message(),
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [
|
||||
mock.call.write(exp_handshake),
|
||||
mock.call.write(exp_kcm)])
|
||||
self.assertEqual(c._manager, None)
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
|
|
@ -5,6 +5,7 @@ from zope.interface import alsoProvides
|
|||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.address import IPv4Address
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import IDilationManager, IDilationConnector
|
||||
from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint
|
||||
|
@ -34,11 +35,11 @@ class Outbound(unittest.TestCase):
|
|||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
relay_handshake = None
|
||||
f = OutboundConnectionFactory(c, relay_handshake)
|
||||
f = OutboundConnectionFactory(c, relay_handshake, "desc")
|
||||
addr = object()
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
|
@ -48,11 +49,11 @@ class Outbound(unittest.TestCase):
|
|||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
relay_handshake = b"relay handshake"
|
||||
f = OutboundConnectionFactory(c, relay_handshake)
|
||||
f = OutboundConnectionFactory(c, relay_handshake, "desc")
|
||||
addr = object()
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "desc")])
|
||||
self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
|
@ -63,10 +64,10 @@ class Inbound(unittest.TestCase):
|
|||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
f = InboundConnectionFactory(c)
|
||||
addr = object()
|
||||
addr = IPv4Address("TCP", "1.2.3.4", 55)
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr, "<-tcp:1.2.3.4:55")])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER):
|
||||
|
@ -115,13 +116,13 @@ class TestConnector(unittest.TestCase):
|
|||
return_value=n0) as bn:
|
||||
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
|
||||
return_value=p0) as dcp:
|
||||
p = c.build_protocol(addr)
|
||||
p = c.build_protocol(addr, "desc")
|
||||
self.assertEqual(bn.mock_calls, [mock.call()])
|
||||
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
|
||||
mock.call.set_as_initiator()])
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(dcp.mock_calls,
|
||||
[mock.call(h.eq, h.role, c, n0,
|
||||
[mock.call(h.eq, h.role, "desc", c, n0,
|
||||
PROLOGUE_LEADER, PROLOGUE_FOLLOWER)])
|
||||
|
||||
def test_build_protocol_follower(self):
|
||||
|
@ -133,13 +134,13 @@ class TestConnector(unittest.TestCase):
|
|||
return_value=n0) as bn:
|
||||
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
|
||||
return_value=p0) as dcp:
|
||||
p = c.build_protocol(addr)
|
||||
p = c.build_protocol(addr, "desc")
|
||||
self.assertEqual(bn.mock_calls, [mock.call()])
|
||||
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
|
||||
mock.call.set_as_responder()])
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(dcp.mock_calls,
|
||||
[mock.call(h.eq, h.role, c, n0,
|
||||
[mock.call(h.eq, h.role, "desc", c, n0,
|
||||
PROLOGUE_FOLLOWER, PROLOGUE_LEADER)])
|
||||
|
||||
def test_start_stop(self):
|
||||
|
@ -244,7 +245,7 @@ class TestConnector(unittest.TestCase):
|
|||
with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
|
||||
return_value=f) as ocf:
|
||||
h.clock.advance(1.0)
|
||||
self.assertEqual(ocf.mock_calls, [mock.call(c, None)])
|
||||
self.assertEqual(ocf.mock_calls, [mock.call(c, None, "->tcp:foo:55")])
|
||||
self.assertEqual(ep.connect.mock_calls, [mock.call(f)])
|
||||
p = mock.Mock()
|
||||
d.callback(p)
|
||||
|
@ -269,7 +270,7 @@ class TestConnector(unittest.TestCase):
|
|||
return_value=f) as ocf:
|
||||
h.clock.advance(1.0)
|
||||
handshake = build_sided_relay_handshake(h.dilation_key, h.side)
|
||||
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)])
|
||||
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake, "->relay:tcp:foo:55")])
|
||||
|
||||
def test_listen_but_tor(self):
|
||||
c, h = make_connector(listen=True, tor=True, role=roles.LEADER)
|
||||
|
@ -388,7 +389,7 @@ class Race(unittest.TestCase):
|
|||
c.add_candidate(p1)
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
h.eq.flush_sync()
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
|
||||
self.assertEqual(p1.mock_calls,
|
||||
[mock.call.select(h.manager),
|
||||
mock.call.send_record(KCM())])
|
||||
|
@ -409,7 +410,7 @@ class Race(unittest.TestCase):
|
|||
c.add_candidate(p1)
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
h.eq.flush_sync()
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
|
||||
# just like LEADER, but follower doesn't send KCM now (it sent one
|
||||
# earlier, to tell the leader that this connection looks viable)
|
||||
self.assertEqual(p1.mock_calls,
|
||||
|
@ -432,7 +433,7 @@ class Race(unittest.TestCase):
|
|||
c.add_candidate(p1)
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
h.eq.flush_sync()
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
|
||||
clear_mock_calls(h.manager)
|
||||
self.assertEqual(p1.mock_calls,
|
||||
[mock.call.select(h.manager),
|
||||
|
@ -454,10 +455,9 @@ class Race(unittest.TestCase):
|
|||
c.add_candidate(p1)
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
h.eq.flush_sync()
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
|
||||
self.assertEqual(p1.mock_calls,
|
||||
[mock.call.select(h.manager),
|
||||
mock.call.send_record(KCM())])
|
||||
self.assertEqual(h.manager.mock_calls, [mock.call.connector_connection_made(p1)])
|
||||
|
||||
c.stop()
|
||||
|
||||
|
|
77
src/wormhole/test/dilate/test_full.py
Normal file
77
src/wormhole/test/dilate/test_full.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
from __future__ import print_function, absolute_import, unicode_literals
|
||||
import wormhole
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, gatherResults
|
||||
from twisted.internet.protocol import Protocol, Factory
|
||||
from twisted.trial import unittest
|
||||
|
||||
from ..common import ServerBase
|
||||
from ...eventual import EventualQueue
|
||||
from ..._dilation._noise import NoiseConnection
|
||||
|
||||
APPID = u"lothar.com/dilate-test"
|
||||
|
||||
def doBoth(d1, d2):
|
||||
return gatherResults([d1, d2], True)
|
||||
|
||||
class L(Protocol):
|
||||
def connectionMade(self):
|
||||
print("got connection")
|
||||
self.transport.write(b"hello\n")
|
||||
def dataReceived(self, data):
|
||||
print("dataReceived: {}".format(data))
|
||||
self.factory.d.callback(data)
|
||||
def connectionLost(self, why):
|
||||
print("connectionLost")
|
||||
|
||||
|
||||
class Full(ServerBase, unittest.TestCase):
|
||||
@inlineCallbacks
|
||||
def setUp(self):
|
||||
if not NoiseConnection:
|
||||
raise unittest.SkipTest("noiseprotocol unavailable")
|
||||
# test_welcome wants to see [current_cli_version]
|
||||
yield self._setup_relay(None)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_full(self):
|
||||
eq = EventualQueue(reactor)
|
||||
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
|
||||
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
|
||||
w1.allocate_code()
|
||||
code = yield w1.get_code()
|
||||
print("code is: {}".format(code))
|
||||
w2.set_code(code)
|
||||
yield doBoth(w1.get_verifier(), w2.get_verifier())
|
||||
print("connected")
|
||||
|
||||
eps1_d = w1.dilate()
|
||||
eps2_d = w2.dilate()
|
||||
(eps1, eps2) = yield doBoth(eps1_d, eps2_d)
|
||||
(control_ep1, connect_ep1, listen_ep1) = eps1
|
||||
(control_ep2, connect_ep2, listen_ep2) = eps2
|
||||
print("w.dilate ready")
|
||||
|
||||
f1 = Factory()
|
||||
f1.protocol = L
|
||||
f1.d = Deferred()
|
||||
f1.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||
d1 = control_ep1.connect(f1)
|
||||
|
||||
f2 = Factory()
|
||||
f2.protocol = L
|
||||
f2.d = Deferred()
|
||||
f2.d.addCallback(lambda data: eq.fire_eventually(data))
|
||||
d2 = control_ep2.connect(f2)
|
||||
yield d1
|
||||
yield d2
|
||||
print("control endpoints connected")
|
||||
data1 = yield f1.d
|
||||
data2 = yield f2.d
|
||||
self.assertEqual(data1, b"hello\n")
|
||||
self.assertEqual(data2, b"hello\n")
|
||||
|
||||
yield w1.close()
|
||||
yield w2.close()
|
||||
|
||||
test_full.timeout = 30
|
|
@ -27,12 +27,12 @@ class InboundTest(unittest.TestCase):
|
|||
self.assertFalse(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
||||
i.update_ack_watermark(r1)
|
||||
i.update_ack_watermark(r1.seqnum)
|
||||
self.assertTrue(i.is_record_old(r1))
|
||||
self.assertFalse(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
||||
i.update_ack_watermark(r2)
|
||||
i.update_ack_watermark(r2.seqnum)
|
||||
self.assertTrue(i.is_record_old(r1))
|
||||
self.assertTrue(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
|
|
@ -5,7 +5,7 @@ from twisted.internet.defer import Deferred
|
|||
from twisted.internet.task import Clock, Cooperator
|
||||
import mock
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import ISend, IDilationManager
|
||||
from ..._interfaces import ISend, IDilationManager, ITerminator
|
||||
from ...util import dict_to_bytes
|
||||
from ..._dilation import roles
|
||||
from ..._dilation.encode import to_be4
|
||||
|
@ -32,7 +32,9 @@ def make_dilator():
|
|||
send = mock.Mock()
|
||||
alsoProvides(send, ISend)
|
||||
dil = Dilator(reactor, eq, coop)
|
||||
dil.wire(send)
|
||||
terminator = mock.Mock()
|
||||
alsoProvides(terminator, ITerminator)
|
||||
dil.wire(send, terminator)
|
||||
return dil, send, reactor, eq, clock, coop
|
||||
|
||||
|
||||
|
@ -64,7 +66,7 @@ class TestDilator(unittest.TestCase):
|
|||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
# that should create the Manager
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
|
||||
None, reactor, eq, coop)])
|
||||
None, reactor, eq, coop, no_listen=False)])
|
||||
# and tell it to start, and get wait-for-it-to-connect Deferred
|
||||
self.assertEqual(m.mock_calls, [mock.call.start(),
|
||||
mock.call.when_first_connected(),
|
||||
|
@ -180,7 +182,7 @@ class TestDilator(unittest.TestCase):
|
|||
return_value="us"):
|
||||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
None, reactor, eq, coop)])
|
||||
None, reactor, eq, coop, no_listen=False)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.start(),
|
||||
mock.call.rx_PLEASE(pleasemsg),
|
||||
mock.call.rx_HINTS(hintmsg),
|
||||
|
@ -198,7 +200,7 @@ class TestDilator(unittest.TestCase):
|
|||
return_value="us"):
|
||||
dil.got_wormhole_versions({"can-dilate": ["1"]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
relay, reactor, eq, coop),
|
||||
relay, reactor, eq, coop, no_listen=False),
|
||||
mock.call().start(),
|
||||
mock.call().when_first_connected()])
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ class OutboundTest(unittest.TestCase):
|
|||
|
||||
# as soon as the connection is established, everything is sent
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
|
||||
mock.call.send_record(r1),
|
||||
mock.call.send_record(r2)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
|
@ -131,7 +131,7 @@ class OutboundTest(unittest.TestCase):
|
|||
# after each write. So only r1 should have been sent before getting
|
||||
# paused
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
|
||||
mock.call.send_record(r1)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [r2])
|
||||
|
@ -172,7 +172,7 @@ class OutboundTest(unittest.TestCase):
|
|||
self.assertEqual(list(o._queued_unsent), [])
|
||||
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
|
||||
mock.call.send_record(r1)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [r2])
|
||||
|
@ -191,7 +191,7 @@ class OutboundTest(unittest.TestCase):
|
|||
def test_pause(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
|
||||
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True)])
|
||||
self.assertEqual(list(o._outbound_queue), [])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
clear_mock_calls(c)
|
||||
|
@ -519,7 +519,7 @@ class OutboundTest(unittest.TestCase):
|
|||
|
||||
o.use_connection(c)
|
||||
o.send_if_connected(KCM())
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
self.assertEqual(c.mock_calls, [mock.call.transport.registerProducer(o, True),
|
||||
mock.call.send_record(KCM())])
|
||||
|
||||
def test_tolerate_duplicate_pause_resume(self):
|
||||
|
|
|
@ -13,11 +13,11 @@ class Parse(unittest.TestCase):
|
|||
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
|
||||
Pong(ping_id=b"\x55\x44\x33\x22"))
|
||||
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
|
||||
Open(scid=513, seqnum=256))
|
||||
Open(scid=b"\x00\x00\x02\x01", seqnum=256))
|
||||
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
|
||||
Data(scid=514, seqnum=257, data=b"dataaa"))
|
||||
Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa"))
|
||||
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
|
||||
Close(scid=515, seqnum=258))
|
||||
Close(scid=b"\x00\x00\x02\x03", seqnum=258))
|
||||
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
|
||||
Ack(resp_seqnum=259))
|
||||
with mock.patch("wormhole._dilation.connection.log.err") as le:
|
||||
|
@ -31,11 +31,11 @@ class Parse(unittest.TestCase):
|
|||
self.assertEqual(encode_record(KCM()), b"\x00")
|
||||
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
|
||||
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
|
||||
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
|
||||
self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)),
|
||||
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
|
||||
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
|
||||
self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")),
|
||||
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
|
||||
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
|
||||
self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)),
|
||||
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
|
||||
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
|
||||
b"\x06\x00\x00\x00\x13")
|
||||
|
|
|
@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage
|
|||
from ..._dilation.connection import (IFramer, Frame, Prologue,
|
||||
_Record, Handshake,
|
||||
Disconnect, Ping)
|
||||
from ..._dilation.roles import LEADER
|
||||
|
||||
|
||||
def make_record():
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
return r, f, n
|
||||
|
||||
|
||||
|
@ -30,7 +32,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
p1, p2 = object(), object()
|
||||
n.decrypt = mock.Mock(side_effect=[p1, p2])
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -79,7 +82,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.read_message = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -103,7 +107,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.decrypt = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -124,7 +129,8 @@ class Record(unittest.TestCase):
|
|||
f1 = object()
|
||||
n.encrypt = mock.Mock(return_value=f1)
|
||||
r1 = Ping(b"pingid")
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
m1 = object()
|
||||
with mock.patch("wormhole._dilation.connection.encode_record",
|
||||
|
|
|
@ -1220,7 +1220,8 @@ class Terminator(unittest.TestCase):
|
|||
rc = Dummy("rc", events, IRendezvousConnector, "stop")
|
||||
n = Dummy("n", events, INameplate, "close")
|
||||
m = Dummy("m", events, IMailbox, "close")
|
||||
t.wire(b, rc, n, m)
|
||||
d = Dummy("d", events, IDilator, "stop")
|
||||
t.wire(b, rc, n, m, d)
|
||||
return t, b, rc, n, m, events
|
||||
|
||||
# there are three events, and we need to test all orderings of them
|
||||
|
@ -1229,45 +1230,64 @@ class Terminator(unittest.TestCase):
|
|||
input_events = {
|
||||
"mailbox": lambda: t.mailbox_done(),
|
||||
"nameplate": lambda: t.nameplate_done(),
|
||||
"close": lambda: t.close("happy"),
|
||||
"rc": lambda: t.close("happy"),
|
||||
}
|
||||
close_events = [
|
||||
("n.close", ),
|
||||
("m.close", "happy"),
|
||||
]
|
||||
|
||||
if ev1 == "mailbox":
|
||||
close_events.remove(("m.close", "happy"))
|
||||
elif ev1 == "nameplate":
|
||||
close_events.remove(("n.close",))
|
||||
|
||||
input_events[ev1]()
|
||||
expected = []
|
||||
if ev1 == "close":
|
||||
if ev1 == "rc":
|
||||
expected.extend(close_events)
|
||||
self.assertEqual(events, expected)
|
||||
events[:] = []
|
||||
|
||||
if ev2 == "mailbox":
|
||||
close_events.remove(("m.close", "happy"))
|
||||
elif ev2 == "nameplate":
|
||||
close_events.remove(("n.close",))
|
||||
|
||||
input_events[ev2]()
|
||||
expected = []
|
||||
if ev2 == "close":
|
||||
if ev2 == "rc":
|
||||
expected.extend(close_events)
|
||||
self.assertEqual(events, expected)
|
||||
events[:] = []
|
||||
|
||||
if ev3 == "mailbox":
|
||||
close_events.remove(("m.close", "happy"))
|
||||
elif ev3 == "nameplate":
|
||||
close_events.remove(("n.close",))
|
||||
|
||||
input_events[ev3]()
|
||||
expected = []
|
||||
if ev3 == "close":
|
||||
if ev3 == "rc":
|
||||
expected.extend(close_events)
|
||||
expected.append(("rc.stop", ))
|
||||
self.assertEqual(events, expected)
|
||||
events[:] = []
|
||||
|
||||
t.stopped()
|
||||
t.stoppedRC()
|
||||
self.assertEqual(events, [("d.stop", )])
|
||||
events[:] = []
|
||||
|
||||
t.stoppedD()
|
||||
self.assertEqual(events, [("b.closed", )])
|
||||
|
||||
def test_terminate(self):
|
||||
self._do_test("mailbox", "nameplate", "close")
|
||||
self._do_test("mailbox", "close", "nameplate")
|
||||
self._do_test("nameplate", "mailbox", "close")
|
||||
self._do_test("nameplate", "close", "mailbox")
|
||||
self._do_test("close", "nameplate", "mailbox")
|
||||
self._do_test("close", "mailbox", "nameplate")
|
||||
self._do_test("mailbox", "nameplate", "rc")
|
||||
self._do_test("mailbox", "rc", "nameplate")
|
||||
self._do_test("nameplate", "mailbox", "rc")
|
||||
self._do_test("nameplate", "rc", "mailbox")
|
||||
self._do_test("rc", "nameplate", "mailbox")
|
||||
self._do_test("rc", "mailbox", "nameplate")
|
||||
|
||||
# TODO: test moods
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user