Merge commit 'dilate-half-close'

Add half-close to dilation API. Fix several bugs exposed by application-level
testing.

refs #344 but doesn't close it yet: I want to call `p.connectionLost()` after
both directions have been closed down
This commit is contained in:
Brian Warner 2019-08-11 22:22:45 -07:00
commit ab5fe65c3b
7 changed files with 229 additions and 39 deletions

View File

@ -205,8 +205,8 @@ class Boss(object):
self._did_start_code = True
self._C.set_code(code)
def dilate(self, no_listen=False):
return self._D.dilate(no_listen=no_listen) # fires with endpoints
def dilate(self, transit_relay_location=None, no_listen=False):
return self._D.dilate(transit_relay_location, no_listen=no_listen) # fires with endpoints
@m.input()
def send(self, plaintext):

View File

@ -10,7 +10,7 @@ from twisted.internet.defer import DeferredList, CancelledError
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
from twisted.internet.error import ConnectingCancelledError
from twisted.internet.error import ConnectingCancelledError, ConnectionRefusedError, DNSLookupError
from twisted.python import log
from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager
@ -28,7 +28,9 @@ from ._noise import NoiseConnection
def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8 * 2
# magic-wormhole-transit-relay expects a specific layout for the
# handshake message: "please relay {64} for side {16}\n"
assert len(side) == 8 * 2, side
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
@ -310,7 +312,13 @@ class Connector(object):
d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay)
d.addErrback(lambda f: f.trap(ConnectingCancelledError,
CancelledError))
ConnectionRefusedError,
CancelledError,
))
# TODO: HostnameEndpoint.connect catches CancelledError and replaces
# it with DNSLookupError. Remove this workaround when
# https://twistedmatrix.com/trac/ticket/9696 is fixed.
d.addErrback(lambda f: f.trap(DNSLookupError))
d.addErrback(log.err)
self._pending_connectors.add(d)

View File

@ -65,7 +65,7 @@ class EndpointRecord(Sequence):
return (self.control, self.connect, self.listen)[n]
def make_side():
return bytes_to_hexstr(os.urandom(6))
return bytes_to_hexstr(os.urandom(8))
# new scheme:
@ -552,6 +552,7 @@ class Manager(object):
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
WAITING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped])
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])

View File

@ -6,6 +6,7 @@ from zope.interface import implementer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IHalfCloseableProtocol,
IStreamClientEndpoint,
IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone
@ -55,6 +56,11 @@ class SingleUseEndpointError(Exception):
class AlreadyClosedError(Exception):
pass
class NormalCloseUsedOnHalfCloseable(Exception):
pass
class HalfCloseUsedOnNonHalfCloseable(Exception):
pass
@implementer(IAddress)
class _WormholeAddress(object):
@ -87,11 +93,29 @@ class SubChannel(object):
# self._pending_outbound = {}
# self._processed = set()
self._protocol = None
self._pending_dataReceived = []
self._pending_connectionLost = (False, None)
self._pending_remote_data = []
self._pending_remote_close = False
@m.state(initial=True)
def open(self):
def unconnected(self):
pass # pragma: no cover
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
# can only be fully closed
@m.state()
def open_half(self):
pass # pragma: no cover
@m.state()
def read_closed():
pass # pragma: no cover
@m.state()
def write_closed():
pass # pragma: no cover
@m.state()
def open_full(self):
pass # pragma: no cover
@m.state()
@ -102,6 +126,14 @@ class SubChannel(object):
def closed():
pass # pragma: no cover
@m.input()
def connect_protocol_half(self):
pass
@m.input()
def connect_protocol_full(self):
pass
@m.input()
def remote_data(self, data):
pass
@ -118,6 +150,14 @@ class SubChannel(object):
def local_close(self):
pass
@m.output()
def queue_remote_data(self, data):
self._pending_remote_data.append(data)
@m.output()
def queue_remote_close(self):
self._pending_remote_close = True
@m.output()
def send_data(self, data):
self._manager.send_data(self._scid, data)
@ -128,17 +168,24 @@ class SubChannel(object):
@m.output()
def signal_dataReceived(self, data):
if self._protocol:
self._protocol.dataReceived(data)
else:
self._pending_dataReceived.append(data)
assert self._protocol
self._protocol.dataReceived(data)
@m.output()
def signal_readConnectionLost(self):
IHalfCloseableProtocol(self._protocol).readConnectionLost()
@m.output()
def signal_writeConnectionLost(self):
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
@m.output()
def signal_connectionLost(self):
if self._protocol:
self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone())
assert self._protocol
self._protocol.connectionLost(ConnectionDone())
@m.output()
def close_subchannel(self):
self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily
@ -151,14 +198,44 @@ class SubChannel(object):
raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
# stuff that arrives before we have a protocol connected
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
# IHalfCloseableProtocol flow
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
open_half.upon(local_data, enter=open_half, outputs=[send_data])
# remote closes first
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
read_closed.upon(local_close, enter=closed, outputs=[send_close,
close_subchannel,
# TODO: eventual-signal this?
signal_writeConnectionLost,
])
# local closes first
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
send_close])
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_readConnectionLost,
])
# error cases
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
# fully-closeable-only flow
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
open_full.upon(local_data, enter=open_full, outputs=[send_data])
open_full.upon(remote_close, enter=closed, outputs=[send_close,
close_subchannel,
signal_connectionLost])
open_full.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_connectionLost])
# error cases
# we won't ever see an OPEN, since L4 will log+ignore those for us
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
@ -170,15 +247,19 @@ class SubChannel(object):
def _set_protocol(self, protocol):
assert not self._protocol
self._protocol = protocol
if IHalfCloseableProtocol.providedBy(protocol):
self.connect_protocol_half()
else:
# move from UNCONNECTED to OPEN
self.connect_protocol_full();
def _deliver_queued_data(self):
if self._pending_dataReceived:
for data in self._pending_dataReceived:
self._protocol.dataReceived(data)
self._pending_dataReceived = []
cl, what = self._pending_connectionLost
if cl:
self._protocol.connectionLost(what)
for data in self._pending_remote_data:
self.remote_data(data)
del self._pending_remote_data
if self._pending_remote_close:
self.remote_close()
del self._pending_remote_close
# ITransport
def write(self, data):
@ -189,7 +270,18 @@ class SubChannel(object):
def writeSequence(self, iovec):
self.write(b"".join(iovec))
def loseWriteConnection(self):
if not IHalfCloseableProtocol.providedBy(self._protocol):
# this is a clear error
raise HalfCloseUsedOnNonHalfCloseable()
self.local_close();
def loseConnection(self):
# TODO: what happens if an IHalfCloseableProtocol calls normal
# loseConnection()? I think we need to close the read side too.
if IHalfCloseableProtocol.providedBy(self._protocol):
# I don't know is correct, so avoid this for now
raise NormalCloseUsedOnHalfCloseable()
self.local_close()
def getHost(self):

