Merge commit 'dilate-half-close'
Add half-close to dilation API. Fix several bugs exposed by application-level testing. refs #344 but doesn't close it yet: I want to call `p.connectionLost()` after both directions have been closed down
This commit is contained in:
commit
ab5fe65c3b
|
@ -205,8 +205,8 @@ class Boss(object):
|
|||
self._did_start_code = True
|
||||
self._C.set_code(code)
|
||||
|
||||
def dilate(self, no_listen=False):
|
||||
return self._D.dilate(no_listen=no_listen) # fires with endpoints
|
||||
def dilate(self, transit_relay_location=None, no_listen=False):
|
||||
return self._D.dilate(transit_relay_location, no_listen=no_listen) # fires with endpoints
|
||||
|
||||
@m.input()
|
||||
def send(self, plaintext):
|
||||
|
|
|
@ -10,7 +10,7 @@ from twisted.internet.defer import DeferredList, CancelledError
|
|||
from twisted.internet.endpoints import serverFromString
|
||||
from twisted.internet.protocol import ClientFactory, ServerFactory
|
||||
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
|
||||
from twisted.internet.error import ConnectingCancelledError
|
||||
from twisted.internet.error import ConnectingCancelledError, ConnectionRefusedError, DNSLookupError
|
||||
from twisted.python import log
|
||||
from .. import ipaddrs # TODO: move into _dilation/
|
||||
from .._interfaces import IDilationConnector, IDilationManager
|
||||
|
@ -28,7 +28,9 @@ from ._noise import NoiseConnection
|
|||
|
||||
def build_sided_relay_handshake(key, side):
|
||||
assert isinstance(side, type(u""))
|
||||
assert len(side) == 8 * 2
|
||||
# magic-wormhole-transit-relay expects a specific layout for the
|
||||
# handshake message: "please relay {64} for side {16}\n"
|
||||
assert len(side) == 8 * 2, side
|
||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||
return (b"please relay " + hexlify(token) +
|
||||
b" for side " + side.encode("ascii") + b"\n")
|
||||
|
@ -310,7 +312,13 @@ class Connector(object):
|
|||
d = deferLater(self._reactor, delay,
|
||||
self._connect, ep, desc, is_relay)
|
||||
d.addErrback(lambda f: f.trap(ConnectingCancelledError,
|
||||
CancelledError))
|
||||
ConnectionRefusedError,
|
||||
CancelledError,
|
||||
))
|
||||
# TODO: HostnameEndpoint.connect catches CancelledError and replaces
|
||||
# it with DNSLookupError. Remove this workaround when
|
||||
# https://twistedmatrix.com/trac/ticket/9696 is fixed.
|
||||
d.addErrback(lambda f: f.trap(DNSLookupError))
|
||||
d.addErrback(log.err)
|
||||
self._pending_connectors.add(d)
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ class EndpointRecord(Sequence):
|
|||
return (self.control, self.connect, self.listen)[n]
|
||||
|
||||
def make_side():
|
||||
return bytes_to_hexstr(os.urandom(6))
|
||||
return bytes_to_hexstr(os.urandom(8))
|
||||
|
||||
|
||||
# new scheme:
|
||||
|
@ -552,6 +552,7 @@ class Manager(object):
|
|||
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
|
||||
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
|
||||
|
||||
WAITING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
|
||||
WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
|
||||
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped])
|
||||
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
|
||||
|
|
|
@ -6,6 +6,7 @@ from zope.interface import implementer
|
|||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IHalfCloseableProtocol,
|
||||
IStreamClientEndpoint,
|
||||
IStreamServerEndpoint)
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
@ -55,6 +56,11 @@ class SingleUseEndpointError(Exception):
|
|||
class AlreadyClosedError(Exception):
|
||||
pass
|
||||
|
||||
class NormalCloseUsedOnHalfCloseable(Exception):
|
||||
pass
|
||||
class HalfCloseUsedOnNonHalfCloseable(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class _WormholeAddress(object):
|
||||
|
@ -87,11 +93,29 @@ class SubChannel(object):
|
|||
# self._pending_outbound = {}
|
||||
# self._processed = set()
|
||||
self._protocol = None
|
||||
self._pending_dataReceived = []
|
||||
self._pending_connectionLost = (False, None)
|
||||
self._pending_remote_data = []
|
||||
self._pending_remote_close = False
|
||||
|
||||
@m.state(initial=True)
|
||||
def open(self):
|
||||
def unconnected(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
|
||||
# can only be fully closed
|
||||
@m.state()
|
||||
def open_half(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def read_closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def write_closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def open_full(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
|
@ -102,6 +126,14 @@ class SubChannel(object):
|
|||
def closed():
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def connect_protocol_half(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def connect_protocol_full(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def remote_data(self, data):
|
||||
pass
|
||||
|
@ -118,6 +150,14 @@ class SubChannel(object):
|
|||
def local_close(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def queue_remote_data(self, data):
|
||||
self._pending_remote_data.append(data)
|
||||
|
||||
@m.output()
|
||||
def queue_remote_close(self):
|
||||
self._pending_remote_close = True
|
||||
|
||||
@m.output()
|
||||
def send_data(self, data):
|
||||
self._manager.send_data(self._scid, data)
|
||||
|
@ -128,17 +168,24 @@ class SubChannel(object):
|
|||
|
||||
@m.output()
|
||||
def signal_dataReceived(self, data):
|
||||
if self._protocol:
|
||||
self._protocol.dataReceived(data)
|
||||
else:
|
||||
self._pending_dataReceived.append(data)
|
||||
assert self._protocol
|
||||
self._protocol.dataReceived(data)
|
||||
|
||||
@m.output()
|
||||
def signal_readConnectionLost(self):
|
||||
IHalfCloseableProtocol(self._protocol).readConnectionLost()
|
||||
|
||||
@m.output()
|
||||
def signal_writeConnectionLost(self):
|
||||
IHalfCloseableProtocol(self._protocol).writeConnectionLost()
|
||||
|
||||
@m.output()
|
||||
def signal_connectionLost(self):
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
else:
|
||||
self._pending_connectionLost = (True, ConnectionDone())
|
||||
assert self._protocol
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
|
||||
@m.output()
|
||||
def close_subchannel(self):
|
||||
self._manager.subchannel_closed(self._scid, self)
|
||||
# we're deleted momentarily
|
||||
|
||||
|
@ -151,14 +198,44 @@ class SubChannel(object):
|
|||
raise AlreadyClosedError(
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
open.upon(local_data, enter=open, outputs=[send_data])
|
||||
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
|
||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
# stuff that arrives before we have a protocol connected
|
||||
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
|
||||
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])
|
||||
|
||||
# IHalfCloseableProtocol flow
|
||||
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
|
||||
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
|
||||
open_half.upon(local_data, enter=open_half, outputs=[send_data])
|
||||
# remote closes first
|
||||
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
|
||||
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
|
||||
read_closed.upon(local_close, enter=closed, outputs=[send_close,
|
||||
close_subchannel,
|
||||
# TODO: eventual-signal this?
|
||||
signal_writeConnectionLost,
|
||||
])
|
||||
# local closes first
|
||||
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
|
||||
send_close])
|
||||
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
|
||||
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
|
||||
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||
signal_readConnectionLost,
|
||||
])
|
||||
# error cases
|
||||
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])
|
||||
|
||||
# fully-closeable-only flow
|
||||
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
|
||||
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
|
||||
open_full.upon(local_data, enter=open_full, outputs=[send_data])
|
||||
open_full.upon(remote_close, enter=closed, outputs=[send_close,
|
||||
close_subchannel,
|
||||
signal_connectionLost])
|
||||
open_full.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
|
||||
signal_connectionLost])
|
||||
# error cases
|
||||
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
||||
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
||||
|
@ -170,15 +247,19 @@ class SubChannel(object):
|
|||
def _set_protocol(self, protocol):
|
||||
assert not self._protocol
|
||||
self._protocol = protocol
|
||||
if IHalfCloseableProtocol.providedBy(protocol):
|
||||
self.connect_protocol_half()
|
||||
else:
|
||||
# move from UNCONNECTED to OPEN
|
||||
self.connect_protocol_full();
|
||||
|
||||
def _deliver_queued_data(self):
|
||||
if self._pending_dataReceived:
|
||||
for data in self._pending_dataReceived:
|
||||
self._protocol.dataReceived(data)
|
||||
self._pending_dataReceived = []
|
||||
cl, what = self._pending_connectionLost
|
||||
if cl:
|
||||
self._protocol.connectionLost(what)
|
||||
for data in self._pending_remote_data:
|
||||
self.remote_data(data)
|
||||
del self._pending_remote_data
|
||||
if self._pending_remote_close:
|
||||
self.remote_close()
|
||||
del self._pending_remote_close
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
|
@ -189,7 +270,18 @@ class SubChannel(object):
|
|||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def loseWriteConnection(self):
|
||||
if not IHalfCloseableProtocol.providedBy(self._protocol):
|
||||
# this is a clear error
|
||||
raise HalfCloseUsedOnNonHalfCloseable()
|
||||
self.local_close();
|
||||
|
||||
def loseConnection(self):
|
||||
# TODO: what happens if an IHalfCloseableProtocol calls normal
|
||||
# loseConnection()? I think we need to close the read side too.
|
||||
if IHalfCloseableProtocol.providedBy(self._protocol):
|
||||
# I don't know is correct, so avoid this for now
|
||||
raise NormalCloseUsedOnHalfCloseable()
|
||||
self.local_close()
|
||||
|
||||
def getHost(self):
|
||||
|
|
|
@ -216,7 +216,7 @@ class TestManager(unittest.TestCase):
|
|||
def test_make_side(self):
|
||||
side = make_side()
|
||||
self.assertEqual(type(side), type(u""))
|
||||
self.assertEqual(len(side), 2 * 6)
|
||||
self.assertEqual(len(side), 2 * 8)
|
||||
|
||||
def test_create(self):
|
||||
m, h = make_manager()
|
||||
|
|
|
@ -1,21 +1,25 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import directlyProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from ..._dilation.subchannel import (Once, SubChannel,
|
||||
_WormholeAddress, _SubchannelAddress,
|
||||
AlreadyClosedError)
|
||||
AlreadyClosedError,
|
||||
NormalCloseUsedOnHalfCloseable)
|
||||
from .common import mock_manager
|
||||
|
||||
|
||||
def make_sc(set_protocol=True):
|
||||
def make_sc(set_protocol=True, half_closeable=False):
|
||||
scid = 4
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(scid)
|
||||
m = mock_manager()
|
||||
sc = SubChannel(scid, m, hostaddr, peeraddr)
|
||||
p = mock.Mock()
|
||||
if half_closeable:
|
||||
directlyProvides(p, IHalfCloseableProtocol)
|
||||
if set_protocol:
|
||||
sc._set_protocol(p)
|
||||
return sc, m, scid, hostaddr, peeraddr, p
|
||||
|
@ -109,11 +113,13 @@ class SubChannelAPI(unittest.TestCase):
|
|||
|
||||
def test_data_before_open(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
|
||||
sc.remote_data(b"data")
|
||||
sc.remote_data(b"data1")
|
||||
sc.remote_data(b"data2")
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
sc._set_protocol(p)
|
||||
sc._deliver_queued_data()
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data1"),
|
||||
mock.call.dataReceived(b"data2")])
|
||||
p.mock_calls[:] = []
|
||||
sc.remote_data(b"more")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])
|
||||
|
@ -145,3 +151,86 @@ class SubChannelAPI(unittest.TestCase):
|
|||
# TODO: more, once this is implemented
|
||||
sc.registerProducer(None, True)
|
||||
sc.unregisterProducer()
|
||||
|
||||
class HalfCloseable(unittest.TestCase):
|
||||
|
||||
def test_create(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
self.assert_(ITransport.providedBy(sc))
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.assertIdentical(sc.getHost(), hostaddr)
|
||||
self.assertIdentical(sc.getPeer(), peeraddr)
|
||||
|
||||
def test_local_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
|
||||
sc.write(b"data")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||
m.mock_calls[:] = []
|
||||
sc.writeSequence([b"more", b"data"])
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
sc.remote_data(b"inbound1")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
|
||||
sc.loseConnection() # TODO: maybe this shouldn't be an error
|
||||
|
||||
# after a local close, we can't write anymore, but we can still
|
||||
# receive data
|
||||
sc.loseWriteConnection() # TODO or loseConnection?
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||
m.mock_calls[:] = []
|
||||
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.write(b"data")
|
||||
self.assertEqual(str(e.exception),
|
||||
"write not allowed on closed subchannel")
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.loseWriteConnection()
|
||||
self.assertEqual(str(e.exception),
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e:
|
||||
sc.loseConnection() # TODO: maybe expect AlreadyClosedError
|
||||
|
||||
sc.remote_data(b"inbound2")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
# the remote end will finally shut down the connection
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||
|
||||
def test_remote_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True)
|
||||
|
||||
sc.write(b"data")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
sc.remote_data(b"inbound1")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
# after a remote close, we can still write data
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
sc.write(b"out2")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
# and a local close will shutdown the connection
|
||||
sc.loseWriteConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid),
|
||||
mock.call.subchannel_closed(scid, sc)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()])
|
||||
|
|
|
@ -193,10 +193,10 @@ class _DeferredWormhole(object):
|
|||
raise NoKeyError()
|
||||
return derive_key(self._key, to_bytes(purpose), length)
|
||||
|
||||
def dilate(self, no_listen=False):
|
||||
def dilate(self, transit_relay_location=None, no_listen=False):
|
||||
if not self._enable_dilate:
|
||||
raise NotImplementedError
|
||||
return self._boss.dilate(no_listen) # fires with (endpoints)
|
||||
return self._boss.dilate(transit_relay_location, no_listen) # fires with (endpoints)
|
||||
|
||||
def close(self):
|
||||
# fails with WormholeError unless we established a connection
|
||||
|
|
Loading…
Reference in New Issue
Block a user