diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index ec6c845..a5cb790 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -205,8 +205,8 @@ class Boss(object): self._did_start_code = True self._C.set_code(code) - def dilate(self, no_listen=False): - return self._D.dilate(no_listen=no_listen) # fires with endpoints + def dilate(self, transit_relay_location=None, no_listen=False): + return self._D.dilate(transit_relay_location, no_listen=no_listen) # fires with endpoints @m.input() def send(self, plaintext): diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 320ae43..b421022 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -10,7 +10,7 @@ from twisted.internet.defer import DeferredList, CancelledError from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address -from twisted.internet.error import ConnectingCancelledError +from twisted.internet.error import ConnectingCancelledError, ConnectionRefusedError, DNSLookupError from twisted.python import log from .. import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager @@ -28,7 +28,9 @@ from ._noise import NoiseConnection def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) - assert len(side) == 8 * 2 + # magic-wormhole-transit-relay expects a specific layout for the + # handshake message: "please relay {64} for side {16}\n" + assert len(side) == 8 * 2, side token = HKDF(key, 32, CTXinfo=b"transit_relay_token") return (b"please relay " + hexlify(token) + b" for side " + side.encode("ascii") + b"\n") @@ -310,7 +312,13 @@ class Connector(object): d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay) d.addErrback(lambda f: f.trap(ConnectingCancelledError, - CancelledError)) + ConnectionRefusedError, + CancelledError, + )) + # TODO: HostnameEndpoint.connect catches CancelledError and replaces + # it with DNSLookupError. Remove this workaround when + # https://twistedmatrix.com/trac/ticket/9696 is fixed. + d.addErrback(lambda f: f.trap(DNSLookupError)) d.addErrback(log.err) self._pending_connectors.add(d) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 19430bb..096b459 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -65,7 +65,7 @@ class EndpointRecord(Sequence): return (self.control, self.connect, self.listen)[n] def make_side(): - return bytes_to_hexstr(os.urandom(6)) + return bytes_to_hexstr(os.urandom(8)) # new scheme: @@ -552,6 +552,7 @@ class Manager(object): ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) + WAITING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped]) CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index ee10e97..cca0d3c 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -6,6 +6,7 @@ from zope.interface import implementer from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, + IHalfCloseableProtocol, IStreamClientEndpoint, IStreamServerEndpoint) from twisted.internet.error import ConnectionDone @@ -55,6 +56,11 @@ class SingleUseEndpointError(Exception): class AlreadyClosedError(Exception): pass +class NormalCloseUsedOnHalfCloseable(Exception): + pass +class HalfCloseUsedOnNonHalfCloseable(Exception): + pass + @implementer(IAddress) class _WormholeAddress(object): @@ -87,11 +93,29 @@ class SubChannel(object): # self._pending_outbound = {} # self._processed = set() self._protocol = None - self._pending_dataReceived = [] - self._pending_connectionLost = (False, None) + self._pending_remote_data = [] + self._pending_remote_close = False @m.state(initial=True) - def open(self): + def unconnected(self): + pass # pragma: no cover + + # once we get the IProtocol, it's either a IHalfCloseableProtocol, or it + # can only be fully closed + @m.state() + def open_half(self): + pass # pragma: no cover + + @m.state() + def read_closed(): + pass # pragma: no cover + + @m.state() + def write_closed(): + pass # pragma: no cover + + @m.state() + def open_full(self): pass # pragma: no cover @m.state() @@ -102,6 +126,14 @@ class SubChannel(object): def closed(): pass # pragma: no cover + @m.input() + def connect_protocol_half(self): + pass + + @m.input() + def connect_protocol_full(self): + pass + @m.input() def remote_data(self, data): pass @@ -118,6 +150,14 @@ class SubChannel(object): def local_close(self): pass + @m.output() + def queue_remote_data(self, data): + self._pending_remote_data.append(data) + + @m.output() + def queue_remote_close(self): + self._pending_remote_close = True + @m.output() def send_data(self, data): self._manager.send_data(self._scid, data) @@ -128,17 +168,24 @@ class SubChannel(object): @m.output() def signal_dataReceived(self, data): - if self._protocol: - self._protocol.dataReceived(data) - else: - self._pending_dataReceived.append(data) + assert self._protocol + self._protocol.dataReceived(data) + + @m.output() + def signal_readConnectionLost(self): + IHalfCloseableProtocol(self._protocol).readConnectionLost() + + @m.output() + def signal_writeConnectionLost(self): + IHalfCloseableProtocol(self._protocol).writeConnectionLost() @m.output() def signal_connectionLost(self): - if self._protocol: - self._protocol.connectionLost(ConnectionDone()) - else: - self._pending_connectionLost = (True, ConnectionDone()) + assert self._protocol + self._protocol.connectionLost(ConnectionDone()) + + @m.output() + def close_subchannel(self): self._manager.subchannel_closed(self._scid, self) # we're deleted momentarily @@ -151,14 +198,44 @@ class SubChannel(object): raise AlreadyClosedError( "loseConnection not allowed on closed subchannel") - # primary transitions - open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) - open.upon(local_data, enter=open, outputs=[send_data]) - open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost]) - open.upon(local_close, enter=closing, outputs=[send_close]) - closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) - closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + # stuff that arrives before we have a protocol connected + unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data]) + unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close]) + # IHalfCloseableProtocol flow + unconnected.upon(connect_protocol_half, enter=open_half, outputs=[]) + open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived]) + open_half.upon(local_data, enter=open_half, outputs=[send_data]) + # remote closes first + open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost]) + read_closed.upon(local_data, enter=read_closed, outputs=[send_data]) + read_closed.upon(local_close, enter=closed, outputs=[send_close, + close_subchannel, + # TODO: eventual-signal this? + signal_writeConnectionLost, + ]) + # local closes first + open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost, + send_close]) + write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write]) + write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived]) + write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel, + signal_readConnectionLost, + ]) + # error cases + write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close]) + + # fully-closeable-only flow + unconnected.upon(connect_protocol_full, enter=open_full, outputs=[]) + open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived]) + open_full.upon(local_data, enter=open_full, outputs=[send_data]) + open_full.upon(remote_close, enter=closed, outputs=[send_close, + close_subchannel, + signal_connectionLost]) + open_full.upon(local_close, enter=closing, outputs=[send_close]) + closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) + closing.upon(remote_close, enter=closed, outputs=[close_subchannel, + signal_connectionLost]) # error cases # we won't ever see an OPEN, since L4 will log+ignore those for us closing.upon(local_data, enter=closing, outputs=[error_closed_write]) @@ -170,15 +247,19 @@ class SubChannel(object): def _set_protocol(self, protocol): assert not self._protocol self._protocol = protocol + if IHalfCloseableProtocol.providedBy(protocol): + self.connect_protocol_half() + else: + # move from UNCONNECTED to OPEN + self.connect_protocol_full(); def _deliver_queued_data(self): - if self._pending_dataReceived: - for data in self._pending_dataReceived: - self._protocol.dataReceived(data) - self._pending_dataReceived = [] - cl, what = self._pending_connectionLost - if cl: - self._protocol.connectionLost(what) + for data in self._pending_remote_data: + self.remote_data(data) + del self._pending_remote_data + if self._pending_remote_close: + self.remote_close() + del self._pending_remote_close # ITransport def write(self, data): @@ -189,7 +270,18 @@ class SubChannel(object): def writeSequence(self, iovec): self.write(b"".join(iovec)) + def loseWriteConnection(self): + if not IHalfCloseableProtocol.providedBy(self._protocol): + # this is a clear error + raise HalfCloseUsedOnNonHalfCloseable() + self.local_close(); + def loseConnection(self): + # TODO: what happens if an IHalfCloseableProtocol calls normal + # loseConnection()? I think we need to close the read side too. + if IHalfCloseableProtocol.providedBy(self._protocol): + # I don't know is correct, so avoid this for now + raise NormalCloseUsedOnHalfCloseable() self.local_close() def getHost(self): diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 3fdd051..d8aa8ab 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -216,7 +216,7 @@ class TestManager(unittest.TestCase): def test_make_side(self): side = make_side() self.assertEqual(type(side), type(u"")) - self.assertEqual(len(side), 2 * 6) + self.assertEqual(len(side), 2 * 8) def test_create(self): m, h = make_manager() diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index 0705a10..88e6565 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -1,21 +1,25 @@ from __future__ import print_function, unicode_literals import mock +from zope.interface import directlyProvides from twisted.trial import unittest -from twisted.internet.interfaces import ITransport +from twisted.internet.interfaces import ITransport, IHalfCloseableProtocol from twisted.internet.error import ConnectionDone from ..._dilation.subchannel import (Once, SubChannel, _WormholeAddress, _SubchannelAddress, - AlreadyClosedError) + AlreadyClosedError, + NormalCloseUsedOnHalfCloseable) from .common import mock_manager -def make_sc(set_protocol=True): +def make_sc(set_protocol=True, half_closeable=False): scid = 4 hostaddr = _WormholeAddress() peeraddr = _SubchannelAddress(scid) m = mock_manager() sc = SubChannel(scid, m, hostaddr, peeraddr) p = mock.Mock() + if half_closeable: + directlyProvides(p, IHalfCloseableProtocol) if set_protocol: sc._set_protocol(p) return sc, m, scid, hostaddr, peeraddr, p @@ -109,11 +113,13 @@ class SubChannelAPI(unittest.TestCase): def test_data_before_open(self): sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) - sc.remote_data(b"data") + sc.remote_data(b"data1") + sc.remote_data(b"data2") self.assertEqual(p.mock_calls, []) sc._set_protocol(p) sc._deliver_queued_data() - self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data1"), + mock.call.dataReceived(b"data2")]) p.mock_calls[:] = [] sc.remote_data(b"more") self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")]) @@ -145,3 +151,86 @@ class SubChannelAPI(unittest.TestCase): # TODO: more, once this is implemented sc.registerProducer(None, True) sc.unregisterProducer() + +class HalfCloseable(unittest.TestCase): + + def test_create(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + self.assert_(ITransport.providedBy(sc)) + self.assertEqual(m.mock_calls, []) + self.assertIdentical(sc.getHost(), hostaddr) + self.assertIdentical(sc.getPeer(), peeraddr) + + def test_local_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + sc.writeSequence([b"more", b"data"]) + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")]) + m.mock_calls[:] = [] + + sc.remote_data(b"inbound1") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")]) + p.mock_calls[:] = [] + + with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e: + sc.loseConnection() # TODO: maybe this shouldn't be an error + + # after a local close, we can't write anymore, but we can still + # receive data + sc.loseWriteConnection() # TODO or loseConnection? + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()]) + p.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.write(b"data") + self.assertEqual(str(e.exception), + "write not allowed on closed subchannel") + + with self.assertRaises(AlreadyClosedError) as e: + sc.loseWriteConnection() + self.assertEqual(str(e.exception), + "loseConnection not allowed on closed subchannel") + + with self.assertRaises(NormalCloseUsedOnHalfCloseable) as e: + sc.loseConnection() # TODO: maybe expect AlreadyClosedError + + sc.remote_data(b"inbound2") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound2")]) + p.mock_calls[:] = [] + + # the remote end will finally shut down the connection + sc.remote_close() + self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(scid, sc)]) + self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()]) + + def test_remote_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(half_closeable=True) + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + + sc.remote_data(b"inbound1") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"inbound1")]) + p.mock_calls[:] = [] + + # after a remote close, we can still write data + sc.remote_close() + self.assertEqual(m.mock_calls, []) + self.assertEqual(p.mock_calls, [mock.call.readConnectionLost()]) + p.mock_calls[:] = [] + + sc.write(b"out2") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"out2")]) + m.mock_calls[:] = [] + + # and a local close will shutdown the connection + sc.loseWriteConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid), + mock.call.subchannel_closed(scid, sc)]) + self.assertEqual(p.mock_calls, [mock.call.writeConnectionLost()]) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 1e5509d..1db12d7 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -193,10 +193,10 @@ class _DeferredWormhole(object): raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) - def dilate(self, no_listen=False): + def dilate(self, transit_relay_location=None, no_listen=False): if not self._enable_dilate: raise NotImplementedError - return self._boss.dilate(no_listen) # fires with (endpoints) + return self._boss.dilate(transit_relay_location, no_listen) # fires with (endpoints) def close(self): # fails with WormholeError unless we established a connection