View File

@ -216,7 +216,7 @@ class TestManager(unittest.TestCase):
def test_make_side(self):
side = make_side()
self.assertEqual(type(side), type(u""))
self.assertEqual(len(side), 2 * 6)
self.assertEqual(len(side), 2 * 8)
def test_create(self):
m, h = make_manager()

View File

@ -1,21 +1,25 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import directlyProvides
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress,
AlreadyClosedError)
AlreadyClosedError,
NormalCloseUsedOnHalfCloseable)
from .common import mock_manager
def make_sc(set_protocol=True):
def make_sc(set_protocol=True, half_closeable=False):
scid = 4
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(scid)
m = mock_manager()
sc = SubChannel(scid, m, hostaddr, peeraddr)
p = mock.Mock()
if half_closeable:
directlyProvides(p, IHalfCloseableProtocol)
if set_protocol:
sc._set_protocol(p)
return sc, m, scid, hostaddr, peeraddr, p
@ -109,11 +113,13 @@ class SubChannelAPI(unittest.TestCase):
def test_data_before_open(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
sc.remote_data(b"data")
sc.remote_data(b"data1")
sc.remote_data(b"data2")
self.assertEqual(p.mock_calls, [])
sc._set_protocol(p)
sc._deliver_queued_data()
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data1"),
mock.call.dataReceived(b"data2")])
p.mock_calls[:] = []
sc.remote_data(b"more")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])
@ -145,3 +151,86 @@ class SubChannelAPI(unittest.TestCase):
# TODO: more, once this is implemented
sc.registerProducer(None, True)
sc.unregisterProducer()
class HalfCloseable(unittest.TestCase):
def test_create(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
self.assert_(ITransport.providedBy(sc))
self.assertEqual(m.mock_calls, [])
self.assertIdentical(sc.getHost(), hostaddr)
self.assertIdentical(sc.getPeer(), peeraddr)
def test_local_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.writeSequence([b"more", b"data"])
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
sc.loseConnection() # TODO: maybe this shouldn't be an error
# after a local close, we can't write anymore, but we can still
# receive data
sc.loseWriteConnection() # TODO or loseConnection?
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
p.mock_calls[:] = []
with self.assertRaises(AlreadyClosedError) as e:
sc.write(b"data")
self.assertEqual(str(e.exception),
"write not allowed on closed subchannel")
with self.assertRaises(AlreadyClosedError) as e:
sc.loseWriteConnection()
self.assertEqual(str(e.exception),
"loseConnection not allowed on closed subchannel")
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
sc.loseConnection() # TODO: maybe expect AlreadyClosedError
sc.remote_data(b"inbound2")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
p.mock_calls[:] = []
# the remote end will finally shut down the connection
sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.remote_data(b"inbound1")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
p.mock_calls[:] = []
# after a remote close, we can still write data
sc.remote_close()
self.assertEqual(m.mock_calls, [])
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
p.mock_calls[:] = []
sc.write(b"out2")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
m.mock_calls[:] = []
# and a local close will shutdown the connection
sc.loseWriteConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
mock.call.subchannel_closed(scid, sc)])
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])

View File

@ -193,10 +193,10 @@ class _DeferredWormhole(object):
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length)
def dilate(self, no_listen=False):
def dilate(self, transit_relay_location=None, no_listen=False):
if not self._enable_dilate:
raise NotImplementedError
return self._boss.dilate(no_listen) # fires with (endpoints)
return self._boss.dilate(transit_relay_location, no_listen) # fires with (endpoints)
def close(self):
# fails with WormholeError unless we established a connection