From eb7c04e366b6f54c151aab114f2b40040db861a5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 13 Mar 2018 14:20:48 -0700 Subject: [PATCH 01/49] observer.py: add EmptyableSet --- src/wormhole/observer.py | 21 +++++++++++++++++++++ src/wormhole/test/test_observer.py | 27 ++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/wormhole/observer.py b/src/wormhole/observer.py index 43afa70..34fbf1b 100644 --- a/src/wormhole/observer.py +++ b/src/wormhole/observer.py @@ -70,3 +70,24 @@ class SequenceObserver(object): if self._observers: d = self._observers.pop(0) self._eq.eventually(d.callback, self._results.pop(0)) + + +class EmptyableSet(set): + # manage a set which grows and shrinks over time. Fire a Deferred the first + # time it becomes empty after you start watching for it. + + def __init__(self, *args, **kwargs): + self._eq = kwargs.pop("_eventual_queue") # required + super(EmptyableSet, self).__init__(*args, **kwargs) + self._observer = None + + def when_next_empty(self): + if not self._observer: + self._observer = OneShotObserver(self._eq) + return self._observer.when_fired() + + def discard(self, o): + super(EmptyableSet, self).discard(o) + if self._observer and not self: + self._observer.fire(None) + self._observer = None diff --git a/src/wormhole/test/test_observer.py b/src/wormhole/test/test_observer.py index 8d89ddd..6ca8957 100644 --- a/src/wormhole/test/test_observer.py +++ b/src/wormhole/test/test_observer.py @@ -3,7 +3,7 @@ from twisted.python.failure import Failure from twisted.trial import unittest from ..eventual import EventualQueue -from ..observer import OneShotObserver, SequenceObserver +from ..observer import OneShotObserver, SequenceObserver, EmptyableSet class OneShot(unittest.TestCase): @@ -121,3 +121,28 @@ class Sequence(unittest.TestCase): d2 = o.when_next_event() eq.flush_sync() self.assertIdentical(self.failureResultOf(d2), f) + + +class Empty(unittest.TestCase): + def test_set(self): + eq = 
EventualQueue(Clock()) + s = EmptyableSet(_eventual_queue=eq) + d1 = s.when_next_empty() + eq.flush_sync() + self.assertNoResult(d1) + s.add(1) + eq.flush_sync() + self.assertNoResult(d1) + s.add(2) + s.discard(1) + d2 = s.when_next_empty() + eq.flush_sync() + self.assertNoResult(d1) + self.assertNoResult(d2) + s.discard(2) + eq.flush_sync() + self.assertEqual(self.successResultOf(d1), None) + self.assertEqual(self.successResultOf(d2), None) + + s.add(3) + s.discard(3) From e260369be1b6548ff97456c43c952357949dde7c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 10 Apr 2018 21:33:50 -0700 Subject: [PATCH 02/49] __main__.py: stop breaking automat-visualize --- src/wormhole/__main__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/wormhole/__main__.py b/src/wormhole/__main__.py index d6c1b97..37e97b7 100644 --- a/src/wormhole/__main__.py +++ b/src/wormhole/__main__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, unicode_literals -from .cli import cli - -if __name__ != "__main__": - raise ImportError('this module should not be imported') - -cli.wormhole() +if __name__ == "__main__": + from .cli import cli + cli.wormhole() +else: + # raise ImportError('this module should not be imported') + pass From a693b1fc48783fca938c1dfecdbcfa4657cf43ec Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 13:56:32 -0700 Subject: [PATCH 03/49] Boss/Receive: add 'side' to got_message this will be used later by Dilation --- src/wormhole/_boss.py | 10 ++++++---- src/wormhole/_receive.py | 14 +++++++------- src/wormhole/test/test_machines.py | 18 +++++++++--------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 097efe7..686335f 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -252,11 +252,12 @@ class Boss(object): def scared(self): pass - def got_message(self, phase, plaintext): + def got_message(self, side, phase, 
plaintext): + # this is only called for side != ours assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) if phase == "version": - self._got_version(plaintext) + self._got_version(side, plaintext) elif re.search(r'^\d+$', phase): self._got_phase(int(phase), plaintext) else: @@ -265,7 +266,7 @@ class Boss(object): log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) @m.input() - def _got_version(self, plaintext): + def _got_version(self, side, plaintext): pass @m.input() @@ -290,9 +291,10 @@ class Boss(object): self._W.got_code(code) @m.output() - def process_version(self, plaintext): + def process_version(self, side, plaintext): # most of this is wormhole-to-wormhole, ignored for now # in the future, this is how Dilation is signalled + self._their_side = side self._their_versions = bytes_to_dict(plaintext) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) diff --git a/src/wormhole/_receive.py b/src/wormhole/_receive.py index 8e9de4f..832dc44 100644 --- a/src/wormhole/_receive.py +++ b/src/wormhole/_receive.py @@ -53,10 +53,10 @@ class Receive(object): except CryptoError: self.got_message_bad() return - self.got_message_good(phase, plaintext) + self.got_message_good(side, phase, plaintext) @m.input() - def got_message_good(self, phase, plaintext): + def got_message_good(self, side, phase, plaintext): pass @m.input() @@ -73,23 +73,23 @@ class Receive(object): self._key = key @m.output() - def S_got_verified_key(self, phase, plaintext): + def S_got_verified_key(self, side, phase, plaintext): assert self._key self._S.got_verified_key(self._key) @m.output() - def W_happy(self, phase, plaintext): + def W_happy(self, side, phase, plaintext): self._B.happy() @m.output() - def W_got_verifier(self, phase, plaintext): + def W_got_verifier(self, side, phase, plaintext): self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) @m.output() - def W_got_message(self, 
phase, plaintext): + def W_got_message(self, side, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) - self._B.got_message(phase, plaintext) + self._B.got_message(side, phase, plaintext) @m.output() def W_scared(self): diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 860b0b8..8a17b76 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -167,7 +167,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"phase2") @@ -178,8 +178,8 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.got_message", u"phase2", data2), + ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"side", u"phase2", data2), ]) def test_early_bad(self): @@ -217,7 +217,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"bad") @@ -228,7 +228,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ("b.scared", ), ]) r.got_message(u"side", u"phase1", good_body) @@ -237,7 +237,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ("b.scared", ), ]) @@ -1320,8 +1320,8 @@ class Boss(unittest.TestCase): b.got_key(b"key") b.happy() b.got_verifier(b"verifier") - 
b.got_message("version", b"{}") - b.got_message("0", b"msg1") + b.got_message("side", "version", b"{}") + b.got_message("side", "0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), ("w.got_verifier", b"verifier"), @@ -1477,7 +1477,7 @@ class Boss(unittest.TestCase): b.happy() # phase=version - b.got_message("unknown-phase", b"spooky") + b.got_message("side", "unknown-phase", b"spooky") self.assertEqual(events, []) self.flushLoggedErrors(errors._UnknownPhaseError) From 6cfabba31a8532a7e30ffc475ae95dd2c3bd6e5b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 14:15:06 -0700 Subject: [PATCH 04/49] add reactor/cooperator to Wormhole and Boss calls --- src/wormhole/_boss.py | 2 ++ src/wormhole/test/test_machines.py | 5 ++++- src/wormhole/wormhole.py | 9 ++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 686335f..373c0f4 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -38,6 +38,8 @@ class Boss(object): _versions = attrib(validator=instance_of(dict)) _client_version = attrib(validator=instance_of(tuple)) _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() _journal = attrib(validator=provides(_interfaces.IJournal)) _tor = attrib(validator=optional(provides(_interfaces.ITorManager))) _timing = attrib(validator=provides(_interfaces.ITiming)) diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 8a17b76..5b5e9f1 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1286,11 +1286,14 @@ class Boss(unittest.TestCase): "closed") versions = {"app": "version1"} reactor = None + eq = None + cooperator = None journal = ImmediateJournal() tor_manager = None client_version = ("python", __version__) b = MockBoss(wormhole, "side", "url", "appid", versions, - client_version, reactor, journal, tor_manager, + client_version, reactor, eq, cooperator, journal, + tor_manager, 
timing.DebugTiming()) b._T = Dummy("t", events, ITerminator, "close") b._S = Dummy("s", events, ISend, "send") diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 610069c..c02aa60 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -5,6 +5,7 @@ import sys from attr import attrib, attrs from twisted.python import failure +from twisted.internet.task import Cooperator from zope.interface import implementer from ._boss import Boss @@ -122,7 +123,8 @@ class _DelegatedWormhole(object): @implementer(IWormhole, IDeferredWormhole) class _DeferredWormhole(object): - def __init__(self, eq): + def __init__(self, reactor, eq): + self._reactor = reactor self._welcome_observer = OneShotObserver(eq) self._code_observer = OneShotObserver(eq) self._key = None @@ -258,10 +260,11 @@ def create( side = bytes_to_hexstr(os.urandom(5)) journal = journal or ImmediateJournal() eq = _eventual_queue or EventualQueue(reactor) + cooperator = Cooperator(scheduler=eq.eventually) if delegate: w = _DelegatedWormhole(delegate) else: - w = _DeferredWormhole(eq) + w = _DeferredWormhole(reactor, eq) wormhole_versions = {} # will be used to indicate Wormhole capabilities wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ @@ -269,7 +272,7 @@ def create( v = v.decode("utf-8", errors="replace") client_version = ("python", v) b = Boss(w, side, relay_url, appid, wormhole_versions, client_version, - reactor, journal, tor, timing) + reactor, eq, cooperator, journal, tor, timing) w._set_boss(b) b.start() return w From cd6ae6390f63564c8b8ee87d7df7e571ea43b7ce Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 15:25:38 -0700 Subject: [PATCH 05/49] _rendezvous: add note to use EventualQueue --- src/wormhole/_rendezvous.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 65d6545..db02166 100644 --- a/src/wormhole/_rendezvous.py +++ 
b/src/wormhole/_rendezvous.py @@ -89,6 +89,7 @@ class RendezvousConnector(object): # if the initial connection fails, signal an error and shut down. do # this in a different reactor turn to avoid some hazards d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res)) + # TODO: use EventualQueue d.addErrback(self._initial_connection_failed) self._debug_record_inbound_f = None From 34686a346ab14a6c812f77aea83bbd3010a4f032 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 14:44:45 -0700 Subject: [PATCH 06/49] add dilation code (this compresses several months of false starts and rearchitecting) --- docs/api.md | 136 +- docs/dilation-protocol.md | 500 +++++ docs/new-protocol.svg | 2000 +++++++++++++++++++ docs/state-machines/machines.dot | 6 +- setup.py | 1 + src/wormhole/_boss.py | 37 +- src/wormhole/_dilation/__init__.py | 0 src/wormhole/_dilation/connection.py | 482 +++++ src/wormhole/_dilation/connector.py | 482 +++++ src/wormhole/_dilation/encode.py | 16 + src/wormhole/_dilation/inbound.py | 127 ++ src/wormhole/_dilation/manager.py | 514 +++++ src/wormhole/_dilation/old-follower.py | 106 + src/wormhole/_dilation/outbound.py | 392 ++++ src/wormhole/_dilation/roles.py | 1 + src/wormhole/_dilation/subchannel.py | 269 +++ src/wormhole/_interfaces.py | 13 + src/wormhole/test/dilate/__init__.py | 0 src/wormhole/test/dilate/common.py | 18 + src/wormhole/test/dilate/test_connection.py | 216 ++ src/wormhole/test/dilate/test_encoding.py | 25 + src/wormhole/test/dilate/test_endpoints.py | 97 + src/wormhole/test/dilate/test_framer.py | 110 + src/wormhole/test/dilate/test_inbound.py | 172 ++ src/wormhole/test/dilate/test_manager.py | 205 ++ src/wormhole/test/dilate/test_outbound.py | 645 ++++++ src/wormhole/test/dilate/test_parse.py | 43 + src/wormhole/test/dilate/test_record.py | 268 +++ src/wormhole/test/dilate/test_subchannel.py | 142 ++ src/wormhole/test/test_machines.py | 5 +- src/wormhole/wormhole.py | 12 +- 31 files changed, 7011 
insertions(+), 29 deletions(-) create mode 100644 docs/dilation-protocol.md create mode 100644 docs/new-protocol.svg create mode 100644 src/wormhole/_dilation/__init__.py create mode 100644 src/wormhole/_dilation/connection.py create mode 100644 src/wormhole/_dilation/connector.py create mode 100644 src/wormhole/_dilation/encode.py create mode 100644 src/wormhole/_dilation/inbound.py create mode 100644 src/wormhole/_dilation/manager.py create mode 100644 src/wormhole/_dilation/old-follower.py create mode 100644 src/wormhole/_dilation/outbound.py create mode 100644 src/wormhole/_dilation/roles.py create mode 100644 src/wormhole/_dilation/subchannel.py create mode 100644 src/wormhole/test/dilate/__init__.py create mode 100644 src/wormhole/test/dilate/common.py create mode 100644 src/wormhole/test/dilate/test_connection.py create mode 100644 src/wormhole/test/dilate/test_encoding.py create mode 100644 src/wormhole/test/dilate/test_endpoints.py create mode 100644 src/wormhole/test/dilate/test_framer.py create mode 100644 src/wormhole/test/dilate/test_inbound.py create mode 100644 src/wormhole/test/dilate/test_manager.py create mode 100644 src/wormhole/test/dilate/test_outbound.py create mode 100644 src/wormhole/test/dilate/test_parse.py create mode 100644 src/wormhole/test/dilate/test_record.py create mode 100644 src/wormhole/test/dilate/test_subchannel.py diff --git a/docs/api.md b/docs/api.md index 39a4e43..43cc318 100644 --- a/docs/api.md +++ b/docs/api.md @@ -524,25 +524,31 @@ object twice. ## Dilation -(NOTE: this section is speculative: this code has not yet been written) - -In the longer term, the Wormhole object will incorporate the "Transit" -functionality (see transit.md) directly, removing the need to instantiate a -second object. A Wormhole can be "dilated" into a form that is suitable for -bulk data transfer. 
+To send bulk data, or anything more than a handful of messages, a Wormhole +can be "dilated" into a form that uses a direct TCP connection between the +two endpoints. All wormholes start out "undilated". In this state, all messages are queued on the Rendezvous Server for the lifetime of the wormhole, and server-imposed number/size/rate limits apply. Calling `w.dilate()` initiates the dilation -process, and success is signalled via either `d=w.when_dilated()` firing, or -`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used -as an IConsumer/IProducer, and messages will be sent on a direct connection -(if possible) or through the transit relay (if not). +process, and eventually yields a set of Endpoints. Once dilated, the usual +`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and +these endpoints can be used to establish multiple (encrypted) "subchannel" +connections to the other side. + +Each subchannel behaves like a regular Twisted `ITransport`, so they can be +glued to the Protocol instance of your choice. They also implement the +IConsumer/IProducer interfaces. + +These subchannels are *durable*: as long as the processes on both sides keep +running, the subchannel will survive the network connection being dropped. +For example, a file transfer can be started from a laptop, then while it is +running, the laptop can be closed, moved to a new wifi network, opened back +up, and the transfer will resume from the new IP address. What's good about a non-dilated wormhole?: * setup is faster: no delay while it tries to make a direct connection -* survives temporary network outages, since messages are queued * works with "journaled mode", allowing progress to be made even when both sides are never online at the same time, by serializing the wormhole @@ -556,21 +562,103 @@ Use non-dilated wormholes when your application only needs to exchange a couple of messages, for example to set up public keys or provision access tokens. 
Use a dilated wormhole to move files. -Dilated wormholes can provide multiple "channels": these are multiplexed -through the single (encrypted) TCP connection. Each channel is a separate -stream (offering IProducer/IConsumer) +Dilated wormholes can provide multiple "subchannels": these are multiplexed +through the single (encrypted) TCP connection. Each subchannel is a separate +stream (offering IProducer/IConsumer for flow control), and is opened and +closed independently. A special "control channel" is available to both sides +so they can coordinate how they use the subchannels. -To create a channel, call `c = w.create_channel()` on a dilated wormhole. The -"channel ID" can be obtained with `c.get_id()`. This ID will be a short -(unicode) string, which can be sent to the other side via a normal -`w.send()`, or any other means. On the other side, use `c = -w.open_channel(channel_id)` to get a matching channel object. +The `d = w.dilate()` Deferred fires with a triple of Endpoints: -Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire -them up with `c.registerProducer()`. Note that channels do not close until -the wormhole connection is closed, so they do not have separate `close()` -methods or events. Therefore if you plan to send files through them, you'll -need to inform the recipient ahead of time about how many bytes to expect. +```python +d = w.dilate() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res +d.addCallback(_dilated) +``` + +The `control_channel_ep` endpoint is a client-style endpoint, so both sides +will connect to it with `ep.connect(factory)`. This endpoint is single-use: +calling `.connect()` a second time will fail. The control channel is +symmetric: it doesn't matter which side is the application-level +client/server or initiator/responder, if the application even has such +concepts. The two applications can use the control channel to negotiate who +goes first, if necessary. 
+ +The subchannel endpoints are *not* symmetric: for each subchannel, one side +must listen as a server, and the other must connect as a client. Subchannels +can be established by either side at any time. This supports e.g. +bidirectional file transfer, where either user of a GUI app can drop files +into the "wormhole" whenever they like. + +The `subchannel_client_ep` on one side is used to connect to the other side's +`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The +server endpoint is single-use: `.listen(factory)` may only be called once. + +Applications are under no obligation to use subchannels: for many use cases, +the control channel is enough. + +To use subchannels, once the wormhole is dilated and the endpoints are +available, the listening-side application should attach a listener to the +`subchannel_server_ep` endpoint: + +```python +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(MyListeningProtocol) + subchannel_server_ep.listen(f) +``` + +When the connecting-side application wants to connect to that listening +protocol, it should use `.connect()` with a suitable connecting protocol +factory: + +```python +def _connect(): + f = Factory(MyConnectingProtocol) + subchannel_client_ep.connect(f) +``` + +For a bidirectional file-transfer application, both sides will establish a +listening protocol. Later, if/when the user drops a file on the application +window, that side will initiate a connection, use the resulting subchannel to +transfer the single file, and then close the subchannel. 
+ +```python +def FileSendingProtocol(internet.Protocol): + def __init__(self, metadata, filename): + self.file_metadata = metadata + self.file_name = filename + def connectionMade(self): + self.transport.write(self.file_metadata) + sender = protocols.basic.FileSender() + f = open(self.file_name,"rb") + d = sender.beginFileTransfer(f, self.transport) + d.addBoth(self._done, f) + def _done(res, f): + self.transport.loseConnection() + f.close() +def _send(metadata, filename): + f = protocol.ClientCreator(reactor, + FileSendingProtocol, metadata, filename) + subchannel_client_ep.connect(f) +def FileReceivingProtocol(internet.Protocol): + state = INITIAL + def dataReceived(self, data): + if state == INITIAL: + self.state = DATA + metadata = parse(data) + self.f = open(metadata.filename, "wb") + else: + # local file writes are blocking, so don't bother with IConsumer + self.f.write(data) + def connectionLost(self, reason): + self.f.close() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(FileReceivingProtocol) + subchannel_server_ep.listen(f) +``` ## Bytes, Strings, Unicode, and Python 3 diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md new file mode 100644 index 0000000..4f0dac1 --- /dev/null +++ b/docs/dilation-protocol.md @@ -0,0 +1,500 @@ +# Dilation Internals + +Wormhole dilation involves several moving parts. Both sides exchange messages +through the Mailbox server to coordinate the establishment of a more direct +connection. This connection might flow in either direction, so they trade +"connection hints" to point at potential listening ports. This process might +succeed in making multiple connections at about the same time, so one side +must select the best one to use, and cleanly shut down the others. To make +the dilated connection *durable*, this side must also decide when the +connection has been lost, and then coordinate the construction of a +replacement. 
Within this connection, a series of queued-and-acked subchannel +messages are used to open/use/close the application-visible subchannels. + +## Leaders and Followers + +Each side of a Wormhole has a randomly-generated "side" string. When the +wormhole is dilated, the side with the lexicographically-higher "side" value +is named the "Leader", and the other side is named the "Follower". The general +wormhole protocol treats both sides identically, but the distinction matters +for the dilation protocol. + +Either side can trigger dilation, but the Follower does so by asking the +Leader to start the process, whereas the Leader just starts the process +unilaterally. The Leader has exclusive control over whether a given +connection is considered established or not: if there are multiple potential +connections to use, the Leader decides which one to use, and the Leader gets +to decide when the connection is no longer viable (and triggers the +establishment of a new one). + +## Connection Layers + +We describe the protocol as a series of layers. Messages sent on one layer +may be encoded or transformed before being delivered on some other layer. + +L1 is the mailbox channel (queued store-and-forward messages that always go +to the mailbox server, and then are forwarded to other clients subscribed to +the same mailbox). Both clients remain connected to the mailbox server until +the Wormhole is closed. They send DILATE-n messages to each other to manage +the dilation process, including records like `please-dilate`, +`start-dilation`, `ok-dilation`, and `connection-hints` + +L2 is the set of competing connection attempts for a given generation of +connection. Each time the Leader decides to establish a new connection, a new +generation number is used. Hopefully these are direct TCP connections between +the two peers, but they may also include connections through the transit +relay. Each connection must go through an encrypted handshake process before +it is considered viable. 
Viable connections are then submitted to a selection
+process (on the Leader side), which chooses exactly one to use, and drops the
+others. It may wait an extra few seconds in the hopes of getting a "better"
+connection (faster, cheaper, etc), but eventually it will select one.
+
+L3 is the current selected connection. There is one L3 for each generation.
+At all times, the wormhole will have exactly zero or one L3 connection. L3 is
+responsible for the selection process, connection monitoring/keepalives, and
+serialization/deserialization of the plaintext frames. L3 delivers decoded
+frames and connection-establishment events up to L4.
+
+L4 is the persistent higher-level channel. It is created as soon as the first
+L3 connection is selected, and lasts until wormhole is closed entirely. L4
+contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number
+(scoped to the L4 connection and the direction of travel), and the ACK
+messages reference those sequence numbers. When a message is given to the L4
+channel for delivery to the remote side, it is always queued, then
+transmitted if there is an L3 connection available. This message remains in
+the queue until an ACK is received to release it. If a new L3 connection is
+made, all queued messages will be re-sent (in seqnum order).
+
+L5 are subchannels. There is one pre-established subchannel 0 known as the
+"control channel", which does not require an OPEN message. All other
+subchannels are created by the receipt of an OPEN message with the subchannel
+number. DATA frames are delivered to a specific subchannel. When the
+subchannel is no longer needed, one side will invoke the ``close()`` API
+(``loseConnection()`` in Twisted), which will cause a CLOSE message to be
+sent, and the local L5 object will be put into the "closing" state. When the
+other side receives the CLOSE, it will send its own CLOSE for the same
+subchannel, and fully close its local object (``connectionLost()``). 
When the +first side receives CLOSE in the "closing" state, it will fully close its +local object too. + +All L5 subchannels will be paused (``pauseProducing()``) when the L3 +connection is paused or lost. They are resumed when the L3 connection is +resumed or reestablished. + +## Initiating Dilation + +Dilation is triggered by calling the `w.dilate()` API. This returns a +Deferred that will fire once the first L3 connection is established. It fires +with a 3-tuple of endpoints that can be used to establish subchannels. + +For dilation to succeed, both sides must call `w.dilate()`, since the +resulting endpoints are the only way to access the subchannels. If the other +side never calls `w.dilate()`, the Deferred will never fire. + +The L1 (mailbox) path is used to deliver dilation requests and connection +hints. The current mailbox protocol uses named "phases" to distinguish +messages (rather than behaving like a regular ordered channel of arbitrary +frames or bytes), and all-number phase names are reserved for application +data (sent via `w.send_message()`). Therefore the dilation control messages +use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own +counter, so one side might be up to e.g. `DILATE-5` while the other has only +gotten as far as `DILATE-2`. This effectively creates a unidirectional stream +of `DILATE-n` messages, each containing one or more dilation record, of +various types described below. Note that all phases beyond the initial +VERSION and PAKE phases are encrypted by the shared session key. + +A future mailbox protocol might provide a simple ordered stream of messages, +with application records and dilation records mixed together. + +Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that +has a string value. The dictionary will have other keys that depend upon the +type. + +`w.dilate()` triggers a `please-dilate` record with a set of versions that +can be accepted. 
Both Leader and Follower emit this record, although the
+Leader is responsible for version decisions. Versions use strings, rather
+than integers, to support experimental protocols, however there is still a
+total ordering of preferability.
+
+```
+{ "type": "please-dilate",
+  "accepted-versions": ["1"]
+}
+```
+
+The Leader then sends a `start-dilation` message with a `version` field (the
+"best" mutually-supported value) and the new "L2 generation" number in the
+`generation` field. Generation numbers are integers, monotonically increasing
+by 1 each time.
+
+```
+{ "type": "start-dilation",
+  "version": "1",
+  "generation": 1,
+}
+```
+
+The Follower responds with an `ok-dilation` message with matching `version`
+and `generation` fields.
+
+The Leader decides when a new dilation connection is necessary, both for the
+initial connection and any subsequent reconnects. Therefore the Leader has
+the exclusive right to send the `start-dilation` record. It won't send this
+until after it has sent its own `please-dilate`, and after it has received
+the Follower's `please-dilate`. As a result, local preparations may begin as
+soon as `w.dilate()` is called, but L2 connections do not begin until the
+Leader declares the start of a new L2 generation with the `start-dilation`
+message.
+
+Generations are non-overlapping. The Leader will drop all connections from
+generation 1 before sending the `start-dilation` for generation 2, and will
+not initiate any gen-2 connections until it receives the matching
+`ok-dilation` from the Follower. The Follower must drop all gen-1 connections
+before it sends the `ok-dilation` response (even if it thinks they are still
+functioning: if the Leader thought the gen-1 connection still worked, it
+wouldn't have started gen-2). Listening sockets can be retained, but any
+previous connection made through them must be dropped. This should avoid a
+race. 
+ +(TODO: what about a follower->leader connection that was started before +start-dilation is received, and gets established on the Leader side after +start-dilation is sent? the follower will drop it after it receives +start-dilation, but meanwhile the leader may accept it as gen2) + +(probably need to include the generation number in the handshake, or in the +derived key) + +(TODO: reduce the number of round-trip stalls here, I've added too many) + +"Connection hints" are type/address/port records that tell the other side of +likely targets for L2 connections. Both sides will try to determine their +external IP addresses, listen on a TCP port, and advertise `(tcp, +external-IP, port)` as a connection hint. The Transit Relay is also used as a +(lower-priority) hint. These are sent in `connection-hint` records, which can +be sent by the Leader any time after the `start-dilation` record, and by the +Follower after the `ok-dilation` record. Each side will initiate connections +upon receipt of the hints. + +``` +{ "type": "connection-hints", + "hints": [ ... ] +} +``` + +Hints can arrive at any time. One side might immediately send hints that can +be computed quickly, then send additional hints later as they become +available. For example, it might enumerate the local network interfaces and +send hints for all of the LAN addresses first, then send port-forwarding +(UPnP) requests to the local router. When the forwarding is established +(providing an externally-visible IP address and port), it can send additional +hints for that new endpoint. If the other peer happens to be on the same LAN, +the local connection can be established without waiting for the router's +response. + + +### Connection Hint Format + +Each member of the `hints` field describes a potential L2 connection target +endpoint, with an associated priority and a set of hints. 
+ +The priority is a number (positive or negative float), where larger numbers +indicate that the client supplying that hint would prefer to use this +connection over others of lower number. This indicates a sense of cost or +performance. For example, the Transit Relay is lower priority than a direct +TCP connection, because it incurs a bandwidth cost (on the relay operator), +as well as adding latency. + +Each endpoint has a set of hints, because the same target might be reachable +by multiple hints. Once one hint succeeds, there is no point in using the +other hints. + +TODO: think this through some more. What's the example of a single endpoint +reachable by multiple hints? Should each hint have its own priority, or just +each endpoint? + +## L2 protocol + +Upon ``connectionMade()``, both sides send their handshake message. The +Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower +sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should +trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP +servers that were contacted by accident). If the wrong handshake is received, +the connection will be dropped. For debugging purposes, the node might want +to keep looking at data beyond the first incorrect character and log +everything until the first newline. + +Everything beyond that point is a Noise protocol message, which consists of a +4-byte big-endian length field, followed by the indicated number of bytes. +This uses the `NNpsk0` pattern with the Leader as the first party ("-> psk, +e" in the Noise spec), and the Follower as the second ("<- e, ee"). The +pre-shared-key is the "dilation key", which is statically derived from the +master PAKE key using HKDF. Each L2 connection uses the same dilation key, +but different ephemeral keys, so each gets a different session key. + +The Leader sends the first message, which is a psk-encrypted ephemeral key. 
+The Follower sends the next message, its own psk-encrypted ephemeral key. The +Follower then sends an empty packet as the "key confirmation message", which +will be encrypted by the shared key. + +The Leader sees the KCM and knows the connection is viable. It delivers the +protocol object to the L3 manager, which will decide which connection to +select. When the L2 connection is selected to be the new L3, it will send an +empty KCM of its own, to let the Follower know the connection being selected. +All other L2 connections (either viable or still in handshake) are dropped, +all other connection attempts are cancelled, and all listening sockets are +shut down. + +The Follower will wait for either an empty KCM (at which point the L2 +connection is delivered to the Dilation manager as the new L3), a +disconnection, or an invalid message (which causes the connection to be +dropped). Other connections and/or listening sockets are stopped. + +Internally, the L2Protocol object manages the Noise session itself. It knows +(via a constructor argument) whether it is on the Leader or Follower side, +which affects both the role it plays in the Noise pattern, and the reaction +to receiving the ephemeral key (for which only the Follower sends an empty +KCM message). After that, the L2Protocol notifies the L3 object in three +situations: + +* the Noise session produces a valid decrypted frame (for Leader, this + includes the Follower's KCM, and thus indicates a viable candidate for + connection selection) +* the Noise session reports a failed decryption +* the TCP session is lost + +All notifications include a reference to the L2Protocol object (`self`). The +L3 object uses this reference to either close the connection (for errors or +when the selection process chooses someone else), to send the KCM message +(after selection, only for the Leader), or to send other L4 messages. The L3 +object will retain a reference to the winning L2 object. 
+ +## L3 protocol + +The L3 layer is responsible for connection selection, monitoring/keepalives, +and message (de)serialization. Framing is handled by L2, so the inbound L3 +codepath receives single-message bytestrings, and delivers the same down to +L2 for encryption, framing, and transmission. + +Connection selection takes place exclusively on the Leader side, and includes +the following: + +* receipt of viable L2 connections from below (indicated by the first valid + decrypted frame received for any given connection) +* expiration of a timer +* comparison of TBD quality/desirability/cost metrics of viable connections +* selection of winner +* instructions to losing connections to disconnect +* delivery of KCM message through winning connection +* retain reference to winning connection + +On the Follower side, the L3 manager just waits for the first connection to +receive the Leader's KCM, at which point it is retained and all others are +dropped. + +The L3 manager knows which "generation" of connection is being established. +Each generation uses a different dilation key (?), and is triggered by a new +set of L1 messages. Connections from one generation should not be confused +with those of a different generation. + +Each time a new L3 connection is established, the L4 protocol is notified. It +will immediately send all the L4 messages waiting in its outbound queue. +The L3 protocol simply wraps these in Noise frames and sends them to the +other side. + +The L3 manager monitors the viability of the current connection, and declares +it as lost when bidirectional traffic cannot be maintained. It uses PING and +PONG messages to detect this. These also serve to keep NAT entries alive, +since many firewalls will stop forwarding packets if they don't observe any +traffic for e.g. 5 minutes. + +Our goals are: + +* don't allow more than 30? 
seconds to pass without at least *some* data + being sent along each side of the connection +* allow the Leader to detect silent connection loss within 60? seconds +* minimize overhead + +We need both sides to: + +* maintain a 30-second repeating timer +* set a flag each time we write to the connection +* each time the timer fires, if the flag was clear then send a PONG, + otherwise clear the flag + +In addition, the Leader must: + +* run a 60-second repeating timer (ideally somewhat offset from the other) +* set a flag each time we receive data from the connection +* each time the timer fires, if the flag was clear then drop the connection, + otherwise clear the flag + +In the future, we might have L2 links that are less connection-oriented, +which might have a unidirectional failure mode, at which point we'll need to +monitor full roundtrips. To accomplish this, the Leader will send periodic +unconditional PINGs, and the Follower will respond with PONGs. If the +Leader->Follower connection is down, the PINGs won't arrive and no PONGs will +be produced. If the Follower->Leader direction has failed, the PONGs won't +arrive. The delivery of both will be delayed by actual data, so the timeouts +should be adjusted if we see regular data arriving. + +If the connection is dropped before the wormhole is closed (either the other +end explicitly dropped it, we noticed a problem and told TCP to drop it, or +TCP noticed a problem itself), the Leader-side L3 manager will initiate a +reconnection attempt. This uses L1 to send a new DILATE message through the +mailbox server, along with new connection hints. Eventually this will result +in a new L3 connection being established. + +Finally, L3 is responsible for message serialization and deserialization. L2 +performs decryption and delivers plaintext frames to L3. Each frame starts +with a one-byte type indicator. 
The rest of the message depends upon the +type: + +* 0x00 PING, 4-byte ping-id +* 0x01 PONG, 4-byte ping-id +* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum +* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload +* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum +* 0x05 ACK, 4-byte response-seqnum + +All seqnums are big-endian, and are provided by the L4 protocol. The other +fields are arbitrary and not interpreted as integers. The subchannel-ids must +be allocated by both sides without collision, but otherwise they are only +used to look up L5 objects for dispatch. The response-seqnum is always copied +from the OPEN/DATA/CLOSE packet being acknowledged. + +L3 consumes the PING and PONG messages. Receiving any PING will provoke a +PONG in response, with a copy of the ping-id field. The 30-second timer will +produce unprovoked PONGs with a ping-id of all zeros. A future viability +protocol will use PINGs to test for roundtrip functionality. + +All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered +"upstairs" to the L4 protocol handler. + +The current L3 connection's `IProducer`/`IConsumer` interface is made +available to the L4 flow-control manager. + +## L4 protocol + +The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages. +Since each will be enclosed in a Noise frame before they pass to L3, they do +not need length fields or other framing. + +Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically +increasing by 1 for each message. Each direction has a separate number space. + +The L4 manager maintains a double-ended queue of unacknowledged outbound +messages. Subchannel activity (opening, closing, sending data) cause messages +to be added to this queue. If an L3 connection is available, these messages +are also sent over that connection, but they remain in the queue in case the +connection is lost and they must be retransmitted on some future replacement +connection. 
Messages stay in the queue until they can be retired by the +receipt of an ACK with a matching response-sequence-number. This provides +reliable message delivery that survives the L3 connection being replaced. + +ACKs are not acked, nor do they have seqnums of their own. Each inbound side +remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE +messages with that sequence number or below. This ensures in-order +at-most-once processing of OPEN/DATA/CLOSE messages. + +Each inbound OPEN message causes a new L5 subchannel object to be created. +Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to +that object. + +Each time an L3 connection is established, the side will immediately send all +L4 messages waiting in the outbound queue. A future protocol might reduce +this duplication by including the highest received sequence number in the L1 +PLEASE-DILATE message, which would effectively retire queued messages before +initiating the L2 connection process. On any given L3 connection, all +messages are sent in-order. The receipt of an ACK for seqnum `N` allows all +messages with `seqnum <= N` to be retired. + +The L4 layer is also responsible for managing flow control among the L3 +connection and the various L5 subchannels. + +## L5 subchannels + +The L5 layer consists of a collection of "subchannel" objects, a dispatcher, +and the endpoints that provide the Twisted-flavored API. + +Other than the "control channel", all subchannels are created by a client +endpoint connection API. The side that calls this API is named the Initiator, +and the other side is named the Acceptor. Subchannels can be initiated in +either direction, independent of the Leader/Follower distinction. For a +typical file-transfer application, the subchannel would be initiated by the +side seeking to send a file. + +Each subchannel uses a distinct subchannel-id, which is a four-byte +identifier. 
Both directions share a number space (unlike L4 seqnums), so the +rule is that the Leader side sets the last bit of the last byte to a 0, while +the Follower sets it to a 1. These are not generally treated as integers, +however for the sake of debugging, the implementation generates them with a +simple big-endian-encoded counter (`next(counter)*2` for the Leader, +`next(counter)*2+1` for the Follower). + +When the `client_ep.connect()` API is called, the Initiator allocates a +subchannel-id and sends an OPEN. It can then immediately send DATA messages +with the outbound data (there is no special response to an OPEN, so there is +no need to wait). The Acceptor will trigger their `.connectionMade` handler +upon receipt of the OPEN. + +Subchannels are durable: they do not close until one side calls +`.loseConnection` on the subchannel object (or the enclosing Wormhole is +closed). Either the Initiator or the Acceptor can call `.loseConnection`. +This causes a CLOSE message to be sent (with the subchannel-id). The other +side will send its own CLOSE message in response. Each side will signal the +`.connectionLost()` event upon receipt of a CLOSE. + +There is no equivalent to TCP's "half-closed" state, however if only one side +calls `close()`, then all data written before that call will be delivered +before the other side observes `.connectionLost()`. Any inbound data that was +queued for delivery before the other side sees the CLOSE will still be +delivered to the side that called `close()` before it sees +`.connectionLost()`. Internally, the side which called `.loseConnection` will +remain in a special "closing" state until the CLOSE response arrives, during +which time DATA payloads are still delivered. After calling `close()` (or +receiving CLOSE), any outbound `.write()` calls will trigger an error. + +DATA payloads that arrive for a non-open subchannel are logged and discarded. 
+ +This protocol calls for one OPEN and two CLOSE messages for each subchannel, +with some arbitrary number of DATA messages in between. Subchannel-ids should +not be reused (it would probably work, the protocol hasn't been analyzed +enough to be sure). + +The "control channel" is special. It uses a subchannel-id of all zeros, and +is opened implicitly by both sides as soon as the first L3 connection is +selected. It is routed to a special client-on-both-sides endpoint, rather +than causing the listening endpoint to accept a new connection. This avoids +the need for application-level code to negotiate who should be the one to +open it (the Leader/Follower distinction is private to the Wormhole +internals: applications are not obligated to pick a side). + +OPEN and CLOSE messages for the control channel are logged and discarded. The +control-channel client endpoints can only be used once, and do not close +until the Wormhole itself is closed. + +Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing, +delivery, and eventual retirement. The L5 layer does not keep track of old +messages. + +### Flow Control + +Subchannels are flow-controlled by pausing their writes when the L3 +connection is paused, and pausing the L3 connection when the subchannel +signals a pause. When the outbound L3 connection is full, *all* subchannels +are paused. Likewise the inbound connection is paused if *any* of the +subchannels asks for a pause. This is much easier to implement and improves +our utilization factor (we can use TCP's window-filling algorithm, instead of +rolling our own), but will block all subchannels even if only one of them +gets full. This shouldn't matter for many applications, but might be +noticeable when combining very different kinds of traffic (e.g. a chat +conversation sharing a wormhole with file-transfer might prefer the IM text +to take priority). + +Each subchannel implements Twisted's `ITransport`, `IProducer`, and +`IConsumer` interfaces. 
The Endpoint API causes a new `IProtocol` object to +be created (by the caller's factory) and glued to the subchannel object in +the `.transport` property, as is standard in Twisted-based applications. + +All subchannels are also paused when the L3 connection is lost, and are +unpaused when a new replacement connection is selected. diff --git a/docs/new-protocol.svg b/docs/new-protocol.svg new file mode 100644 index 0000000..a69b79a --- /dev/null +++ b/docs/new-protocol.svg @@ -0,0 +1,2000 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + connectionMade() + dataReceived() + dataReceived() + connectionLost() + + + empty + + + + open + + + + open + + + + closing + + + + empty + + + + + connect() + + + Open1 + + write() + write() + loseConnection() + connectionLost() + + + open + + + + + + Data1 + + + + Data1 + + + + Close1 + + + + Close1 + + + + + + + + + + + + + + empty + + + + open + + + + open + + + + empty + + + + + + open + + + + + + + + + + + + + + + Open1 + + + + Data1 + + + + Data1 + + + + Close1 + + connection 1 + + + Open1 + + + + Data1 + + 0 + 1 + 2 + 3 + + + ack 0 + + connection 2 + + + Data1 + + + + ack 1 + + + + Data1 + + + + Close1 + + + + ack 2 + + + + ack 3 + + + + Close1 + + + + ack 0' + + 0' + 0 + 1 + 1 + 2 + 3 + 0' + logical + 0 + 1 + 2 + 3 + + diff --git a/docs/state-machines/machines.dot b/docs/state-machines/machines.dot index ecfdf9a..09cc02c 100644 --- a/docs/state-machines/machines.dot +++ b/docs/state-machines/machines.dot @@ -23,12 +23,13 @@ digraph { Terminator [shape="box" color="blue" fontcolor="blue"] InputHelperAPI [shape="oval" label="input\nhelper\nAPI" color="blue" fontcolor="blue"] + 
Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"] #Connection -> websocket [color="blue"] #Connection -> Order [color="blue"] Wormhole -> Boss [style="dashed" - label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)" + label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)" color="red" fontcolor="red"] #Wormhole -> Boss [color="blue"] Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"] @@ -112,4 +113,7 @@ digraph { Terminator -> Boss [style="dashed" label="closed\n(once)"] Boss -> Terminator [style="dashed" color="red" fontcolor="red" label="close"] + + Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"] + Dilator -> Send [style="dashed" label="send(dilate-N)"] } diff --git a/setup.py b/setup.py index a99a2d2..c275e8b 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ setup(name="magic-wormhole", "click", "humanize", "txtorcon >= 18.0.2", # 18.0.2 fixes py3.4 support + "noiseprotocol", ], extras_require={ ':sys_platform=="win32"': ["pywin32"], diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 373c0f4..cc6f4fa 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -12,6 +12,7 @@ from zope.interface import implementer from . 
import _interfaces from ._allocator import Allocator from ._code import Code, validate_code +from ._dilation.manager import Dilator from ._input import Input from ._key import Key from ._lister import Lister @@ -66,6 +67,7 @@ class Boss(object): self._I = Input(self._timing) self._C = Code(self._timing) self._T = Terminator() + self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) self._N.wire(self._M, self._I, self._RC, self._T) self._M.wire(self._N, self._RC, self._O, self._T) @@ -79,6 +81,7 @@ class Boss(object): self._I.wire(self._C, self._L) self._C.wire(self, self._A, self._N, self._K, self._I) self._T.wire(self, self._RC, self._N, self._M) + self._D.wire(self._S) def _init_other_state(self): self._did_start_code = False @@ -86,6 +89,9 @@ class Boss(object): self._next_rx_phase = 0 self._rx_phases = {} # phase -> plaintext + self._next_rx_dilate_seqnum = 0 + self._rx_dilate_seqnums = {} # seqnum -> plaintext + self._result = "empty" # these methods are called from outside @@ -198,6 +204,9 @@ class Boss(object): self._did_start_code = True self._C.set_code(code) + def dilate(self): + return self._D.dilate() # fires with endpoints + @m.input() def send(self, plaintext): pass @@ -258,8 +267,11 @@ class Boss(object): # this is only called for side != ours assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) + d_mo = re.search(r'^dilate-(\d+)$', phase) if phase == "version": self._got_version(side, plaintext) + elif d_mo: + self._got_dilate(int(d_mo.group(1)), plaintext) elif re.search(r'^\d+$', phase): self._got_phase(int(phase), plaintext) else: @@ -275,6 +287,10 @@ class Boss(object): def _got_phase(self, phase, plaintext): pass + @m.input() + def _got_dilate(self, seqnum, plaintext): + pass + @m.input() def got_key(self, key): pass @@ -298,6 +314,8 @@ class Boss(object): # in the future, this is how Dilation is signalled self._their_side = side self._their_versions = 
bytes_to_dict(plaintext) + self._D.got_wormhole_versions(self._side, self._their_side, + self._their_versions) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) self._W.got_versions(app_versions) @@ -339,6 +357,10 @@ class Boss(object): def W_got_key(self, key): self._W.got_key(key) + @m.output() + def D_got_key(self, key): + self._D.got_key(key) + @m.output() def W_got_verifier(self, verifier): self._W.got_verifier(verifier) @@ -352,6 +374,16 @@ class Boss(object): self._W.received(self._rx_phases.pop(self._next_rx_phase)) self._next_rx_phase += 1 + @m.output() + def D_received_dilate(self, seqnum, plaintext): + assert isinstance(seqnum, six.integer_types), type(seqnum) + # strict phase order, no gaps + self._rx_dilate_seqnums[seqnum] = plaintext + while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums: + m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum) + self._D.received_dilate(m) + self._next_rx_dilate_seqnum += 1 + @m.output() def W_close_with_error(self, err): self._result = err # exception @@ -374,7 +406,7 @@ class Boss(object): S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared]) S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely]) S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send]) - S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key]) + S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key]) S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error]) S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error]) @@ -382,6 +414,7 @@ class Boss(object): S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier]) S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received]) S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version]) + S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate]) S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared]) S2_happy.upon(close, enter=S3_closing, 
outputs=[close_happy]) S2_happy.upon(send, enter=S2_happy, outputs=[S_send]) @@ -393,6 +426,7 @@ class Boss(object): S3_closing.upon(got_verifier, enter=S3_closing, outputs=[]) S3_closing.upon(_got_phase, enter=S3_closing, outputs=[]) S3_closing.upon(_got_version, enter=S3_closing, outputs=[]) + S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[]) S3_closing.upon(happy, enter=S3_closing, outputs=[]) S3_closing.upon(scared, enter=S3_closing, outputs=[]) S3_closing.upon(close, enter=S3_closing, outputs=[]) @@ -404,6 +438,7 @@ class Boss(object): S4_closed.upon(got_verifier, enter=S4_closed, outputs=[]) S4_closed.upon(_got_phase, enter=S4_closed, outputs=[]) S4_closed.upon(_got_version, enter=S4_closed, outputs=[]) + S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[]) S4_closed.upon(happy, enter=S4_closed, outputs=[]) S4_closed.upon(scared, enter=S4_closed, outputs=[]) S4_closed.upon(close, enter=S4_closed, outputs=[]) diff --git a/src/wormhole/_dilation/__init__.py b/src/wormhole/_dilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py new file mode 100644 index 0000000..f0c8e35 --- /dev/null +++ b/src/wormhole/_dilation/connection.py @@ -0,0 +1,482 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +import six +from attr import attrs, attrib +from attr.validators import instance_of, provides +from automat import MethodicalMachine +from zope.interface import Interface, implementer +from twisted.python import log +from twisted.internet.protocol import Protocol +from twisted.internet.interfaces import ITransport +from .._interfaces import IDilationConnector +from ..observer import OneShotObserver +from .encode import to_be4, from_be4 +from .roles import FOLLOWER + +# InboundFraming is given data and returns Frames (Noise wire-side +# bytestrings). It handles the relay handshake and the prologue. 
The Frames it +# returns are either the ephemeral key (the Noise "handshake") or ciphertext +# messages. + +# The next object up knows whether it's expecting a Handshake or a message. It +# feeds the first into Noise as a handshake, it feeds the rest into Noise as a +# message (which produces a plaintext stream). It emits tokens that are either +# "i've finished with the handshake (so you can send the KCM if you want)", or +# "here is a decrypted message (which might be the KCM)". + +# the transmit direction goes directly to transport.write, and doesn't touch +# the state machine. we can do this because the way we encode/encrypt/frame +# things doesn't depend upon the receiver state. It would be more safe to e.g. +# prohibit sending ciphertext frames unless we're in the received-handshake +# state, but then we'll be in the middle of an inbound state transition ("we +# just received the handshake, so you can send a KCM now") when we perform an +# operation that depends upon the state (send_plaintext(kcm)), which is not a +# coherent/safe place to touch the state machine. + +# we could set a flag and test it from inside send_plaintext, which kind of +# violates the state machine owning the state (ideally all "if" statements +# would be translated into same-input transitions from different starting +# states). For the specific question of sending plaintext frames, Noise will +# refuse us unless it's ready anyways, so the question is probably moot. 
+ +class IFramer(Interface): + pass +class IRecord(Interface): + pass + +def first(l): + return l[0] + +class Disconnect(Exception): + pass +RelayOK = namedtuple("RelayOk", []) +Prologue = namedtuple("Prologue", []) +Frame = namedtuple("Frame", ["frame"]) + +@attrs +@implementer(IFramer) +class _Framer(object): + _transport = attrib(validator=provides(ITransport)) + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + _buffer = b"" + _can_send_frames = False + + # in: use_relay + # in: connectionMade, dataReceived + # out: prologue_received, frame_received + # out (shared): transport.loseConnection + # out (shared): transport.write (relay handshake, prologue) + # states: want_relay, want_prologue, want_frame + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + @m.state() + def want_relay(self): pass # pragma: no cover + @m.state(initial=True) + def want_prologue(self): pass # pragma: no cover + @m.state() + def want_frame(self): pass # pragma: no cover + + @m.input() + def use_relay(self, relay_handshake): pass + @m.input() + def connectionMade(self): pass + @m.input() + def parse(self): pass + @m.input() + def got_relay_ok(self): pass + @m.input() + def got_prologue(self): pass + + @m.output() + def store_relay_handshake(self, relay_handshake): + self._outbound_relay_handshake = relay_handshake + self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + @m.output() + def send_relay_handshake(self): + self._transport.write(self._outbound_relay_handshake) + + @m.output() + def send_prologue(self): + self._transport.write(self._outbound_prologue) + + @m.output() + def parse_relay_ok(self): + if self._get_expected("relay_ok", self._expected_relay_handshake): + return RelayOK() + + @m.output() + def parse_prologue(self): + if self._get_expected("prologue", self._inbound_prologue): + return Prologue() + + @m.output() + def 
can_send_frames(self): + self._can_send_frames = True # for assertion in send_frame() + + @m.output() + def parse_frame(self): + if len(self._buffer) < 4: + return None + frame_length = from_be4(self._buffer[0:4]) + if len(self._buffer) < 4+frame_length: + return None + frame = self._buffer[4:4+frame_length] + self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy + return Frame(frame=frame) + + want_prologue.upon(use_relay, outputs=[store_relay_handshake], + enter=want_relay) + + want_relay.upon(connectionMade, outputs=[send_relay_handshake], + enter=want_relay) + want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay, + collector=first) + want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue) + + want_prologue.upon(connectionMade, outputs=[send_prologue], + enter=want_prologue) + want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue, + collector=first) + want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame) + + want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, + collector=first) + + + def _get_expected(self, name, expected): + lb = len(self._buffer) + le = len(expected) + if self._buffer.startswith(expected): + # if the buffer starts with the expected string, consume it and + # return True + self._buffer = self._buffer[le:] + return True + if not expected.startswith(self._buffer): + # we're not on track: the data we've received so far does not + # match the expected value, so this can't possibly be right. + # Don't complain until we see the expected length, or a newline, + # so we can capture the weird input in the log for debugging. 
+ if (b"\n" in self._buffer or lb >= le): + log.msg("bad {}: {}".format(name, self._buffer[:le])) + raise Disconnect() + return False # wait a bit longer + # good so far, just waiting for the rest + return False + + # external API is: connectionMade, add_and_parse, and send_frame + + def add_and_parse(self, data): + # we can't make this an @m.input because we can't change the state + # from within an input. Instead, let the state choose the parser to + # use, and use the parsed token drive a state transition. + self._buffer += data + while True: + # it'd be nice to use an iterator here, but since self.parse() + # dispatches to a different parser (depending upon the current + # state), we'd be using multiple iterators + token = self.parse() + if isinstance(token, RelayOK): + self.got_relay_ok() + elif isinstance(token, Prologue): + self.got_prologue() + yield token # triggers send_handshake + elif isinstance(token, Frame): + yield token + else: + break + + def send_frame(self, frame): + assert self._can_send_frames + self._transport.write(to_be4(len(frame)) + frame) + +# RelayOK: Newline-terminated buddy-is-connected response from Relay. +# First data received from relay. +# Prologue: double-newline-terminated this-is-really-wormhole response +# from peer. First data received from peer. +# Frame: Either handshake or encrypted message. Length-prefixed on wire. +# Handshake: the Noise ephemeral key, first framed message +# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK +# KCM: Key Confirmation Message (encrypted b"\x00"). First frame +# from peer. Sent immediately by Follower, after Selection by Leader. 
+# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong + +Handshake = namedtuple("Handshake", []) +# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack +KCM = namedtuple("KCM", []) +Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value +Pong = namedtuple("Pong", ["ping_id"]) +Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer +Data = namedtuple("Data", ["seqnum", "scid", "data"]) +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer +Records = (KCM, Ping, Pong, Open, Data, Close, Ack) +Handshake_or_Records = (Handshake,) + Records + +T_KCM = b"\x00" +T_PING = b"\x01" +T_PONG = b"\x02" +T_OPEN = b"\x03" +T_DATA = b"\x04" +T_CLOSE = b"\x05" +T_ACK = b"\x06" + +def parse_record(plaintext): + msgtype = plaintext[0:1] + if msgtype == T_KCM: + return KCM() + if msgtype == T_PING: + ping_id = plaintext[1:5] + return Ping(ping_id) + if msgtype == T_PONG: + ping_id = plaintext[1:5] + return Pong(ping_id) + if msgtype == T_OPEN: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Open(seqnum, scid) + if msgtype == T_DATA: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + data = plaintext[9:] + return Data(seqnum, scid, data) + if msgtype == T_CLOSE: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Close(seqnum, scid) + if msgtype == T_ACK: + resp_seqnum = from_be4(plaintext[1:5]) + return Ack(resp_seqnum) + log.err("received unknown message type: {}".format(plaintext)) + raise ValueError() + +def encode_record(r): + if isinstance(r, KCM): + return b"\x00" + if isinstance(r, Ping): + return b"\x01" + r.ping_id + if isinstance(r, Pong): + return b"\x02" + r.ping_id + if isinstance(r, Open): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Data): + 
assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data + if isinstance(r, Close): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Ack): + assert isinstance(r.resp_seqnum, six.integer_types) + return b"\x06" + to_be4(r.resp_seqnum) + raise TypeError(r) + +@attrs +@implementer(IRecord) +class _Record(object): + _framer = attrib(validator=provides(IFramer)) + _noise = attrib() + + n = MethodicalMachine() + # TODO: set_trace + + def __attrs_post_init__(self): + self._noise.start_handshake() + + # in: role= + # in: prologue_received, frame_received + # out: handshake_received, record_received + # out: transport.write (noise handshake, encrypted records) + # states: want_prologue, want_handshake, want_record + + @n.state(initial=True) + def want_prologue(self): pass # pragma: no cover + @n.state() + def want_handshake(self): pass # pragma: no cover + @n.state() + def want_message(self): pass # pragma: no cover + + @n.input() + def got_prologue(self): + pass + @n.input() + def got_frame(self, frame): + pass + + @n.output() + def send_handshake(self): + handshake = self._noise.write_message() # generate the ephemeral key + self._framer.send_frame(handshake) + + @n.output() + def process_handshake(self, frame): + from noise.exceptions import NoiseInvalidMessage + try: + payload = self._noise.read_message(frame) + # Noise can include unencrypted data in the handshake, but we don't + # use it + del payload + except NoiseInvalidMessage as e: + log.err(e, "bad inbound noise handshake") + raise Disconnect() + return Handshake() + + @n.output() + def decrypt_message(self, frame): + from noise.exceptions import NoiseInvalidMessage + try: + message = self._noise.decrypt(frame) + except NoiseInvalidMessage as e: + # if this happens during tests, flunk the test + 
log.err(e, "bad inbound noise frame") + raise Disconnect() + return parse_record(message) + + want_prologue.upon(got_prologue, outputs=[send_handshake], + enter=want_handshake) + want_handshake.upon(got_frame, outputs=[process_handshake], + collector=first, enter=want_message) + want_message.upon(got_frame, outputs=[decrypt_message], + collector=first, enter=want_message) + + # external API is: connectionMade, dataReceived, send_record + + def connectionMade(self): + self._framer.connectionMade() + + def add_and_unframe(self, data): + for token in self._framer.add_and_parse(data): + if isinstance(token, Prologue): + self.got_prologue() # triggers send_handshake + else: + assert isinstance(token, Frame) + yield self.got_frame(token.frame) # Handshake or a Record type + + def send_record(self, r): + message = encode_record(r) + frame = self._noise.encrypt(message) + self._framer.send_frame(frame) + + +@attrs +class DilatedConnectionProtocol(Protocol, object): + """I manage an L2 connection. + + When a new L2 connection is needed (as determined by the Leader), + both Leader and Follower will initiate many simultaneous connections + (probably TCP, but conceivably others). A subset will actually + connect. A subset of those will successfully pass negotiation by + exchanging handshakes to demonstrate knowledge of the session key. + One of the negotiated connections will be selected by the Leader for + active use, and the others will be dropped. + + At any given time, there is at most one active L2 connection. 
+ """ + + _eventual_queue = attrib() + _role = attrib() + _connector = attrib(validator=provides(IDilationConnector)) + _noise = attrib() + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + + _use_relay = False + _relay_handshake = None + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + self._manager = None # set if/when we are selected + self._disconnected = OneShotObserver(self._eventual_queue) + self._can_send_records = False + + @m.state(initial=True) + def unselected(self): pass # pragma: no cover + @m.state() + def selecting(self): pass # pragma: no cover + @m.state() + def selected(self): pass # pragma: no cover + + @m.input() + def got_kcm(self): + pass + @m.input() + def select(self, manager): + pass # fires set_manager() + @m.input() + def got_record(self, record): + pass + + @m.output() + def add_candidate(self): + self._connector.add_candidate(self) + + @m.output() + def set_manager(self, manager): + self._manager = manager + + @m.output() + def can_send_records(self, manager): + self._can_send_records = True + + @m.output() + def deliver_record(self, record): + self._manager.got_record(record) + + unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting) + selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected) + selected.upon(got_record, outputs=[deliver_record], enter=selected) + + # called by Connector + + def use_relay(self, relay_handshake): + assert isinstance(relay_handshake, bytes) + self._use_relay = True + self._relay_handshake = relay_handshake + + def when_disconnected(self): + return self._disconnected.when_fired() + + def disconnect(self): + self.transport.loseConnection() + + # select() called by Connector + + # called by Manager + def send_record(self, record): + assert self._can_send_records + self._record.send_record(record) + + # IProtocol 
methods

    def connectionMade(self):
        framer = _Framer(self.transport,
                         self._outbound_prologue, self._inbound_prologue)
        if self._use_relay:
            framer.use_relay(self._relay_handshake)
        self._record = _Record(framer, self._noise)
        self._record.connectionMade()

    def dataReceived(self, data):
        try:
            for token in self._record.add_and_unframe(data):
                assert isinstance(token, Handshake_or_Records)
                if isinstance(token, Handshake):
                    if self._role is FOLLOWER:
                        self._record.send_record(KCM())
                elif isinstance(token, KCM):
                    # if we're the leader, add this connection as a candidate.
                    # if we're the follower, accept this connection.
                    self.got_kcm() # connector.add_candidate()
                else:
                    self.got_record(token) # manager.got_record()
        except Disconnect:
            self.transport.loseConnection()

    def connectionLost(self, why=None):
        self._disconnected.fire(self)
diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py
new file mode 100644
index 0000000..86f2a72
--- /dev/null
+++ b/src/wormhole/_dilation/connector.py
@@ -0,0 +1,482 @@
+from __future__ import print_function, unicode_literals
+import sys, re
+from collections import defaultdict, namedtuple
+from binascii import hexlify
+import six
+from attr import attrs, attrib
+from attr.validators import instance_of, provides, optional
+from automat import MethodicalMachine
+from zope.interface import implementer
+from twisted.internet.task import deferLater
+from twisted.internet.defer import DeferredList
+from twisted.internet.endpoints import HostnameEndpoint, serverFromString
+from twisted.internet.protocol import ClientFactory, ServerFactory
+from twisted.python import log
+from hkdf import Hkdf
+from ..
import ipaddrs # TODO: move into _dilation/ +from .._interfaces import IDilationConnector, IDilationManager +from ..timing import DebugTiming +from ..observer import EmptyableSet +from .connection import DilatedConnectionProtocol, KCM +from .roles import LEADER + + +# These namedtuples are "hint objects". The JSON-serializable dictionaries +# are "hint dicts". + +# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: +# * make a TCP connection (possibly via Tor) +# * send the sender/receiver handshake bytes first +# * expect to see the receiver/sender handshake bytes from the other side +# * the sender writes "go\n", the receiver waits for "go\n" +# * the rest of the connection contains transit data +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) +# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we +# use a tuple rather than a list so they'll be hashable into a set). For each +# one, make the TCP connection, send the relay handshake, then complete the +# rest of the V1 protocol. Only one hint per relay is useful. 
+RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + +def describe_hint_obj(hint, relay, tor): + prefix = "tor->" if tor else "->" + if relay: + prefix = prefix + "relay:" + if isinstance(hint, DirectTCPV1Hint): + return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) + elif isinstance(hint, TorTCPV1Hint): + return prefix+"tor:%s:%d" % (hint.hostname, hint.port) + else: + return prefix+str(hint) + +def parse_hint_argv(hint, stderr=sys.stderr): + assert isinstance(hint, type("")) + # return tuple or None for an unparseable hint + priority = 0.0 + mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) + if not mo: + print("unparseable hint '%s'" % (hint,), file=stderr) + return None + hint_type = mo.group(1) + if hint_type != "tcp": + print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + return None + hint_value = mo.group(2) + pieces = hint_value.split(":") + if len(pieces) < 2: + print("unparseable TCP hint (need more colons) '%s'" % (hint,), + file=stderr) + return None + mo = re.search(r'^(\d+)$', pieces[1]) + if not mo: + print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) + return None + hint_host = pieces[0] + hint_port = int(pieces[1]) + for more in pieces[2:]: + if more.startswith("priority="): + more_pieces = more.split("=") + try: + priority = float(more_pieces[1]) + except ValueError: + print("non-float priority= in TCP hint '%s'" % (hint,), + file=stderr) + return None + return DirectTCPV1Hint(hint_host, hint_port, priority) + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + hint_type = hint.get("type", "") + if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: + log.msg("unknown hint type: %r" % (hint,)) + return None + if not("hostname" in hint + and isinstance(hint["hostname"], type(""))): + log.msg("invalid hostname in hint: %r" % (hint,)) + return None + if not("port" in hint + and isinstance(hint["port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint,)) + return None + priority = 
hint.get("priority", 0.0) + if hint_type == "direct-tcp-v1": + return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) + else: + return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + +def parse_hint(hint_struct): + hint_type = hint_struct.get("type", "") + if hint_type == "relay-v1": + # the struct can include multiple ways to reach the same relay + rhints = filter(lambda h: h, # drop None (unrecognized) + [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) + return RelayV1Hint(rhints) + return parse_tcp_v1_hint(hint_struct) + +def encode_hint(h): + if isinstance(h, DirectTCPV1Hint): + return {"type": "direct-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + elif isinstance(h, RelayV1Hint): + rhint = {"type": "relay-v1", "hints": []} + for rh in h.hints: + rhint["hints"].append({"type": "direct-tcp-v1", + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) + return rhint + elif isinstance(h, TorTCPV1Hint): + return {"type": "tor-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + raise ValueError("unknown hint type", h) + +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + +def build_sided_relay_handshake(key, side): + assert isinstance(side, type(u"")) + assert len(side) == 8*2 + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") + return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + +PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" +PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" +NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + +@attrs +@implementer(IDilationConnector) +class Connector(object): + _dilation_key = attrib(validator=instance_of(type(b""))) + _transit_relay_location = attrib(validator=optional(instance_of(str))) + _manager = attrib(validator=provides(IDilationManager)) + _reactor = attrib() + 
_eventual_queue = attrib() + _no_listen = attrib(validator=instance_of(bool)) + _tor = attrib() + _timing = attrib() + _side = attrib(validator=instance_of(type(u""))) + # was self._side = bytes_to_hexstr(os.urandom(8)) # unicode + _role = attrib() + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + RELAY_DELAY = 2.0 + + def __attrs_post_init__(self): + if self._transit_relay_location: + # TODO: allow multiple hints for a single relay + relay_hint = parse_hint_argv(self._transit_relay_location) + relay = RelayV1Hint(hints=(relay_hint,)) + self._transit_relays = [relay] + else: + self._transit_relays = [] + self._listeners = set() # IListeningPorts that can be stopped + self._pending_connectors = set() # Deferreds that can be cancelled + self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped + self._contenders = set() # viable connections + self._winning_connection = None + self._timing = self._timing or DebugTiming() + self._timing.add("transit") + + # this describes what our Connector can do, for the initial advertisement + @classmethod + def get_connection_abilities(klass): + return [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ] + + def build_protocol(self, addr): + # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses + # ephemeral keys plus a pre-shared symmetric key (the Transit key), a + # different one for each potential connection. 
+ from noise.connection import NoiseConnection + noise = NoiseConnection.from_name(NOISEPROTO) + noise.set_psks(self._dilation_key) + if self._role is LEADER: + noise.set_as_initiator() + outbound_prologue = PROLOGUE_LEADER + inbound_prologue = PROLOGUE_FOLLOWER + else: + noise.set_as_responder() + outbound_prologue = PROLOGUE_FOLLOWER + inbound_prologue = PROLOGUE_LEADER + p = DilatedConnectionProtocol(self._eventual_queue, self._role, + self, noise, + outbound_prologue, inbound_prologue) + return p + + @m.state(initial=True) + def connecting(self): pass # pragma: no cover + @m.state() + def connected(self): pass # pragma: no cover + @m.state(terminal=True) + def stopped(self): pass # pragma: no cover + + # TODO: unify the tense of these method-name verbs + @m.input() + def listener_ready(self, hint_objs): pass + @m.input() + def add_relay(self, hint_objs): pass + @m.input() + def got_hints(self, hint_objs): pass + @m.input() + def add_candidate(self, c): # called by DilatedConnectionProtocol + pass + @m.input() + def accept(self, c): pass + @m.input() + def stop(self): pass + + @m.output() + def use_hints(self, hint_objs): + self._use_hints(hint_objs) + + @m.output() + def publish_hints(self, hint_objs): + self._manager.send_hints([encode_hint(h) for h in hint_objs]) + + @m.output() + def consider(self, c): + self._contenders.add(c) + if self._role is LEADER: + # for now, just accept the first one. TODO: be clever. + self._eventual_queue.eventually(self.accept, c) + else: + # the follower always uses the first contender, since that's the + # only one the leader picked + self._eventual_queue.eventually(self.accept, c) + + @m.output() + def select_and_stop_remaining(self, c): + self._winning_connection = c + self._contenders.clear() # we no longer care who else came close + # remove this winner from the losers, so we don't shut it down + self._pending_connections.discard(c) + # shut down losing connections + self.stop_listeners() # TODO: maybe keep it open? 
NAT/p2p assist + self.stop_pending_connectors() + self.stop_pending_connections() + + c.select(self._manager) # subsequent frames go directly to the manager + if self._role is LEADER: + # TODO: this should live in Connection + c.send_record(KCM()) # leader sends KCM now + self._manager.use_connection(c) # manager sends frames to Connection + + @m.output() + def stop_everything(self): + self.stop_listeners() + self.stop_pending_connectors() + self.stop_pending_connections() + self.break_cycles() + + def stop_listeners(self): + d = DeferredList([l.stopListening() for l in self._listeners]) + self._listeners.clear() + return d # synchronization for tests + + def stop_pending_connectors(self): + return DeferredList([d.cancel() for d in self._pending_connectors]) + + def stop_pending_connections(self): + d = self._pending_connections.when_next_empty() + [c.loseConnection() for c in self._pending_connections] + return d + + def stop_winner(self): + d = self._winner.when_disconnected() + self._winner.disconnect() + return d + + def break_cycles(self): + # help GC by forgetting references to things that reference us + self._listeners.clear() + self._pending_connectors.clear() + self._pending_connections.clear() + self._winner = None + + connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints]) + connecting.upon(add_relay, enter=connecting, outputs=[use_hints, + publish_hints]) + connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) + connecting.upon(add_candidate, enter=connecting, outputs=[consider]) + connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) + connecting.upon(stop, enter=stopped, outputs=[stop_everything]) + + # once connected, we ignore everything except stop + connected.upon(listener_ready, enter=connected, outputs=[]) + connected.upon(add_relay, enter=connected, outputs=[]) + connected.upon(got_hints, enter=connected, outputs=[]) + connected.upon(add_candidate, enter=connected, outputs=[]) + 
connected.upon(accept, enter=connected, outputs=[]) + connected.upon(stop, enter=stopped, outputs=[stop_everything]) + + + # from Manager: start, got_hints, stop + # maybe add_candidate, accept + def start(self): + self._start_listener() + if self._transit_relays: + self.publish_hints(self._transit_relays) + self._use_hints(self._transit_relays) + + def _start_listener(self): + if self._no_listen or self._tor: + return + addresses = ipaddrs.find_addresses() + non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"] + if non_loopback_addresses: + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. + addresses = non_loopback_addresses + # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also + # to make firewall configs easier + # TODO: retain listening port between connection generations? + ep = serverFromString(self._reactor, "tcp:0") + f = InboundConnectionFactory(self) + d = ep.listen(f) + def _listening(lp): + # lp is an IListeningPort + self._listeners.add(lp) # for shutdown and tests + portnum = lp.getHost().port + direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) + for addr in addresses] + self.listener_ready(direct_hints) + d.addCallback(_listening) + d.addErrback(log.err) + + def _use_hints(self, hints): + # first, pull out all the relays, we'll connect to them later + relays = defaultdict(list) + direct = defaultdict(list) + for h in hints: + if isinstance(h, RelayV1Hint): + relays[h.priority].append(h) + else: + direct[h.priority].append(h) + delay = 0.0 + priorities = sorted(set(direct.keys()), reverse=True) + for p in priorities: + for h in direct[p]: + if isinstance(h, TorTCPV1Hint) and not self._tor: + continue + ep = self._endpoint_from_hint_obj(h) + desc = describe_hint_obj(h, False, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay=False) + self._pending_connectors.add(d) + # Make all direct connections 
immediately. Later, we'll change + # the add_candidate() function to look at the priority when + # deciding whether to accept a successful connection or not, + # and it can wait for more options if it sees a higher-priority + # one still running. But if we bail on that, we might consider + # putting an inter-direct-hint delay here to influence the + # process. + #delay += 1.0 + if delay > 0.0: + # Start trying the relays a few seconds after we start to try the + # direct hints. The idea is to prefer direct connections, but not + # be afraid of using a relay when we have direct hints that don't + # resolve quickly. Many direct hints will be to unused + # local-network IP addresses, which won't answer, and would take + # the full TCP timeout (30s or more) to fail. If there were no + # direct hints, don't delay at all. + delay += self.RELAY_DELAY + + # prefer direct connections by stalling relay connections by a few + # seconds, unless we're using --no-listen in which case we're probably + # going to have to use the relay + delay = self.RELAY_DELAY if self._no_listen else 0.0 + + # It might be nice to wire this so that a failure in the direct hints + # causes the relay hints to be used right away (fast failover). But + # none of our current use cases would take advantage of that: if we + # have any viable direct hints, then they're either going to succeed + # quickly or hang for a long time. 
+ for p in priorities: + for r in relays[p]: + for h in r.hints: + ep = self._endpoint_from_hint_obj(h) + desc = describe_hint_obj(h, True, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay=True) + self._pending_connectors.add(d) + # TODO: + #if not contenders: + # raise TransitError("No contenders for connection") + + # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for + # the initial connection + + def _connect(self, h, ep, description, is_relay=False): + relay_handshake = None + if is_relay: + relay_handshake = build_sided_relay_handshake(self._dilation_key, + self._side) + f = OutboundConnectionFactory(self, relay_handshake) + d = ep.connect(f) + # fires with protocol, or ConnectError + def _connected(p): + self._pending_connections.add(p) + # c might not be in _pending_connections, if it turned out to be a + # winner, which is why we use discard() and not remove() + p.when_disconnected().addCallback(self._pending_connections.discard) + d.addCallback(_connected) + return d + + def _endpoint_from_hint_obj(self, hint): + if self._tor: + if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): + # this Tor object will throw ValueError for non-public IPv4 + # addresses and any IPv6 address + try: + return self._tor.stream_via(hint.hostname, hint.port) + except ValueError: + return None + return None + if isinstance(hint, DirectTCPV1Hint): + return HostnameEndpoint(self._reactor, hint.hostname, hint.port) + return None + + + # Connection selection. All instances of DilatedConnectionProtocol which + # look viable get passed into our add_contender() method. + + # On the Leader side, "viable" means we've seen their KCM frame, which is + # the first Noise-encrypted packet on any given connection, and it has an + # empty body. We gather viable connections until we see one that we like, + # or a timer expires. Then we "select" it, close the others, and tell our + # Manager to use it. 
+
+    # On the Follower side, we'll only see a KCM on the one connection selected
+    # by the Leader, so the first viable connection wins.
+
+    # our Connection protocols call: add_candidate
+
+@attrs
+class OutboundConnectionFactory(ClientFactory, object):
+    _connector = attrib(validator=provides(IDilationConnector))
+    _relay_handshake = attrib(validator=optional(instance_of(bytes)))
+
+    def buildProtocol(self, addr):
+        p = self._connector.build_protocol(addr)
+        p.factory = self
+        if self._relay_handshake is not None:
+            p.use_relay(self._relay_handshake)
+        return p
+
+@attrs
+class InboundConnectionFactory(ServerFactory, object):
+    _connector = attrib(validator=provides(IDilationConnector))
+    protocol = DilatedConnectionProtocol
+
+    def buildProtocol(self, addr):
+        p = self._connector.build_protocol(addr)
+        p.factory = self
+        return p
diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py
new file mode 100644
index 0000000..eb1b0b2
--- /dev/null
+++ b/src/wormhole/_dilation/encode.py
@@ -0,0 +1,16 @@
+from __future__ import print_function, unicode_literals
+import struct
+
+assert len(struct.pack("<L", 1)) == 4
+
+def to_be4(value):
+    if not 0 <= value < 2**32:
+        raise ValueError
+    return struct.pack(">L", value)
+
+def from_be4(b):
+    if not isinstance(b, bytes):
+        raise TypeError(repr(b))
+    if len(b) != 4:
+        raise ValueError
+    return struct.unpack(">L", b)[0]
diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py
new file mode 100644
index 0000000..9235681
--- /dev/null
+++ b/src/wormhole/_dilation/inbound.py
@@ -0,0 +1,127 @@
+from __future__ import print_function, unicode_literals
+from attr import attrs, attrib
+from attr.validators import provides
+from zope.interface import implementer
+from twisted.python import log
+from .._interfaces import IDilationManager, IInbound
+from .subchannel import (SubChannel, _SubchannelAddress)
+
+class DuplicateOpenError(Exception):
+    pass
+class DataForMissingSubchannelError(Exception):
+    pass
+class CloseForMissingSubchannelError(Exception):
+    pass
+
+@attrs
+@implementer(IInbound) +class Inbound(object): + # Inbound flow control: TCP delivers data to Connection.dataReceived, + # Connection delivers to our handle_data, we deliver to + # SubChannel.remote_data, subchannel delivers to proto.dataReceived + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib() + + def __attrs_post_init__(self): + # we route inbound Data records to Subchannels .dataReceived + self._open_subchannels = {} # scid -> Subchannel + self._paused_subchannels = set() # Subchannels that have paused us + # the set is non-empty, we pause the transport + self._highest_inbound_acked = -1 + self._connection = None + + # from our Manager + def set_listener_endpoint(self, listener_endpoint): + self._listener_endpoint = listener_endpoint + + def set_subchannel_zero(self, scid0, sc0): + self._open_subchannels[scid0] = sc0 + + + def use_connection(self, c): + self._connection = c + # We can pause the connection's reads when we receive too much data. If + # this is a non-initial connection, then we might already have + # subchannels that are paused from before, so we might need to pause + # the new connection before it can send us any data + if self._paused_subchannels: + self._connection.pauseProducing() + + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + def is_record_old(self, r): + if r.seqnum <= self._highest_inbound_acked: + return True + return False + + def update_ack_watermark(self, r): + self._highest_inbound_acked = max(self._highest_inbound_acked, + r.seqnum) + + def handle_open(self, scid): + if scid in self._open_subchannels: + log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) + return + peer_addr = _SubchannelAddress(scid) + sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) + self._open_subchannels[scid] = sc + self._listener_endpoint._got_open(sc, peer_addr) + + def handle_data(self, scid, data): + sc = 
self._open_subchannels.get(scid)
+        if sc is None:
+            log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid)))
+            return
+        sc.remote_data(data)
+
+    def handle_close(self, scid):
+        sc = self._open_subchannels.get(scid)
+        if sc is None:
+            log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid)))
+            return
+        sc.remote_close()
+
+    def subchannel_closed(self, scid, sc):
+        # connectionLost has just been signalled
+        assert self._open_subchannels[scid] is sc
+        del self._open_subchannels[scid]
+
+    def stop_using_connection(self):
+        self._connection = None
+
+
+    # from our Subchannel, or rather from the Protocol above it and sent
+    # through the subchannel
+
+    # The subchannel is an IProducer, and application protocols can always
+    # tell them to pauseProducing if we're delivering inbound data too
+    # quickly. They don't need to register anything.
+
+    def subchannel_pauseProducing(self, sc):
+        was_paused = bool(self._paused_subchannels)
+        self._paused_subchannels.add(sc)
+        if self._connection and not was_paused:
+            self._connection.pauseProducing()
+
+    def subchannel_resumeProducing(self, sc):
+        was_paused = bool(self._paused_subchannels)
+        self._paused_subchannels.discard(sc)
+        if self._connection and was_paused and not self._paused_subchannels:
+            self._connection.resumeProducing()
+
+    def subchannel_stopProducing(self, sc):
+        # This protocol doesn't want any additional data. If we were a normal
+        # (single-owner) Transport, we'd call .loseConnection now. But our
+        # Connection is shared among many subchannels, so instead we just
+        # stop letting them pause the connection.
+ was_paused = bool(self._paused_subchannels) + self._paused_subchannels.discard(sc) + if self._connection and was_paused and not self._paused_subchannels: + self._connection.resumeProducing() + + # TODO: we might refactor these pause/resume/stop methods by building a + # context manager that look at the paused/not-paused state first, then + # lets the caller modify self._paused_subchannels, then looks at it a + # second time, and calls c.pauseProducing/c.resumeProducing as + # appropriate. I'm not sure it would be any cleaner, though. diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py new file mode 100644 index 0000000..860665a --- /dev/null +++ b/src/wormhole/_dilation/manager.py @@ -0,0 +1,514 @@ +from __future__ import print_function, unicode_literals +from collections import deque +from attr import attrs, attrib +from attr.validators import provides, instance_of, optional +from automat import MethodicalMachine +from zope.interface import implementer +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.python import log +from .._interfaces import IDilator, IDilationManager, ISend +from ..util import dict_to_bytes, bytes_to_dict +from ..observer import OneShotObserver +from .._key import derive_key +from .encode import to_be4 +from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress, + ControlEndpoint, SubchannelConnectorEndpoint, + SubchannelListenerEndpoint) +from .connector import Connector, parse_hint +from .roles import LEADER, FOLLOWER +from .connection import KCM, Ping, Pong, Open, Data, Close, Ack +from .inbound import Inbound +from .outbound import Outbound + +class OldPeerCannotDilateError(Exception): + pass +class UnknownDilationMessageType(Exception): + pass +class ReceivedHintsTooEarly(Exception): + pass + +@attrs +@implementer(IDilationManager) +class _ManagerBase(object): + _S = attrib(validator=provides(ISend)) + _my_side = 
attrib(validator=instance_of(type(u""))) + _transit_key = attrib(validator=instance_of(bytes)) + _transit_relay_location = attrib(validator=optional(instance_of(str))) + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + _no_listen = False # TODO + _tor = None # TODO + _timing = None # TODO + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + + self._my_role = None # determined upon rx_PLEASE + + self._connection = None + self._made_first_connection = False + self._first_connected = OneShotObserver(self._eventual_queue) + self._host_addr = _WormholeAddress() + + self._next_dilation_phase = 0 + + self._next_subchannel_id = 0 # increments by 2 + + # I kept getting confused about which methods were for inbound data + # (and thus flow-control methods go "out") and which were for + # outbound data (with flow-control going "in"), so I split them up + # into separate pieces. + self._inbound = Inbound(self, self._host_addr) + self._outbound = Outbound(self, self._cooperator) # from us to peer + + def set_listener_endpoint(self, listener_endpoint): + self._inbound.set_listener_endpoint(listener_endpoint) + def set_subchannel_zero(self, scid0, sc0): + self._inbound.set_subchannel_zero(scid0, sc0) + + def when_first_connected(self): + return self._first_connected.when_fired() + + + def send_dilation_phase(self, **fields): + dilation_phase = self._next_dilation_phase + self._next_dilation_phase += 1 + self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) + + def send_hints(self, hints): # from Connector + self.send_dilation_phase(type="connection-hints", hints=hints) + + + # forward inbound-ish things to _Inbound + def subchannel_pauseProducing(self, sc): + self._inbound.subchannel_pauseProducing(sc) + def subchannel_resumeProducing(self, sc): + self._inbound.subchannel_resumeProducing(sc) + def subchannel_stopProducing(self, sc): + self._inbound.subchannel_stopProducing(sc) + + # forward outbound-ish things to _Outbound 
+ def subchannel_registerProducer(self, sc, producer, streaming): + self._outbound.subchannel_registerProducer(sc, producer, streaming) + def subchannel_unregisterProducer(self, sc): + self._outbound.subchannel_unregisterProducer(sc) + + def send_open(self, scid): + self._queue_and_send(Open, scid) + def send_data(self, scid, data): + self._queue_and_send(Data, scid, data) + def send_close(self, scid): + self._queue_and_send(Close, scid) + + def _queue_and_send(self, record_type, *args): + r = self._outbound.build_record(record_type, *args) + # Outbound owns the send_record() pipe, so that it can stall new + # writes after a new connection is made until after all queued + # messages are written (to preserve ordering). + self._outbound.queue_and_send_record(r) # may trigger pauseProducing + + def subchannel_closed(self, scid, sc): + # let everyone clean up. This happens just after we delivered + # connectionLost to the Protocol, except for the control channel, + # which might get connectionLost later after they use ep.connect. + # TODO: is this inversion a problem? 
+ self._inbound.subchannel_closed(scid, sc) + self._outbound.subchannel_closed(scid, sc) + + + def _start_connecting(self, role): + assert self._my_role is not None + self._connector = Connector(self._transit_key, + self._transit_relay_location, + self, + self._reactor, self._eventual_queue, + self._no_listen, self._tor, + self._timing, + self._side, # needed for relay handshake + self._my_role) + self._connector.start() + + # our Connector calls these + + def connector_connection_made(self, c): + self.connection_made() # state machine update + self._connection = c + self._inbound.use_connection(c) + self._outbound.use_connection(c) # does c.registerProducer + if not self._made_first_connection: + self._made_first_connection = True + self._first_connected.fire(None) + pass + def connector_connection_lost(self): + self._stop_using_connection() + if self.role is LEADER: + self.connection_lost_leader() # state machine + else: + self.connection_lost_follower() + + + def _stop_using_connection(self): + # the connection is already lost by this point + self._connection = None + self._inbound.stop_using_connection() + self._outbound.stop_using_connection() # does c.unregisterProducer + + # from our active Connection + + def got_record(self, r): + # records with sequence numbers: always ack, ignore old ones + if isinstance(r, (Open, Data, Close)): + self.send_ack(r.seqnum) # always ack, even for old ones + if self._inbound.is_record_old(r): + return + self._inbound.update_ack_watermark(r.seqnum) + if isinstance(r, Open): + self._inbound.handle_open(r.scid) + elif isinstance(r, Data): + self._inbound.handle_data(r.scid, r.data) + else: # isinstance(r, Close) + self._inbound.handle_close(r.scid) + if isinstance(r, KCM): + log.err("got unexpected KCM") + elif isinstance(r, Ping): + self.handle_ping(r.ping_id) + elif isinstance(r, Pong): + self.handle_pong(r.ping_id) + elif isinstance(r, Ack): + self._outbound.handle_ack(r.resp_seqnum) # retire queued messages + else: + 
log.err("received unknown message type {}".format(r)) + + # pings, pongs, and acks are not queued + def send_ping(self, ping_id): + self._outbound.send_if_connected(Ping(ping_id)) + + def send_pong(self, ping_id): + self._outbound.send_if_connected(Pong(ping_id)) + + def send_ack(self, resp_seqnum): + self._outbound.send_if_connected(Ack(resp_seqnum)) + + + def handle_ping(self, ping_id): + self.send_pong(ping_id) + + def handle_pong(self, ping_id): + # TODO: update is-alive timer + pass + + # subchannel maintenance + def allocate_subchannel_id(self): + raise NotImplemented # subclass knows if we're leader or follower + +# new scheme: +# * both sides send PLEASE as soon as they have an unverified key and +# w.dilate has been called, +# * PLEASE includes a dilation-specific "side" (independent of the "side" +# used by mailbox messages) +# * higher "side" is Leader, lower is Follower +# * PLEASE includes can-dilate list of version integers, requires overlap +# "1" is current +# * dilation starts as soon as we've sent PLEASE and received PLEASE +# (four-state two-variable IDLE/WANTING/WANTED/STARTED diamond FSM) +# * HINTS sent after dilation starts +# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This +# is the only difference between the two sides, and is not enforced +# by the protocol (i.e. if the Follower sends RECONNECT to the Leader, +# the Leader will obey, although TODO how confusing will this get?) +# * upon receiving RECONNECT: drop Connector, start new Connector, send +# RECONNECTING, start sending HINTS +# * upon sending CONNECT: go into FLUSHING state and ignore all HINTS until +# RECONNECTING received. The new Connector can be spun up earlier, and it +# can send HINTS, but it must not be given any HINTS that arrive before +# RECONNECTING (since they're probably stale) + +# * after VERSIONS(KCM) received, we might learn that they other side cannot +# dilate. 
w.dilate errbacks at this point + +# * maybe signal warning if we stay in a "want" state for too long +# * nobody sends HINTS until they're ready to receive +# * nobody sends HINTS unless they've called w.dilate() and received PLEASE +# * nobody connects to inbound hints unless they've called w.dilate() +# * if leader calls w.dilate() but not follower, leader waits forever in +# "want" (doesn't send anything) +# * if follower calls w.dilate() but not leader, follower waits forever +# in "want", leader waits forever in "wanted" + +class ManagerShared(_ManagerBase): + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + @m.state(initial=True) + def IDLE(self): pass # pragma: no cover + + @m.state() + def WANTING(self): pass # pragma: no cover + @m.state() + def WANTED(self): pass # pragma: no cover + @m.state() + def CONNECTING(self): pass # pragma: no cover + @m.state() + def CONNECTED(self): pass # pragma: no cover + @m.state() + def FLUSHING(self): pass # pragma: no cover + @m.state() + def ABANDONING(self): pass # pragma: no cover + @m.state() + def LONELY(self): pass # pragme: no cover + @m.state() + def STOPPING(self): pass # pragma: no cover + @m.state(terminal=True) + def STOPPED(self): pass # pragma: no cover + + @m.input() + def start(self): pass # pragma: no cover + @m.input() + def rx_PLEASE(self, message): pass # pragma: no cover + @m.input() # only sent by Follower + def rx_HINTS(self, hint_message): pass # pragma: no cover + @m.input() # only Leader sends RECONNECT, so only Follower receives it + def rx_RECONNECT(self): pass # pragma: no cover + @m.input() # only Follower sends RECONNECTING, so only Leader receives it + def rx_RECONNECTING(self): pass # pragma: no cover + + # Connector gives us connection_made() + @m.input() + def connection_made(self, c): pass # pragma: no cover + + # our connection_lost() fires connection_lost_leader or + # connection_lost_follower depending upon our role. 
If either side sees a + # problem with the connection (timeouts, bad authentication) then they + # just drop it and let connection_lost() handle the cleanup. + @m.input() + def connection_lost_leader(self): pass # pragma: no cover + @m.input() + def connection_lost_follower(self): pass + + @m.input() + def stop(self): pass # pragma: no cover + + @m.output() + def stash_side(self, message): + their_side = message["side"] + self.my_role = LEADER if self._my_side > their_side else FOLLOWER + + # these Outputs behave differently for the Leader vs the Follower + @m.output() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + + @m.output() + def start_connecting(self): + self._start_connecting() # TODO: merge + @m.output() + def ignore_message_start_connecting(self, message): + self.start_connecting() + + @m.output() + def send_reconnect(self): + self.send_dilation_phase(type="reconnect") # TODO: generation number? + @m.output() + def send_reconnecting(self): + self.send_dilation_phase(type="reconnecting") # TODO: generation? + + @m.output() + def use_hints(self, hint_message): + hint_objs = filter(lambda h: h, # ignore None, unrecognizable + [parse_hint(hs) for hs in hint_message["hints"]]) + hint_objs = list(hint_objs) + self._connector.got_hints(hint_objs) + @m.output() + def stop_connecting(self): + self._connector.stop() + @m.output() + def abandon_connection(self): + # we think we're still connected, but the Leader disagrees. Or we've + # been told to shut down. 
+ self._connection.disconnect() # let connection_lost do cleanup + + + # we don't start CONNECTING until a local start() plus rx_PLEASE + IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) + IDLE.upon(start, enter=WANTING, outputs=[send_please]) + WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) + WANTING.upon(rx_PLEASE, enter=CONNECTING, + outputs=[stash_side, + ignore_message_start_connecting]) + + CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[]) + + # Leader + CONNECTED.upon(connection_lost_leader, enter=FLUSHING, + outputs=[send_reconnect]) + FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) + + # Follower + # if we notice a lost connection, just wait for the Leader to notice too + CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[]) + LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) + # but if they notice it first, abandon our (seemingly functional) + # connection, then tell them that we're ready to try again + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss + outputs=[abandon_connection]) + ABANDONING.upon(connection_lost_follower, enter=CONNECTING, + outputs=[send_reconnecting, start_connecting]) + # and if they notice a problem while we're still connecting, abandon our + # incomplete attempt and try again. 
in this case we don't have to wait + # for a connection to finish shutdown + CONNECTING.upon(rx_RECONNECT, enter=CONNECTING, + outputs=[stop_connecting, + send_reconnecting, + start_connecting]) + + + # rx_HINTS never changes state, they're just accepted or ignored + IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early + WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early + WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early + CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore + LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore + ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen + STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) + + IDLE.upon(stop, enter=STOPPED, outputs=[]) + WANTED.upon(stop, enter=STOPPED, outputs=[]) + WANTING.upon(stop, enter=STOPPED, outputs=[]) + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) + ABANDONING.upon(stop, enter=STOPPING, outputs=[]) + FLUSHING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + LONELY.upon(stop, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) + + + def allocate_subchannel_id(self): + # scid 0 is reserved for the control channel. the leader uses odd + # numbers starting with 1 + scid_num = self._next_outbound_seqnum + 1 + self._next_outbound_seqnum += 2 + return to_be4(scid_num) + +@attrs +@implementer(IDilator) +class Dilator(object): + """I launch the dilation process. + + I am created with every Wormhole (regardless of whether .dilate() + was called or not), and I handle the initial phase of dilation, + before we know whether we'll be the Leader or the Follower. 
Once we + hear the other side's VERSION message (which tells us that we have a + connection, they are capable of dilating, and which side we're on), + then we build a DilationManager and hand control to it. + """ + + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + self._started = False + self._endpoints = OneShotObserver(self._eventual_queue) + self._pending_inbound_dilate_messages = deque() + self._manager = None + + def wire(self, sender): + self._S = ISend(sender) + + # this is the primary entry point, called when w.dilate() is invoked + def dilate(self, transit_relay_location=None): + self._transit_relay_location = transit_relay_location + if not self._started: + self._started = True + self._start().addBoth(self._endpoints.fire) + return self._endpoints.when_fired() + + @inlineCallbacks + def _start(self): + # first, we wait until we hear the VERSION message, which tells us 1: + # the PAKE key works, so we can talk securely, 2: their side, so we + # know who will lead, and 3: that they can do dilation at all + + dilation_version = yield self._got_versions_d + + if not dilation_version: # 1 or None + raise OldPeerCannotDilateError() + + my_dilation_side = TODO # random + self._manager = Manager(self._S, my_dilation_side, + self._transit_key, + self._transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator) + self._manager.start() + + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + self.received_dilate(plaintext) + + # we could probably return the endpoints earlier + yield self._manager.when_first_connected() + # we can open subchannels as soon as we get our first connection + scid0 = b"\x00\x00\x00\x00" + self._host_addr = _WormholeAddress() # TODO: share with Manager + peer_addr0 = _SubchannelAddress(scid0) + control_ep = ControlEndpoint(peer_addr0) + sc0 = SubChannel(scid0, 
self._manager, self._host_addr, peer_addr0) + control_ep._subchannel_zero_opened(sc0) + self._manager.set_subchannel_zero(scid0, sc0) + + connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + + listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) + self._manager.set_listener_endpoint(listen_ep) + + endpoints = (control_ep, connect_ep, listen_ep) + returnValue(endpoints) + + # from Boss + + def got_key(self, key): + # TODO: verify this happens before got_wormhole_versions, or add a gate + # to tolerate either ordering + purpose = b"dilation-v1" + LENGTH =32 # TODO: whatever Noise wants, I guess + self._transit_key = derive_key(key, purpose, LENGTH) + + def got_wormhole_versions(self, our_side, their_side, + their_wormhole_versions): + # TODO: remove our_side, their_side + assert isinstance(our_side, str), str + assert isinstance(their_side, str), str + # this always happens before received_dilate + dilation_version = None + their_dilation_versions = their_wormhole_versions.get("can-dilate", []) + if 1 in their_dilation_versions: + dilation_version = 1 + self._got_versions_d.callback(dilation_version) + + def received_dilate(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. 
+ + # this can appear before our .dilate() method is called, in which case + # we queue them for later + if not self._manager: + self._pending_inbound_dilate_messages.append(plaintext) + return + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self._manager.rx_PLEASE() # message) + elif type == "dilate": + self._manager.rx_DILATE() #message) + elif type == "connection-hints": + self._manager.rx_HINTS(message) + else: + log.err(UnknownDilationMessageType(message)) + return diff --git a/src/wormhole/_dilation/old-follower.py b/src/wormhole/_dilation/old-follower.py new file mode 100644 index 0000000..68e3a38 --- /dev/null +++ b/src/wormhole/_dilation/old-follower.py @@ -0,0 +1,106 @@ + +class ManagerFollower(_ManagerBase): + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + @m.state(initial=True) + def IDLE(self): pass # pragma: no cover + + @m.state() + def WANTING(self): pass # pragma: no cover + @m.state() + def CONNECTING(self): pass # pragma: no cover + @m.state() + def CONNECTED(self): pass # pragma: no cover + @m.state(terminal=True) + def STOPPED(self): pass # pragma: no cover + + @m.input() + def start(self): pass # pragma: no cover + @m.input() + def rx_PLEASE(self): pass # pragma: no cover + @m.input() + def rx_DILATE(self): pass # pragma: no cover + @m.input() + def rx_HINTS(self, hint_message): pass # pragma: no cover + + @m.input() + def connection_made(self): pass # pragma: no cover + @m.input() + def connection_lost(self): pass # pragma: no cover + # follower doesn't react to connection_lost, but waits for a new LETS_DILATE + + @m.input() + def stop(self): pass # pragma: no cover + + # these Outputs behave differently for the Leader vs the Follower + @m.output() + def send_please(self): + self.send_dilation_phase(type="please") + + @m.output() + def start_connecting(self): + self._start_connecting(FOLLOWER) + + # these Outputs delegate to the same code in both the Leader and 
the + # Follower, but they must be replicated here because the Automat instance + # is on the subclass, not the shared superclass + + @m.output() + def use_hints(self, hint_message): + hint_objs = filter(lambda h: h, # ignore None, unrecognizable + [parse_hint(hs) for hs in hint_message["hints"]]) + self._connector.got_hints(hint_objs) + @m.output() + def stop_connecting(self): + self._connector.stop() + @m.output() + def use_connection(self, c): + self._use_connection(c) + @m.output() + def stop_using_connection(self): + self._stop_using_connection() + @m.output() + def signal_error(self): + pass # TODO + @m.output() + def signal_error_hints(self, hint_message): + pass # TODO + + IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early + IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early + # leader shouldn't send us DILATE before receiving our PLEASE + IDLE.upon(stop, enter=STOPPED, outputs=[]) + IDLE.upon(start, enter=WANTING, outputs=[send_please]) + WANTING.upon(rx_DILATE, enter=CONNECTING, outputs=[start_connecting]) + WANTING.upon(stop, enter=STOPPED, outputs=[]) + + CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) + CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) + # shouldn't happen: connection_lost + #CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) + CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, + start_connecting]) + # receiving rx_DILATE while we're still working on the last one means the + # leader thought we'd connected, then thought we'd been disconnected, all + # before we heard about that connection + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + + CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) + CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, + start_connecting]) + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + 
CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) + # shouldn't happen: connection_made + + # we should never receive PLEASE, we're the follower + IDLE.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + WANTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + CONNECTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + CONNECTED.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + + def allocate_subchannel_id(self): + # the follower uses even numbers starting with 2 + scid_num = self._next_outbound_seqnum + 2 + self._next_outbound_seqnum += 2 + return to_be4(scid_num) diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py new file mode 100644 index 0000000..6538ffe --- /dev/null +++ b/src/wormhole/_dilation/outbound.py @@ -0,0 +1,392 @@ +from __future__ import print_function, unicode_literals +from collections import deque +from attr import attrs, attrib +from attr.validators import provides +from zope.interface import implementer +from twisted.internet.interfaces import IPushProducer, IPullProducer +from twisted.python import log +from twisted.python.reflect import safe_str +from .._interfaces import IDilationManager, IOutbound +from .connection import KCM, Ping, Pong, Ack + + +# Outbound flow control: app writes to subchannel, we write to Connection + +# The app can register an IProducer of their choice, to let us throttle their +# outbound data. Not all subchannels will have producers registered, and the +# producer probably won't be the IProtocol instance (it'll be something else +# which feeds data out through the protocol, like a t.p.basic.FileSender). If +# a producerless subchannel writes too much, we won't be able to stop them, +# and we'll keep writing records into the Connection even though it's asked +# us to pause. 
Likewise, when the connection is down (and we're busily trying +# to reestablish a new one), registered subchannels will be paused, but +# unregistered ones will just dump everything in _outbound_queue, and we'll +# consume memory without bound until they stop. + +# We need several things: +# +# * Add each registered IProducer to a list, whose order remains stable. We +# want fairness under outbound throttling: each time the outbound +# connection opens up (our resumeProducing method is called), we should let +# just one producer have an opportunity to do transport.write, and then we +# should pause them again, and not come back to them until everyone else +# has gotten a turn. So we want an ordered list of producers to track this +# rotation. +# +# * Remove the IProducer if/when the protocol uses unregisterProducer +# +# * Remove any registered IProducer when the associated Subchannel is closed. +# This isn't a problem for normal transports, because usually there's a +# one-to-one mapping from Protocol to Transport, so when the Transport you +# forget the only reference to the Producer anyways. Our situation is +# unusual because we have multiple Subchannels that get merged into the +# same underlying Connection: each Subchannel's Protocol can register a +# producer on the Subchannel (which is an ITransport), but that adds it to +# a set of Producers for the Connection (which is also an ITransport). So +# if the Subchannel is closed, we need to remove its Producer (if any) even +# though the Connection remains open. +# +# * Register ourselves as an IPushProducer with each successive Connection +# object. These connections will come and go, but there will never be more +# than one. When the connection goes away, pause all our producers. When a +# new one is established, write all our queued messages, then unpause our +# producers as we would in resumeProducing. 
+# +# * Inside our resumeProducing call, we'll cycle through all producers, +# calling their individual resumeProducing methods one at a time. If they +# write so much data that the Connection pauses us again, we'll find out +# because our pauseProducing will be called inside that loop. When that +# happens, we need to stop looping. If we make it through the whole loop +# without being paused, then all subchannel Producers are left unpaused, +# and are free to write whenever they want. During this loop, some +# Producers will be paused, and others will be resumed +# +# * If our pauseProducing is called, all Producers must be paused, and a flag +# should be set to notify the resumeProducing loop to exit +# +# * In between calls to our resumeProducing method, we're in one of two +# states. +# * If we're writing data too fast, then we'll be left in the "paused" +# state, in which all Subchannel producers are paused, and the aggregate +# is paused too (our Connection told us to pauseProducing and hasn't yet +# told us to resumeProducing). In this state, activity is driven by the +# outbound TCP window opening up, which calls resumeProducing and allows +# (probably just) one message to be sent. We receive pauseProducing in +# the middle of their transport.write, so the loop exits early, and the +# only state change is that some other Producer should get to go next +# time. +# * If we're writing too slowly, we'll be left in the "unpaused" state: all +# Subchannel producers are unpaused, and the aggregate is unpaused too +# (resumeProducing is the last thing we've been told). In this satte, +# activity is driven by the Subchannels doing a transport.write, which +# queues some data on the TCP connection (and then might call +# pauseProducing if it's now full). 
+# +# * We want to guard against: +# +# * application protocol registering a Producer without first unregistering +# the previous one +# +# * application protocols writing data despite being told to pause +# (Subchannels without a registered Producer cannot be throttled, and we +# can't do anything about that, but we must also handle the case where +# they give us a pause switch and then proceed to ignore it) +# +# * our Connection calling resumeProducing or pauseProducing without an +# intervening call of the other kind +# +# * application protocols that don't handle a resumeProducing or +# pauseProducing call without an intervening call of the other kind (i.e. +# we should keep track of the last thing we told them, and not repeat +# ourselves) +# +# * If the Wormhole is closed, all Subchannels should close. This is not our +# responsibility: it lives in (Manager? Inbound?) +# +# * If we're given an IPullProducer, we should keep calling its +# resumeProducing until it runs out of data. We still want fairness, so we +# won't call it a second time until everyone else has had a turn. + + +# There are a couple of different ways to approach this. The one I've +# selected is: +# +# * keep a dict that maps from Subchannel to Producer, which only contains +# entries for Subchannels that have registered a producer. We use this to +# remove Producers when Subchannels are closed +# +# * keep a Deque of Producers. This represents the fair-throttling rotation: +# the left-most item gets the next upcoming turn, and then they'll be moved +# to the end of the queue. +# +# * keep a set of IPushProducers which are paused, a second set of +# IPushProducers which are unpaused, and a third set of IPullProducers +# (which are always left paused) Enforce the invariant that these three +# sets are disjoint, and that their union equals the contents of the deque. +# +# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and +# set upon entry to pauseProducing. 
The loop inside resumeProducing checks +# this flag after each call to producer.resumeProducing, to sense whether +# they used their turn to write data, and if that write was large enough to +# fill the TCP window. If set, we break out of the loop. If not, we look +# for the next producer to unpause. The loop finishes when all producers +# are unpaused (evidenced by the two sets of paused producers being empty) +# +# * the "paused" flag also determines whether new IPushProducers are added to +# the paused or unpaused set (IPullProducers are always added to the +# pull+paused set). If we have any IPullProducers, we're always in the +# "writing data too fast" state. + +# other approaches that I didn't decide to do at this time (but might use in +# the future): +# +# * use one set instead of two. pros: fewer moving parts. cons: harder to +# spot decoherence bugs like adding a producer to the deque but forgetting +# to add it to one of the +# +# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as +# a visible boolean flag. This conflates Subchannels with their associated +# Producer (so if we went this way, we should also let them track their own +# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc +# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer +# mappings lying around to disagree with one another. 
Cons: exposes a bit +# too much of the Subchannel internals + + +@attrs +@implementer(IOutbound) +class Outbound(object): + # Manage outbound data: subchannel writes to us, we write to transport + _manager = attrib(validator=provides(IDilationManager)) + _cooperator = attrib() + + def __attrs_post_init__(self): + # _outbound_queue holds all messages we've ever sent but not retired + self._outbound_queue = deque() + self._next_outbound_seqnum = 0 + # _queued_unsent are messages to retry with our new connection + self._queued_unsent = deque() + + # outbound flow control: the Connection throttles our writes + self._subchannel_producers = {} # Subchannel -> IProducer + self._paused = True # our Connection called our pauseProducing + self._all_producers = deque() # rotates, left-is-next + self._paused_producers = set() + self._unpaused_producers = set() + self._check_invariants() + + self._connection = None + + def _check_invariants(self): + assert self._unpaused_producers.isdisjoint(self._paused_producers) + assert (self._paused_producers.union(self._unpaused_producers) == + set(self._all_producers)) + + def build_record(self, record_type, *args): + seqnum = self._next_outbound_seqnum + self._next_outbound_seqnum += 1 + r = record_type(seqnum, *args) + assert hasattr(r, "seqnum"), r # only Open/Data/Close + return r + + def queue_and_send_record(self, r): + # we always queue it, to resend on a subsequent connection if + # necessary + self._outbound_queue.append(r) + + if self._connection: + if self._queued_unsent: + # to maintain correct ordering, queue this instead of sending it + self._queued_unsent.append(r) + else: + # we're allowed to send it immediately + self._connection.send_record(r) + + def send_if_connected(self, r): + assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum + if self._connection: + self._connection.send_record(r) + + # our subchannels call these to register a producer + + def subchannel_registerProducer(self, sc, producer, 
streaming): + # streaming==True: IPushProducer (pause/resume) + # streaming==False: IPullProducer (just resume) + if sc in self._subchannel_producers: + raise ValueError( + "registering producer %s before previous one (%s) was " + "unregistered" % (producer, + self._subchannel_producers[sc])) + # our underlying Connection uses streaming==True, so to make things + # easier, use an adapter when the Subchannel asks for streaming=False + if not streaming: + def unregister(): + self.subchannel_unregisterProducer(sc) + producer = PullToPush(producer, unregister, self._cooperator) + + self._subchannel_producers[sc] = producer + self._all_producers.append(producer) + if self._paused: + self._paused_producers.add(producer) + else: + self._unpaused_producers.add(producer) + self._check_invariants() + if streaming: + if self._paused: + # IPushProducers need to be paused immediately, before they + # speak + producer.pauseProducing() # you wake up sleeping + else: + # our PullToPush adapter must be started, but if we're paused then + # we tell it to pause before it gets a chance to write anything + producer.startStreaming(self._paused) + + def subchannel_unregisterProducer(self, sc): + # TODO: what if the subchannel closes, so we unregister their + # producer for them, then the application reacts to connectionLost + # with a duplicate unregisterProducer? 
+ p = self._subchannel_producers.pop(sc) + if isinstance(p, PullToPush): + p.stopStreaming() + self._all_producers.remove(p) + self._paused_producers.discard(p) + self._unpaused_producers.discard(p) + self._check_invariants() + + def subchannel_closed(self, sc): + self._check_invariants() + if sc in self._subchannel_producers: + self.subchannel_unregisterProducer(sc) + + # our Manager tells us when we've got a new Connection to work with + + def use_connection(self, c): + self._connection = c + assert not self._queued_unsent + self._queued_unsent.extend(self._outbound_queue) + # the connection can tell us to pause when we send too much data + c.registerProducer(self, True) # IPushProducer: pause+resume + # send our queued messages + self.resumeProducing() + + def stop_using_connection(self): + self._connection.unregisterProducer() + self._connection = None + self._queued_unsent.clear() + self.pauseProducing() + # TODO: I expect this will call pauseProducing twice: the first time + # when we get stopProducing (since we're registere with the + # underlying connection as the producer), and again when the manager + # notices the connectionLost and calls our _stop_using_connection + + def handle_ack(self, resp_seqnum): + # we've received an inbound ack, so retire something + while (self._outbound_queue and + self._outbound_queue[0].seqnum <= resp_seqnum): + self._outbound_queue.popleft() + while (self._queued_unsent and + self._queued_unsent[0].seqnum <= resp_seqnum): + self._queued_unsent.popleft() + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + + # IProducer: the active connection calls these because we used + # c.registerProducer to ask for them + def pauseProducing(self): + if self._paused: + return # someone is confused and called us twice + self._paused = True + for p in self._all_producers: + if p in self._unpaused_producers: + self._unpaused_producers.remove(p) + 
self._paused_producers.add(p) + p.pauseProducing() + + def resumeProducing(self): + if not self._paused: + return # someone is confused and called us twice + self._paused = False + + while not self._paused: + if self._queued_unsent: + r = self._queued_unsent.popleft() + self._connection.send_record(r) + continue + p = self._get_next_unpaused_producer() + if not p: + break + self._paused_producers.remove(p) + self._unpaused_producers.add(p) + p.resumeProducing() + + def _get_next_unpaused_producer(self): + self._check_invariants() + if not self._paused_producers: + return None + while True: + p = self._all_producers[0] + self._all_producers.rotate(-1) # p moves to the end of the line + # the only unpaused Producers are at the end of the list + assert p in self._paused_producers + return p + + def stopProducing(self): + # we'll hopefully have a new connection to work with in the future, + # so we don't shut anything down. We do pause everyone, though. + self.pauseProducing() + + +# modelled after twisted.internet._producer_helper._PullToPush , but with a +# configurable Cooperator, a pause-immediately argument to startStreaming() +@implementer(IPushProducer) +@attrs(cmp=False) +class PullToPush(object): + _producer = attrib(validator=provides(IPullProducer)) + _unregister = attrib(validator=lambda _a,_b,v: callable(v)) + _cooperator = attrib() + _finished = False + + def _pull(self): + while True: + try: + self._producer.resumeProducing() + except: + log.err(None, "%s failed, producing will be stopped:" % + (safe_str(self._producer),)) + try: + self._unregister() + # The consumer should now call stopStreaming() on us, + # thus stopping the streaming. 
+ except: + # Since the consumer blew up, we may not have had + # stopStreaming() called, so we just stop on our own: + log.err(None, "%s failed to unregister producer:" % + (safe_str(self._unregister),)) + self._finished = True + return + yield None + + def startStreaming(self, paused): + self._coopTask = self._cooperator.cooperate(self._pull()) + if paused: + self.pauseProducing() # timer is scheduled, but task is removed + + def stopStreaming(self): + if self._finished: + return + self._finished = True + self._coopTask.stop() + + + def pauseProducing(self): + self._coopTask.pause() + + + def resumeProducing(self): + self._coopTask.resume() + + + def stopProducing(self): + self.stopStreaming() + self._producer.stopProducing() diff --git a/src/wormhole/_dilation/roles.py b/src/wormhole/_dilation/roles.py new file mode 100644 index 0000000..8f9adac --- /dev/null +++ b/src/wormhole/_dilation/roles.py @@ -0,0 +1 @@ +LEADER, FOLLOWER = object(), object() diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py new file mode 100644 index 0000000..94b4a03 --- /dev/null +++ b/src/wormhole/_dilation/subchannel.py @@ -0,0 +1,269 @@ +from attr import attrs, attrib +from attr.validators import instance_of, provides +from zope.interface import implementer +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed +from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, + IAddress, IListeningPort, + IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.internet.error import ConnectionDone +from automat import MethodicalMachine +from .._interfaces import ISubChannel, IDilationManager + +@attrs +class Once(object): + _errtype = attrib() + def __attrs_post_init__(self): + self._called = False + + def __call__(self): + if self._called: + raise self._errtype() + self._called = True + +class SingleUseEndpointError(Exception): + pass + +# created in the (OPEN) state, by either: +# * receipt of an OPEN 
message +# * or local client_endpoint.connect() +# then transitions are: +# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN) +# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED) +# (OPEN) local .write(): send DATA, -> (OPEN) +# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING) +# (CLOSING) local .write(): error +# (CLOSING) local .loseConnection(): error +# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING) +# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) +# object is deleted upon transition to (CLOSED) + +class AlreadyClosedError(Exception): + pass + +@implementer(IAddress) +class _WormholeAddress(object): + pass + +@implementer(IAddress) +@attrs +class _SubchannelAddress(object): + _scid = attrib() + + +@attrs +@implementer(ITransport) +@implementer(IProducer) +@implementer(IConsumer) +@implementer(ISubChannel) +class SubChannel(object): + _id = attrib(validator=instance_of(bytes)) + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _peer_addr = attrib(validator=instance_of(_SubchannelAddress)) + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + #self._mailbox = None + #self._pending_outbound = {} + #self._processed = set() + self._protocol = None + self._pending_dataReceived = [] + self._pending_connectionLost = (False, None) + + @m.state(initial=True) + def open(self): pass # pragma: no cover + + @m.state() + def closing(): pass # pragma: no cover + + @m.state() + def closed(): pass # pragma: no cover + + @m.input() + def remote_data(self, data): pass + @m.input() + def remote_close(self): pass + + @m.input() + def local_data(self, data): pass + @m.input() + def local_close(self): pass + + + @m.output() + def send_data(self, data): + self._manager.send_data(self._id, data) + + @m.output() + def send_close(self): + self._manager.send_close(self._id) + + 
@m.output() + def signal_dataReceived(self, data): + if self._protocol: + self._protocol.dataReceived(data) + else: + self._pending_dataReceived.append(data) + + @m.output() + def signal_connectionLost(self): + if self._protocol: + self._protocol.connectionLost(ConnectionDone()) + else: + self._pending_connectionLost = (True, ConnectionDone()) + self._manager.subchannel_closed(self) + # we're deleted momentarily + + @m.output() + def error_closed_write(self, data): + raise AlreadyClosedError("write not allowed on closed subchannel") + @m.output() + def error_closed_close(self): + raise AlreadyClosedError("loseConnection not allowed on closed subchannel") + + # primary transitions + open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) + open.upon(local_data, enter=open, outputs=[send_data]) + open.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + open.upon(local_close, enter=closing, outputs=[send_close]) + closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) + closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + + # error cases + # we won't ever see an OPEN, since L4 will log+ignore those for us + closing.upon(local_data, enter=closing, outputs=[error_closed_write]) + closing.upon(local_close, enter=closing, outputs=[error_closed_close]) + # the CLOSED state won't ever see messages, since we'll be deleted + + # our endpoints use this + + def _set_protocol(self, protocol): + assert not self._protocol + self._protocol = protocol + if self._pending_dataReceived: + for data in self._pending_dataReceived: + self._protocol.dataReceived(data) + self._pending_dataReceived = [] + cl, what = self._pending_connectionLost + if cl: + self._protocol.connectionLost(what) + + # ITransport + def write(self, data): + assert isinstance(data, type(b"")) + self.local_data(data) + def writeSequence(self, iovec): + self.write(b"".join(iovec)) + def loseConnection(self): + self.local_close() + def getHost(self): + # 
we define "host addr" as the overall wormhole + return self._host_addr + def getPeer(self): + # and "peer addr" as the subchannel within that wormhole + return self._peer_addr + + # IProducer: throttle inbound data (wormhole "up" to local app's Protocol) + def stopProducing(self): + self._manager.subchannel_stopProducing(self) + def pauseProducing(self): + self._manager.subchannel_pauseProducing(self) + def resumeProducing(self): + self._manager.subchannel_resumeProducing(self) + + # IConsumer: allow the wormhole to throttle outbound data (app->wormhole) + def registerProducer(self, producer, streaming): + self._manager.subchannel_registerProducer(self, producer, streaming) + def unregisterProducer(self): + self._manager.subchannel_unregisterProducer(self) + + +@implementer(IStreamClientEndpoint) +class ControlEndpoint(object): + _used = False + def __init__(self, peer_addr): + self._subchannel_zero = Deferred() + self._peer_addr = peer_addr + self._once = Once(SingleUseEndpointError) + + # from manager + def _subchannel_zero_opened(self, subchannel): + assert ISubChannel.providedBy(subchannel), subchannel + self._subchannel_zero.callback(subchannel) + + @inlineCallbacks + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + self._once() + t = yield self._subchannel_zero + p = protocolFactory.buildProtocol(self._peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + returnValue(p) + +@implementer(IStreamClientEndpoint) +@attrs +class SubchannelConnectorEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + scid = self._manager.allocate_subchannel_id() + self._manager.send_open(scid) + peer_addr = _SubchannelAddress(scid) + # ? f.doStart() + # ? 
f.startedConnecting(CONNECTOR) # ?? + t = SubChannel(scid, self._manager, self._host_addr, peer_addr) + p = protocolFactory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + return succeed(p) + +@implementer(IStreamServerEndpoint) +@attrs +class SubchannelListenerEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=provides(IAddress)) + + def __attrs_post_init__(self): + self._factory = None + self._pending_opens = [] + + # from manager + def _got_open(self, t, peer_addr): + if self._factory: + self._connect(t, peer_addr) + else: + self._pending_opens.append( (t, peer_addr) ) + + def _connect(self, t, peer_addr): + p = self._factory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) + + # IStreamServerEndpoint + + def listen(self, protocolFactory): + self._factory = protocolFactory + for (t, peer_addr) in self._pending_opens: + self._connect(t, peer_addr) + self._pending_opens = [] + lp = SubchannelListeningPort(self._host_addr) + return succeed(lp) + +@implementer(IListeningPort) +@attrs +class SubchannelListeningPort(object): + _host_addr = attrib(validator=provides(IAddress)) + + def startListening(self): + pass + def stopListening(self): + # TODO + pass + def getHost(self): + return self._host_addr diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 9f59ff4..52bde35 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -433,3 +433,16 @@ class IInputHelper(Interface): class IJournal(Interface): # TODO: this needs to be public pass + +class IDilator(Interface): + pass +class IDilationManager(Interface): + pass +class IDilationConnector(Interface): + pass +class ISubChannel(Interface): + pass +class IInbound(Interface): + pass +class IOutbound(Interface): + pass diff --git a/src/wormhole/test/dilate/__init__.py b/src/wormhole/test/dilate/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py new file mode 100644 index 0000000..2ddacfb --- /dev/null +++ b/src/wormhole/test/dilate/common.py @@ -0,0 +1,18 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from ..._interfaces import IDilationManager, IWormhole + +def mock_manager(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + return m + +def mock_wormhole(): + m = mock.Mock() + alsoProvides(m, IWormhole) + return m + +def clear_mock_calls(*args): + for a in args: + a.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py new file mode 100644 index 0000000..406e42a --- /dev/null +++ b/src/wormhole/test/dilate/test_connection.py @@ -0,0 +1,216 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.internet.interfaces import ITransport +from ...eventual import EventualQueue +from ..._interfaces import IDilationConnector +from ..._dilation.roles import LEADER, FOLLOWER +from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, + KCM, Open, Ack) +from .common import clear_mock_calls + +def make_con(role, use_relay=False): + clock = Clock() + eq = EventualQueue(clock) + connector = mock.Mock() + alsoProvides(connector, IDilationConnector) + n = mock.Mock() # pretends to be a Noise object + n.write_message = mock.Mock(side_effect=[b"handshake"]) + c = DilatedConnectionProtocol(eq, role, connector, n, + b"outbound_prologue\n", b"inbound_prologue\n") + if use_relay: + c.use_relay(b"relay_handshake\n") + t = mock.Mock() + alsoProvides(t, ITransport) + return c, n, connector, t, eq + +class Connection(unittest.TestCase): + def test_bad_prologue(self): + c, n, connector, t, eq = make_con(LEADER) + 
c.makeConnection(t) + d = c.when_disconnected() + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"prologue\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + + eq.flush_sync() + self.assertNoResult(d) + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def _test_no_relay(self, role): + c, n, connector, t, eq = make_con(role) + t_kcm = KCM() + t_open = Open(seqnum=1, scid=0x11223344) + t_ack = Ack(resp_seqnum=2) + n.decrypt = mock.Mock(side_effect=[ + encode_record(t_kcm), + encode_record(t_open), + ]) + exp_kcm = b"\x00\x00\x00\x03kcm" + n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) + m = mock.Mock() # Manager + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x0Ahandshake2") + if role is LEADER: + # we're the leader, so we don't send the KCM right away + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(c._manager, None) + else: + # we're the follower, so we encrypt and send the KCM immediately + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2"), + 
mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [ + mock.call.write(exp_kcm)]) + self.assertEqual(c._manager, None) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x03KCM") + # leader: inbound KCM means we add the candidate + # follower: inbound KCM means we've been selected. + # in both cases we notify Connector.add_candidate(), and the Connector + # decides if/when to call .select() + + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")]) + self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)]) + self.assertEqual(t.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + # now pretend this connection wins (either the Leader decides to use + # this one among all the candiates, or we're the Follower and the + # Connector is reacting to add_candidate() by recognizing we're the + # only candidate there is) + c.select(m) + self.assertIdentical(c._manager, m) + if role is LEADER: + # TODO: currently Connector.select_and_stop_remaining() is + # responsible for sending the KCM just before calling c.select() + # iff we're the LEADER, therefore Connection.select won't send + # anything. This should be moved to c.select(). 
+ self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + + c.send_record(KCM()) + self.assertEqual(n.mock_calls, [ + mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) + self.assertEqual(m.mock_calls, []) + else: + # follower: we already sent the KCM, do nothing + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x04msg1") + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)]) + clear_mock_calls(n, connector, t, m) + + c.send_record(t_ack) + exp_ack = b"\x06\x00\x00\x00\x02" + self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.disconnect() + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + def test_no_relay_leader(self): + return self._test_no_relay(LEADER) + + def test_no_relay_follower(self): + return self._test_no_relay(FOLLOWER) + + + def test_relay(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + 
c.dataReceived(b"ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t) + + def test_relay_jilted(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + d = c.when_disconnected() + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def test_relay_bad_response(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"not ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + clear_mock_calls(n, connector, t) diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py new file mode 100644 index 0000000..e2c854e --- /dev/null +++ b/src/wormhole/test/dilate/test_encoding.py @@ -0,0 +1,25 @@ +from __future__ import print_function, unicode_literals +from twisted.trial import unittest +from ..._dilation.encode import to_be4, from_be4 + +class Encoding(unittest.TestCase): + + def test_be4(self): + self.assertEqual(to_be4(0), 
b"\x00\x00\x00\x00") + self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") + self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") + self.assertEqual(to_be4(257), b"\x00\x00\x01\x01") + with self.assertRaises(ValueError): + to_be4(-1) + with self.assertRaises(ValueError): + to_be4(2**32) + + self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0) + self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1) + self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256) + self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257) + + with self.assertRaises(TypeError): + from_be4(0) + with self.assertRaises(ValueError): + from_be4(b"\x01\x00\x00\x00\x00") diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py new file mode 100644 index 0000000..bd8f995 --- /dev/null +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -0,0 +1,97 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import ISubChannel +from ..._dilation.subchannel import (ControlEndpoint, + SubchannelConnectorEndpoint, + SubchannelListenerEndpoint, + SubchannelListeningPort, + _WormholeAddress, _SubchannelAddress, + SingleUseEndpointError) +from .common import mock_manager + +class Endpoints(unittest.TestCase): + def test_control(self): + scid0 = b"scid0" + peeraddr = _SubchannelAddress(scid0) + ep = ControlEndpoint(peeraddr) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + t = mock.Mock() + alsoProvides(t, ISubChannel) + ep._subchannel_zero_opened(t) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def 
assert_makeConnection(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "makeConnection") + self.assertEqual(len(mock_calls[0][1]), 1) + return mock_calls[0][1][0] + + def test_connector(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(b"scid") + ep = SubchannelConnectorEndpoint(m, hostaddr) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_listener(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + ep = SubchannelListenerEndpoint(m, hostaddr) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + # OPEN that arrives before we ep.listen() should be queued + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(b"peer1") + ep._got_open(t1, peeraddr1) + + d = ep.listen(f) + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(b"peer2") + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + + lp.stopListening() # TODO: 
should this do more? diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py new file mode 100644 index 0000000..81d4cf9 --- /dev/null +++ b/src/wormhole/test/dilate/test_framer.py @@ -0,0 +1,110 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect + +def make_framer(): + t = mock.Mock() + alsoProvides(t, ITransport) + f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") + return f, t + +class Framer(unittest.TestCase): + def test_bad_prologue_length(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual(t.mock_calls, []) + + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"not the prologue after all")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not the p"))]) + self.assertEqual(t.mock_calls, []) + + def test_bad_prologue_newline(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + + self.assertEqual([], list(f.add_and_parse(b"not"))) + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not\n"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_prologue(self): + f, t = make_framer() + 
self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + # now send_frame should work + f.send_frame(b"frame") + self.assertEqual(t.mock_calls, + [mock.call.write(b"\x00\x00\x00\x05frame")]) + + def test_bad_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"goodbye\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad relay_ok: {}".format(b"goo"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + + self.assertEqual([], list(f.add_and_parse(b"ok\n"))) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + + def test_frame(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + encoded_frame = b"\x00\x00\x00\x05frame" + self.assertEqual([], list(f.add_and_parse(encoded_frame[:2]))) + self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6]))) + self.assertEqual([Frame(frame=b"frame")], + list(f.add_and_parse(encoded_frame[6:]))) diff --git a/src/wormhole/test/dilate/test_inbound.py 
b/src/wormhole/test/dilate/test_inbound.py new file mode 100644 index 0000000..d147283 --- /dev/null +++ b/src/wormhole/test/dilate/test_inbound.py @@ -0,0 +1,172 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import IDilationManager +from ..._dilation.connection import Open, Data, Close +from ..._dilation.inbound import (Inbound, DuplicateOpenError, + DataForMissingSubchannelError, + CloseForMissingSubchannelError) + +def make_inbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + host_addr = object() + i = Inbound(m, host_addr) + return i, m, host_addr + +class InboundTest(unittest.TestCase): + def test_seqnum(self): + i, m, host_addr = make_inbound() + r1 = Open(scid=513, seqnum=1) + r2 = Data(scid=513, seqnum=2, data=b"") + r3 = Close(scid=513, seqnum=3) + self.assertFalse(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r1) + self.assertTrue(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r2) + self.assertTrue(i.is_record_old(r1)) + self.assertTrue(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + def test_open_data_close(self): + i, m, host_addr = make_inbound() + scid1 = b"scid" + scid2 = b"scXX" + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + i.use_connection(c) + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)]) + self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)]) + self.assertEqual(sca.mock_calls, [mock.call(scid1)]) + 
lep.mock_calls[:] = [] + + # a subsequent duplicate OPEN should be ignored + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + i.handle_data(scid1, b"data") + self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")]) + sc1.mock_calls[:] = [] + + i.handle_data(scid2, b"for non-existent subchannel") + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(DataForMissingSubchannelError) + + i.handle_close(scid1) + self.assertEqual(sc1.mock_calls, [mock.call.remote_close()]) + sc1.mock_calls[:] = [] + + i.handle_close(scid2) + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(CloseForMissingSubchannelError) + + # after the subchannel is closed, the Manager will notify Inbound + i.subchannel_closed(scid1, sc1) + + i.stop_using_connection() + + def test_control_channel(self): + i, m, host_addr = make_inbound() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + scid0 = b"scid" + sc0 = mock.Mock() + i.set_subchannel_zero(scid0, sc0) + + # OPEN on the control channel identifier should be ignored as a + # duplicate, since the control channel is already registered + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid0) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + # and DATA to it should be delivered correctly + i.handle_data(scid0, b"data") + self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")]) + 
sc0.mock_calls[:] = [] + + def test_pause(self): + i, m, host_addr = make_inbound() + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + # add two subchannels, pause one, then add a connection + scid1 = b"sci1" + scid2 = b"sci2" + sc1 = mock.Mock() + sc2 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1, sc2]): + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + return_value=peer_addr): + i.handle_open(scid1) + i.handle_open(scid2) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] + + # consumers aren't really supposed to do this, but tolerate it + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) # was already paused + + # tolerate duplicate pauseProducing + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) + + # stopProducing is treated like a terminal resumeProducing + i.subchannel_stopProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_stopProducing(sc2) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py new file mode 100644 index 0000000..625039d --- /dev/null +++ b/src/wormhole/test/dilate/test_manager.py @@ -0,0 +1,205 @@ +from __future__ import 
print_function, unicode_literals +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.defer import Deferred +from twisted.internet.task import Clock, Cooperator +import mock +from ...eventual import EventualQueue +from ..._interfaces import ISend, IDilationManager +from ...util import dict_to_bytes +from ..._dilation.manager import (Dilator, + OldPeerCannotDilateError, + UnknownDilationMessageType) +from ..._dilation.subchannel import _WormholeAddress +from .common import clear_mock_calls + +def make_dilator(): + reactor = object() + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + send = mock.Mock() + alsoProvides(send, ISend) + dil = Dilator(reactor, eq, coop) + dil.wire(send) + return dil, send, reactor, eq, clock, coop + +class TestDilator(unittest.TestCase): + def test_leader(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + d2 = dil.dilate() + self.assertNoResult(d1) + self.assertNoResult(d2) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + self.assertIdentical(dil._transit_key, transit_key) + self.assertNoResult(d1) + self.assertNoResult(d2) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = wfc_d = Deferred() + # TODO: test missing can-dilate, and no-overlap + with mock.patch("wormhole._dilation.manager.ManagerLeader", + return_value=m) as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + # that should create the Manager. 
Because "us" > "them", we're + # the leader + self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.when_first_connected(), + ]) + clear_mock_calls(m) + self.assertNoResult(d1) + self.assertNoResult(d2) + + host_addr = _WormholeAddress() + m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress", + return_value=host_addr) + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) + + with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m: + wfc_d.callback(None) + eq.flush_sync() + scid0 = b"\x00\x00\x00\x00" + self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) + self.assertEqual(m_sc_m.mock_calls, + [mock.call(scid0, m, host_addr, peer_addr)]) + self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) + self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) + self.assertEqual(m.mock_calls, + [mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), + ]) + clear_mock_calls(m) + + eps = self.successResultOf(d1) + self.assertEqual(eps, self.successResultOf(d2)) + d3 = dil.dilate() + eq.flush_sync() + self.assertEqual(eps, self.successResultOf(d3)) + + self.assertEqual(m.mock_calls, []) + dil.received_dilate(dict_to_bytes(dict(type="please"))) + self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE()]) + clear_mock_calls(m) + + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) + clear_mock_calls(m) + + 
dil.received_dilate(dict_to_bytes(dict(type="dilate"))) + self.assertEqual(m.mock_calls, [mock.call.rx_DILATE()]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="unknown"))) + self.assertEqual(m.mock_calls, []) + self.flushLoggedErrors(UnknownDilationMessageType) + + def test_follower(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key): + dil.got_key(key) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = Deferred() + with mock.patch("wormhole._dilation.manager.ManagerFollower", + return_value=m) as mf: + dil.got_wormhole_versions("me", "you", {"can-dilate": [1]}) + # "me" < "you", so we're the follower + self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key, + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.when_first_connected(), + ]) + + def test_peer_cannot_dilate(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + + def test_disjoint_versions(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil.got_wormhole_versions("me", "you", {"can-dilate": [-1]}) + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + + def test_early_dilate_messages(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + d1 = dil.dilate() + self.assertNoResult(d1) + dil.received_dilate(dict_to_bytes(dict(type="please"))) + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + + m = mock.Mock() + alsoProvides(m, 
IDilationManager) + m.when_first_connected.return_value = Deferred() + + with mock.patch("wormhole._dilation.manager.ManagerLeader", + return_value=m) as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.rx_PLEASE(), + mock.call.rx_HINTS(hintmsg), + mock.call.when_first_connected()]) + + + + def test_transit_relay(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + relay = object() + d1 = dil.dilate(transit_relay_location=relay) + self.assertNoResult(d1) + + with mock.patch("wormhole._dilation.manager.ManagerLeader") as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + relay, reactor, eq, coop), + mock.call().start(), + mock.call().when_first_connected()]) diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py new file mode 100644 index 0000000..fab596a --- /dev/null +++ b/src/wormhole/test/dilate/test_outbound.py @@ -0,0 +1,645 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +from itertools import cycle +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock, Cooperator +from twisted.internet.interfaces import IPullProducer +from ...eventual import EventualQueue +from ..._interfaces import IDilationManager +from ..._dilation.connection import KCM, Open, Data, Close, Ack +from ..._dilation.outbound import Outbound, PullToPush +from .common import clear_mock_calls + +Pauser = namedtuple("Pauser", ["seqnum"]) +NonPauser = namedtuple("NonPauser", ["seqnum"]) +Stopper = namedtuple("Stopper", ["sc"]) + +def make_outbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + clock = Clock() + eq = 
EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + o = Outbound(m, coop) + c = mock.Mock() # Connection + def maybe_pause(r): + if isinstance(r, Pauser): + o.pauseProducing() + elif isinstance(r, Stopper): + o.subchannel_unregisterProducer(r.sc) + c.send_record = mock.Mock(side_effect=maybe_pause) + o._test_eq = eq + o._test_term = term + return o, m, c + +class OutboundTest(unittest.TestCase): + def test_build_record(self): + o, m, c = make_outbound() + scid1 = b"scid" + self.assertEqual(o.build_record(Open, scid1), + Open(seqnum=0, scid=b"scid")) + self.assertEqual(o.build_record(Data, scid1, b"dataaa"), + Data(seqnum=1, scid=b"scid", data=b"dataaa")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=2, scid=b"scid")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=3, scid=b"scid")) + + def test_outbound_queue(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + + # we would never normally receive an ACK without first getting a + # connection + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + + o.handle_ack(r3.seqnum) + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r3.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r1.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + def test_duplicate_registerProducer(self): + o, m, c = make_outbound() + sc1 = object() + p1 = mock.Mock() + o.subchannel_registerProducer(sc1, p1, True) + with self.assertRaises(ValueError) as ar: + o.subchannel_registerProducer(sc1, p1, 
True) + s = str(ar.exception) + self.assertIn("registering producer", s) + self.assertIn("before previous one", s) + self.assertIn("was unregistered", s) + + def test_connection_send_queued_unpaused(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # as soon as the connection is established, everything is sent + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1), + mock.call.send_record(r2)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + + def test_connection_send_queued_paused(self): + o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # pausing=True, so our mock Manager will pause the Outbound producer + # after each write. 
So only r1 should have been sent before getting + # paused + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + # Outbound is responsible for sending all records, so when Manager + # wants to send a new one, and Outbound is still in the middle of + # draining the beginning-of-connection queue, the new message gets + # queued behind the rest (in addition to being queued in + # _outbound_queue until an ACK retires it). + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r1.seqnum) + self.assertEqual(list(o._outbound_queue), [r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + def test_premptive_ack(self): + # one mode I have in mind is for each side to send an immediate ACK, + # with everything they've ever seen, as the very first message on each + # new connection. The idea is that you might preempt sending stuff from + # the _queued_unsent list if it arrives fast enough (in practice this + # is more likely to be delivered via the DILATE mailbox message, but + # the effects might be vaguely similar, so it seems worth testing + # here). A similar situation would be if each side sends ACKs with the + # highest seqnum they've ever seen, instead of merely ACKing the + # message which was just received. 
+ o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + self.assertEqual(list(o._queued_unsent), [r3]) + self.assertEqual(c.mock_calls, []) + + def test_pause(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)]) + self.assertEqual(list(o._outbound_queue), []) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + sc1, sc2, sc3 = object(), object(), object() + p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3") + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + + r1 = Pauser(seqnum=1) + o.queue_and_send_record(r1) + # now we should be paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1, c) + + # so an IPushProducer will be paused right away + o.subchannel_registerProducer(sc2, p2, True) + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p2) + + o.subchannel_registerProducer(sc3, p3, True) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(o._paused_producers, 
set([p1, p2, p3])) + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + clear_mock_calls(p3) + + # one resumeProducing should cause p1 to get a turn, since p2 was added + # after we were paused and p1 was at the "end" of a one-element list. + # If it writes anything, it will get paused again immediately. + r2 = Pauser(seqnum=2) + p1.resumeProducing.side_effect = lambda: c.send_record(r2) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(p3.mock_calls, []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2)]) + clear_mock_calls(p1, p2, p3, c) + # p2 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p2, p3, p1]) + + # next turn: p2 has nothing to send, but p3 does. we should see p3 + # called but not p1. The actual sequence of expected calls is: + # p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause) + r3 = Pauser(seqnum=3) + p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: c.send_record(r3) + o.resumeProducing() + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # next turn: p1 has data to send, but not enough to cause a pause. same + # for p2. 
p3 causes a pause + r4 = NonPauser(seqnum=4) + r5 = NonPauser(seqnum=5) + r6 = Pauser(seqnum=6) + p1.resumeProducing.side_effect = lambda: c.send_record(r4) + p2.resumeProducing.side_effect = lambda: c.send_record(r5) + p3.resumeProducing.side_effect = lambda: c.send_record(r6) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r4), + mock.call.send_record(r5), + mock.call.send_record(r6), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # now we let it catch up. p1 and p2 send non-pausing data, p3 sends + # nothing. + r7 = NonPauser(seqnum=4) + r8 = NonPauser(seqnum=5) + p1.resumeProducing.side_effect = lambda: c.send_record(r7) + p2.resumeProducing.side_effect = lambda: c.send_record(r8) + p3.resumeProducing.side_effect = lambda: None + + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r7), + mock.call.send_record(r8), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + self.assertFalse(o._paused) + + # now a producer disconnects itself (spontaneously, not from inside a + # resumeProducing) + o.subchannel_unregisterProducer(sc1) + self.assertEqual(list(o._all_producers), [p2, p3]) + self.assertEqual(p1.mock_calls, []) + self.assertFalse(o._paused) + + # and another disconnects itself when called + 
p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3) + o.pauseProducing() + o.resumeProducing() + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + clear_mock_calls(p2, p3) + self.assertEqual(list(o._all_producers), [p2]) + self.assertFalse(o._paused) + + def test_subchannel_closed(self): + o, m, c = make_outbound() + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1) + + o.subchannel_closed(sc1) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(list(o._all_producers), []) + + sc2 = mock.Mock() + o.subchannel_closed(sc2) + + def test_disconnect(self): + o, m, c = make_outbound() + o.use_connection(c) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + o.stop_using_connection() + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + + def OFF_test_push_pull(self): + # use one IPushProducer and one IPullProducer. 
They should take turns + o, m, c = make_outbound() + o.use_connection(c) + clear_mock_calls(c) + + sc1, sc2 = object(), object() + p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2") + r1 = Pauser(seqnum=1) + r2 = NonPauser(seqnum=2) + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) # push + o.queue_and_send_record(r1) + # now we're paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(p2.mock_calls, []) + clear_mock_calls(p1, p2, c) + + p1.resumeProducing.side_effect = lambda: c.send_record(r1) + p2.resumeProducing.side_effect = lambda: c.send_record(r2) + o.subchannel_registerProducer(sc2, p2, False) # pull: always ready + + # p1 is still first, since p2 was just added (at the end) + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p1, p2]) + clear_mock_calls(p1, p2, c) + + # resume should send r1, which should pause everything + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next + clear_mock_calls(p1, p2, c) + + # next should fire p2, then p1 + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2), + mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat + clear_mock_calls(p1, p2, c) + + def test_pull_producer(self): + 
# a single pull producer should write until it is paused, rate-limited + # by the cooperator (so we'll see back-to-back resumeProducing calls + # until the Connection is paused, or 10ms have passed, whichever comes + # first, and if it's stopped by the timer, then the next EventualQueue + # turn will start it off again) + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + + records = [NonPauser(seqnum=1)] * 10 + records.append(Pauser(seqnum=2)) + records.append(Stopper(sc1)) + it = iter(records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) + o.subchannel_registerProducer(sc1, p1, False) + eq.flush_sync() # fast forward into the glorious (paused) future + + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in records[:-1]]) + self.assertEqual(p1.mock_calls, + [mock.call.resumeProducing()]*(len(records)-1)) + clear_mock_calls(c, p1) + + # next resumeProducing should cause it to disconnect + o.resumeProducing() + eq.flush_sync() + self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()]) + self.assertEqual(len(o._all_producers), 0) + self.assertFalse(o._paused) + + def test_two_pull_producers(self): + # we should alternate between them until paused + p1_records = ([NonPauser(seqnum=i) for i in range(5)] + + [Pauser(seqnum=5)] + + [NonPauser(seqnum=i) for i in range(6, 10)]) + p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] + + [Pauser(seqnum=19)]) + expected1 = [NonPauser(0), NonPauser(10), + NonPauser(1), NonPauser(11), + NonPauser(2), NonPauser(12), + NonPauser(3), NonPauser(13), + NonPauser(4), NonPauser(14), + Pauser(5)] + expected2 = [ NonPauser(15), + NonPauser(6), NonPauser(16), + NonPauser(7), NonPauser(17), + NonPauser(8), NonPauser(18), + NonPauser(9), 
Pauser(19), + ] + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + it1 = iter(p1_records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it1)) + o.subchannel_registerProducer(sc1, p1, False) + + sc2 = mock.Mock() + p2 = mock.Mock(name="p2") + alsoProvides(p2, IPullProducer) + it2 = iter(p2_records) + p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) + o.subchannel_registerProducer(sc2, p2, False) + + eq.flush_sync() # fast forward into the glorious (paused) future + + sends = [mock.call.resumeProducing()] + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected1]) + self.assertEqual(p1.mock_calls, 6*sends) + self.assertEqual(p2.mock_calls, 5*sends) + clear_mock_calls(c, p1, p2) + + o.resumeProducing() + eq.flush_sync() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected2]) + self.assertEqual(p1.mock_calls, 4*sends) + self.assertEqual(p2.mock_calls, 5*sends) + clear_mock_calls(c, p1, p2) + + def test_send_if_connected(self): + o, m, c = make_outbound() + o.send_if_connected(Ack(1)) # not connected yet + + o.use_connection(c) + o.send_if_connected(KCM()) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(KCM())]) + + def test_tolerate_duplicate_pause_resume(self): + o, m, c = make_outbound() + self.assertTrue(o._paused) # no connection + o.use_connection(c) + self.assertFalse(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + + def test_stopProducing(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertFalse(o._paused) + o.stopProducing() # connection 
does this before loss + self.assertTrue(o._paused) + o.stop_using_connection() + self.assertTrue(o._paused) + + def test_resume_error(self): + o, m, c = make_outbound() + o.use_connection(c) + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + p1.resumeProducing.side_effect = PretendResumptionError + o.subchannel_registerProducer(sc1, p1, False) + o._test_eq.flush_sync() + # the error is supposed to automatically unregister the producer + self.assertEqual(list(o._all_producers), []) + self.flushLoggedErrors(PretendResumptionError) + + +def make_pushpull(pauses): + p = mock.Mock() + alsoProvides(p, IPullProducer) + unregister = mock.Mock() + + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + pp = PullToPush(p, unregister, coop) + + it = cycle(pauses) + def action(i): + if isinstance(i, Exception): + raise i + elif i: + pp.pauseProducing() + p.resumeProducing.side_effect = lambda: action(next(it)) + return p, unregister, pp, eq + +class PretendResumptionError(Exception): + pass +class PretendUnregisterError(Exception): + pass + +class PushPull(unittest.TestCase): + # test our wrapper utility, which I copied from + # twisted.internet._producer_helpers since it isn't publically exposed + + def test_start_unpaused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + # if it starts unpaused, it gets one write before being halted + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + clear_mock_calls(p) + + # now each time we call resumeProducing, we should see one delivered to + # the underlying IPullProducer + pp.resumeProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + + pp.stopStreaming() + pp.stopStreaming() # should 
tolerate this + + def test_start_unpaused_two_writes(self): + p, unr, pp, eq = make_pushpull([False, True]) # pause every other time + # it should get two writes, since the first didn't pause + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2) + + def test_start_paused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + pp.startStreaming(True) + eq.flush_sync() + self.assertEqual(p.mock_calls, []) + pp.stopStreaming() + + def test_stop(self): + p, unr, pp, eq = make_pushpull([True]) + pp.startStreaming(True) + pp.stopProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.stopProducing()]) + + def test_error(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = lambda: pp.stopStreaming() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError) + + def test_error_during_unregister(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = PretendUnregisterError() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) + + + + + + # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I + # could capture the inter-call ordering that way diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py new file mode 100644 index 0000000..8365e62 --- /dev/null +++ b/src/wormhole/test/dilate/test_parse.py @@ -0,0 +1,43 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from ..._dilation.connection import (parse_record, encode_record, + KCM, Ping, Pong, Open, Data, Close, Ack) + +class Parse(unittest.TestCase): + def test_parse(self): + self.assertEqual(parse_record(b"\x00"), KCM()) + 
self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"), + Ping(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"), + Pong(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"), + Open(scid=513, seqnum=256)) + self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"), + Data(scid=514, seqnum=257, data=b"dataaa")) + self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"), + Close(scid=515, seqnum=258)) + self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"), + Ack(resp_seqnum=259)) + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(ValueError): + parse_record(b"\x07unknown") + self.assertEqual(le.mock_calls, + [mock.call("received unknown message type: {}".format( + b"\x07unknown"))]) + + def test_encode(self): + self.assertEqual(encode_record(KCM()), b"\x00") + self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping") + self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong") + self.assertEqual(encode_record(Open(scid=65536, seqnum=16)), + b"\x03\x00\x01\x00\x00\x00\x00\x00\x10") + self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")), + b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa") + self.assertEqual(encode_record(Close(scid=65538, seqnum=18)), + b"\x05\x00\x01\x00\x02\x00\x00\x00\x12") + self.assertEqual(encode_record(Ack(resp_seqnum=19)), + b"\x06\x00\x00\x00\x13") + with self.assertRaises(TypeError) as ar: + encode_record("not a record") + self.assertEqual(str(ar.exception), "not a record") diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py new file mode 100644 index 0000000..810396c --- /dev/null +++ b/src/wormhole/test/dilate/test_record.py @@ -0,0 +1,268 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from 
noise.exceptions import NoiseInvalidMessage +from ..._dilation.connection import (IFramer, Frame, Prologue, + _Record, Handshake, + Disconnect, Ping) + +def make_record(): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() # pretends to be a Noise object + r = _Record(f, n) + return r, f, n + +class Record(unittest.TestCase): + def test_good2(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(side_effect=[ + [], + [Prologue()], + [Frame(frame=b"rx-handshake")], + [Frame(frame=b"frame1"), Frame(frame=b"frame2")], + ]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + p1, p2 = object(), object() + n.decrypt = mock.Mock(side_effect=[p1, p2]) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # Pretend to deliver the prologue in two parts. The text we send in + # doesn't matter: the side_effect= is what causes the prologue to be + # recognized by the second call. 
+ self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, []) + + # recognizing the prologue causes a handshake frame to be sent + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(b"tx-handshake")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + n.mock_calls[:] = [] + + # next add_and_unframe is recognized as the Handshake + self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")]) + n.mock_calls[:] = [] + + # next is a pair of Records + r1, r2 = object() , object() + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[r1,r2]) as pr: + self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"), + mock.call.decrypt(b"frame2")]) + self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)]) + + def test_bad_handshake(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.read_message = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad 
inbound noise handshake")]) + + def test_bad_message(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake"), + Frame(frame=b"bad-message")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.decrypt = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad inbound noise frame")]) + + def test_send_record(self): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() + f1 = object() + n.encrypt = mock.Mock(return_value=f1) + r1 = Ping(b"pingid") + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + m1 = object() + with mock.patch("wormhole._dilation.connection.encode_record", + return_value=m1) as er: + r.send_record(r1) + self.assertEqual(er.mock_calls, [mock.call(r1)]) + self.assertEqual(n.mock_calls, [mock.call.start_handshake(), + mock.call.encrypt(m1)]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)]) + + def test_good(self): + # Exercise the success path. The Record instance is given each chunk + # of data as it arrives on Protocol.dataReceived, and is supposed to + # return a series of Tokens (maybe none, if the chunk was incomplete, + # or more than one, if the chunk was larger). Internally, it delivers + # the chunks to the Framer for unframing (which returns 0 or more + # frames), manages the Noise decryption object, and parses any + # decrypted messages into tokens (some of which are consumed + # internally, others for delivery upstairs). 
+ # + # in the normal flow, we get: + # + # | | Inbound | NoiseAction | Outbound | ToUpstairs | + # | | - | - | - | - | + # | 1 | | | prologue | | + # | 2 | prologue | | | | + # | 3 | | write_message | handshake | | + # | 4 | handshake | read_message | | Handshake | + # | 5 | | encrypt | KCM | | + # | 6 | KCM | decrypt | | KCM | + # | 7 | msg1 | decrypt | | msg1 | + + # 1: instantiating the Record instance causes the outbound prologue + # to be sent + + # 2+3: receipt of the inbound prologue triggers creation of the + # ephemeral key (the "handshake") by calling noise.write_message() + # and then writes the handshake to the outbound transport + + # 4: when the peer's handshake is received, it is delivered to + # noise.read_message(), which generates the shared key (enabling + # noise.send() and noise.decrypt()). It also delivers the Handshake + # token upstairs, which might (on the Follower) trigger immediate + # transmission of the Key Confirmation Message (KCM) + + # 5: the outbound KCM is framed and fed into noise.encrypt(), then + # sent outbound + + # 6: the peer's KCM is decrypted then delivered upstairs. The + # Follower treats this as a signal that it should use this connection + # (and drop all others). + + # 7: the peer's first message is decrypted, parsed, and delivered + # upstairs. 
This might be an Open or a Data, depending upon what + # queued messages were left over from the previous connection + + r, f, n = make_record() + outbound_handshake = object() + kcm, msg1 = object(), object() + f_kcm, f_msg1 = object(), object() + n.write_message = mock.Mock(return_value=outbound_handshake) + n.decrypt = mock.Mock(side_effect=[kcm, msg1]) + n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) + f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet + [Prologue()], + [Frame("f_handshake")], + [Frame("f_kcm"), + Frame("f_msg1")], + ]) + + self.assertEqual(f.mock_calls, []) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # 1. The Framer is responsible for sending the prologue, so we don't + # have to check that here, we just check that the Framer was told + # about connectionMade properly. + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + # 2 + # we dribble the prologue in over two messages, to make sure we can + # handle a dataReceived that doesn't complete the token + + # remember, add_and_unframe is a generator + self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + # 3: write_message, send outbound handshake + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(outbound_handshake), + ]) + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + # 4 + # Now deliver the Noise "handshake", the ephemeral public key. 
This + # is framed, but not a record, so it shouldn't decrypt or parse + # anything, but the handshake is delivered to the Noise object, and + # it does return a Handshake token so we can let the next layer up + # react (by sending the KCM frame if we're a Follower, or not if + # we're the Leader) + + self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")]) + self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + + # 5: at this point we ought to be able to send a messge, the KCM + with mock.patch("wormhole._dilation.connection.encode_record", + side_effect=[b"r-kcm"]) as er: + r.send_record(kcm) + self.assertEqual(er.mock_calls, [mock.call(kcm)]) + self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] + + # 6: Now we deliver two messages stacked up: the KCM (Key + # Confirmation Message) and the first real message. Concatenating + # them tests that we can handle more than one token in a single + # chunk. We need to mock parse_record() because everything past the + # handshake is decrypted and parsed. 
+ + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[kcm, msg1]) as pr: + self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")), + [kcm, msg1]) + self.assertEqual(f.mock_calls, + [mock.call.add_and_parse(b"kcm,msg1")]) + self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"), + mock.call.decrypt("f_msg1")]) + self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py new file mode 100644 index 0000000..69fa001 --- /dev/null +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -0,0 +1,142 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from twisted.internet.error import ConnectionDone +from ..._dilation.subchannel import (Once, SubChannel, + _WormholeAddress, _SubchannelAddress, + AlreadyClosedError) +from .common import mock_manager + +def make_sc(set_protocol=True): + scid = b"scid" + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(scid) + m = mock_manager() + sc = SubChannel(scid, m, hostaddr, peeraddr) + p = mock.Mock() + if set_protocol: + sc._set_protocol(p) + return sc, m, scid, hostaddr, peeraddr, p + +class SubChannelAPI(unittest.TestCase): + def test_once(self): + o = Once(ValueError) + o() + with self.assertRaises(ValueError): + o() + + def test_create(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + self.assert_(ITransport.providedBy(sc)) + self.assertEqual(m.mock_calls, []) + self.assertIdentical(sc.getHost(), hostaddr) + self.assertIdentical(sc.getPeer(), peeraddr) + + def test_write(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + sc.writeSequence([b"more", b"data"]) + self.assertEqual(m.mock_calls, 
[mock.call.send_data(scid, b"moredata")]) + + def test_write_when_closing(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.write(b"data") + self.assertEqual(str(e.exception), + "write not allowed on closed subchannel") + + def test_local_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + # late arriving data is still delivered + sc.remote_data(b"late") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")]) + p.mock_calls[:] = [] + + sc.remote_close() + self.assert_connectionDone(p.mock_calls) + + def test_local_double_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.loseConnection() + self.assertEqual(str(e.exception), + "loseConnection not allowed on closed subchannel") + + def assert_connectionDone(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "connectionLost") + self.assertEqual(len(mock_calls[0][1]), 1) + self.assertIsInstance(mock_calls[0][1][0], ConnectionDone) + + def test_remote_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_close() + self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)]) + self.assert_connectionDone(p.mock_calls) + + def test_data(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"not") + sc.remote_data(b"coalesced") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"), + mock.call.dataReceived(b"coalesced"), + ]) + + 
def test_data_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"more") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")]) + + def test_close_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_close() + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assert_connectionDone(p.mock_calls) + + def test_producer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.pauseProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)]) + m.mock_calls[:] = [] + sc.resumeProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)]) + m.mock_calls[:] = [] + sc.stopProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)]) + m.mock_calls[:] = [] + + def test_consumer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + # TODO: more, once this is implemented + sc.registerProducer(None, True) + sc.unregisterProducer() diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 5b5e9f1..35854ae 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -12,7 +12,7 @@ import mock from .. 
import (__version__, _allocator, _boss, _code, _input, _key, _lister, _mailbox, _nameplate, _order, _receive, _rendezvous, _send, _terminator, errors, timing) -from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister, +from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister, IMailbox, INameplate, IOrder, IReceive, IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data @@ -1300,6 +1300,7 @@ class Boss(unittest.TestCase): b._RC = Dummy("rc", events, IRendezvousConnector, "start") b._C = Dummy("c", events, ICode, "allocate_code", "input_code", "set_code") + b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key") return b, events def test_basic(self): @@ -1327,7 +1328,9 @@ class Boss(unittest.TestCase): b.got_message("side", "0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), + ("d.got_key", b"key"), ("w.got_verifier", b"verifier"), + ("d.got_wormhole_versions", "side", "side", {}), ("w.got_versions", {}), ("w.received", b"msg1"), ]) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index c02aa60..f967c25 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -9,6 +9,7 @@ from twisted.internet.task import Cooperator from zope.interface import implementer from ._boss import Boss +from ._dilation.connector import Connector from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key from .errors import NoKeyError, WormholeClosed @@ -189,6 +190,9 @@ class _DeferredWormhole(object): raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) + def dilate(self): + return self._boss.dilate() # fires with (endpoints) + def close(self): # fails with WormholeError unless we established a connection # (state=="happy"). 
Fails with WrongPasswordError (a subclass of @@ -265,8 +269,12 @@ def create( w = _DelegatedWormhole(delegate) else: w = _DeferredWormhole(reactor, eq) - wormhole_versions = {} # will be used to indicate Wormhole capabilities - wormhole_versions["app_versions"] = versions # app-specific capabilities + # this indicates Wormhole capabilities + wormhole_versions = { + "can-dilate": [1], + "dilation-abilities": Connector.get_connection_abilities(), + } + wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace") From 39666f3fed3183ada348dcb7113fb7a9bd3fe4f4 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 14:44:04 -0700 Subject: [PATCH 07/49] travis: tolerate failures in py2.7, 'noiseprotocol' dep is py3-only --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 33c3dad..69344cf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -34,5 +34,8 @@ matrix: dist: xenial - python: nightly allow_failures: + - python: 2.7 - python: 3.3 + # travis doesn't support py3.7 yet + - python: 3.7 - python: nightly From 05900bd08b1ae7e9ef3cec46a71eea5408786b40 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 16:19:41 -0700 Subject: [PATCH 08/49] fix some flake8 complaints --- src/wormhole/_boss.py | 7 +- src/wormhole/_dilation/connection.py | 85 +++++++----- src/wormhole/_dilation/connector.py | 129 +++++++++++------- src/wormhole/_dilation/encode.py | 3 + src/wormhole/_dilation/inbound.py | 21 ++- src/wormhole/_dilation/manager.py | 187 +++++++++++++++++---------- src/wormhole/_dilation/outbound.py | 33 +++-- src/wormhole/_dilation/subchannel.py | 69 +++++++--- 8 files changed, 339 insertions(+), 195 deletions(-) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index cc6f4fa..d3cb5ff 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -67,7 +67,8 @@ class Boss(object): self._I = 
Input(self._timing) self._C = Code(self._timing) self._T = Terminator() - self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) + self._D = Dilator(self._reactor, self._eventual_queue, + self._cooperator) self._N.wire(self._M, self._I, self._RC, self._T) self._M.wire(self._N, self._RC, self._O, self._T) @@ -90,7 +91,7 @@ class Boss(object): self._rx_phases = {} # phase -> plaintext self._next_rx_dilate_seqnum = 0 - self._rx_dilate_seqnums = {} # seqnum -> plaintext + self._rx_dilate_seqnums = {} # seqnum -> plaintext self._result = "empty" @@ -205,7 +206,7 @@ class Boss(object): self._C.set_code(code) def dilate(self): - return self._D.dilate() # fires with endpoints + return self._D.dilate() # fires with endpoints @m.input() def send(self, plaintext): diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index f0c8e35..4b9ced7 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -39,20 +39,28 @@ from .roles import FOLLOWER # states). For the specific question of sending plaintext frames, Noise will # refuse us unless it's ready anyways, so the question is probably moot. 
+ class IFramer(Interface): pass + + class IRecord(Interface): pass + def first(l): return l[0] + class Disconnect(Exception): pass + + RelayOK = namedtuple("RelayOk", []) Prologue = namedtuple("Prologue", []) Frame = namedtuple("Frame", ["frame"]) + @attrs @implementer(IFramer) class _Framer(object): @@ -69,30 +77,37 @@ class _Framer(object): # out (shared): transport.write (relay handshake, prologue) # states: want_relay, want_prologue, want_frame m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover @m.state() - def want_relay(self): pass # pragma: no cover + def want_relay(self): pass # pragma: no cover + @m.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): pass # pragma: no cover + @m.state() - def want_frame(self): pass # pragma: no cover + def want_frame(self): pass # pragma: no cover @m.input() def use_relay(self, relay_handshake): pass + @m.input() def connectionMade(self): pass + @m.input() def parse(self): pass + @m.input() def got_relay_ok(self): pass + @m.input() def got_prologue(self): pass @m.output() def store_relay_handshake(self, relay_handshake): self._outbound_relay_handshake = relay_handshake - self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + @m.output() def send_relay_handshake(self): self._transport.write(self._outbound_relay_handshake) @@ -113,17 +128,17 @@ class _Framer(object): @m.output() def can_send_frames(self): - self._can_send_frames = True # for assertion in send_frame() + self._can_send_frames = True # for assertion in send_frame() @m.output() def parse_frame(self): if len(self._buffer) < 4: return None frame_length = from_be4(self._buffer[0:4]) - if len(self._buffer) < 4+frame_length: + if len(self._buffer) < 4 + frame_length: return None - frame = 
self._buffer[4:4+frame_length] - self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy + frame = self._buffer[4:4 + frame_length] + self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy return Frame(frame=frame) want_prologue.upon(use_relay, outputs=[store_relay_handshake], @@ -144,7 +159,6 @@ class _Framer(object): want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, collector=first) - def _get_expected(self, name, expected): lb = len(self._buffer) le = len(expected) @@ -161,7 +175,7 @@ class _Framer(object): if (b"\n" in self._buffer or lb >= le): log.msg("bad {}: {}".format(name, self._buffer[:le])) raise Disconnect() - return False # wait a bit longer + return False # wait a bit longer # good so far, just waiting for the rest return False @@ -181,7 +195,7 @@ class _Framer(object): self.got_relay_ok() elif isinstance(token, Prologue): self.got_prologue() - yield token # triggers send_handshake + yield token # triggers send_handshake elif isinstance(token, Frame): yield token else: @@ -202,15 +216,16 @@ class _Framer(object): # from peer. Sent immediately by Follower, after Selection by Leader. 
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong + Handshake = namedtuple("Handshake", []) # decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack KCM = namedtuple("KCM", []) -Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value +Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value Pong = namedtuple("Pong", ["ping_id"]) -Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer +Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer Data = namedtuple("Data", ["seqnum", "scid", "data"]) -Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer -Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer Records = (KCM, Ping, Pong, Open, Data, Close, Ack) Handshake_or_Records = (Handshake,) + Records @@ -222,6 +237,7 @@ T_DATA = b"\x04" T_CLOSE = b"\x05" T_ACK = b"\x06" + def parse_record(plaintext): msgtype = plaintext[0:1] if msgtype == T_KCM: @@ -251,6 +267,7 @@ def parse_record(plaintext): log.err("received unknown message type: {}".format(plaintext)) raise ValueError() + def encode_record(r): if isinstance(r, KCM): return b"\x00" @@ -275,6 +292,7 @@ def encode_record(r): return b"\x06" + to_be4(r.resp_seqnum) raise TypeError(r) + @attrs @implementer(IRecord) class _Record(object): @@ -294,22 +312,25 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): pass # pragma: no cover + @n.state() - def want_handshake(self): pass # pragma: no cover + def want_handshake(self): pass # pragma: no cover + @n.state() - def want_message(self): pass # pragma: no cover + def want_message(self): pass # pragma: no cover @n.input() def got_prologue(self): pass + @n.input() def got_frame(self, frame): pass @n.output() def 
send_handshake(self): - handshake = self._noise.write_message() # generate the ephemeral key + handshake = self._noise.write_message() # generate the ephemeral key self._framer.send_frame(handshake) @n.output() @@ -351,10 +372,10 @@ class _Record(object): def add_and_unframe(self, data): for token in self._framer.add_and_parse(data): if isinstance(token, Prologue): - self.got_prologue() # triggers send_handshake + self.got_prologue() # triggers send_handshake else: assert isinstance(token, Frame) - yield self.got_frame(token.frame) # Handshake or a Record type + yield self.got_frame(token.frame) # Handshake or a Record type def send_record(self, r): message = encode_record(r) @@ -388,26 +409,30 @@ class DilatedConnectionProtocol(Protocol, object): _relay_handshake = None m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): - self._manager = None # set if/when we are selected + self._manager = None # set if/when we are selected self._disconnected = OneShotObserver(self._eventual_queue) self._can_send_records = False @m.state(initial=True) - def unselected(self): pass # pragma: no cover + def unselected(self): pass # pragma: no cover + @m.state() - def selecting(self): pass # pragma: no cover + def selecting(self): pass # pragma: no cover + @m.state() - def selected(self): pass # pragma: no cover + def selected(self): pass # pragma: no cover @m.input() def got_kcm(self): pass + @m.input() def select(self, manager): - pass # fires set_manager() + pass # fires set_manager() + @m.input() def got_record(self, record): pass @@ -472,9 +497,9 @@ class DilatedConnectionProtocol(Protocol, object): elif isinstance(token, KCM): # if we're the leader, add this connection as a candiate. # if we're the follower, accept this connection. 
- self.got_kcm() # connector.add_candidate() + self.got_kcm() # connector.add_candidate() else: - self.got_record(token) # manager.got_record() + self.got_record(token) # manager.got_record() except Disconnect: self.transport.loseConnection() diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 86f2a72..f530039 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -1,5 +1,6 @@ from __future__ import print_function, unicode_literals -import sys, re +import sys +import re from collections import defaultdict, namedtuple from binascii import hexlify import six @@ -13,7 +14,7 @@ from twisted.internet.endpoints import HostnameEndpoint, serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.python import log from hkdf import Hkdf -from .. import ipaddrs # TODO: move into _dilation/ +from .. import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager from ..timing import DebugTiming from ..observer import EmptyableSet @@ -30,7 +31,8 @@ from .roles import LEADER # * expect to see the receiver/sender handshake bytes from the other side # * the sender writes "go\n", the receiver waits for "go\n" # * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +DirectTCPV1Hint = namedtuple( + "DirectTCPV1Hint", ["hostname", "port", "priority"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # use a tuple rather than a list so they'll be hashable into a set). For each @@ -38,6 +40,7 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. 
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + def describe_hint_obj(hint, relay, tor): prefix = "tor->" if tor else "->" if relay: @@ -45,9 +48,10 @@ def describe_hint_obj(hint, relay, tor): if isinstance(hint, DirectTCPV1Hint): return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) elif isinstance(hint, TorTCPV1Hint): - return prefix+"tor:%s:%d" % (hint.hostname, hint.port) + return prefix + "tor:%s:%d" % (hint.hostname, hint.port) else: - return prefix+str(hint) + return prefix + str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type("")) @@ -59,7 +63,8 @@ def parse_hint_argv(hint, stderr=sys.stderr): return None hint_type = mo.group(1) if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + print("unknown hint type '%s' in '%s'" % (hint_type, hint), + file=stderr) return None hint_value = mo.group(2) pieces = hint_value.split(":") @@ -84,17 +89,18 @@ def parse_hint_argv(hint, stderr=sys.stderr): return None return DirectTCPV1Hint(hint_host, hint_port, priority) -def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj hint_type = hint.get("type", "") if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: log.msg("unknown hint type: %r" % (hint,)) return None - if not("hostname" in hint - and isinstance(hint["hostname"], type(""))): + if not("hostname" in hint and + isinstance(hint["hostname"], type(""))): log.msg("invalid hostname in hint: %r" % (hint,)) return None - if not("port" in hint - and isinstance(hint["port"], six.integer_types)): + if not("port" in hint and + isinstance(hint["port"], six.integer_types)): log.msg("invalid port in hint: %r" % (hint,)) return None priority = hint.get("priority", 0.0) @@ -103,51 +109,58 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj else: return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + def parse_hint(hint_struct): hint_type = hint_struct.get("type", "") if 
hint_type == "relay-v1": # the struct can include multiple ways to reach the same relay - rhints = filter(lambda h: h, # drop None (unrecognized) + rhints = filter(lambda h: h, # drop None (unrecognized) [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) return RelayV1Hint(rhints) return parse_tcp_v1_hint(hint_struct) + def encode_hint(h): if isinstance(h, DirectTCPV1Hint): return {"type": "direct-tcp-v1", "priority": h.priority, "hostname": h.hostname, - "port": h.port, # integer + "port": h.port, # integer } elif isinstance(h, RelayV1Hint): rhint = {"type": "relay-v1", "hints": []} for rh in h.hints: rhint["hints"].append({"type": "direct-tcp-v1", - "priority": rh.priority, - "hostname": rh.hostname, - "port": rh.port}) + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) return rhint elif isinstance(h, TorTCPV1Hint): return {"type": "tor-tcp-v1", "priority": h.priority, "hostname": h.hostname, - "port": h.port, # integer + "port": h.port, # integer } raise ValueError("unknown hint type", h) + def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) + def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) - assert len(side) == 8*2 + assert len(side) == 8 * 2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + return (b"please relay " + hexlify(token) + + b" for side " + side.encode("ascii") + b"\n") -PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" + +PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + @attrs @implementer(IDilationConnector) class Connector(object): @@ -176,10 +189,11 @@ class Connector(object): self._transit_relays = [relay] else: self._transit_relays = [] - self._listeners = set() # IListeningPorts that can be 
stopped - self._pending_connectors = set() # Deferreds that can be cancelled - self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped - self._contenders = set() # viable connections + self._listeners = set() # IListeningPorts that can be stopped + self._pending_connectors = set() # Deferreds that can be cancelled + self._pending_connections = EmptyableSet( + _eventual_queue=self._eventual_queue) # Protocols to be stopped + self._contenders = set() # viable connections self._winning_connection = None self._timing = self._timing or DebugTiming() self._timing.add("transit") @@ -212,26 +226,41 @@ class Connector(object): return p @m.state(initial=True) - def connecting(self): pass # pragma: no cover + def connecting(self): + pass # pragma: no cover + @m.state() - def connected(self): pass # pragma: no cover + def connected(self): + pass # pragma: no cover + @m.state(terminal=True) - def stopped(self): pass # pragma: no cover + def stopped(self): + pass # pragma: no cover # TODO: unify the tense of these method-name verbs @m.input() - def listener_ready(self, hint_objs): pass - @m.input() - def add_relay(self, hint_objs): pass - @m.input() - def got_hints(self, hint_objs): pass - @m.input() - def add_candidate(self, c): # called by DilatedConnectionProtocol + def listener_ready(self, hint_objs): pass + @m.input() - def accept(self, c): pass + def add_relay(self, hint_objs): + pass + @m.input() - def stop(self): pass + def got_hints(self, hint_objs): + pass + + @m.input() + def add_candidate(self, c): # called by DilatedConnectionProtocol + pass + + @m.input() + def accept(self, c): + pass + + @m.input() + def stop(self): + pass @m.output() def use_hints(self, hint_objs): @@ -255,19 +284,19 @@ class Connector(object): @m.output() def select_and_stop_remaining(self, c): self._winning_connection = c - self._contenders.clear() # we no longer care who else came close + self._contenders.clear() # we no longer care who else 
came close # remove this winner from the losers, so we don't shut it down self._pending_connections.discard(c) # shut down losing connections - self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist + self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist self.stop_pending_connectors() self.stop_pending_connections() - c.select(self._manager) # subsequent frames go directly to the manager + c.select(self._manager) # subsequent frames go directly to the manager if self._role is LEADER: # TODO: this should live in Connection - c.send_record(KCM()) # leader sends KCM now - self._manager.use_connection(c) # manager sends frames to Connection + c.send_record(KCM()) # leader sends KCM now + self._manager.use_connection(c) # manager sends frames to Connection @m.output() def stop_everything(self): @@ -279,7 +308,7 @@ class Connector(object): def stop_listeners(self): d = DeferredList([l.stopListening() for l in self._listeners]) self._listeners.clear() - return d # synchronization for tests + return d # synchronization for tests def stop_pending_connectors(self): return DeferredList([d.cancel() for d in self._pending_connectors]) @@ -306,7 +335,8 @@ class Connector(object): publish_hints]) connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) connecting.upon(add_candidate, enter=connecting, outputs=[consider]) - connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) + connecting.upon(accept, enter=connected, outputs=[ + select_and_stop_remaining]) connecting.upon(stop, enter=stopped, outputs=[stop_everything]) # once connected, we ignore everything except stop @@ -317,9 +347,9 @@ class Connector(object): connected.upon(accept, enter=connected, outputs=[]) connected.upon(stop, enter=stopped, outputs=[stop_everything]) - # from Manager: start, got_hints, stop # maybe add_candidate, accept + def start(self): self._start_listener() if self._transit_relays: @@ -341,9 +371,10 @@ class Connector(object): ep = 
serverFromString(self._reactor, "tcp:0") f = InboundConnectionFactory(self) d = ep.listen(f) + def _listening(lp): # lp is an IListeningPort - self._listeners.add(lp) # for shutdown and tests + self._listeners.add(lp) # for shutdown and tests portnum = lp.getHost().port direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses] @@ -378,7 +409,7 @@ class Connector(object): # one still running. But if we bail on that, we might consider # putting an inter-direct-hint delay here to influence the # process. - #delay += 1.0 + # delay += 1.0 if delay > 0.0: # Start trying the relays a few seconds after we start to try the # direct hints. The idea is to prefer direct connections, but not @@ -408,7 +439,7 @@ class Connector(object): self._connect, ep, desc, is_relay=True) self._pending_connectors.add(d) # TODO: - #if not contenders: + # if not contenders: # raise TransitError("No contenders for connection") # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for @@ -422,6 +453,7 @@ class Connector(object): f = OutboundConnectionFactory(self, relay_handshake) d = ep.connect(f) # fires with protocol, or ConnectError + def _connected(p): self._pending_connections.add(p) # c might not be in _pending_connections, if it turned out to be a @@ -444,7 +476,6 @@ class Connector(object): return HostnameEndpoint(self._reactor, hint.hostname, hint.port) return None - # Connection selection. All instances of DilatedConnectionProtocol which # look viable get passed into our add_contender() method. 
@@ -459,6 +490,7 @@ class Connector(object): # our Connection protocols call: add_candidate + @attrs class OutboundConnectionFactory(ClientFactory, object): _connector = attrib(validator=provides(IDilationConnector)) @@ -471,6 +503,7 @@ class OutboundConnectionFactory(ClientFactory, object): p.use_relay(self._relay_handshake) return p + @attrs class InboundConnectionFactory(ServerFactory, object): _connector = attrib(validator=provides(IDilationConnector)) diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py index eb1b0b2..80e9902 100644 --- a/src/wormhole/_dilation/encode.py +++ b/src/wormhole/_dilation/encode.py @@ -4,10 +4,13 @@ import struct assert len(struct.pack("L", value) + + def from_be4(b): if not isinstance(b, bytes): raise TypeError(repr(b)) diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py index 9235681..2f6ffaf 100644 --- a/src/wormhole/_dilation/inbound.py +++ b/src/wormhole/_dilation/inbound.py @@ -6,13 +6,19 @@ from twisted.python import log from .._interfaces import IDilationManager, IInbound from .subchannel import (SubChannel, _SubchannelAddress) + class DuplicateOpenError(Exception): pass + + class DataForMissingSubchannelError(Exception): pass + + class CloseForMissingSubchannelError(Exception): pass + @attrs @implementer(IInbound) class Inbound(object): @@ -24,8 +30,8 @@ class Inbound(object): def __attrs_post_init__(self): # we route inbound Data records to Subchannels .dataReceived - self._open_subchannels = {} # scid -> Subchannel - self._paused_subchannels = set() # Subchannels that have paused us + self._open_subchannels = {} # scid -> Subchannel + self._paused_subchannels = set() # Subchannels that have paused us # the set is non-empty, we pause the transport self._highest_inbound_acked = -1 self._connection = None @@ -37,7 +43,6 @@ class Inbound(object): def set_subchannel_zero(self, scid0, sc0): self._open_subchannels[scid0] = sc0 - def use_connection(self, c): 
self._connection = c # We can pause the connection's reads when we receive too much data. If @@ -61,7 +66,8 @@ class Inbound(object): def handle_open(self, scid): if scid in self._open_subchannels: - log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) + log.err(DuplicateOpenError( + "received duplicate OPEN for {}".format(scid))) return peer_addr = _SubchannelAddress(scid) sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) @@ -71,14 +77,16 @@ class Inbound(object): def handle_data(self, scid, data): sc = self._open_subchannels.get(scid) if sc is None: - log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid))) + log.err(DataForMissingSubchannelError( + "received DATA for non-existent subchannel {}".format(scid))) return sc.remote_data(data) def handle_close(self, scid): sc = self._open_subchannels.get(scid) if sc is None: - log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid))) + log.err(CloseForMissingSubchannelError( + "received CLOSE for non-existent subchannel {}".format(scid))) return sc.remote_close() @@ -90,7 +98,6 @@ class Inbound(object): def stop_using_connection(self): self._connection = None - # from our Subchannel, or rather from the Protocol above it and sent # through the subchannel diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 860665a..b16d198 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -20,13 +20,19 @@ from .connection import KCM, Ping, Pong, Open, Data, Close, Ack from .inbound import Inbound from .outbound import Outbound + class OldPeerCannotDilateError(Exception): pass + + class UnknownDilationMessageType(Exception): pass + + class ReceivedHintsTooEarly(Exception): pass + @attrs @implementer(IDilationManager) class _ManagerBase(object): @@ -37,14 +43,14 @@ class _ManagerBase(object): _reactor = attrib() _eventual_queue = attrib() 
_cooperator = attrib() - _no_listen = False # TODO - _tor = None # TODO - _timing = None # TODO + _no_listen = False # TODO + _tor = None # TODO + _timing = None # TODO def __attrs_post_init__(self): self._got_versions_d = Deferred() - self._my_role = None # determined upon rx_PLEASE + self._my_role = None # determined upon rx_PLEASE self._connection = None self._made_first_connection = False @@ -53,51 +59,56 @@ class _ManagerBase(object): self._next_dilation_phase = 0 - self._next_subchannel_id = 0 # increments by 2 + self._next_subchannel_id = 0 # increments by 2 # I kept getting confused about which methods were for inbound data # (and thus flow-control methods go "out") and which were for # outbound data (with flow-control going "in"), so I split them up # into separate pieces. self._inbound = Inbound(self, self._host_addr) - self._outbound = Outbound(self, self._cooperator) # from us to peer + self._outbound = Outbound(self, self._cooperator) # from us to peer def set_listener_endpoint(self, listener_endpoint): self._inbound.set_listener_endpoint(listener_endpoint) + def set_subchannel_zero(self, scid0, sc0): self._inbound.set_subchannel_zero(scid0, sc0) def when_first_connected(self): return self._first_connected.when_fired() - def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) - def send_hints(self, hints): # from Connector + def send_hints(self, hints): # from Connector self.send_dilation_phase(type="connection-hints", hints=hints) - # forward inbound-ish things to _Inbound + def subchannel_pauseProducing(self, sc): self._inbound.subchannel_pauseProducing(sc) + def subchannel_resumeProducing(self, sc): self._inbound.subchannel_resumeProducing(sc) + def subchannel_stopProducing(self, sc): self._inbound.subchannel_stopProducing(sc) # forward outbound-ish things to _Outbound def subchannel_registerProducer(self, sc, producer, 
streaming): self._outbound.subchannel_registerProducer(sc, producer, streaming) + def subchannel_unregisterProducer(self, sc): self._outbound.subchannel_unregisterProducer(sc) def send_open(self, scid): self._queue_and_send(Open, scid) + def send_data(self, scid, data): self._queue_and_send(Data, scid, data) + def send_close(self, scid): self._queue_and_send(Close, scid) @@ -106,7 +117,7 @@ class _ManagerBase(object): # Outbound owns the send_record() pipe, so that it can stall new # writes after a new connection is made until after all queued # messages are written (to preserve ordering). - self._outbound.queue_and_send_record(r) # may trigger pauseProducing + self._outbound.queue_and_send_record(r) # may trigger pauseProducing def subchannel_closed(self, scid, sc): # let everyone clean up. This happens just after we delivered @@ -116,7 +127,6 @@ class _ManagerBase(object): self._inbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc) - def _start_connecting(self, role): assert self._my_role is not None self._connector = Connector(self._transit_key, @@ -125,41 +135,41 @@ class _ManagerBase(object): self._reactor, self._eventual_queue, self._no_listen, self._tor, self._timing, - self._side, # needed for relay handshake + self._side, # needed for relay handshake self._my_role) self._connector.start() # our Connector calls these def connector_connection_made(self, c): - self.connection_made() # state machine update + self.connection_made() # state machine update self._connection = c self._inbound.use_connection(c) - self._outbound.use_connection(c) # does c.registerProducer + self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True self._first_connected.fire(None) pass + def connector_connection_lost(self): self._stop_using_connection() if self.role is LEADER: - self.connection_lost_leader() # state machine + self.connection_lost_leader() # state machine else: 
self.connection_lost_follower() - def _stop_using_connection(self): # the connection is already lost by this point self._connection = None self._inbound.stop_using_connection() - self._outbound.stop_using_connection() # does c.unregisterProducer + self._outbound.stop_using_connection() # does c.unregisterProducer # from our active Connection def got_record(self, r): # records with sequence numbers: always ack, ignore old ones if isinstance(r, (Open, Data, Close)): - self.send_ack(r.seqnum) # always ack, even for old ones + self.send_ack(r.seqnum) # always ack, even for old ones if self._inbound.is_record_old(r): return self._inbound.update_ack_watermark(r.seqnum) @@ -167,7 +177,7 @@ class _ManagerBase(object): self._inbound.handle_open(r.scid) elif isinstance(r, Data): self._inbound.handle_data(r.scid, r.data) - else: # isinstance(r, Close) + else: # isinstance(r, Close) self._inbound.handle_close(r.scid) if isinstance(r, KCM): log.err("got unexpected KCM") @@ -176,7 +186,7 @@ class _ManagerBase(object): elif isinstance(r, Pong): self.handle_pong(r.ping_id) elif isinstance(r, Ack): - self._outbound.handle_ack(r.resp_seqnum) # retire queued messages + self._outbound.handle_ack(r.resp_seqnum) # retire queued messages else: log.err("received unknown message type {}".format(r)) @@ -190,7 +200,6 @@ class _ManagerBase(object): def send_ack(self, resp_seqnum): self._outbound.send_if_connected(Ack(resp_seqnum)) - def handle_ping(self, ping_id): self.send_pong(ping_id) @@ -200,7 +209,7 @@ class _ManagerBase(object): # subchannel maintenance def allocate_subchannel_id(self): - raise NotImplemented # subclass knows if we're leader or follower + raise NotImplementedError # subclass knows if we're leader or follower # new scheme: # * both sides send PLEASE as soon as they have an unverified key and @@ -236,58 +245,91 @@ class _ManagerBase(object): # * if follower calls w.dilate() but not leader, follower waits forever # in "want", leader waits forever in "wanted" + class 
ManagerShared(_ManagerBase): m = MethodicalMachine() set_trace = getattr(m, "_setTrace", lambda self, f: None) @m.state(initial=True) - def IDLE(self): pass # pragma: no cover + def IDLE(self): + pass # pragma: no cover @m.state() - def WANTING(self): pass # pragma: no cover + def WANTING(self): + pass # pragma: no cover + @m.state() - def WANTED(self): pass # pragma: no cover + def WANTED(self): + pass # pragma: no cover + @m.state() - def CONNECTING(self): pass # pragma: no cover + def CONNECTING(self): + pass # pragma: no cover + @m.state() - def CONNECTED(self): pass # pragma: no cover + def CONNECTED(self): + pass # pragma: no cover + @m.state() - def FLUSHING(self): pass # pragma: no cover + def FLUSHING(self): + pass # pragma: no cover + @m.state() - def ABANDONING(self): pass # pragma: no cover + def ABANDONING(self): + pass # pragma: no cover + @m.state() - def LONELY(self): pass # pragme: no cover + def LONELY(self): + pass # pragme: no cover + @m.state() - def STOPPING(self): pass # pragma: no cover + def STOPPING(self): + pass # pragma: no cover + @m.state(terminal=True) - def STOPPED(self): pass # pragma: no cover + def STOPPED(self): + pass # pragma: no cover @m.input() - def start(self): pass # pragma: no cover + def start(self): + pass # pragma: no cover + @m.input() - def rx_PLEASE(self, message): pass # pragma: no cover - @m.input() # only sent by Follower - def rx_HINTS(self, hint_message): pass # pragma: no cover - @m.input() # only Leader sends RECONNECT, so only Follower receives it - def rx_RECONNECT(self): pass # pragma: no cover - @m.input() # only Follower sends RECONNECTING, so only Leader receives it - def rx_RECONNECTING(self): pass # pragma: no cover + def rx_PLEASE(self, message): + pass # pragma: no cover + + @m.input() # only sent by Follower + def rx_HINTS(self, hint_message): + pass # pragma: no cover + + @m.input() # only Leader sends RECONNECT, so only Follower receives it + def rx_RECONNECT(self): + pass # pragma: no cover + + 
@m.input() # only Follower sends RECONNECTING, so only Leader receives it + def rx_RECONNECTING(self): + pass # pragma: no cover # Connector gives us connection_made() @m.input() - def connection_made(self, c): pass # pragma: no cover + def connection_made(self, c): + pass # pragma: no cover # our connection_lost() fires connection_lost_leader or # connection_lost_follower depending upon our role. If either side sees a # problem with the connection (timeouts, bad authentication) then they # just drop it and let connection_lost() handle the cleanup. @m.input() - def connection_lost_leader(self): pass # pragma: no cover - @m.input() - def connection_lost_follower(self): pass + def connection_lost_leader(self): + pass # pragma: no cover @m.input() - def stop(self): pass # pragma: no cover + def connection_lost_follower(self): + pass + + @m.input() + def stop(self): + pass # pragma: no cover @m.output() def stash_side(self, message): @@ -301,38 +343,42 @@ class ManagerShared(_ManagerBase): @m.output() def start_connecting(self): - self._start_connecting() # TODO: merge + self._start_connecting() # TODO: merge + @m.output() def ignore_message_start_connecting(self, message): self.start_connecting() @m.output() def send_reconnect(self): - self.send_dilation_phase(type="reconnect") # TODO: generation number? + self.send_dilation_phase(type="reconnect") # TODO: generation number? + @m.output() def send_reconnecting(self): - self.send_dilation_phase(type="reconnecting") # TODO: generation? + self.send_dilation_phase(type="reconnecting") # TODO: generation? 
@m.output() def use_hints(self, hint_message): - hint_objs = filter(lambda h: h, # ignore None, unrecognizable + hint_objs = filter(lambda h: h, # ignore None, unrecognizable [parse_hint(hs) for hs in hint_message["hints"]]) hint_objs = list(hint_objs) self._connector.got_hints(hint_objs) + @m.output() def stop_connecting(self): self._connector.stop() + @m.output() def abandon_connection(self): # we think we're still connected, but the Leader disagrees. Or we've # been told to shut down. - self._connection.disconnect() # let connection_lost do cleanup - + self._connection.disconnect() # let connection_lost do cleanup # we don't start CONNECTING until a local start() plus rx_PLEASE IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) IDLE.upon(start, enter=WANTING, outputs=[send_please]) - WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) + WANTED.upon(start, enter=CONNECTING, outputs=[ + send_please, start_connecting]) WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[stash_side, ignore_message_start_connecting]) @@ -342,7 +388,8 @@ class ManagerShared(_ManagerBase): # Leader CONNECTED.upon(connection_lost_leader, enter=FLUSHING, outputs=[send_reconnect]) - FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) + FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, + outputs=[start_connecting]) # Follower # if we notice a lost connection, just wait for the Leader to notice too @@ -350,7 +397,7 @@ class ManagerShared(_ManagerBase): LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) # but if they notice it first, abandon our (seemingly functional) # connection, then tell them that we're ready to try again - CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss outputs=[abandon_connection]) ABANDONING.upon(connection_lost_follower, enter=CONNECTING, outputs=[send_reconnecting, start_connecting]) @@ -362,16 +409,15 @@ 
class ManagerShared(_ManagerBase): send_reconnecting, start_connecting]) - # rx_HINTS never changes state, they're just accepted or ignored - IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early - WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early - WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early + IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early + WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early + WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) - CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore - FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore - LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore - ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore + LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore + ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) IDLE.upon(stop, enter=STOPPED, outputs=[]) @@ -385,7 +431,6 @@ class ManagerShared(_ManagerBase): STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) - def allocate_subchannel_id(self): # scid 0 is reserved for the control channel. 
the leader uses odd # numbers starting with 1 @@ -393,6 +438,7 @@ class ManagerShared(_ManagerBase): self._next_outbound_seqnum += 2 return to_be4(scid_num) + @attrs @implementer(IDilator) class Dilator(object): @@ -436,10 +482,10 @@ class Dilator(object): dilation_version = yield self._got_versions_d - if not dilation_version: # 1 or None + if not dilation_version: # 1 or None raise OldPeerCannotDilateError() - my_dilation_side = TODO # random + my_dilation_side = TODO # random self._manager = Manager(self._S, my_dilation_side, self._transit_key, self._transit_relay_location, @@ -455,14 +501,15 @@ class Dilator(object): yield self._manager.when_first_connected() # we can open subchannels as soon as we get our first connection scid0 = b"\x00\x00\x00\x00" - self._host_addr = _WormholeAddress() # TODO: share with Manager + self._host_addr = _WormholeAddress() # TODO: share with Manager peer_addr0 = _SubchannelAddress(scid0) control_ep = ControlEndpoint(peer_addr0) sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) control_ep._subchannel_zero_opened(sc0) self._manager.set_subchannel_zero(scid0, sc0) - connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + connect_ep = SubchannelConnectorEndpoint( + self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) self._manager.set_listener_endpoint(listen_ep) @@ -476,7 +523,7 @@ class Dilator(object): # TODO: verify this happens before got_wormhole_versions, or add a gate # to tolerate either ordering purpose = b"dilation-v1" - LENGTH =32 # TODO: whatever Noise wants, I guess + LENGTH = 32 # TODO: whatever Noise wants, I guess self._transit_key = derive_key(key, purpose, LENGTH) def got_wormhole_versions(self, our_side, their_side, @@ -504,9 +551,9 @@ class Dilator(object): message = bytes_to_dict(plaintext) type = message["type"] if type == "please": - self._manager.rx_PLEASE() # message) + self._manager.rx_PLEASE() # message) elif type == 
"dilate": - self._manager.rx_DILATE() #message) + self._manager.rx_DILATE() # message) elif type == "connection-hints": self._manager.rx_HINTS(message) else: diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py index 6538ffe..96fbd3d 100644 --- a/src/wormhole/_dilation/outbound.py +++ b/src/wormhole/_dilation/outbound.py @@ -168,9 +168,9 @@ class Outbound(object): self._queued_unsent = deque() # outbound flow control: the Connection throttles our writes - self._subchannel_producers = {} # Subchannel -> IProducer - self._paused = True # our Connection called our pauseProducing - self._all_producers = deque() # rotates, left-is-next + self._subchannel_producers = {} # Subchannel -> IProducer + self._paused = True # our Connection called our pauseProducing + self._all_producers = deque() # rotates, left-is-next self._paused_producers = set() self._unpaused_producers = set() self._check_invariants() @@ -186,7 +186,7 @@ class Outbound(object): seqnum = self._next_outbound_seqnum self._next_outbound_seqnum += 1 r = record_type(seqnum, *args) - assert hasattr(r, "seqnum"), r # only Open/Data/Close + assert hasattr(r, "seqnum"), r # only Open/Data/Close return r def queue_and_send_record(self, r): @@ -203,7 +203,7 @@ class Outbound(object): self._connection.send_record(r) def send_if_connected(self, r): - assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum + assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum if self._connection: self._connection.send_record(r) @@ -235,7 +235,7 @@ class Outbound(object): if self._paused: # IPushProducers need to be paused immediately, before they # speak - producer.pauseProducing() # you wake up sleeping + producer.pauseProducing() # you wake up sleeping else: # our PullToPush adapter must be started, but if we're paused then # we tell it to pause before it gets a chance to write anything @@ -265,7 +265,7 @@ class Outbound(object): assert not self._queued_unsent 
self._queued_unsent.extend(self._outbound_queue) # the connection can tell us to pause when we send too much data - c.registerProducer(self, True) # IPushProducer: pause+resume + c.registerProducer(self, True) # IPushProducer: pause+resume # send our queued messages self.resumeProducing() @@ -290,12 +290,12 @@ class Outbound(object): # Inbound is responsible for tracking the high watermark and deciding # whether to ignore inbound messages or not - # IProducer: the active connection calls these because we used # c.registerProducer to ask for them + def pauseProducing(self): if self._paused: - return # someone is confused and called us twice + return # someone is confused and called us twice self._paused = True for p in self._all_producers: if p in self._unpaused_producers: @@ -305,7 +305,7 @@ class Outbound(object): def resumeProducing(self): if not self._paused: - return # someone is confused and called us twice + return # someone is confused and called us twice self._paused = False while not self._paused: @@ -326,7 +326,7 @@ class Outbound(object): return None while True: p = self._all_producers[0] - self._all_producers.rotate(-1) # p moves to the end of the line + self._all_producers.rotate(-1) # p moves to the end of the line # the only unpaused Producers are at the end of the list assert p in self._paused_producers return p @@ -343,7 +343,7 @@ class Outbound(object): @attrs(cmp=False) class PullToPush(object): _producer = attrib(validator=provides(IPullProducer)) - _unregister = attrib(validator=lambda _a,_b,v: callable(v)) + _unregister = attrib(validator=lambda _a, _b, v: callable(v)) _cooperator = attrib() _finished = False @@ -351,14 +351,14 @@ class PullToPush(object): while True: try: self._producer.resumeProducing() - except: + except Exception: log.err(None, "%s failed, producing will be stopped:" % (safe_str(self._producer),)) try: self._unregister() # The consumer should now call stopStreaming() on us, # thus stopping the streaming. 
- except: + except Exception: # Since the consumer blew up, we may not have had # stopStreaming() called, so we just stop on our own: log.err(None, "%s failed to unregister producer:" % @@ -370,7 +370,7 @@ class PullToPush(object): def startStreaming(self, paused): self._coopTask = self._cooperator.cooperate(self._pull()) if paused: - self.pauseProducing() # timer is scheduled, but task is removed + self.pauseProducing() # timer is scheduled, but task is removed def stopStreaming(self): if self._finished: @@ -378,15 +378,12 @@ class PullToPush(object): self._finished = True self._coopTask.stop() - def pauseProducing(self): self._coopTask.pause() - def resumeProducing(self): self._coopTask.resume() - def stopProducing(self): self.stopStreaming() self._producer.stopProducing() diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index 94b4a03..abd1939 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -1,7 +1,8 @@ from attr import attrs, attrib from attr.validators import instance_of, provides from zope.interface import implementer -from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed +from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, + succeed) from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, IAddress, IListeningPort, IStreamClientEndpoint, @@ -10,9 +11,11 @@ from twisted.internet.error import ConnectionDone from automat import MethodicalMachine from .._interfaces import ISubChannel, IDilationManager + @attrs class Once(object): _errtype = attrib() + def __attrs_post_init__(self): self._called = False @@ -21,6 +24,7 @@ class Once(object): raise self._errtype() self._called = True + class SingleUseEndpointError(Exception): pass @@ -38,13 +42,16 @@ class SingleUseEndpointError(Exception): # (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) # object is deleted upon transition to (CLOSED) + class 
AlreadyClosedError(Exception): pass + @implementer(IAddress) class _WormholeAddress(object): pass + @implementer(IAddress) @attrs class _SubchannelAddress(object): @@ -63,35 +70,44 @@ class SubChannel(object): _peer_addr = attrib(validator=instance_of(_SubchannelAddress)) m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + set_trace = getattr(m, "_setTrace", lambda self, + f: None) # pragma: no cover def __attrs_post_init__(self): - #self._mailbox = None - #self._pending_outbound = {} - #self._processed = set() + # self._mailbox = None + # self._pending_outbound = {} + # self._processed = set() self._protocol = None self._pending_dataReceived = [] self._pending_connectionLost = (False, None) @m.state(initial=True) - def open(self): pass # pragma: no cover + def open(self): + pass # pragma: no cover @m.state() - def closing(): pass # pragma: no cover + def closing(): + pass # pragma: no cover @m.state() - def closed(): pass # pragma: no cover + def closed(): + pass # pragma: no cover @m.input() - def remote_data(self, data): pass - @m.input() - def remote_close(self): pass + def remote_data(self, data): + pass @m.input() - def local_data(self, data): pass - @m.input() - def local_close(self): pass + def remote_close(self): + pass + @m.input() + def local_data(self, data): + pass + + @m.input() + def local_close(self): + pass @m.output() def send_data(self, data): @@ -120,9 +136,11 @@ class SubChannel(object): @m.output() def error_closed_write(self, data): raise AlreadyClosedError("write not allowed on closed subchannel") + @m.output() def error_closed_close(self): - raise AlreadyClosedError("loseConnection not allowed on closed subchannel") + raise AlreadyClosedError( + "loseConnection not allowed on closed subchannel") # primary transitions open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) @@ -146,7 +164,7 @@ class SubChannel(object): if self._pending_dataReceived: for data in 
self._pending_dataReceived: self._protocol.dataReceived(data) - self._pending_dataReceived = [] + self._pending_dataReceived = [] cl, what = self._pending_connectionLost if cl: self._protocol.connectionLost(what) @@ -155,13 +173,17 @@ class SubChannel(object): def write(self, data): assert isinstance(data, type(b"")) self.local_data(data) + def writeSequence(self, iovec): self.write(b"".join(iovec)) + def loseConnection(self): self.local_close() + def getHost(self): # we define "host addr" as the overall wormhole return self._host_addr + def getPeer(self): # and "peer addr" as the subchannel within that wormhole return self._peer_addr @@ -169,14 +191,17 @@ class SubChannel(object): # IProducer: throttle inbound data (wormhole "up" to local app's Protocol) def stopProducing(self): self._manager.subchannel_stopProducing(self) + def pauseProducing(self): self._manager.subchannel_pauseProducing(self) + def resumeProducing(self): self._manager.subchannel_resumeProducing(self) # IConsumer: allow the wormhole to throttle outbound data (app->wormhole) def registerProducer(self, producer, streaming): self._manager.subchannel_registerProducer(self, producer, streaming) + def unregisterProducer(self): self._manager.subchannel_unregisterProducer(self) @@ -184,6 +209,7 @@ class SubChannel(object): @implementer(IStreamClientEndpoint) class ControlEndpoint(object): _used = False + def __init__(self, peer_addr): self._subchannel_zero = Deferred() self._peer_addr = peer_addr @@ -201,9 +227,10 @@ class ControlEndpoint(object): t = yield self._subchannel_zero p = protocolFactory.buildProtocol(self._peer_addr) t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + p.makeConnection(t) # set p.transport = t and call connectionMade() returnValue(p) + @implementer(IStreamClientEndpoint) @attrs class SubchannelConnectorEndpoint(object): @@ -220,9 +247,10 @@ class SubchannelConnectorEndpoint(object): t = SubChannel(scid, self._manager, self._host_addr, 
peer_addr) p = protocolFactory.buildProtocol(peer_addr) t._set_protocol(p) - p.makeConnection(t) # set p.transport = t and call connectionMade() + p.makeConnection(t) # set p.transport = t and call connectionMade() return succeed(p) + @implementer(IStreamServerEndpoint) @attrs class SubchannelListenerEndpoint(object): @@ -238,7 +266,7 @@ class SubchannelListenerEndpoint(object): if self._factory: self._connect(t, peer_addr) else: - self._pending_opens.append( (t, peer_addr) ) + self._pending_opens.append((t, peer_addr)) def _connect(self, t, peer_addr): p = self._factory.buildProtocol(peer_addr) @@ -255,6 +283,7 @@ class SubchannelListenerEndpoint(object): lp = SubchannelListeningPort(self._host_addr) return succeed(lp) + @implementer(IListeningPort) @attrs class SubchannelListeningPort(object): @@ -262,8 +291,10 @@ class SubchannelListeningPort(object): def startListening(self): pass + def stopListening(self): # TODO pass + def getHost(self): return self._host_addr From ea35e570a2275203f08c88dae10856ae6b18a0b0 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 16:19:48 -0700 Subject: [PATCH 09/49] setup.cfg: bump flake8 max-line-length to 84 --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index aec9318..996e37d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,3 +7,6 @@ versionfile_source = src/wormhole/_version.py versionfile_build = wormhole/_version.py tag_prefix = parentdir_prefix = magic-wormhole + +[flake8] +max-line-length = 84 From bf0c93eddc4bac5e8a8b248702241cc00ab6239d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 16:23:39 -0700 Subject: [PATCH 10/49] more flake8 fixes --- src/wormhole/_dilation/connection.py | 42 +++++++---- src/wormhole/_dilation/old-follower.py | 68 ++++++++++++------ src/wormhole/test/dilate/common.py | 3 + src/wormhole/test/dilate/test_connection.py | 13 ++-- src/wormhole/test/dilate/test_encoding.py | 5 +- src/wormhole/test/dilate/test_endpoints.py | 3 +- 
src/wormhole/test/dilate/test_framer.py | 6 +- src/wormhole/test/dilate/test_inbound.py | 4 +- src/wormhole/test/dilate/test_manager.py | 13 ++-- src/wormhole/test/dilate/test_outbound.py | 78 ++++++++++++--------- src/wormhole/test/dilate/test_parse.py | 1 + src/wormhole/test/dilate/test_record.py | 13 ++-- src/wormhole/test/dilate/test_subchannel.py | 2 + 13 files changed, 155 insertions(+), 96 deletions(-) diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index 4b9ced7..d142eed 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -80,28 +80,36 @@ class _Framer(object): set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover @m.state() - def want_relay(self): pass # pragma: no cover + def want_relay(self): + pass # pragma: no cover @m.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): + pass # pragma: no cover @m.state() - def want_frame(self): pass # pragma: no cover + def want_frame(self): + pass # pragma: no cover @m.input() - def use_relay(self, relay_handshake): pass + def use_relay(self, relay_handshake): + pass @m.input() - def connectionMade(self): pass + def connectionMade(self): + pass @m.input() - def parse(self): pass + def parse(self): + pass @m.input() - def got_relay_ok(self): pass + def got_relay_ok(self): + pass @m.input() - def got_prologue(self): pass + def got_prologue(self): + pass @m.output() def store_relay_handshake(self, relay_handshake): @@ -312,13 +320,16 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): + pass # pragma: no cover @n.state() - def want_handshake(self): pass # pragma: no cover + def want_handshake(self): + pass # pragma: no cover @n.state() - def want_message(self): pass # pragma: no cover + def want_message(self): + pass # pragma: no cover 
@n.input() def got_prologue(self): @@ -417,13 +428,16 @@ class DilatedConnectionProtocol(Protocol, object): self._can_send_records = False @m.state(initial=True) - def unselected(self): pass # pragma: no cover + def unselected(self): + pass # pragma: no cover @m.state() - def selecting(self): pass # pragma: no cover + def selecting(self): + pass # pragma: no cover @m.state() - def selected(self): pass # pragma: no cover + def selected(self): + pass # pragma: no cover @m.input() def got_kcm(self): diff --git a/src/wormhole/_dilation/old-follower.py b/src/wormhole/_dilation/old-follower.py index 68e3a38..7be4307 100644 --- a/src/wormhole/_dilation/old-follower.py +++ b/src/wormhole/_dilation/old-follower.py @@ -4,34 +4,53 @@ class ManagerFollower(_ManagerBase): set_trace = getattr(m, "_setTrace", lambda self, f: None) @m.state(initial=True) - def IDLE(self): pass # pragma: no cover + def IDLE(self): + pass # pragma: no cover @m.state() - def WANTING(self): pass # pragma: no cover + def WANTING(self): + pass # pragma: no cover + @m.state() - def CONNECTING(self): pass # pragma: no cover + def CONNECTING(self): + pass # pragma: no cover + @m.state() - def CONNECTED(self): pass # pragma: no cover + def CONNECTED(self): + pass # pragma: no cover + @m.state(terminal=True) - def STOPPED(self): pass # pragma: no cover + def STOPPED(self): + pass # pragma: no cover @m.input() - def start(self): pass # pragma: no cover - @m.input() - def rx_PLEASE(self): pass # pragma: no cover - @m.input() - def rx_DILATE(self): pass # pragma: no cover - @m.input() - def rx_HINTS(self, hint_message): pass # pragma: no cover + def start(self): + pass # pragma: no cover @m.input() - def connection_made(self): pass # pragma: no cover + def rx_PLEASE(self): + pass # pragma: no cover + @m.input() - def connection_lost(self): pass # pragma: no cover + def rx_DILATE(self): + pass # pragma: no cover + + @m.input() + def rx_HINTS(self, hint_message): + pass # pragma: no cover + + @m.input() + def 
connection_made(self): + pass # pragma: no cover + + @m.input() + def connection_lost(self): + pass # pragma: no cover # follower doesn't react to connection_lost, but waits for a new LETS_DILATE @m.input() - def stop(self): pass # pragma: no cover + def stop(self): + pass # pragma: no cover # these Outputs behave differently for the Leader vs the Follower @m.output() @@ -48,27 +67,32 @@ class ManagerFollower(_ManagerBase): @m.output() def use_hints(self, hint_message): - hint_objs = filter(lambda h: h, # ignore None, unrecognizable + hint_objs = filter(lambda h: h, # ignore None, unrecognizable [parse_hint(hs) for hs in hint_message["hints"]]) self._connector.got_hints(hint_objs) + @m.output() def stop_connecting(self): self._connector.stop() + @m.output() def use_connection(self, c): self._use_connection(c) + @m.output() def stop_using_connection(self): self._stop_using_connection() + @m.output() def signal_error(self): - pass # TODO + pass # TODO + @m.output() def signal_error_hints(self, hint_message): - pass # TODO + pass # TODO - IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early - IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early + IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early + IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early # leader shouldn't send us DILATE before receiving our PLEASE IDLE.upon(stop, enter=STOPPED, outputs=[]) IDLE.upon(start, enter=WANTING, outputs=[send_please]) @@ -78,7 +102,7 @@ class ManagerFollower(_ManagerBase): CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) # shouldn't happen: connection_lost - #CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) + # CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, start_connecting]) # receiving rx_DILATE 
while we're still working on the last one means the @@ -89,7 +113,7 @@ class ManagerFollower(_ManagerBase): CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, start_connecting]) - CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) # shouldn't happen: connection_made diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py index 2ddacfb..4f398d7 100644 --- a/src/wormhole/test/dilate/common.py +++ b/src/wormhole/test/dilate/common.py @@ -3,16 +3,19 @@ import mock from zope.interface import alsoProvides from ..._interfaces import IDilationManager, IWormhole + def mock_manager(): m = mock.Mock() alsoProvides(m, IDilationManager) return m + def mock_wormhole(): m = mock.Mock() alsoProvides(m, IWormhole) return m + def clear_mock_calls(*args): for a in args: a.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py index 406e42a..ee761fd 100644 --- a/src/wormhole/test/dilate/test_connection.py +++ b/src/wormhole/test/dilate/test_connection.py @@ -11,12 +11,13 @@ from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, KCM, Open, Ack) from .common import clear_mock_calls + def make_con(role, use_relay=False): clock = Clock() eq = EventualQueue(clock) connector = mock.Mock() alsoProvides(connector, IDilationConnector) - n = mock.Mock() # pretends to be a Noise object + n = mock.Mock() # pretends to be a Noise object n.write_message = mock.Mock(side_effect=[b"handshake"]) c = DilatedConnectionProtocol(eq, role, connector, n, b"outbound_prologue\n", b"inbound_prologue\n") @@ -26,6 +27,7 @@ def make_con(role, use_relay=False): alsoProvides(t, ITransport) return c, n, connector, t, eq + class 
Connection(unittest.TestCase): def test_bad_prologue(self): c, n, connector, t, eq = make_con(LEADER) @@ -55,10 +57,10 @@ class Connection(unittest.TestCase): n.decrypt = mock.Mock(side_effect=[ encode_record(t_kcm), encode_record(t_open), - ]) + ]) exp_kcm = b"\x00\x00\x00\x03kcm" n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) - m = mock.Mock() # Manager + m = mock.Mock() # Manager c.makeConnection(t) self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) @@ -86,7 +88,7 @@ class Connection(unittest.TestCase): self.assertEqual(n.mock_calls, [ mock.call.read_message(b"handshake2"), mock.call.encrypt(encode_record(t_kcm)), - ]) + ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [ mock.call.write(exp_kcm)]) @@ -123,7 +125,7 @@ class Connection(unittest.TestCase): c.send_record(KCM()) self.assertEqual(n.mock_calls, [ mock.call.encrypt(encode_record(t_kcm)), - ]) + ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) self.assertEqual(m.mock_calls, []) @@ -163,7 +165,6 @@ class Connection(unittest.TestCase): def test_no_relay_follower(self): return self._test_no_relay(FOLLOWER) - def test_relay(self): c, n, connector, t, eq = make_con(LEADER, use_relay=True) diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py index e2c854e..6bf2c7a 100644 --- a/src/wormhole/test/dilate/test_encoding.py +++ b/src/wormhole/test/dilate/test_encoding.py @@ -2,11 +2,12 @@ from __future__ import print_function, unicode_literals from twisted.trial import unittest from ..._dilation.encode import to_be4, from_be4 + class Encoding(unittest.TestCase): def test_be4(self): - self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") - self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") + self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") + self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") self.assertEqual(to_be4(257), 
b"\x00\x00\x01\x01") with self.assertRaises(ValueError): diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index bd8f995..ba07fe0 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -11,6 +11,7 @@ from ..._dilation.subchannel import (ControlEndpoint, SingleUseEndpointError) from .common import mock_manager + class Endpoints(unittest.TestCase): def test_control(self): scid0 = b"scid0" @@ -94,4 +95,4 @@ class Endpoints(unittest.TestCase): self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) - lp.stopListening() # TODO: should this do more? + lp.stopListening() # TODO: should this do more? diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py index 81d4cf9..51ac039 100644 --- a/src/wormhole/test/dilate/test_framer.py +++ b/src/wormhole/test/dilate/test_framer.py @@ -5,12 +5,14 @@ from twisted.trial import unittest from twisted.internet.interfaces import ITransport from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect + def make_framer(): t = mock.Mock() alsoProvides(t, ITransport) f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") return f, t + class Framer(unittest.TestCase): def test_bad_prologue_length(self): f, t = make_framer() @@ -19,7 +21,7 @@ class Framer(unittest.TestCase): f.connectionMade() self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) t.mock_calls[:] = [] - self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual(t.mock_calls, []) with mock.patch("wormhole._dilation.connection.log.msg") as m: @@ -37,7 +39,7 @@ class Framer(unittest.TestCase): f.connectionMade() self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) t.mock_calls[:] = [] - self.assertEqual([], 
list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual([], list(f.add_and_parse(b"not"))) with mock.patch("wormhole._dilation.connection.log.msg") as m: diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py index d147283..392a661 100644 --- a/src/wormhole/test/dilate/test_inbound.py +++ b/src/wormhole/test/dilate/test_inbound.py @@ -8,6 +8,7 @@ from ..._dilation.inbound import (Inbound, DuplicateOpenError, DataForMissingSubchannelError, CloseForMissingSubchannelError) + def make_inbound(): m = mock.Mock() alsoProvides(m, IDilationManager) @@ -15,6 +16,7 @@ def make_inbound(): i = Inbound(m, host_addr) return i, m, host_addr + class InboundTest(unittest.TestCase): def test_seqnum(self): i, m, host_addr = make_inbound() @@ -158,7 +160,7 @@ class InboundTest(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) c.mock_calls[:] = [] i.subchannel_pauseProducing(sc2) - self.assertEqual(c.mock_calls, []) # was already paused + self.assertEqual(c.mock_calls, []) # was already paused # tolerate duplicate pauseProducing i.subchannel_pauseProducing(sc2) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 625039d..6acf25b 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -13,12 +13,14 @@ from ..._dilation.manager import (Dilator, from ..._dilation.subchannel import _WormholeAddress from .common import clear_mock_calls + def make_dilator(): reactor = object() clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) send = mock.Mock() @@ -27,6 
+29,7 @@ def make_dilator(): dil.wire(send) return dil, send, reactor, eq, clock, coop + class TestDilator(unittest.TestCase): def test_leader(self): dil, send, reactor, eq, clock, coop = make_dilator() @@ -148,12 +151,11 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) - dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" + dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" eq.flush_sync() f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) - def test_disjoint_versions(self): dil, send, reactor, eq, clock, coop = make_dilator() d1 = dil.dilate() @@ -164,7 +166,6 @@ class TestDilator(unittest.TestCase): f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) - def test_early_dilate_messages(self): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" @@ -188,8 +189,6 @@ class TestDilator(unittest.TestCase): mock.call.rx_HINTS(hintmsg), mock.call.when_first_connected()]) - - def test_transit_relay(self): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py index fab596a..db38502 100644 --- a/src/wormhole/test/dilate/test_outbound.py +++ b/src/wormhole/test/dilate/test_outbound.py @@ -16,17 +16,20 @@ Pauser = namedtuple("Pauser", ["seqnum"]) NonPauser = namedtuple("NonPauser", ["seqnum"]) Stopper = namedtuple("Stopper", ["sc"]) + def make_outbound(): m = mock.Mock() alsoProvides(m, IDilationManager) clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) o = Outbound(m, coop) - c = mock.Mock() # Connection + c = mock.Mock() # Connection + def maybe_pause(r): 
if isinstance(r, Pauser): o.pauseProducing() @@ -37,6 +40,7 @@ def make_outbound(): o._test_term = term return o, m, c + class OutboundTest(unittest.TestCase): def test_build_record(self): o, m, c = make_outbound() @@ -69,10 +73,10 @@ class OutboundTest(unittest.TestCase): o.handle_ack(r3.seqnum) self.assertEqual(list(o._outbound_queue), []) - o.handle_ack(r3.seqnum) # ignored + o.handle_ack(r3.seqnum) # ignored self.assertEqual(list(o._outbound_queue), []) - o.handle_ack(r1.seqnum) # ignored + o.handle_ack(r1.seqnum) # ignored self.assertEqual(list(o._outbound_queue), []) def test_duplicate_registerProducer(self): @@ -192,7 +196,8 @@ class OutboundTest(unittest.TestCase): clear_mock_calls(c) sc1, sc2, sc3 = object(), object(), object() - p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3") + p1, p2, p3 = mock.Mock(name="p1"), mock.Mock( + name="p2"), mock.Mock(name="p3") # we aren't paused yet, since we haven't sent any data o.subchannel_registerProducer(sc1, p1, True) @@ -310,7 +315,8 @@ class OutboundTest(unittest.TestCase): # and another disconnects itself when called p2.resumeProducing.side_effect = lambda: None - p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3) + p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer( + sc3) o.pauseProducing() o.resumeProducing() self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), @@ -360,7 +366,7 @@ class OutboundTest(unittest.TestCase): r2 = NonPauser(seqnum=2) # we aren't paused yet, since we haven't sent any data - o.subchannel_registerProducer(sc1, p1, True) # push + o.subchannel_registerProducer(sc1, p1, True) # push o.queue_and_send_record(r1) # now we're paused self.assertTrue(o._paused) @@ -371,7 +377,7 @@ class OutboundTest(unittest.TestCase): p1.resumeProducing.side_effect = lambda: c.send_record(r1) p2.resumeProducing.side_effect = lambda: c.send_record(r2) - o.subchannel_registerProducer(sc2, p2, False) # pull: always ready + 
o.subchannel_registerProducer(sc2, p2, False) # pull: always ready # p1 is still first, since p2 was just added (at the end) self.assertTrue(o._paused) @@ -390,7 +396,7 @@ class OutboundTest(unittest.TestCase): mock.call.pauseProducing(), ]) self.assertEqual(p2.mock_calls, []) - self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next + self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next clear_mock_calls(p1, p2, c) # next should fire p2, then p1 @@ -404,7 +410,7 @@ class OutboundTest(unittest.TestCase): ]) self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), ]) - self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat + self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat clear_mock_calls(p1, p2, c) def test_pull_producer(self): @@ -430,13 +436,13 @@ class OutboundTest(unittest.TestCase): it = iter(records) p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) o.subchannel_registerProducer(sc1, p1, False) - eq.flush_sync() # fast forward into the glorious (paused) future + eq.flush_sync() # fast forward into the glorious (paused) future self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in records[:-1]]) self.assertEqual(p1.mock_calls, - [mock.call.resumeProducing()]*(len(records)-1)) + [mock.call.resumeProducing()] * (len(records) - 1)) clear_mock_calls(c, p1) # next resumeProducing should cause it to disconnect @@ -460,7 +466,7 @@ class OutboundTest(unittest.TestCase): NonPauser(3), NonPauser(13), NonPauser(4), NonPauser(14), Pauser(5)] - expected2 = [ NonPauser(15), + expected2 = [NonPauser(15), NonPauser(6), NonPauser(16), NonPauser(7), NonPauser(17), NonPauser(8), NonPauser(18), @@ -487,14 +493,14 @@ class OutboundTest(unittest.TestCase): p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) o.subchannel_registerProducer(sc2, p2, False) - eq.flush_sync() # fast forward into the glorious (paused) future + eq.flush_sync() # fast forward 
into the glorious (paused) future sends = [mock.call.resumeProducing()] self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in expected1]) - self.assertEqual(p1.mock_calls, 6*sends) - self.assertEqual(p2.mock_calls, 5*sends) + self.assertEqual(p1.mock_calls, 6 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) clear_mock_calls(c, p1, p2) o.resumeProducing() @@ -502,13 +508,13 @@ class OutboundTest(unittest.TestCase): self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in expected2]) - self.assertEqual(p1.mock_calls, 4*sends) - self.assertEqual(p2.mock_calls, 5*sends) + self.assertEqual(p1.mock_calls, 4 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) clear_mock_calls(c, p1, p2) def test_send_if_connected(self): o, m, c = make_outbound() - o.send_if_connected(Ack(1)) # not connected yet + o.send_if_connected(Ack(1)) # not connected yet o.use_connection(c) o.send_if_connected(KCM()) @@ -517,7 +523,7 @@ class OutboundTest(unittest.TestCase): def test_tolerate_duplicate_pause_resume(self): o, m, c = make_outbound() - self.assertTrue(o._paused) # no connection + self.assertTrue(o._paused) # no connection o.use_connection(c) self.assertFalse(o._paused) o.pauseProducing() @@ -533,7 +539,7 @@ class OutboundTest(unittest.TestCase): o, m, c = make_outbound() o.use_connection(c) self.assertFalse(o._paused) - o.stopProducing() # connection does this before loss + o.stopProducing() # connection does this before loss self.assertTrue(o._paused) o.stop_using_connection() self.assertTrue(o._paused) @@ -559,13 +565,15 @@ def make_pushpull(pauses): clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) pp = 
PullToPush(p, unregister, coop) it = cycle(pauses) + def action(i): if isinstance(i, Exception): raise i @@ -574,41 +582,45 @@ def make_pushpull(pauses): p.resumeProducing.side_effect = lambda: action(next(it)) return p, unregister, pp, eq + class PretendResumptionError(Exception): pass + + class PretendUnregisterError(Exception): pass + class PushPull(unittest.TestCase): # test our wrapper utility, which I copied from # twisted.internet._producer_helpers since it isn't publically exposed def test_start_unpaused(self): - p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing # if it starts unpaused, it gets one write before being halted pp.startStreaming(False) eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) clear_mock_calls(p) # now each time we call resumeProducing, we should see one delivered to # the underlying IPullProducer pp.resumeProducing() eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) pp.stopStreaming() - pp.stopStreaming() # should tolerate this + pp.stopStreaming() # should tolerate this def test_start_unpaused_two_writes(self): - p, unr, pp, eq = make_pushpull([False, True]) # pause every other time + p, unr, pp, eq = make_pushpull([False, True]) # pause every other time # it should get two writes, since the first didn't pause pp.startStreaming(False) eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2) def test_start_paused(self): - p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing pp.startStreaming(True) eq.flush_sync() self.assertEqual(p.mock_calls, []) @@ -637,9 +649,5 @@ 
class PushPull(unittest.TestCase): self.assertEqual(unr.mock_calls, [mock.call()]) self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) - - - - # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I # could capture the inter-call ordering that way diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py index 8365e62..f7276a6 100644 --- a/src/wormhole/test/dilate/test_parse.py +++ b/src/wormhole/test/dilate/test_parse.py @@ -4,6 +4,7 @@ from twisted.trial import unittest from ..._dilation.connection import (parse_record, encode_record, KCM, Ping, Pong, Open, Data, Close, Ack) + class Parse(unittest.TestCase): def test_parse(self): self.assertEqual(parse_record(b"\x00"), KCM()) diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py index 810396c..41b36e3 100644 --- a/src/wormhole/test/dilate/test_record.py +++ b/src/wormhole/test/dilate/test_record.py @@ -7,13 +7,15 @@ from ..._dilation.connection import (IFramer, Frame, Prologue, _Record, Handshake, Disconnect, Ping) + def make_record(): f = mock.Mock() alsoProvides(f, IFramer) - n = mock.Mock() # pretends to be a Noise object + n = mock.Mock() # pretends to be a Noise object r = _Record(f, n) return r, f, n + class Record(unittest.TestCase): def test_good2(self): f = mock.Mock() @@ -23,7 +25,7 @@ class Record(unittest.TestCase): [Prologue()], [Frame(frame=b"rx-handshake")], [Frame(frame=b"frame1"), Frame(frame=b"frame2")], - ]) + ]) n = mock.Mock() n.write_message = mock.Mock(return_value=b"tx-handshake") p1, p2 = object(), object() @@ -60,9 +62,9 @@ class Record(unittest.TestCase): n.mock_calls[:] = [] # next is a pair of Records - r1, r2 = object() , object() + r1, r2 = object(), object() with mock.patch("wormhole._dilation.connection.parse_record", - side_effect=[r1,r2]) as pr: + side_effect=[r1, r2]) as pr: self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) self.assertEqual(n.mock_calls, 
[mock.call.decrypt(b"frame1"), mock.call.decrypt(b"frame2")]) @@ -186,7 +188,7 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=outbound_handshake) n.decrypt = mock.Mock(side_effect=[kcm, msg1]) n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) - f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet + f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet [Prologue()], [Frame("f_handshake")], [Frame("f_kcm"), @@ -238,7 +240,6 @@ class Record(unittest.TestCase): f.mock_calls[:] = [] n.mock_calls[:] = [] - # 5: at this point we ought to be able to send a messge, the KCM with mock.patch("wormhole._dilation.connection.encode_record", side_effect=[b"r-kcm"]) as er: diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index 69fa001..d56cdf1 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -8,6 +8,7 @@ from ..._dilation.subchannel import (Once, SubChannel, AlreadyClosedError) from .common import mock_manager + def make_sc(set_protocol=True): scid = b"scid" hostaddr = _WormholeAddress() @@ -19,6 +20,7 @@ def make_sc(set_protocol=True): sc._set_protocol(p) return sc, m, scid, hostaddr, peeraddr, p + class SubChannelAPI(unittest.TestCase): def test_once(self): o = Once(ValueError) From 5f61531445df878f06edcb49bc59a8cb571f83cc Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 16:24:28 -0700 Subject: [PATCH 11/49] more flake8 fixes, in wormhole.py and _interfaces.py --- src/wormhole/_interfaces.py | 11 +++++++++++ src/wormhole/wormhole.py | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 52bde35..15302d1 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -434,15 +434,26 @@ class IInputHelper(Interface): class IJournal(Interface): # TODO: this needs to be public pass + class IDilator(Interface): 
pass + + class IDilationManager(Interface): pass + + class IDilationConnector(Interface): pass + + class ISubChannel(Interface): pass + + class IInbound(Interface): pass + + class IOutbound(Interface): pass diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index f967c25..3734fd2 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -191,7 +191,7 @@ class _DeferredWormhole(object): return derive_key(self._key, to_bytes(purpose), length) def dilate(self): - return self._boss.dilate() # fires with (endpoints) + return self._boss.dilate() # fires with (endpoints) def close(self): # fails with WormholeError unless we established a connection @@ -273,8 +273,8 @@ def create( wormhole_versions = { "can-dilate": [1], "dilation-abilities": Connector.get_connection_abilities(), - } - wormhole_versions["app_versions"] = versions # app-specific capabilities + } + wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace") From 48d740406bdadd4e9509125174259a22701a5cc1 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 17:06:34 -0700 Subject: [PATCH 12/49] setup.cfg: bump flake8 max-line-length to 85 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 996e37d..c2199e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,4 +9,4 @@ tag_prefix = parentdir_prefix = magic-wormhole [flake8] -max-line-length = 84 +max-line-length = 85 From d65fcaa1a61d64e15013db745caa47af70aea4d7 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 30 Jun 2018 17:06:48 -0700 Subject: [PATCH 13/49] more flake8 fixes --- src/wormhole/test/dilate/test_manager.py | 3 ++- src/wormhole/test/dilate/test_outbound.py | 6 ++++-- src/wormhole/test/test_machines.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py 
index 6acf25b..690e1cf 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -20,7 +20,8 @@ def make_dilator(): eq = EventualQueue(clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - def term_factory(): return term + def term_factory(): + return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) send = mock.Mock() diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py index db38502..6ba5264 100644 --- a/src/wormhole/test/dilate/test_outbound.py +++ b/src/wormhole/test/dilate/test_outbound.py @@ -24,7 +24,8 @@ def make_outbound(): eq = EventualQueue(clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - def term_factory(): return term + def term_factory(): + return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) o = Outbound(m, coop) @@ -567,7 +568,8 @@ def make_pushpull(pauses): eq = EventualQueue(clock) term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - def term_factory(): return term + def term_factory(): + return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) pp = PullToPush(p, unregister, coop) diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 35854ae..fe956c1 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -12,8 +12,8 @@ import mock from .. 
import (__version__, _allocator, _boss, _code, _input, _key, _lister, _mailbox, _nameplate, _order, _receive, _rendezvous, _send, _terminator, errors, timing) -from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister, - IMailbox, INameplate, IOrder, IReceive, +from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, + ILister, IMailbox, INameplate, IOrder, IReceive, IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data from ..journal import ImmediateJournal From 5dca0542ebdd999146d33d015b402978e2211f96 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 14:45:56 -0700 Subject: [PATCH 14/49] travis: tolerate py3.4 failure because of txtorcon bug https://github.com/meejah/txtorcon/issues/306 --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 69344cf..3bcad1a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -36,6 +36,8 @@ matrix: allow_failures: - python: 2.7 - python: 3.3 + # txtorcon is currently broken on py3.4 + - python: 3.4 # travis doesn't support py3.7 yet - python: 3.7 - python: nightly From 72c9683cdfa3db3ce5996ed41169f7e076ade555 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 16:40:52 -0700 Subject: [PATCH 15/49] dilation-protocol.md: update for new PLEASE+PLEASE approach --- docs/dilation-protocol.md | 116 ++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 48 deletions(-) diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md index 4f0dac1..ab5b50b 100644 --- a/docs/dilation-protocol.md +++ b/docs/dilation-protocol.md @@ -13,19 +13,20 @@ messages are used to open/use/close the application-visible subchannels. ## Leaders and Followers -Each side of a Wormhole has a randomly-generated "side" string. 
When the -wormhole is dilated, the side with the lexicographically-higher "side" value -is named the "Leader", and the other side is named the "Follower". The general -wormhole protocol treats both sides identically, but the distinction matters -for the dilation protocol. +Each side of a Wormhole has a randomly-generated dilation "side" string (this +is independent of the Wormhole's mailbox "side"). When the wormhole is +dilated, the side with the lexicographically-higher "side" value is named the +"Leader", and the other side is named the "Follower". The general wormhole +protocol treats both sides identically, but the distinction matters for the +dilation protocol. -Either side can trigger dilation, but the Follower does so by asking the -Leader to start the process, whereas the Leader just starts the process -unilaterally. The Leader has exclusive control over whether a given -connection is considered established or not: if there are multiple potential -connections to use, the Leader decides which one to use, and the Leader gets -to decide when the connection is no longer viable (and triggers the -establishment of a new one). +Both sides send a `please-dilate` as soon as dilation is triggered. Each side +discovers whether it is the Leader or the Follower when the peer's +"please-dilate" arrives. The Leader has exclusive control over whether a +given connection is considered established or not: if there are multiple +potential connections to use, the Leader decides which one to use, and the +Leader gets to decide when the connection is no longer viable (and triggers +the establishment of a new one). ## Connection Layers @@ -110,51 +111,60 @@ Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that has a string value. The dictionary will have other keys that depend upon the type. -`w.dilate()` triggers a `please-dilate` record with a set of versions that -can be accepted. 
Both Leader and Follower emit this record, although the -Leader is responsible for version decisions. Versions use strings, rather -than integers, to support experimental protocols, however there is still a -total ordering of preferability. +`w.dilate()` triggers transmission of a `please-dilate` record with a set of +versions that can be accepted. Versions use strings, rather than integers, to +support experimental protocols, however there is still a total ordering of +preferability. ``` { "type": "please-dilate", + "side": "abcdef", "accepted-versions": ["1"] } ``` -The Leader then sends a `start-dilation` message with a `version` field (the -"best" mutually-supported value) and the new "L2 generation" number in the -`generation` field. Generation numbers are integers, monotonically increasing -by 1 each time. +If one side receives a `please-dilate` before `w.dilate()` has been called +locally, the contents are stored in case `w.dilate()` is called in the +future. Once both `w.dilate()` has been called and the peer's `please-dilate` +has been received, the side determines whether it is the Leader or the +Follower. Both sides also compare `accepted-versions` fields to choose the +best mutually-compatible version to use: they should always pick the same +one. -``` -{ "type": start-dilation, - "version": "1", - "generation": 1, -} -``` +Then both sides begin the connection process for generation 1 by opening +listening sockets and sending `connection-hint` records for each one. After a +slight delay they will also open connections to the Transit Relay of their +choice and produce hints for it too. The receipt of inbound hints (on both +sides) will trigger outbound connection attempts. -The Follower responds with a `ok-dilation` message with matching `version` -and `generation` fields. +Some number of these connections may succeed, and the Leader decides which to +use (via an in-band signal on the established connection). The others are +dropped. 
-The Leader decides when a new dilation connection is necessary, both for the -initial connection and any subsequent reconnects. Therefore the Leader has -the exclusive right to send the `start-dilation` record. It won't send this -until after it has sent its own `please-dilate`, and after it has received -the Follower's `please-dilate`. As a result, local preparations may begin as -soon as `w.dilate()` is called, but L2 connections do not begin until the -Leader declares the start of a new L2 generation with the `start-dilation` -message. +If something goes wrong with the established connection and the Leader +decides a new one is necessary, the Leader will send a `reconnect` message. +This might happen while connections are still being established, or while the +Follower thinks it still has a viable connection (the Leader might observer +problems that the Follower does not), or after the Follower thinks the +connection has been lost. In all cases, the Leader is the only side which +should send `reconnect`. The state machine code looks the same on both sides, +for simplicity, but one path on each side is never used. + +Upon receiving a `reconnect`, the Follower should stop any pending connection +attempts and terminate any existing connections (even if they appear viable). +Listening sockets may be retained, but any previous connection made through +them must be dropped. + +Once all connections have stopped, the Follower should send an `ok` message, +then start the connection process for the next generation, which will send +new `connection-hint` messages for all listening sockets). Generations are non-overlapping. The Leader will drop all connections from -generation 1 before sending the `start-dilation` for generation 2, and will -not initiate any gen-2 connections until it receives the matching -`ok-dilation` from the Follower. 
The Follower must drop all gen-1 connections -before it sends the `ok-dilation` response (even if it thinks they are still -functioning: if the Leader thought the gen-1 connection still worked, it -wouldn't have started gen-2). Listening sockets can be retained, but any -previous connection made through them must be dropped. This should avoid a -race. +generation 1 before sending the `reconnect` for generation 2, and will not +initiate any gen-2 connections until it receives the matching `ok` from the +Follower. The Follower must drop all gen-1 connections before it sends the +`ok` response (even if it thinks they are still functioning: if the Leader +thought the gen-1 connection still worked, it wouldn't have started gen-2). (TODO: what about a follower->leader connection that was started before start-dilation is received, and gets established on the Leader side after @@ -166,6 +176,16 @@ derived key) (TODO: reduce the number of round-trip stalls here, I've added too many) +Each side is in the "connecting" state (which encompasses both making +connection attempts and having an established connection) starting with the +receipt of a `please-dilate` message and a local `w.dilate()` call. The +Leader remains in that state until it abandons the connection and sends a +`reconnect` message, at which point it remains in the "flushing" state until +the Follower's `ok` message is received. The Follower remains in "connecting" +until it receives `reconnect`, then it stays in "dropping" until it finishes +halting all outstanding connections, after which it sends `ok` and switches +back to "connecting". + "Connection hints" are type/address/port records that tell the other side of likely targets for L2 connections. Both sides will try to determine their external IP addresses, listen on a TCP port, and advertise `(tcp, @@ -225,7 +245,7 @@ everything until the first newline. 
Everything beyond that point is a Noise protocol message, which consists of a 4-byte big-endian length field, followed by the indicated number of bytes. -This ises the `NNpsk0` pattern with the Leader as the first party ("-> psk, +This uses the `NNpsk0` pattern with the Leader as the first party ("-> psk, e" in the Noise spec), and the Follower as the second ("<- e, ee"). The pre-shared-key is the "dilation key", which is statically derived from the master PAKE key using HKDF. Each L2 connection uses the same dilation key, @@ -241,8 +261,8 @@ protocol object to the L3 manager, which will decide which connection to select. When the L2 connection is selected to be the new L3, it will send an empty KCM of its own, to let the Follower know the connection being selected. All other L2 connections (either viable or still in handshake) are dropped, -all other connection attempts are cancelled, and all listening sockets are -shut down. +all other connection attempts are cancelled. All listening sockets may or may +not be shut down (TODO: think about it). The Follower will wait for either an empty KCM (at which point the L2 connection is delivered to the Dilation manager as the new L3), a From d4c9210a4e44287b82a175377c8d13a6e4ce0b45 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 17:03:57 -0700 Subject: [PATCH 16/49] more docs updates --- docs/dilation-protocol.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md index ab5b50b..e7a90bd 100644 --- a/docs/dilation-protocol.md +++ b/docs/dilation-protocol.md @@ -37,8 +37,8 @@ L1 is the mailbox channel (queued store-and-forward messages that always go to the mailbox server, and then are forwarded to other clients subscribed to the same mailbox). Both clients remain connected to the mailbox server until the Wormhole is closed. 
They send DILATE-n messages to each other to manage -the dilation process, including records like `please-dilate`, -`start-dilation`, `ok-dilation`, and `connection-hints` +the dilation process, including records like `please`, `connection-hints`, +`reconnect`, and `reconnecting`. L2 is the set of competing connection attempts for a given generation of connection. Each time the Leader decides to establish a new connection, a new @@ -155,16 +155,17 @@ attempts and terminate any existing connections (even if they appear viable). Listening sockets may be retained, but any previous connection made through them must be dropped. -Once all connections have stopped, the Follower should send an `ok` message, -then start the connection process for the next generation, which will send -new `connection-hint` messages for all listening sockets). +Once all connections have stopped, the Follower should send an `reconnecting` +message, then start the connection process for the next generation, which +will send new `connection-hint` messages for all listening sockets). Generations are non-overlapping. The Leader will drop all connections from generation 1 before sending the `reconnect` for generation 2, and will not -initiate any gen-2 connections until it receives the matching `ok` from the -Follower. The Follower must drop all gen-1 connections before it sends the -`ok` response (even if it thinks they are still functioning: if the Leader -thought the gen-1 connection still worked, it wouldn't have started gen-2). +initiate any gen-2 connections until it receives the matching `reconnecting` +from the Follower. The Follower must drop all gen-1 connections before it +sends the `reconnecting` response (even if it thinks they are still +functioning: if the Leader thought the gen-1 connection still worked, it +wouldn't have started gen-2). 
(TODO: what about a follower->leader connection that was started before start-dilation is received, and gets established on the Leader side after @@ -181,19 +182,18 @@ connection attempts and having an established connection) starting with the receipt of a `please-dilate` message and a local `w.dilate()` call. The Leader remains in that state until it abandons the connection and sends a `reconnect` message, at which point it remains in the "flushing" state until -the Follower's `ok` message is received. The Follower remains in "connecting" -until it receives `reconnect`, then it stays in "dropping" until it finishes -halting all outstanding connections, after which it sends `ok` and switches -back to "connecting". +the Follower's `reconnecting` message is received. The Follower remains in +"connecting" until it receives `reconnect`, then it stays in "dropping" until +it finishes halting all outstanding connections, after which it sends +`reconnecting` and switches back to "connecting". "Connection hints" are type/address/port records that tell the other side of likely targets for L2 connections. Both sides will try to determine their external IP addresses, listen on a TCP port, and advertise `(tcp, external-IP, port)` as a connection hint. The Transit Relay is also used as a (lower-priority) hint. These are sent in `connection-hint` records, which can -be sent by the Leader any time after the `start-dilation` record, and by the -Follower after the `ok-dilation` record. Each side will initiate connections -upon receipt of the hints. +be sent any time after both sending and receiving a `please` record. Each +side will initiate connections upon receipt of the hints. ``` { "type": "connection-hints", From a594a854272e73bb3b299397c55e44dcaf37a474 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 17:08:23 -0700 Subject: [PATCH 17/49] Revert "Boss/Receive: add 'side' to got_message" This reverts commit 1fece5701c9de5e470526d2e7e9cfd7b461977e0. 
--- src/wormhole/_boss.py | 10 ++++------ src/wormhole/_receive.py | 14 +++++++------- src/wormhole/test/test_machines.py | 18 +++++++++--------- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index d3cb5ff..8d01e8d 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -264,13 +264,12 @@ class Boss(object): def scared(self): pass - def got_message(self, side, phase, plaintext): - # this is only called for side != ours + def got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) d_mo = re.search(r'^dilate-(\d+)$', phase) if phase == "version": - self._got_version(side, plaintext) + self._got_version(plaintext) elif d_mo: self._got_dilate(int(d_mo.group(1)), plaintext) elif re.search(r'^\d+$', phase): @@ -281,7 +280,7 @@ class Boss(object): log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) @m.input() - def _got_version(self, side, plaintext): + def _got_version(self, plaintext): pass @m.input() @@ -310,10 +309,9 @@ class Boss(object): self._W.got_code(code) @m.output() - def process_version(self, side, plaintext): + def process_version(self, plaintext): # most of this is wormhole-to-wormhole, ignored for now # in the future, this is how Dilation is signalled - self._their_side = side self._their_versions = bytes_to_dict(plaintext) self._D.got_wormhole_versions(self._side, self._their_side, self._their_versions) diff --git a/src/wormhole/_receive.py b/src/wormhole/_receive.py index 832dc44..8e9de4f 100644 --- a/src/wormhole/_receive.py +++ b/src/wormhole/_receive.py @@ -53,10 +53,10 @@ class Receive(object): except CryptoError: self.got_message_bad() return - self.got_message_good(side, phase, plaintext) + self.got_message_good(phase, plaintext) @m.input() - def got_message_good(self, side, phase, plaintext): + def got_message_good(self, phase, plaintext): pass @m.input() @@ -73,23 +73,23 @@ 
class Receive(object): self._key = key @m.output() - def S_got_verified_key(self, side, phase, plaintext): + def S_got_verified_key(self, phase, plaintext): assert self._key self._S.got_verified_key(self._key) @m.output() - def W_happy(self, side, phase, plaintext): + def W_happy(self, phase, plaintext): self._B.happy() @m.output() - def W_got_verifier(self, side, phase, plaintext): + def W_got_verifier(self, phase, plaintext): self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) @m.output() - def W_got_message(self, side, phase, plaintext): + def W_got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) - self._B.got_message(side, phase, plaintext) + self._B.got_message(phase, plaintext) @m.output() def W_scared(self): diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index fe956c1..c989e0a 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -167,7 +167,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"phase2") @@ -178,8 +178,8 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"side", u"phase1", data1), - ("b.got_message", u"side", u"phase2", data2), + ("b.got_message", u"phase1", data1), + ("b.got_message", u"phase2", data2), ]) def test_early_bad(self): @@ -217,7 +217,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"bad") @@ -228,7 +228,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), 
("b.got_verifier", verifier), - ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"phase1", data1), ("b.scared", ), ]) r.got_message(u"side", u"phase1", good_body) @@ -237,7 +237,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"phase1", data1), ("b.scared", ), ]) @@ -1324,8 +1324,8 @@ class Boss(unittest.TestCase): b.got_key(b"key") b.happy() b.got_verifier(b"verifier") - b.got_message("side", "version", b"{}") - b.got_message("side", "0", b"msg1") + b.got_message("version", b"{}") + b.got_message("0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), ("d.got_key", b"key"), @@ -1483,7 +1483,7 @@ class Boss(unittest.TestCase): b.happy() # phase=version - b.got_message("side", "unknown-phase", b"spooky") + b.got_message("unknown-phase", b"spooky") self.assertEqual(events, []) self.flushLoggedErrors(errors._UnknownPhaseError) From 74e5d9948b0cc4e62d85a10e8e4426645fc7d7d5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 21:14:02 -0700 Subject: [PATCH 18/49] move old-follower.py out of src, kept for temporary reference --- src/wormhole/_dilation/old-follower.py => old-follower.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/wormhole/_dilation/old-follower.py => old-follower.py (100%) diff --git a/src/wormhole/_dilation/old-follower.py b/old-follower.py similarity index 100% rename from src/wormhole/_dilation/old-follower.py rename to old-follower.py From ec5df72cd3c62787d113d6650849a964fca595ed Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 21:14:21 -0700 Subject: [PATCH 19/49] more protocol docs updates --- docs/dilation-protocol.md | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md index e7a90bd..deabe90 100644 --- a/docs/dilation-protocol.md +++ 
b/docs/dilation-protocol.md @@ -11,14 +11,24 @@ connection has been lost, and then coordinate the construction of a replacement. Within this connection, a series of queued-and-acked subchannel messages are used to open/use/close the application-visible subchannels. +## Versions and can-dilate + +The Wormhole protocol includes a `versions` message sent immediately after +the shared PAKE key is established. This also serves as a key-confirmation +message, allowing each side to confirm that the other side knows the right +key. The body of the `versions` message is a JSON-formatted string with keys +that are available for learning the abilities of the peer. Dilation is +signaled by a key named `can-dilate`, whose value is a list of strings. Any +version present in both side's lists is eligible for use. + ## Leaders and Followers -Each side of a Wormhole has a randomly-generated dilation "side" string (this -is independent of the Wormhole's mailbox "side"). When the wormhole is -dilated, the side with the lexicographically-higher "side" value is named the -"Leader", and the other side is named the "Follower". The general wormhole -protocol treats both sides identically, but the distinction matters for the -dilation protocol. +Each side of a Wormhole has a randomly-generated dilation `side` string (this +is included in the `please-dilate` message, and is independent of the +Wormhole's mailbox "side"). When the wormhole is dilated, the side with the +lexicographically-higher "side" value is named the "Leader", and the other +side is named the "Follower". The general wormhole protocol treats both sides +identically, but the distinction matters for the dilation protocol. Both sides send a `please-dilate` as soon as dilation is triggered. 
Each side discovers whether it is the Leader or the Follower when the peer's @@ -28,6 +38,15 @@ potential connections to use, the Leader decides which one to use, and the Leader gets to decide when the connection is no longer viable (and triggers the establishment of a new one). +The `please-dilate` includes a `use-version` key, computed as the "best" +version of the intersection of the two sides' abilities as reported in the +`versions` message. Both sides will use whichever `use-version` was specified +by the Leader (they learn which side is the Leader at the same moment they +learn the peer's `use-version` value). If the Follower cannot handle the +`use-version` value, dilation fails (this shouldn't happen, as the Leader +knew what the Follower was and was not capable of before sending that +message). + ## Connection Layers We describe the protocol as a series of layers. Messages sent on one layer From d4a551c6b82562c424aa347d534039d8c4a7f135 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 21:14:41 -0700 Subject: [PATCH 20/49] boss: remove sides from call to D.got_wormhole_versions() --- src/wormhole/_boss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 8d01e8d..ce650b7 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -313,8 +313,7 @@ class Boss(object): # most of this is wormhole-to-wormhole, ignored for now # in the future, this is how Dilation is signalled self._their_versions = bytes_to_dict(plaintext) - self._D.got_wormhole_versions(self._side, self._their_side, - self._their_versions) + self._D.got_wormhole_versions(self._their_versions) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) self._W.got_versions(app_versions) From 7e168b819e39a776a1e478cde69f87ae02354afa Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 1 Jul 2018 21:15:16 -0700 Subject: [PATCH 21/49] manager: clean up versions, merge state 
machines --- src/wormhole/_dilation/manager.py | 215 +++++++++++++++--------------- src/wormhole/wormhole.py | 3 +- 2 files changed, 109 insertions(+), 109 deletions(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index b16d198..bae38d0 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -1,4 +1,5 @@ from __future__ import print_function, unicode_literals +import os from collections import deque from attr import attrs, attrib from attr.validators import provides, instance_of, optional @@ -7,7 +8,7 @@ from zope.interface import implementer from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.python import log from .._interfaces import IDilator, IDilationManager, ISend -from ..util import dict_to_bytes, bytes_to_dict +from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver from .._key import derive_key from .encode import to_be4 @@ -21,6 +22,10 @@ from .inbound import Inbound from .outbound import Outbound +# exported to Wormhole() for inclusion in versions message +DILATION_VERSIONS = ["1"] + + class OldPeerCannotDilateError(Exception): pass @@ -33,9 +38,45 @@ class ReceivedHintsTooEarly(Exception): pass +# new scheme: +# * both sides send PLEASE as soon as they have an unverified key and +# w.dilate has been called, +# * PLEASE includes a dilation-specific "side" (independent of the "side" +# used by mailbox messages) +# * higher "side" is Leader, lower is Follower +# * PLEASE includes can-dilate list of version integers, requires overlap +# "1" is current + +# * we start dilation after both w.dilate() and receiving VERSION, putting us +# in WANTING, then we process all previously-queued inbound DILATE-n +# messages. When PLEASE arrives, we move to CONNECTING +# * HINTS sent after dilation starts +# * only Leader sends RECONNECT, only Follower sends RECONNECTING. 
This +# is the only difference between the two sides, and is not enforced +# by the protocol (i.e. if the Follower sends RECONNECT to the Leader, +# the Leader will obey, although TODO how confusing will this get?) +# * upon receiving RECONNECT: drop Connector, start new Connector, send +# RECONNECTING, start sending HINTS +# * upon sending RECONNECT: go into FLUSHING state and ignore all HINTS until +# RECONNECTING received. The new Connector can be spun up earlier, and it +# can send HINTS, but it must not be given any HINTS that arrive before +# RECONNECTING (since they're probably stale) + +# * after VERSIONS(KCM) received, we might learn that they other side cannot +# dilate. w.dilate errbacks at this point + +# * maybe signal warning if we stay in a "want" state for too long +# * nobody sends HINTS until they're ready to receive +# * nobody sends HINTS unless they've called w.dilate() and received PLEASE +# * nobody connects to inbound hints unless they've called w.dilate() +# * if leader calls w.dilate() but not follower, leader waits forever in +# "want" (doesn't send anything) +# * if follower calls w.dilate() but not leader, follower waits forever +# in "want", leader waits forever in "wanted" + @attrs @implementer(IDilationManager) -class _ManagerBase(object): +class Manager(object): _S = attrib(validator=provides(ISend)) _my_side = attrib(validator=instance_of(type(u""))) _transit_key = attrib(validator=instance_of(bytes)) @@ -46,6 +87,10 @@ class _ManagerBase(object): _no_listen = False # TODO _tor = None # TODO _timing = None # TODO + _next_subchannel_id = None # initialized in choose_role + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) def __attrs_post_init__(self): self._got_versions_d = Deferred() @@ -59,8 +104,6 @@ class _ManagerBase(object): self._next_dilation_phase = 0 - self._next_subchannel_id = 0 # increments by 2 - # I kept getting confused about which methods were for inbound data # (and thus 
flow-control methods go "out") and which were for # outbound data (with flow-control going "in"), so I split them up @@ -127,18 +170,6 @@ class _ManagerBase(object): self._inbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc) - def _start_connecting(self, role): - assert self._my_role is not None - self._connector = Connector(self._transit_key, - self._transit_relay_location, - self, - self._reactor, self._eventual_queue, - self._no_listen, self._tor, - self._timing, - self._side, # needed for relay handshake - self._my_role) - self._connector.start() - # our Connector calls these def connector_connection_made(self, c): @@ -209,59 +240,22 @@ class _ManagerBase(object): # subchannel maintenance def allocate_subchannel_id(self): - raise NotImplementedError # subclass knows if we're leader or follower + scid_num = self._next_subchannel_id + self._next_subchannel_id += 2 + return to_be4(scid_num) -# new scheme: -# * both sides send PLEASE as soon as they have an unverified key and -# w.dilate has been called, -# * PLEASE includes a dilation-specific "side" (independent of the "side" -# used by mailbox messages) -# * higher "side" is Leader, lower is Follower -# * PLEASE includes can-dilate list of version integers, requires overlap -# "1" is current -# * dilation starts as soon as we've sent PLEASE and received PLEASE -# (four-state two-variable IDLE/WANTING/WANTED/STARTED diamond FSM) -# * HINTS sent after dilation starts -# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This -# is the only difference between the two sides, and is not enforced -# by the protocol (i.e. if the Follower sends RECONNECT to the Leader, -# the Leader will obey, although TODO how confusing will this get?) -# * upon receiving RECONNECT: drop Connector, start new Connector, send -# RECONNECTING, start sending HINTS -# * upon sending CONNECT: go into FLUSHING state and ignore all HINTS until -# RECONNECTING received. 
The new Connector can be spun up earlier, and it -# can send HINTS, but it must not be given any HINTS that arrive before -# RECONNECTING (since they're probably stale) + # state machine -# * after VERSIONS(KCM) received, we might learn that they other side cannot -# dilate. w.dilate errbacks at this point + # We are born WANTING after the local app calls w.dilate(). We start + # CONNECTING when we receive PLEASE from the remote side -# * maybe signal warning if we stay in a "want" state for too long -# * nobody sends HINTS until they're ready to receive -# * nobody sends HINTS unless they've called w.dilate() and received PLEASE -# * nobody connects to inbound hints unless they've called w.dilate() -# * if leader calls w.dilate() but not follower, leader waits forever in -# "want" (doesn't send anything) -# * if follower calls w.dilate() but not leader, follower waits forever -# in "want", leader waits forever in "wanted" - - -class ManagerShared(_ManagerBase): - m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) + def start(self): + self.send_please() @m.state(initial=True) - def IDLE(self): - pass # pragma: no cover - - @m.state() def WANTING(self): pass # pragma: no cover - @m.state() - def WANTED(self): - pass # pragma: no cover - @m.state() def CONNECTING(self): pass # pragma: no cover @@ -290,10 +284,6 @@ class ManagerShared(_ManagerBase): def STOPPED(self): pass # pragma: no cover - @m.input() - def start(self): - pass # pragma: no cover - @m.input() def rx_PLEASE(self, message): pass # pragma: no cover @@ -332,9 +322,19 @@ class ManagerShared(_ManagerBase): pass # pragma: no cover @m.output() - def stash_side(self, message): + def choose_role(self, message): their_side = message["side"] - self.my_role = LEADER if self._my_side > their_side else FOLLOWER + if self._my_side > their_side: + self._my_role = LEADER + # scid 0 is reserved for the control channel. 
the leader uses odd + # numbers starting with 1 + self._next_subchannel_id = 1 + elif their_side > self._my_side: + self._my_role = FOLLOWER + # the follower uses even numbers starting with 2 + self._next_subchannel_id = 2 + else: + raise ValueError("their side shouldn't be equal: reflection?") # these Outputs behave differently for the Leader vs the Follower @m.output() @@ -342,12 +342,22 @@ class ManagerShared(_ManagerBase): self.send_dilation_phase(type="please", side=self._my_side) @m.output() - def start_connecting(self): - self._start_connecting() # TODO: merge + def start_connecting_ignore_message(self, message): + del message # ignored + return self.start_connecting() @m.output() - def ignore_message_start_connecting(self, message): - self.start_connecting() + def start_connecting(self): + assert self._my_role is not None + self._connector = Connector(self._transit_key, + self._transit_relay_location, + self, + self._reactor, self._eventual_queue, + self._no_listen, self._tor, + self._timing, + self._my_side, # needed for relay handshake + self._my_role) + self._connector.start() @m.output() def send_reconnect(self): @@ -374,14 +384,9 @@ class ManagerShared(_ManagerBase): # been told to shut down. 
self._connection.disconnect() # let connection_lost do cleanup - # we don't start CONNECTING until a local start() plus rx_PLEASE - IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) - IDLE.upon(start, enter=WANTING, outputs=[send_please]) - WANTED.upon(start, enter=CONNECTING, outputs=[ - send_please, start_connecting]) + # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, - outputs=[stash_side, - ignore_message_start_connecting]) + outputs=[choose_role, start_connecting_ignore_message]) CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[]) @@ -394,11 +399,11 @@ class ManagerShared(_ManagerBase): # Follower # if we notice a lost connection, just wait for the Leader to notice too CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[]) - LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) + LONELY.upon(rx_RECONNECT, enter=CONNECTING, + outputs=[send_reconnecting, start_connecting]) # but if they notice it first, abandon our (seemingly functional) # connection, then tell them that we're ready to try again - CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss - outputs=[abandon_connection]) + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, outputs=[abandon_connection]) ABANDONING.upon(connection_lost_follower, enter=CONNECTING, outputs=[send_reconnecting, start_connecting]) # and if they notice a problem while we're still connecting, abandon our @@ -410,8 +415,6 @@ class ManagerShared(_ManagerBase): start_connecting]) # rx_HINTS never changes state, they're just accepted or ignored - IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early - WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore @@ -420,24 +423,15 @@ class ManagerShared(_ManagerBase): ABANDONING.upon(rx_HINTS, 
enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) - IDLE.upon(stop, enter=STOPPED, outputs=[]) - WANTED.upon(stop, enter=STOPPED, outputs=[]) WANTING.upon(stop, enter=STOPPED, outputs=[]) CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) ABANDONING.upon(stop, enter=STOPPING, outputs=[]) - FLUSHING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + FLUSHING.upon(stop, enter=STOPPED, outputs=[]) LONELY.upon(stop, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) - def allocate_subchannel_id(self): - # scid 0 is reserved for the control channel. the leader uses odd - # numbers starting with 1 - scid_num = self._next_outbound_seqnum + 1 - self._next_outbound_seqnum += 2 - return to_be4(scid_num) - @attrs @implementer(IDilator) @@ -477,15 +471,23 @@ class Dilator(object): @inlineCallbacks def _start(self): # first, we wait until we hear the VERSION message, which tells us 1: - # the PAKE key works, so we can talk securely, 2: their side, so we - # know who will lead, and 3: that they can do dilation at all + # the PAKE key works, so we can talk securely, 2: that they can do + # dilation at all (if they can't then w.dilate() errbacks) dilation_version = yield self._got_versions_d - if not dilation_version: # 1 or None + # TODO: we could probably return the endpoints earlier, if we flunk + # any connection/listen attempts upon OldPeerCannotDilateError, or + # if/when we give up on the initial connection + + if not dilation_version: # "1" or None + # TODO: be more specific about the error. 
dilation_version==None + # means we had no version in common with them, which could either + # be because they're so old they don't dilate at all, or because + # they're so new that they no longer accomodate our old version raise OldPeerCannotDilateError() - my_dilation_side = TODO # random + my_dilation_side = bytes_to_hexstr(os.urandom(6)) self._manager = Manager(self._S, my_dilation_side, self._transit_key, self._transit_relay_location, @@ -497,8 +499,8 @@ class Dilator(object): plaintext = self._pending_inbound_dilate_messages.popleft() self.received_dilate(plaintext) - # we could probably return the endpoints earlier yield self._manager.when_first_connected() + # we can open subchannels as soon as we get our first connection scid0 = b"\x00\x00\x00\x00" self._host_addr = _WormholeAddress() # TODO: share with Manager @@ -508,8 +510,7 @@ class Dilator(object): control_ep._subchannel_zero_opened(sc0) self._manager.set_subchannel_zero(scid0, sc0) - connect_ep = SubchannelConnectorEndpoint( - self._manager, self._host_addr) + connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) self._manager.set_listener_endpoint(listen_ep) @@ -526,16 +527,14 @@ class Dilator(object): LENGTH = 32 # TODO: whatever Noise wants, I guess self._transit_key = derive_key(key, purpose, LENGTH) - def got_wormhole_versions(self, our_side, their_side, - their_wormhole_versions): - # TODO: remove our_side, their_side - assert isinstance(our_side, str), str - assert isinstance(their_side, str), str + def got_wormhole_versions(self, their_wormhole_versions): # this always happens before received_dilate dilation_version = None - their_dilation_versions = their_wormhole_versions.get("can-dilate", []) - if 1 in their_dilation_versions: - dilation_version = 1 + their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) + my_versions = set(DILATION_VERSIONS) + shared_versions = 
my_versions.intersection(their_dilation_versions) + if "1" in shared_versions: + dilation_version = "1" self._got_versions_d.callback(dilation_version) def received_dilate(self, plaintext): diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 3734fd2..95f5146 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -9,6 +9,7 @@ from twisted.internet.task import Cooperator from zope.interface import implementer from ._boss import Boss +from ._dilation.manager import DILATION_VERSIONS from ._dilation.connector import Connector from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key @@ -271,7 +272,7 @@ def create( w = _DeferredWormhole(reactor, eq) # this indicates Wormhole capabilities wormhole_versions = { - "can-dilate": [1], + "can-dilate": DILATION_VERSIONS, "dilation-abilities": Connector.get_connection_abilities(), } wormhole_versions["app_versions"] = versions # app-specific capabilities From 8a1a8b1f9c0fe5441082bb6cbef1711ddd30ada5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 2 Jul 2018 08:58:28 -0700 Subject: [PATCH 22/49] manager: factor out make_side for testing and override --- src/wormhole/_dilation/manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index bae38d0..16553b2 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -37,6 +37,8 @@ class UnknownDilationMessageType(Exception): class ReceivedHintsTooEarly(Exception): pass +def make_side(): + return bytes_to_hexstr(os.urandom(6)) # new scheme: # * both sides send PLEASE as soon as they have an unverified key and @@ -487,7 +489,7 @@ class Dilator(object): # they're so new that they no longer accomodate our old version raise OldPeerCannotDilateError() - my_dilation_side = bytes_to_hexstr(os.urandom(6)) + my_dilation_side = make_side() self._manager = Manager(self._S, my_dilation_side, 
self._transit_key, self._transit_relay_location, From a4234cdecf7716ec511e7300c713633ac9e68a46 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 2 Jul 2018 08:59:02 -0700 Subject: [PATCH 23/49] test_machines: fix for change to got_wormhole_versions --- src/wormhole/test/test_machines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index c989e0a..dff3fb0 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -1330,7 +1330,7 @@ class Boss(unittest.TestCase): ("w.got_key", b"key"), ("d.got_key", b"key"), ("w.got_verifier", b"verifier"), - ("d.got_wormhole_versions", "side", "side", {}), + ("d.got_wormhole_versions", {}), ("w.got_versions", {}), ("w.received", b"msg1"), ]) From 7084cbcb6fb16399b34136198ba0cce01aaa0813 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 2 Jul 2018 08:59:25 -0700 Subject: [PATCH 24/49] test_manager: fix --- src/wormhole/test/dilate/test_manager.py | 34 +++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 690e1cf..66f0186 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -53,9 +53,11 @@ class TestDilator(unittest.TestCase): alsoProvides(m, IDilationManager) m.when_first_connected.return_value = wfc_d = Deferred() # TODO: test missing can-dilate, and no-overlap - with mock.patch("wormhole._dilation.manager.ManagerLeader", + with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as ml: - dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) # that should create the Manager. 
Because "us" > "them", we're # the leader self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, @@ -124,9 +126,11 @@ class TestDilator(unittest.TestCase): self.flushLoggedErrors(UnknownDilationMessageType) def test_follower(self): + # todo: this no longer proceeds far enough to pick a side dil, send, reactor, eq, clock, coop = make_dilator() d1 = dil.dilate() self.assertNoResult(d1) + self.assertEqual(send.mock_calls, []) key = b"key" transit_key = object() @@ -137,10 +141,12 @@ class TestDilator(unittest.TestCase): m = mock.Mock() alsoProvides(m, IDilationManager) m.when_first_connected.return_value = Deferred() - with mock.patch("wormhole._dilation.manager.ManagerFollower", - return_value=m) as mf: - dil.got_wormhole_versions("me", "you", {"can-dilate": [1]}) - # "me" < "you", so we're the follower + with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as mf: + with mock.patch("wormhole._dilation.manager.make_side", + return_value="me"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) + # we want to dilate (dil.dilate() above), and now we know they *can* + # dilate (got_wormhole_versions), so we create and start the manager self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key, None, reactor, eq, coop)]) self.assertEqual(m.mock_calls, [mock.call.start(), @@ -152,7 +158,7 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) - dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" + dil.got_wormhole_versions({}) # missing "can-dilate" eq.flush_sync() f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) @@ -162,7 +168,7 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) - dil.got_wormhole_versions("me", "you", {"can-dilate": [-1]}) + dil.got_wormhole_versions({"can-dilate": [-1]}) eq.flush_sync() f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) @@ -180,9 +186,11 @@ class TestDilator(unittest.TestCase): alsoProvides(m, 
IDilationManager) m.when_first_connected.return_value = Deferred() - with mock.patch("wormhole._dilation.manager.ManagerLeader", + with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as ml: - dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", None, reactor, eq, coop)]) self.assertEqual(m.mock_calls, [mock.call.start(), @@ -197,8 +205,10 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate(transit_relay_location=relay) self.assertNoResult(d1) - with mock.patch("wormhole._dilation.manager.ManagerLeader") as ml: - dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + with mock.patch("wormhole._dilation.manager.Manager") as ml: + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", relay, reactor, eq, coop), mock.call().start(), From 78358358bc8f550c6efb022a986d1174f3f783cf Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 2 Jul 2018 09:04:23 -0700 Subject: [PATCH 25/49] manager: hush flake8 --- src/wormhole/_dilation/manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 16553b2..9a12409 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -37,9 +37,11 @@ class UnknownDilationMessageType(Exception): class ReceivedHintsTooEarly(Exception): pass + def make_side(): return bytes_to_hexstr(os.urandom(6)) + # new scheme: # * both sides send PLEASE as soon as they have an unverified key and # w.dilate has been called, From dd8bff30f2bbb016516fc4613f397f076e5269ef Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 2 Jul 2018 09:09:52 -0700 Subject: [PATCH 26/49] remove 
old-follower.py, flake8 still sees it and doesn't like what it sees --- old-follower.py | 130 ------------------------------------------------ 1 file changed, 130 deletions(-) delete mode 100644 old-follower.py diff --git a/old-follower.py b/old-follower.py deleted file mode 100644 index 7be4307..0000000 --- a/old-follower.py +++ /dev/null @@ -1,130 +0,0 @@ - -class ManagerFollower(_ManagerBase): - m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) - - @m.state(initial=True) - def IDLE(self): - pass # pragma: no cover - - @m.state() - def WANTING(self): - pass # pragma: no cover - - @m.state() - def CONNECTING(self): - pass # pragma: no cover - - @m.state() - def CONNECTED(self): - pass # pragma: no cover - - @m.state(terminal=True) - def STOPPED(self): - pass # pragma: no cover - - @m.input() - def start(self): - pass # pragma: no cover - - @m.input() - def rx_PLEASE(self): - pass # pragma: no cover - - @m.input() - def rx_DILATE(self): - pass # pragma: no cover - - @m.input() - def rx_HINTS(self, hint_message): - pass # pragma: no cover - - @m.input() - def connection_made(self): - pass # pragma: no cover - - @m.input() - def connection_lost(self): - pass # pragma: no cover - # follower doesn't react to connection_lost, but waits for a new LETS_DILATE - - @m.input() - def stop(self): - pass # pragma: no cover - - # these Outputs behave differently for the Leader vs the Follower - @m.output() - def send_please(self): - self.send_dilation_phase(type="please") - - @m.output() - def start_connecting(self): - self._start_connecting(FOLLOWER) - - # these Outputs delegate to the same code in both the Leader and the - # Follower, but they must be replicated here because the Automat instance - # is on the subclass, not the shared superclass - - @m.output() - def use_hints(self, hint_message): - hint_objs = filter(lambda h: h, # ignore None, unrecognizable - [parse_hint(hs) for hs in hint_message["hints"]]) - 
self._connector.got_hints(hint_objs) - - @m.output() - def stop_connecting(self): - self._connector.stop() - - @m.output() - def use_connection(self, c): - self._use_connection(c) - - @m.output() - def stop_using_connection(self): - self._stop_using_connection() - - @m.output() - def signal_error(self): - pass # TODO - - @m.output() - def signal_error_hints(self, hint_message): - pass # TODO - - IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early - IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early - # leader shouldn't send us DILATE before receiving our PLEASE - IDLE.upon(stop, enter=STOPPED, outputs=[]) - IDLE.upon(start, enter=WANTING, outputs=[send_please]) - WANTING.upon(rx_DILATE, enter=CONNECTING, outputs=[start_connecting]) - WANTING.upon(stop, enter=STOPPED, outputs=[]) - - CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) - CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) - # shouldn't happen: connection_lost - # CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) - CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, - start_connecting]) - # receiving rx_DILATE while we're still working on the last one means the - # leader thought we'd connected, then thought we'd been disconnected, all - # before we heard about that connection - CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) - - CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) - CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, - start_connecting]) - CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore - CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) - # shouldn't happen: connection_made - - # we should never receive PLEASE, we're the follower - IDLE.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) - WANTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) - 
CONNECTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) - CONNECTED.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) - - def allocate_subchannel_id(self): - # the follower uses even numbers starting with 2 - scid_num = self._next_outbound_seqnum + 2 - self._next_outbound_seqnum += 2 - return to_be4(scid_num) From e19c7d1281d4856f793500e6012a81b3f6a5d3a7 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 00:04:10 -0500 Subject: [PATCH 27/49] typos/cleanups in docs/dilation-protocol.md --- docs/dilation-protocol.md | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md index deabe90..03716e9 100644 --- a/docs/dilation-protocol.md +++ b/docs/dilation-protocol.md @@ -82,7 +82,7 @@ contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number messages reference those sequence numbers. When a message is given to the L4 channel for delivery to the remote side, it is always queued, then transmitted if there is an L3 connection available. This message remains in -the queue until an ACK is received to release it. If a new L3 connection is +the queue until an ACK is received to retire it. If a new L3 connection is made, all queued messages will be re-sent (in seqnum order). L5 are subchannels. There is one pre-established subchannel 0 known as the @@ -118,13 +118,14 @@ frames or bytes), and all-number phase names are reserved for application data (sent via `w.send_message()`). Therefore the dilation control messages use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own counter, so one side might be up to e.g. `DILATE-5` while the other has only -gotten as far as `DILATE-2`. This effectively creates a unidirectional stream -of `DILATE-n` messages, each containing one or more dilation record, of -various types described below. 
Note that all phases beyond the initial -VERSION and PAKE phases are encrypted by the shared session key. +gotten as far as `DILATE-2`. This effectively creates a pair of +unidirectional streams of `DILATE-n` messages, each containing one or more +dilation record, of various types described below. Note that all phases +beyond the initial VERSION and PAKE phases are encrypted by the shared +session key. -A future mailbox protocol might provide a simple ordered stream of messages, -with application records and dilation records mixed together. +A future mailbox protocol might provide a simple ordered stream of typed +messages, with application records and dilation records mixed together. Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that has a string value. The dictionary will have other keys that depend upon the @@ -163,7 +164,7 @@ dropped. If something goes wrong with the established connection and the Leader decides a new one is necessary, the Leader will send a `reconnect` message. This might happen while connections are still being established, or while the -Follower thinks it still has a viable connection (the Leader might observer +Follower thinks it still has a viable connection (the Leader might observe problems that the Follower does not), or after the Follower thinks the connection has been lost. In all cases, the Leader is the only side which should send `reconnect`. The state machine code looks the same on both sides, @@ -174,9 +175,9 @@ attempts and terminate any existing connections (even if they appear viable). Listening sockets may be retained, but any previous connection made through them must be dropped. -Once all connections have stopped, the Follower should send an `reconnecting` +Once all connections have stopped, the Follower should send a `reconnecting` message, then start the connection process for the next generation, which -will send new `connection-hint` messages for all listening sockets). 
+will send new `connection-hint` messages for all listening sockets. Generations are non-overlapping. The Leader will drop all connections from generation 1 before sending the `reconnect` for generation 2, and will not @@ -467,8 +468,9 @@ identifier. Both directions share a number space (unlike L4 seqnums), so the rule is that the Leader side sets the last bit of the last byte to a 0, while the Follower sets it to a 1. These are not generally treated as integers, however for the sake of debugging, the implementation generates them with a -simple big-endian-encoded counter (`next(counter)*2` for the Leader, -`next(counter)*2+1` for the Follower). +simple big-endian-encoded counter (`counter*2+2` for the Leader, +`counter*2+1` for the Follower, with id `0` reserved for the control +channel). When the `client_ep.connect()` API is called, the Initiator allocates a subchannel-id and sends an OPEN. It can then immediately send DATA messages From e55787c6936d22fc50e9913c960eebb6ec224c6d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 16:51:28 -0500 Subject: [PATCH 28/49] get most of Manager working and tested still need to test the subchannel interfaces, and ping/pong/kcm --- docs/dilation-protocol.md | 31 +- src/wormhole/_dilation/manager.py | 28 +- src/wormhole/test/dilate/test_manager.py | 444 +++++++++++++++++++++-- 3 files changed, 436 insertions(+), 67 deletions(-) diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md index 03716e9..7625353 100644 --- a/docs/dilation-protocol.md +++ b/docs/dilation-protocol.md @@ -131,25 +131,24 @@ Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that has a string value. The dictionary will have other keys that depend upon the type. -`w.dilate()` triggers transmission of a `please-dilate` record with a set of -versions that can be accepted. Versions use strings, rather than integers, to -support experimental protocols, however there is still a total ordering of -preferability. 
+`w.dilate()` triggers transmission of a `please` (i.e. "please dilate") +record with a set of versions that can be accepted. Versions use strings, +rather than integers, to support experimental protocols, however there is +still a total ordering of preferability. ``` -{ "type": "please-dilate", +{ "type": "please", "side": "abcdef", "accepted-versions": ["1"] } ``` -If one side receives a `please-dilate` before `w.dilate()` has been called -locally, the contents are stored in case `w.dilate()` is called in the -future. Once both `w.dilate()` has been called and the peer's `please-dilate` -has been received, the side determines whether it is the Leader or the -Follower. Both sides also compare `accepted-versions` fields to choose the -best mutually-compatible version to use: they should always pick the same -one. +If one side receives a `please` before `w.dilate()` has been called locally, +the contents are stored in case `w.dilate()` is called in the future. Once +both `w.dilate()` has been called and the peer's `please` has been received, +the side determines whether it is the Leader or the Follower. Both sides also +compare `accepted-versions` fields to choose the best mutually-compatible +version to use: they should always pick the same one. Then both sides begin the connection process for generation 1 by opening listening sockets and sending `connection-hint` records for each one. After a @@ -465,11 +464,11 @@ side seeking to send a file. Each subchannel uses a distinct subchannel-id, which is a four-byte identifier. Both directions share a number space (unlike L4 seqnums), so the -rule is that the Leader side sets the last bit of the last byte to a 0, while -the Follower sets it to a 1. These are not generally treated as integers, +rule is that the Leader side sets the last bit of the last byte to a 1, while +the Follower sets it to a 0. 
These are not generally treated as integers, however for the sake of debugging, the implementation generates them with a -simple big-endian-encoded counter (`counter*2+2` for the Leader, -`counter*2+1` for the Follower, with id `0` reserved for the control +simple big-endian-encoded counter (`counter*2+1` for the Leader, +`counter*2+2` for the Follower, with id `0` reserved for the control channel). When the `client_ep.connect()` API is called, the Initiator allocates a diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 9a12409..f38390b 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -188,7 +188,7 @@ class Manager(object): def connector_connection_lost(self): self._stop_using_connection() - if self.role is LEADER: + if self._my_role is LEADER: self.connection_lost_leader() # state machine else: self.connection_lost_follower() @@ -256,6 +256,9 @@ class Manager(object): def start(self): self.send_please() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + @m.state(initial=True) def WANTING(self): pass # pragma: no cover @@ -278,7 +281,7 @@ class Manager(object): @m.state() def LONELY(self): - pass # pragme: no cover + pass # pragma: no cover @m.state() def STOPPING(self): @@ -306,7 +309,7 @@ class Manager(object): # Connector gives us connection_made() @m.input() - def connection_made(self, c): + def connection_made(self): pass # pragma: no cover # our connection_lost() fires connection_lost_leader or @@ -341,17 +344,17 @@ class Manager(object): raise ValueError("their side shouldn't be equal: reflection?") # these Outputs behave differently for the Leader vs the Follower - @m.output() - def send_please(self): - self.send_dilation_phase(type="please", side=self._my_side) @m.output() def start_connecting_ignore_message(self, message): del message # ignored - return self.start_connecting() + return self._start_connecting() @m.output() def 
start_connecting(self): + self._start_connecting() + + def _start_connecting(self): assert self._my_role is not None self._connector = Connector(self._transit_key, self._transit_relay_location, @@ -447,7 +450,7 @@ class Dilator(object): before we know whether we'll be the Leader or the Follower. Once we hear the other side's VERSION message (which tells us that we have a connection, they are capable of dilating, and which side we're on), - then we build a DilationManager and hand control to it. + then we build a Manager and hand control to it. """ _reactor = attrib() @@ -532,6 +535,7 @@ class Dilator(object): self._transit_key = derive_key(key, purpose, LENGTH) def got_wormhole_versions(self, their_wormhole_versions): + assert self._transit_key is not None # this always happens before received_dilate dilation_version = None their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) @@ -554,11 +558,13 @@ class Dilator(object): message = bytes_to_dict(plaintext) type = message["type"] if type == "please": - self._manager.rx_PLEASE() # message) - elif type == "dilate": - self._manager.rx_DILATE() # message) + self._manager.rx_PLEASE(message) elif type == "connection-hints": self._manager.rx_HINTS(message) + elif type == "reconnect": + self._manager.rx_RECONNECT() + elif type == "reconnecting": + self._manager.rx_RECONNECTING() else: log.err(UnknownDilationMessageType(message)) return diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 66f0186..f0d50a0 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -7,10 +7,13 @@ import mock from ...eventual import EventualQueue from ..._interfaces import ISend, IDilationManager from ...util import dict_to_bytes -from ..._dilation.manager import (Dilator, +from ..._dilation import roles +from ..._dilation.encode import to_be4 +from ..._dilation.manager import (Dilator, Manager, make_side, 
OldPeerCannotDilateError, UnknownDilationMessageType) from ..._dilation.subchannel import _WormholeAddress +from ..._dilation.connection import Open, Data, Close, Ack from .common import clear_mock_calls @@ -30,9 +33,8 @@ def make_dilator(): dil.wire(send) return dil, send, reactor, eq, clock, coop - class TestDilator(unittest.TestCase): - def test_leader(self): + def test_manager_and_endpoints(self): dil, send, reactor, eq, clock, coop = make_dilator() d1 = dil.dilate() d2 = dil.dilate() @@ -52,16 +54,15 @@ class TestDilator(unittest.TestCase): m = mock.Mock() alsoProvides(m, IDilationManager) m.when_first_connected.return_value = wfc_d = Deferred() - # TODO: test missing can-dilate, and no-overlap with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as ml: with mock.patch("wormhole._dilation.manager.make_side", return_value="us"): dil.got_wormhole_versions({"can-dilate": ["1"]}) - # that should create the Manager. Because "us" > "them", we're - # the leader + # that should create the Manager self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, None, reactor, eq, coop)]) + # and tell it to start, and get wait-for-it-to-connect Deferred self.assertEqual(m.mock_calls, [mock.call.start(), mock.call.when_first_connected(), ]) @@ -107,9 +108,11 @@ class TestDilator(unittest.TestCase): eq.flush_sync() self.assertEqual(eps, self.successResultOf(d3)) + # all subsequent DILATE-n messages should get passed to the manager self.assertEqual(m.mock_calls, []) - dil.received_dilate(dict_to_bytes(dict(type="please"))) - self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE()]) + pleasemsg = dict(type="please", side="them") + dil.received_dilate(dict_to_bytes(pleasemsg)) + self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)]) clear_mock_calls(m) hintmsg = dict(type="connection-hints") @@ -117,47 +120,27 @@ class TestDilator(unittest.TestCase): self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) clear_mock_calls(m) - 
dil.received_dilate(dict_to_bytes(dict(type="dilate"))) - self.assertEqual(m.mock_calls, [mock.call.rx_DILATE()]) + # we're nominally the LEADER, and the leader would not normally be + # receiving a RECONNECT, but since we've mocked out the Manager it + # won't notice + dil.received_dilate(dict_to_bytes(dict(type="reconnect"))) + self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="reconnecting"))) + self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()]) clear_mock_calls(m) dil.received_dilate(dict_to_bytes(dict(type="unknown"))) self.assertEqual(m.mock_calls, []) self.flushLoggedErrors(UnknownDilationMessageType) - def test_follower(self): - # todo: this no longer proceeds far enough to pick a side - dil, send, reactor, eq, clock, coop = make_dilator() - d1 = dil.dilate() - self.assertNoResult(d1) - self.assertEqual(send.mock_calls, []) - - key = b"key" - transit_key = object() - with mock.patch("wormhole._dilation.manager.derive_key", - return_value=transit_key): - dil.got_key(key) - - m = mock.Mock() - alsoProvides(m, IDilationManager) - m.when_first_connected.return_value = Deferred() - with mock.patch("wormhole._dilation.manager.Manager", return_value=m) as mf: - with mock.patch("wormhole._dilation.manager.make_side", - return_value="me"): - dil.got_wormhole_versions({"can-dilate": ["1"]}) - # we want to dilate (dil.dilate() above), and now we know they *can* - # dilate (got_wormhole_versions), so we create and start the manager - self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key, - None, reactor, eq, coop)]) - self.assertEqual(m.mock_calls, [mock.call.start(), - mock.call.when_first_connected(), - ]) - def test_peer_cannot_dilate(self): dil, send, reactor, eq, clock, coop = make_dilator() d1 = dil.dilate() self.assertNoResult(d1) + dil._transit_key = b"\x01"*32 dil.got_wormhole_versions({}) # missing "can-dilate" eq.flush_sync() f = self.failureResultOf(d1) @@ 
-168,6 +151,7 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) + dil._transit_key = b"key" dil.got_wormhole_versions({"can-dilate": [-1]}) eq.flush_sync() f = self.failureResultOf(d1) @@ -178,7 +162,8 @@ class TestDilator(unittest.TestCase): dil._transit_key = b"key" d1 = dil.dilate() self.assertNoResult(d1) - dil.received_dilate(dict_to_bytes(dict(type="please"))) + pleasemsg = dict(type="please", side="them") + dil.received_dilate(dict_to_bytes(pleasemsg)) hintmsg = dict(type="connection-hints") dil.received_dilate(dict_to_bytes(hintmsg)) @@ -194,7 +179,7 @@ class TestDilator(unittest.TestCase): self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", None, reactor, eq, coop)]) self.assertEqual(m.mock_calls, [mock.call.start(), - mock.call.rx_PLEASE(), + mock.call.rx_PLEASE(pleasemsg), mock.call.rx_HINTS(hintmsg), mock.call.when_first_connected()]) @@ -213,3 +198,382 @@ class TestDilator(unittest.TestCase): relay, reactor, eq, coop), mock.call().start(), mock.call().when_first_connected()]) + +LEADER = "ff3456abcdef" +FOLLOWER = "123456abcdef" + +def make_manager(leader=True): + class Holder: + pass + h = Holder() + h.send = mock.Mock() + alsoProvides(h.send, ISend) + if leader: + side = LEADER + else: + side = FOLLOWER + h.key = b"\x00"*32 + h.relay = None + h.reactor = object() + h.clock = Clock() + h.eq = EventualQueue(h.clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + def term_factory(): + return term + h.coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=h.eq.eventually) + h.inbound = mock.Mock() + h.Inbound = mock.Mock(return_value=h.inbound) + h.outbound = mock.Mock() + h.Outbound = mock.Mock(return_value=h.outbound) + h.hostaddr = object() + with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): + with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + with mock.patch("wormhole._dilation.manager._WormholeAddress", + 
return_value=h.hostaddr): + m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop) + return m, h + +class TestManager(unittest.TestCase): + def test_make_side(self): + side = make_side() + self.assertEqual(type(side), type(u"")) + self.assertEqual(len(side), 2*6) + + def test_create(self): + m, h = make_manager() + + def test_leader(self): + m, h = make_manager(leader=True) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)]) + self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)]) + + m.start() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-0", + dict_to_bytes({"type": "please", "side": LEADER})) + ]) + clear_mock_calls(h.send) + + wfc_d = m.when_first_connected() + self.assertNoResult(wfc_d) + + # ignore early hints + m.rx_HINTS({}) + self.assertEqual(h.send.mock_calls, []) + + c = mock.Mock() + connector = mock.Mock(return_value=c) + with mock.patch("wormhole._dilation.manager.Connector", connector): + # receiving this PLEASE triggers creation of the Connector + m.rx_PLEASE({"side": FOLLOWER}) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + LEADER, roles.LEADER), + ]) + self.assertEqual(c.mock_calls, [mock.call.start()]) + clear_mock_calls(connector, c) + + self.assertNoResult(wfc_d) + + # now any inbound hints should get passed to our Connector + with mock.patch("wormhole._dilation.manager.parse_hint", + side_effect=["p1", None, "p3"]) as ph: + m.rx_HINTS({"hints": [1, 2, 3]}) + self.assertEqual(ph.mock_calls, [mock.call(1), mock.call(2), mock.call(3)]) + self.assertEqual(c.mock_calls, [mock.call.got_hints(["p1", "p3"])]) + clear_mock_calls(ph, c) + + # and we send out any (listening) hints from our Connector + m.send_hints([1,2]) + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-1", + dict_to_bytes({"type": 
"connection-hints", + "hints": [1, 2]})) + ]) + clear_mock_calls(h.send) + + # the first successful connection fires when_first_connected(), so + # the Dilator can create and return the endpoints + c1 = mock.Mock() + m.connector_connection_made(c1) + + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) + clear_mock_calls(h.inbound, h.outbound) + + h.eq.flush_sync() + self.successResultOf(wfc_d) # fires with None + wfc_d2 = m.when_first_connected() + h.eq.flush_sync() + self.successResultOf(wfc_d2) + + scid0 = b"\x00\x00\x00\x00" + sc0 = mock.Mock() + m.set_subchannel_zero(scid0, sc0) + listen_ep = mock.Mock() + m.set_listener_endpoint(listen_ep) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.set_subchannel_zero(scid0, sc0), + mock.call.set_listener_endpoint(listen_ep), + ]) + clear_mock_calls(h.inbound) + + # the Leader making a new outbound channel should get scid=1 + scid1 = to_be4(1) + self.assertEqual(m.allocate_subchannel_id(), scid1) + r1 = Open(10, scid1) # seqnum=10 + h.outbound.build_record = mock.Mock(return_value=r1) + m.send_open(scid1) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Open, scid1), + mock.call.queue_and_send_record(r1), + ]) + clear_mock_calls(h.outbound) + + r2 = Data(11, scid1, b"data") + h.outbound.build_record = mock.Mock(return_value=r2) + m.send_data(scid1, b"data") + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Data, scid1, b"data"), + mock.call.queue_and_send_record(r2), + ]) + clear_mock_calls(h.outbound) + + r3 = Close(12, scid1) + h.outbound.build_record = mock.Mock(return_value=r3) + m.send_close(scid1) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Close, scid1), + mock.call.queue_and_send_record(r3), + ]) + clear_mock_calls(h.outbound) + + # ack the OPEN + m.got_record(Ack(10)) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.handle_ack(10) + ]) + 
clear_mock_calls(h.outbound) + + # test that inbound records get acked and routed to Inbound + h.inbound.is_record_old = mock.Mock(return_value=False) + scid2 = to_be4(2) + o200 = Open(200, scid2) + m.got_record(o200) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(200)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(o200), + mock.call.update_ack_watermark(200), + mock.call.handle_open(scid2), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # old (duplicate) records should provoke new Acks, but not get + # forwarded + h.inbound.is_record_old = mock.Mock(return_value=True) + m.got_record(o200) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(200)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(o200), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # check Data and Close too + h.inbound.is_record_old = mock.Mock(return_value=False) + d201 = Data(201, scid2, b"data") + m.got_record(d201) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(201)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(d201), + mock.call.update_ack_watermark(201), + mock.call.handle_data(scid2, b"data"), + ]) + clear_mock_calls(h.outbound, h.inbound) + + c202 = Close(202, scid2) + m.got_record(c202) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(202)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(c202), + mock.call.update_ack_watermark(202), + mock.call.handle_close(scid2), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # Now we lose the connection. The Leader should tell the other side + # that we're reconnecting. 
+ + m.connector_connection_lost() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-2", + dict_to_bytes({"type": "reconnect"})) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + clear_mock_calls(h.send, h.inbound, h.outbound) + + # leader does nothing (stays in FLUSHING) until the follower acks by + # sending RECONNECTING + + # inbound hints should be ignored during FLUSHING + with mock.patch("wormhole._dilation.manager.parse_hint", + return_value=None) as ph: + m.rx_HINTS({"hints": [1, 2, 3]}) + self.assertEqual(ph.mock_calls, []) # ignored + + c2 = mock.Mock() + connector2 = mock.Mock(return_value=c2) + with mock.patch("wormhole._dilation.manager.Connector", connector2): + # this triggers creation of a new Connector + m.rx_RECONNECTING() + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector2.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + LEADER, roles.LEADER), + ]) + self.assertEqual(c2.mock_calls, [mock.call.start()]) + clear_mock_calls(connector2, c2) + + self.assertEqual(h.inbound.mock_calls, []) + self.assertEqual(h.outbound.mock_calls, []) + + # and a new connection should re-register with Inbound/Outbound, + # which are responsible for re-sending unacked queued messages + c3 = mock.Mock() + m.connector_connection_made(c3) + + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c3)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c3)]) + clear_mock_calls(h.inbound, h.outbound) + + def test_follower(self): + m, h = make_manager(leader=False) + + m.start() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-0", + dict_to_bytes({"type": "please", "side": FOLLOWER})) + ]) + clear_mock_calls(h.send) + + c = mock.Mock() + connector = mock.Mock(return_value=c) + with 
mock.patch("wormhole._dilation.manager.Connector", connector): + # receiving this PLEASE triggers creation of the Connector + m.rx_PLEASE({"side": LEADER}) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c.mock_calls, [mock.call.start()]) + clear_mock_calls(connector, c) + + # get connected, then lose the connection + c1 = mock.Mock() + m.connector_connection_made(c1) + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) + clear_mock_calls(h.inbound, h.outbound) + + # now lose the connection. As the follower, we don't notify the + # leader, we just wait for them to notice + m.connector_connection_lost() + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + clear_mock_calls(h.send, h.inbound, h.outbound) + + # now we get a RECONNECT: we should send RECONNECTING + c2 = mock.Mock() + connector2 = mock.Mock(return_value=c2) + with mock.patch("wormhole._dilation.manager.Connector", connector2): + m.rx_RECONNECT() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-1", + dict_to_bytes({"type": "reconnecting"})) + ]) + self.assertEqual(connector2.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c2.mock_calls, [mock.call.start()]) + clear_mock_calls(connector2, c2) + + # while we're trying to connect, we get told to stop again, so we + # should abandon the connection attempt and start another + c3 = mock.Mock() + connector3 = mock.Mock(return_value=c3) + with 
mock.patch("wormhole._dilation.manager.Connector", connector3): + m.rx_RECONNECT() + self.assertEqual(c2.mock_calls, [mock.call.stop()]) + self.assertEqual(connector3.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c3.mock_calls, [mock.call.start()]) + clear_mock_calls(c2, connector3, c3) + + m.connector_connection_made(c3) + # finally if we're already connected, rx_RECONNECT means we should + # abandon this connection (even though it still looks ok to us), then + # when the attempt is finished stopping, we should start another + + m.rx_RECONNECT() + + c4 = mock.Mock() + connector4 = mock.Mock(return_value=c4) + with mock.patch("wormhole._dilation.manager.Connector", connector4): + m.connector_connection_lost() + self.assertEqual(c3.mock_calls, [mock.call.disconnect()]) + self.assertEqual(connector4.mock_calls, [ + mock.call(b"\x00"*32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c4.mock_calls, [mock.call.start()]) + clear_mock_calls(c3, connector4, c4) + + def test_stop(self): + pass + + + def test_mirror(self): + # receive a PLEASE with the same side as us: shouldn't happen + pass From 40dadfeb7175b8d28900e06d58deb476c5dcf838 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 21:48:27 -0500 Subject: [PATCH 29/49] finish fixing/testing manager.py --- src/wormhole/_dilation/manager.py | 14 ++++- src/wormhole/test/dilate/test_manager.py | 80 +++++++++++++++++++++--- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index f38390b..7f5c6d8 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -38,6 +38,13 @@ class ReceivedHintsTooEarly(Exception): pass +class UnexpectedKCM(Exception): + pass + + +class 
UnknownMessageType(Exception): + pass + def make_side(): return bytes_to_hexstr(os.urandom(6)) @@ -94,7 +101,7 @@ class Manager(object): _next_subchannel_id = None # initialized in choose_role m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._got_versions_d = Deferred() @@ -214,8 +221,9 @@ class Manager(object): self._inbound.handle_data(r.scid, r.data) else: # isinstance(r, Close) self._inbound.handle_close(r.scid) + return if isinstance(r, KCM): - log.err("got unexpected KCM") + log.err(UnexpectedKCM()) elif isinstance(r, Ping): self.handle_ping(r.ping_id) elif isinstance(r, Pong): @@ -223,7 +231,7 @@ class Manager(object): elif isinstance(r, Ack): self._outbound.handle_ack(r.resp_seqnum) # retire queued messages else: - log.err("received unknown message type {}".format(r)) + log.err(UnknownMessageType("{}".format(r))) # pings, pongs, and acks are not queued def send_ping(self, ping_id): diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index f0d50a0..45832c4 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -11,9 +11,11 @@ from ..._dilation import roles from ..._dilation.encode import to_be4 from ..._dilation.manager import (Dilator, Manager, make_side, OldPeerCannotDilateError, - UnknownDilationMessageType) + UnknownDilationMessageType, + UnexpectedKCM, + UnknownMessageType) from ..._dilation.subchannel import _WormholeAddress -from ..._dilation.connection import Open, Data, Close, Ack +from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong from .common import clear_mock_calls @@ -570,10 +572,74 @@ class TestManager(unittest.TestCase): self.assertEqual(c4.mock_calls, [mock.call.start()]) clear_mock_calls(c3, connector4, c4) - def test_stop(self): - pass - - def test_mirror(self): # receive a 
PLEASE with the same side as us: shouldn't happen - pass + m, h = make_manager(leader=True) + + m.start() + clear_mock_calls(h.send) + e = self.assertRaises(ValueError, m.rx_PLEASE, {"side": LEADER}) + self.assertEqual(str(e), "their side shouldn't be equal: reflection?") + + def test_ping_pong(self): + m, h = make_manager(leader=False) + + m.got_record(KCM()) + self.flushLoggedErrors(UnexpectedKCM) + + m.got_record(Ping(1)) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(1))]) + clear_mock_calls(h.outbound) + + m.got_record(Pong(2)) + # currently ignored, will eventually update a timer + + m.got_record("not recognized") + e = self.flushLoggedErrors(UnknownMessageType) + self.assertEqual(len(e), 1) + self.assertEqual(str(e[0].value), "not recognized") + + m.send_ping(3) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(3))]) + clear_mock_calls(h.outbound) + + def test_subchannel(self): + m, h = make_manager(leader=True) + sc = object() + + m.subchannel_pauseProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_pauseProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_resumeProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_resumeProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_stopProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_stopProducing(sc)]) + clear_mock_calls(h.inbound) + + p = object() + streaming = object() + + m.subchannel_registerProducer(sc, p, streaming) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_registerProducer(sc, p, streaming)]) + clear_mock_calls(h.outbound) + + m.subchannel_unregisterProducer(sc) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_unregisterProducer(sc)]) + clear_mock_calls(h.outbound) + + m.subchannel_closed("scid", sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + 
self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + clear_mock_calls(h.inbound, h.outbound) From 3b7c9831f64b557a36356caed08f42837720ae38 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 21:54:17 -0500 Subject: [PATCH 30/49] appease flake8 somewhat --- src/wormhole/_dilation/manager.py | 1 + src/wormhole/test/dilate/test_manager.py | 67 +++++++++++++----------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index 7f5c6d8..d9db016 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -45,6 +45,7 @@ class UnexpectedKCM(Exception): class UnknownMessageType(Exception): pass + def make_side(): return bytes_to_hexstr(os.urandom(6)) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 45832c4..e223a1c 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -35,6 +35,7 @@ def make_dilator(): dil.wire(send) return dil, send, reactor, eq, clock, coop + class TestDilator(unittest.TestCase): def test_manager_and_endpoints(self): dil, send, reactor, eq, clock, coop = make_dilator() @@ -142,7 +143,7 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) - dil._transit_key = b"\x01"*32 + dil._transit_key = b"\x01" * 32 dil.got_wormhole_versions({}) # missing "can-dilate" eq.flush_sync() f = self.failureResultOf(d1) @@ -201,9 +202,11 @@ class TestDilator(unittest.TestCase): mock.call().start(), mock.call().when_first_connected()]) + LEADER = "ff3456abcdef" FOLLOWER = "123456abcdef" + def make_manager(leader=True): class Holder: pass @@ -214,12 +217,13 @@ def make_manager(leader=True): side = LEADER else: side = FOLLOWER - h.key = b"\x00"*32 + h.key = b"\x00" * 32 h.relay = None h.reactor = object() h.clock = Clock() h.eq = EventualQueue(h.clock) term = mock.Mock(side_effect=lambda: 
True) # one write per Eventual tick + def term_factory(): return term h.coop = Cooperator(terminationPredicateFactory=term_factory, @@ -236,11 +240,12 @@ def make_manager(leader=True): m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop) return m, h + class TestManager(unittest.TestCase): def test_make_side(self): side = make_side() self.assertEqual(type(side), type(u"")) - self.assertEqual(len(side), 2*6) + self.assertEqual(len(side), 2 * 6) def test_create(self): m, h = make_manager() @@ -272,10 +277,10 @@ class TestManager(unittest.TestCase): m.rx_PLEASE({"side": FOLLOWER}) self.assertEqual(h.send.mock_calls, []) self.assertEqual(connector.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing LEADER, roles.LEADER), ]) self.assertEqual(c.mock_calls, [mock.call.start()]) @@ -292,7 +297,7 @@ class TestManager(unittest.TestCase): clear_mock_calls(ph, c) # and we send out any (listening) hints from our Connector - m.send_hints([1,2]) + m.send_hints([1, 2]) self.assertEqual(h.send.mock_calls, [ mock.call.send("dilate-1", dict_to_bytes({"type": "connection-hints", @@ -310,7 +315,7 @@ class TestManager(unittest.TestCase): clear_mock_calls(h.inbound, h.outbound) h.eq.flush_sync() - self.successResultOf(wfc_d) # fires with None + self.successResultOf(wfc_d) # fires with None wfc_d2 = m.when_first_connected() h.eq.flush_sync() self.successResultOf(wfc_d2) @@ -329,7 +334,7 @@ class TestManager(unittest.TestCase): # the Leader making a new outbound channel should get scid=1 scid1 = to_be4(1) self.assertEqual(m.allocate_subchannel_id(), scid1) - r1 = Open(10, scid1) # seqnum=10 + r1 = Open(10, scid1) # seqnum=10 h.outbound.build_record = mock.Mock(return_value=r1) m.send_open(scid1) self.assertEqual(h.outbound.mock_calls, [ @@ -439,7 +444,7 @@ class TestManager(unittest.TestCase): with 
mock.patch("wormhole._dilation.manager.parse_hint", return_value=None) as ph: m.rx_HINTS({"hints": [1, 2, 3]}) - self.assertEqual(ph.mock_calls, []) # ignored + self.assertEqual(ph.mock_calls, []) # ignored c2 = mock.Mock() connector2 = mock.Mock(return_value=c2) @@ -448,10 +453,10 @@ class TestManager(unittest.TestCase): m.rx_RECONNECTING() self.assertEqual(h.send.mock_calls, []) self.assertEqual(connector2.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing LEADER, roles.LEADER), ]) self.assertEqual(c2.mock_calls, [mock.call.start()]) @@ -486,10 +491,10 @@ class TestManager(unittest.TestCase): m.rx_PLEASE({"side": LEADER}) self.assertEqual(h.send.mock_calls, []) self.assertEqual(connector.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing FOLLOWER, roles.FOLLOWER), ]) self.assertEqual(c.mock_calls, [mock.call.start()]) @@ -524,10 +529,10 @@ class TestManager(unittest.TestCase): dict_to_bytes({"type": "reconnecting"})) ]) self.assertEqual(connector2.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing FOLLOWER, roles.FOLLOWER), ]) self.assertEqual(c2.mock_calls, [mock.call.start()]) @@ -541,10 +546,10 @@ class TestManager(unittest.TestCase): m.rx_RECONNECT() self.assertEqual(c2.mock_calls, [mock.call.stop()]) self.assertEqual(connector3.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing FOLLOWER, 
roles.FOLLOWER), ]) self.assertEqual(c3.mock_calls, [mock.call.start()]) @@ -563,10 +568,10 @@ class TestManager(unittest.TestCase): m.connector_connection_lost() self.assertEqual(c3.mock_calls, [mock.call.disconnect()]) self.assertEqual(connector4.mock_calls, [ - mock.call(b"\x00"*32, None, m, h.reactor, h.eq, - False, # no_listen - None, # tor - None, # timing + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing FOLLOWER, roles.FOLLOWER), ]) self.assertEqual(c4.mock_calls, [mock.call.start()]) From bd1a199f3e9e1177ac2f17c7acd920ecf8f26a34 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 22:39:10 -0500 Subject: [PATCH 31/49] start factoring Hints out to separate file shared between old transit.py and new _dilation/connector.py --- src/wormhole/_dilation/connector.py | 61 +---------------------------- src/wormhole/_hints.py | 58 +++++++++++++++++++++++++++ src/wormhole/transit.py | 61 +---------------------------- 3 files changed, 62 insertions(+), 118 deletions(-) create mode 100644 src/wormhole/_hints.py diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index f530039..8d7f4a5 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -1,7 +1,5 @@ from __future__ import print_function, unicode_literals -import sys -import re -from collections import defaultdict, namedtuple +from collections import defaultdict from binascii import hexlify import six from attr import attrs, attrib @@ -21,25 +19,7 @@ from ..observer import EmptyableSet from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER - -# These namedtuples are "hint objects". The JSON-serializable dictionaries -# are "hint dicts". 
- -# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: -# * make a TCP connection (possibly via Tor) -# * send the sender/receiver handshake bytes first -# * expect to see the receiver/sender handshake bytes from the other side -# * the sender writes "go\n", the receiver waits for "go\n" -# * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple( - "DirectTCPV1Hint", ["hostname", "port", "priority"]) -TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) -# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we -# use a tuple rather than a list so they'll be hashable into a set). For each -# one, make the TCP connection, send the relay handshake, then complete the -# rest of the V1 protocol. Only one hint per relay is useful. -RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) - +from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint def describe_hint_obj(hint, relay, tor): prefix = "tor->" if tor else "->" @@ -53,43 +33,6 @@ def describe_hint_obj(hint, relay, tor): return prefix + str(hint) -def parse_hint_argv(hint, stderr=sys.stderr): - assert isinstance(hint, type("")) - # return tuple or None for an unparseable hint - priority = 0.0 - mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) - if not mo: - print("unparseable hint '%s'" % (hint,), file=stderr) - return None - hint_type = mo.group(1) - if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint), - file=stderr) - return None - hint_value = mo.group(2) - pieces = hint_value.split(":") - if len(pieces) < 2: - print("unparseable TCP hint (need more colons) '%s'" % (hint,), - file=stderr) - return None - mo = re.search(r'^(\d+)$', pieces[1]) - if not mo: - print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) - return None - hint_host = pieces[0] - hint_port = int(pieces[1]) - for more in pieces[2:]: - if more.startswith("priority="): - more_pieces = more.split("=") - try: 
- priority = float(more_pieces[1]) - except ValueError: - print("non-float priority= in TCP hint '%s'" % (hint,), - file=stderr) - return None - return DirectTCPV1Hint(hint_host, hint_port, priority) - - def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj hint_type = hint.get("type", "") if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py new file mode 100644 index 0000000..8955e97 --- /dev/null +++ b/src/wormhole/_hints.py @@ -0,0 +1,58 @@ +from __future__ import print_function, unicode_literals +import sys +import re +from collections import namedtuple + +# These namedtuples are "hint objects". The JSON-serializable dictionaries +# are "hint dicts". + +# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: +# * make a TCP connection (possibly via Tor) +# * send the sender/receiver handshake bytes first +# * expect to see the receiver/sender handshake bytes from the other side +# * the sender writes "go\n", the receiver waits for "go\n" +# * the rest of the connection contains transit data +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", + ["hostname", "port", "priority"]) +TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) +# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we +# use a tuple rather than a list so they'll be hashable into a set). For each +# one, make the TCP connection, send the relay handshake, then complete the +# rest of the V1 protocol. Only one hint per relay is useful. 
+RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + +def parse_hint_argv(hint, stderr=sys.stderr): + assert isinstance(hint, type(u"")) + # return tuple or None for an unparseable hint + priority = 0.0 + mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) + if not mo: + print("unparseable hint '%s'" % (hint, ), file=stderr) + return None + hint_type = mo.group(1) + if hint_type != "tcp": + print("unknown hint type '%s' in '%s'" % (hint_type, hint), + file=stderr) + return None + hint_value = mo.group(2) + pieces = hint_value.split(":") + if len(pieces) < 2: + print("unparseable TCP hint (need more colons) '%s'" % (hint, ), + file=stderr) + return None + mo = re.search(r'^(\d+)$', pieces[1]) + if not mo: + print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) + return None + hint_host = pieces[0] + hint_port = int(pieces[1]) + for more in pieces[2:]: + if more.startswith("priority="): + more_pieces = more.split("=") + try: + priority = float(more_pieces[1]) + except ValueError: + print("non-float priority= in TCP hint '%s'" % (hint, ), + file=stderr) + return None + return DirectTCPV1Hint(hint_host, hint_port, priority) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 96d8f95..fa708fb 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -2,12 +2,11 @@ from __future__ import absolute_import, print_function import os -import re import socket import sys import time from binascii import hexlify, unhexlify -from collections import deque, namedtuple +from collections import deque import six from hkdf import Hkdf @@ -95,24 +94,7 @@ def build_sided_relay_handshake(key, side): "ascii") + b"\n" -# These namedtuples are "hint objects". The JSON-serializable dictionaries -# are "hint dicts". 
- -# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: -# * make a TCP connection (possibly via Tor) -# * send the sender/receiver handshake bytes first -# * expect to see the receiver/sender handshake bytes from the other side -# * the sender writes "go\n", the receiver waits for "go\n" -# * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", - ["hostname", "port", "priority"]) -TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) -# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we -# use a tuple rather than a list so they'll be hashable into a set). For each -# one, make the TCP connection, send the relay handshake, then complete the -# rest of the V1 protocol. Only one hint per relay is useful. -RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) - +from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint def describe_hint_obj(hint): if isinstance(hint, DirectTCPV1Hint): @@ -123,45 +105,6 @@ def describe_hint_obj(hint): return str(hint) -def parse_hint_argv(hint, stderr=sys.stderr): - assert isinstance(hint, type(u"")) - # return tuple or None for an unparseable hint - priority = 0.0 - mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) - if not mo: - print("unparseable hint '%s'" % (hint, ), file=stderr) - return None - hint_type = mo.group(1) - if hint_type != "tcp": - print( - "unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) - return None - hint_value = mo.group(2) - pieces = hint_value.split(":") - if len(pieces) < 2: - print( - "unparseable TCP hint (need more colons) '%s'" % (hint, ), - file=stderr) - return None - mo = re.search(r'^(\d+)$', pieces[1]) - if not mo: - print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) - return None - hint_host = pieces[0] - hint_port = int(pieces[1]) - for more in pieces[2:]: - if more.startswith("priority="): - more_pieces = more.split("=") - try: - priority = 
float(more_pieces[1]) - except ValueError: - print( - "non-float priority= in TCP hint '%s'" % (hint, ), - file=stderr) - return None - return DirectTCPV1Hint(hint_host, hint_port, priority) - - TIMEOUT = 60 # seconds From 2f4e4d30318261548ddf4b7009eba17ffde3b5dd Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:12:17 -0500 Subject: [PATCH 32/49] factor out describe_hint_obj and endpoint_from_hint_obj --- src/wormhole/_dilation/connector.py | 34 ++------ src/wormhole/_hints.py | 26 ++++++ src/wormhole/test/test_transit.py | 122 +++++++++++++++------------- src/wormhole/transit.py | 42 ++-------- 4 files changed, 104 insertions(+), 120 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 8d7f4a5..f472ea9 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -8,7 +8,7 @@ from automat import MethodicalMachine from zope.interface import implementer from twisted.internet.task import deferLater from twisted.internet.defer import DeferredList -from twisted.internet.endpoints import HostnameEndpoint, serverFromString +from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.python import log from hkdf import Hkdf @@ -19,18 +19,8 @@ from ..observer import EmptyableSet from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER -from .._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint - -def describe_hint_obj(hint, relay, tor): - prefix = "tor->" if tor else "->" - if relay: - prefix = prefix + "relay:" - if isinstance(hint, DirectTCPV1Hint): - return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) - elif isinstance(hint, TorTCPV1Hint): - return prefix + "tor:%s:%d" % (hint.hostname, hint.port) - else: - return prefix + str(hint) +from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, 
endpoint_from_hint_obj) def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj @@ -340,7 +330,7 @@ class Connector(object): for h in direct[p]: if isinstance(h, TorTCPV1Hint) and not self._tor: continue - ep = self._endpoint_from_hint_obj(h) + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) desc = describe_hint_obj(h, False, self._tor) d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay=False) @@ -376,7 +366,7 @@ class Connector(object): for p in priorities: for r in relays[p]: for h in r.hints: - ep = self._endpoint_from_hint_obj(h) + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) desc = describe_hint_obj(h, True, self._tor) d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay=True) @@ -405,20 +395,6 @@ class Connector(object): d.addCallback(_connected) return d - def _endpoint_from_hint_obj(self, hint): - if self._tor: - if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): - # this Tor object will throw ValueError for non-public IPv4 - # addresses and any IPv6 address - try: - return self._tor.stream_via(hint.hostname, hint.port) - except ValueError: - return None - return None - if isinstance(hint, DirectTCPV1Hint): - return HostnameEndpoint(self._reactor, hint.hostname, hint.port) - return None - # Connection selection. All instances of DilatedConnectionProtocol which # look viable get passed into our add_contender() method. diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py index 8955e97..9b79542 100644 --- a/src/wormhole/_hints.py +++ b/src/wormhole/_hints.py @@ -2,6 +2,7 @@ from __future__ import print_function, unicode_literals import sys import re from collections import namedtuple +from twisted.internet.endpoints import HostnameEndpoint # These namedtuples are "hint objects". The JSON-serializable dictionaries # are "hint dicts". @@ -21,6 +22,17 @@ TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # rest of the V1 protocol. Only one hint per relay is useful. 
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) +def describe_hint_obj(hint, relay, tor): + prefix = "tor->" if tor else "->" + if relay: + prefix = prefix + "relay:" + if isinstance(hint, DirectTCPV1Hint): + return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) + elif isinstance(hint, TorTCPV1Hint): + return prefix + "tor:%s:%d" % (hint.hostname, hint.port) + else: + return prefix + str(hint) + def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type(u"")) # return tuple or None for an unparseable hint @@ -56,3 +68,17 @@ def parse_hint_argv(hint, stderr=sys.stderr): file=stderr) return None return DirectTCPV1Hint(hint_host, hint_port, priority) + +def endpoint_from_hint_obj(hint, tor, reactor): + if tor: + if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): + # this Tor object will throw ValueError for non-public IPv4 + # addresses and any IPv6 address + try: + return tor.stream_via(hint.hostname, hint.port) + except ValueError: + return None + return None + if isinstance(hint, DirectTCPV1Hint): + return HostnameEndpoint(reactor, hint.hostname, hint.port) + return None diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 63216ae..9a53983 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -8,7 +8,7 @@ from collections import namedtuple import six from nacl.exceptions import CryptoError from nacl.secret import SecretBox -from twisted.internet import address, defer, endpoints, error, protocol, task +from twisted.internet import address, defer, endpoints, error, protocol, task, reactor from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import proto_helpers @@ -18,6 +18,7 @@ import mock from wormhole_transit_relay import transit_server from .. 
import transit +from .._hints import endpoint_from_hint_obj from ..errors import InternalError from .common import ServerBase @@ -145,30 +146,32 @@ UnknownHint = namedtuple("UnknownHint", ["stuff"]) class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): - c = transit.Common("") - efho = c._endpoint_from_hint_obj + def efho(hint, tor=None): + return endpoint_from_hint_obj(hint, tor, reactor) self.assertIsInstance( efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) - # c._tor is currently None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) - c._tor = mock.Mock() + # tor=None + self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) + + tor = mock.Mock() def tor_ep(hostname, port): if hostname == "non-public": return None return ("tor_ep", hostname, port) + tor.stream_via = mock.Mock(side_effect=tor_ep) - c._tor.stream_via = mock.Mock(side_effect=tor_ep) self.assertEqual( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), + efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor), ("tor_ep", "host", 1234)) self.assertEqual( - efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), + efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor), ("tor_ep", "host2.onion", 1234)) self.assertEqual( - efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) + efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): @@ -270,10 +273,13 @@ class Hints(unittest.TestCase): def test_describe_hint_obj(self): d = transit.describe_hint_obj self.assertEqual( - d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") + d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") self.assertEqual( - d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") - self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) + 
d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") + self.assertEqual(d(UnknownHint("stuff"), False, False), + "->%s" % str(UnknownHint("stuff"))) # ipaddrs.py currently uses native strings: bytes on py2, unicode on @@ -1507,7 +1513,7 @@ class Transit(unittest.TestCase): self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) - def _endpoint_from_hint_obj(self, hint): + def _endpoint_from_hint_obj(self, hint, _tor, _reactor): if isinstance(hint, transit.DirectTCPV1Hint): if hint.hostname == "unavailable": return None @@ -1523,20 +1529,21 @@ class Transit(unittest.TestCase): del hints s.add_connection_hints( [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # the direct connectors are tried right away, but the relay - # connectors are stalled for a few seconds - self.assertEqual(self._connectors, ["direct"]) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # the direct connectors are tried right away, but the relay + # connectors are stalled for a few seconds + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_priorities(self): @@ -1586,31 +1593,32 @@ class Transit(unittest.TestCase): }] }, ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # 
direct connector should be used first, then the priority=3.0 relay, - # then the two 2.0 relays, then the (default) 0.0 relay + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # direct connector should be used first, then the priority=3.0 relay, + # then the two 2.0 relays, then the (default) 0.0 relay - self.assertEqual(self._connectors, ["direct"]) + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay3"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay3"]) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4"], - ["direct", "relay3", "relay4", "relay2"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4"], + ["direct", "relay3", "relay4", "relay2"])) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4", "relay"], - ["direct", "relay3", "relay4", "relay2", "relay"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4", "relay"], + ["direct", "relay3", "relay4", "relay2", "relay"])) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_direct_hints(self): @@ -1624,20 +1632,21 @@ class Transit(unittest.TestCase): UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, UNAVAILABLE_RELAY_HINT_JSON ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # since there are no usable direct hints, the relay connector will - # only be stalled for 0 seconds - 
self.assertEqual(self._connectors, []) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # since there are no usable direct hints, the relay connector will + # only be stalled for 0 seconds + self.assertEqual(self._connectors, []) - clock.advance(0) - self.assertEqual(self._connectors, ["relay"]) + clock.advance(0) + self.assertEqual(self._connectors, ["relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_contenders(self): @@ -1647,12 +1656,13 @@ class Transit(unittest.TestCase): hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([]) # no hints at all - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - f = self.failureResultOf(d, transit.TransitError) - self.assertEqual(str(f.value), "No contenders for connection") + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + f = self.failureResultOf(d, transit.TransitError) + self.assertEqual(str(f.value), "No contenders for connection") class RelayHandshake(unittest.TestCase): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index fa708fb..b122055 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -23,6 +23,8 @@ from . 
import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr +from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -94,16 +96,6 @@ def build_sided_relay_handshake(key, side): "ascii") + b"\n" -from ._hints import parse_hint_argv, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint - -def describe_hint_obj(hint): - if isinstance(hint, DirectTCPV1Hint): - return u"tcp:%s:%d" % (hint.hostname, hint.port) - elif isinstance(hint, TorTCPV1Hint): - return u"tor:%s:%d" % (hint.hostname, hint.port) - else: - return str(hint) - TIMEOUT = 60 # seconds @@ -818,13 +810,11 @@ class Common: # Check the hint type to see if we can support it (e.g. skip # onion hints on a non-Tor client). Do not increase relay_delay # unless we have at least one viable hint. - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description - d = self._start_connector(ep, description) + d = self._start_connector(ep, + describe_hint_obj(hint_obj, False, self._tor)) contenders.append(d) relay_delay = self.RELAY_DELAY @@ -845,18 +835,15 @@ class Common: for priority in sorted(prioritized_relays, reverse=True): for hint_obj in prioritized_relays[priority]: - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->relay:%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description d = task.deferLater( self._reactor, relay_delay, self._start_connector, ep, - description, + describe_hint_obj(hint_obj, True, self._tor), is_relay=True) contenders.append(d) relay_delay += self.RELAY_DELAY @@ -894,21 +881,6 @@ class Common: d.addCallback(lambda p: 
p.startNegotiation()) return d - def _endpoint_from_hint_obj(self, hint): - if self._tor: - if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): - # this Tor object will throw ValueError for non-public IPv4 - # addresses and any IPv6 address - try: - return self._tor.stream_via(hint.hostname, hint.port) - except ValueError: - return None - return None - if isinstance(hint, DirectTCPV1Hint): - return endpoints.HostnameEndpoint(self._reactor, hint.hostname, - hint.port) - return None - def connection_ready(self, p): # inbound/outbound Connection protocols call this when they finish # negotiation. The first one wins and gets a "go". Any subsequent From 7720312c8f22e914bf9ffce1f1240f05ff86d86d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:22:02 -0500 Subject: [PATCH 33/49] factor out parse_tcp_v1_hint --- src/wormhole/_dilation/connector.py | 23 +--------- src/wormhole/_hints.py | 21 +++++++++ src/wormhole/test/test_transit.py | 68 +++++++++++++---------------- src/wormhole/transit.py | 28 +++--------- 4 files changed, 59 insertions(+), 81 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index f472ea9..3d4d773 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -20,27 +20,8 @@ from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, - parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) - - -def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj - hint_type = hint.get("type", "") - if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint,)) - return None - if not("hostname" in hint and - isinstance(hint["hostname"], type(""))): - log.msg("invalid hostname in hint: %r" % (hint,)) - return None - if not("port" in hint and - isinstance(hint["port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint,)) - 
return None - priority = hint.get("priority", 0.0) - if hint_type == "direct-tcp-v1": - return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) - else: - return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + parse_tcp_v1_hint) def parse_hint(hint_struct): diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py index 9b79542..0a44b3b 100644 --- a/src/wormhole/_hints.py +++ b/src/wormhole/_hints.py @@ -1,8 +1,10 @@ from __future__ import print_function, unicode_literals import sys import re +import six from collections import namedtuple from twisted.internet.endpoints import HostnameEndpoint +from twisted.python import log # These namedtuples are "hint objects". The JSON-serializable dictionaries # are "hint dicts". @@ -82,3 +84,22 @@ def endpoint_from_hint_obj(hint, tor, reactor): if isinstance(hint, DirectTCPV1Hint): return HostnameEndpoint(reactor, hint.hostname, hint.port) return None + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + hint_type = hint.get("type", "") + if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: + log.msg("unknown hint type: %r" % (hint, )) + return None + if not ("hostname" in hint and + isinstance(hint["hostname"], type(""))): + log.msg("invalid hostname in hint: %r" % (hint, )) + return None + if not ("port" in hint and + isinstance(hint["port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint, )) + return None + priority = hint.get("priority", 0.0) + if hint_type == "direct-tcp-v1": + return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) + else: + return TorTCPV1Hint(hint["hostname"], hint["port"], priority) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 9a53983..86f008b 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -18,7 +18,8 @@ import mock from wormhole_transit_relay import transit_server from .. 
import transit -from .._hints import endpoint_from_hint_obj +from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, + DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) from ..errors import InternalError from .common import ServerBase @@ -148,13 +149,12 @@ class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): def efho(hint, tor=None): return endpoint_from_hint_obj(hint, tor, reactor) - self.assertIsInstance( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) + self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) # tor=None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) + self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) tor = mock.Mock() def tor_ep(hostname, port): @@ -163,50 +163,46 @@ class Hints(unittest.TestCase): return ("tor_ep", hostname, port) tor.stream_via = mock.Mock(side_effect=tor_ep) - self.assertEqual( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0), tor), - ("tor_ep", "host", 1234)) - self.assertEqual( - efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0), tor), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual( - efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): - h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0) - r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) - r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) - r3 = transit.RelayV1Hint(tuple(sorted([h1b, 
h2]))) + h1 = DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = DirectTCPV1Hint("hostname", "port2", 0.0) + r1 = RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) self.assertEqual(r1, r2) self.assertEqual(r2, r3) self.assertEqual(len(set([r1, r2, r3])), 1) def test_parse_tcp_v1_hint(self): - c = transit.Common("") - p = c._parse_tcp_v1_hint + p = parse_tcp_v1_hint self.assertEqual(p({"type": "unknown"}), None) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) h = p({ "type": "direct-tcp-v1", "hostname": "foo", "port": 1234, "priority": 2.5 }) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) h = p({ "type": "tor-tcp-v1", "hostname": "foo", "port": 1234, "priority": 2.5 }) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(p({ "type": "direct-tcp-v1" }), None) # missing hostname @@ -229,19 +225,19 @@ class Hints(unittest.TestCase): def test_parse_hint_argv(self): def p(hint): stderr = io.StringIO() - value = transit.parse_hint_argv(hint, stderr=stderr) + value = parse_hint_argv(hint, stderr=stderr) return value, stderr.getvalue() h, stderr = p("tcp:host:1234") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") h, stderr = p("tcp:host:1234:priority=2.6") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) 
self.assertEqual(stderr, "") h, stderr = p("tcp:host:1234:unknown=stuff") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) self.assertEqual(stderr, "") h, stderr = p("$!@#^") @@ -272,12 +268,10 @@ class Hints(unittest.TestCase): def test_describe_hint_obj(self): d = transit.describe_hint_obj - self.assertEqual( - d(transit.DirectTCPV1Hint("host", 1234, 0.0), False, False), - "->tcp:host:1234") - self.assertEqual( - d(transit.TorTCPV1Hint("host", 1234, 0.0), False, False), - "->tor:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") + self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") self.assertEqual(d(UnknownHint("stuff"), False, False), "->%s" % str(UnknownHint("stuff"))) @@ -443,7 +437,7 @@ class Listener(unittest.TestCase): hints, ep = c._build_listener() self.assertIsInstance(hints, (list, set)) if hints: - self.assertIsInstance(hints[0], transit.DirectTCPV1Hint) + self.assertIsInstance(hints[0], DirectTCPV1Hint) self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) def test_get_direct_hints(self): @@ -1514,7 +1508,7 @@ class Transit(unittest.TestCase): self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) def _endpoint_from_hint_obj(self, hint, _tor, _reactor): - if isinstance(hint, transit.DirectTCPV1Hint): + if isinstance(hint, DirectTCPV1Hint): if hint.hostname == "unavailable": return None return hint.hostname diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index b122055..63aafdb 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -23,8 +23,9 @@ from . 
import ipaddrs from .errors import InternalError from .timing import DebugTiming from .util import bytes_to_hexstr -from ._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, - parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj) +from ._hints import (DirectTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + parse_tcp_v1_hint) def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -681,30 +682,11 @@ class Common: self._listener_d.addErrback(lambda f: None) self._listener_d.cancel() - def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj - hint_type = hint.get(u"type", u"") - if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint, )) - return None - if not (u"hostname" in hint and - isinstance(hint[u"hostname"], type(u""))): - log.msg("invalid hostname in hint: %r" % (hint, )) - return None - if not (u"port" in hint and - isinstance(hint[u"port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint, )) - return None - priority = hint.get(u"priority", 0.0) - if hint_type == u"direct-tcp-v1": - return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - else: - return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - def add_connection_hints(self, hints): for h in hints: # hint structs hint_type = h.get(u"type", u"") if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: - dh = self._parse_tcp_v1_hint(h) + dh = parse_tcp_v1_hint(h) if dh: self._their_direct_hints.append(dh) # hint_obj elif hint_type == u"relay-v1": @@ -714,7 +696,7 @@ class Common: # together like this. 
relay_hints = [] for rhs in h.get(u"hints", []): - h = self._parse_tcp_v1_hint(rhs) + h = parse_tcp_v1_hint(rhs) if h: relay_hints.append(h) if relay_hints: From 1bb5634d0ef5b3d653194571020b9e312d6149f6 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:28:45 -0500 Subject: [PATCH 34/49] factor Hints tests out of test_transit into a new file --- src/wormhole/test/test_hints.py | 142 ++++++++++++++++++++++++++++++ src/wormhole/test/test_transit.py | 139 +---------------------------- 2 files changed, 144 insertions(+), 137 deletions(-) create mode 100644 src/wormhole/test/test_hints.py diff --git a/src/wormhole/test/test_hints.py b/src/wormhole/test/test_hints.py new file mode 100644 index 0000000..8cc71e3 --- /dev/null +++ b/src/wormhole/test/test_hints.py @@ -0,0 +1,142 @@ +from __future__ import print_function, unicode_literals +import io +from collections import namedtuple +import mock +from twisted.internet import endpoints, reactor +from twisted.trial import unittest +from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, + describe_hint_obj, + DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) + +UnknownHint = namedtuple("UnknownHint", ["stuff"]) + + +class Hints(unittest.TestCase): + def test_endpoint_from_hint_obj(self): + def efho(hint, tor=None): + return endpoint_from_hint_obj(hint, tor, reactor) + self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) + self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) + + # tor=None + self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) + + tor = mock.Mock() + def tor_ep(hostname, port): + if hostname == "non-public": + return None + return ("tor_ep", hostname, port) + tor.stream_via = mock.Mock(side_effect=tor_ep) + + self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), + ("tor_ep", "host2.onion", 1234)) + 
self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + + self.assertEqual(efho(UnknownHint("foo")), None) + + def test_comparable(self): + h1 = DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = DirectTCPV1Hint("hostname", "port2", 0.0) + r1 = RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) + self.assertEqual(r1, r2) + self.assertEqual(r2, r3) + self.assertEqual(len(set([r1, r2, r3])), 1) + + def test_parse_tcp_v1_hint(self): + p = parse_tcp_v1_hint + self.assertEqual(p({"type": "unknown"}), None) + h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) + h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "tor-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(p({ + "type": "direct-tcp-v1" + }), None) # missing hostname + self.assertEqual(p({ + "type": "direct-tcp-v1", + "hostname": 12 + }), None) # invalid hostname + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo" + }), None) # missing port + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": "not a number" + }), None) # invalid port + + def test_parse_hint_argv(self): + def p(hint): + stderr = io.StringIO() + value = parse_hint_argv(hint, stderr=stderr) + return value, stderr.getvalue() + + h, stderr = p("tcp:host:1234") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:priority=2.6") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) + 
self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:unknown=stuff") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("$!@#^") + self.assertEqual(h, None) + self.assertEqual(stderr, "unparseable hint '$!@#^'\n") + + h, stderr = p("unknown:stuff") + self.assertEqual(h, None) + self.assertEqual(stderr, + "unknown hint type 'unknown' in 'unknown:stuff'\n") + + h, stderr = p("tcp:just-a-hostname") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + + h, stderr = p("tcp:host:number") + self.assertEqual(h, None) + self.assertEqual(stderr, + "non-numeric port in TCP hint 'tcp:host:number'\n") + + h, stderr = p("tcp:host:1234:priority=bad") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") + + def test_describe_hint_obj(self): + d = describe_hint_obj + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") + self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") + self.assertEqual(d(UnknownHint("stuff"), False, False), + "->%s" % str(UnknownHint("stuff"))) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 86f008b..b3d7590 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -3,12 +3,11 @@ from __future__ import print_function, unicode_literals import gc import io from binascii import hexlify, unhexlify -from collections import namedtuple import six from nacl.exceptions import CryptoError from nacl.secret import SecretBox -from twisted.internet import address, defer, endpoints, error, protocol, task, reactor +from twisted.internet import address, defer, endpoints, error, protocol, task from twisted.internet.defer import gatherResults, inlineCallbacks from twisted.python import log from twisted.test import 
proto_helpers @@ -18,8 +17,7 @@ import mock from wormhole_transit_relay import transit_server from .. import transit -from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, - DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) +from .._hints import DirectTCPV1Hint from ..errors import InternalError from .common import ServerBase @@ -142,139 +140,6 @@ class Misc(unittest.TestCase): self.assertIsInstance(portno, int) -UnknownHint = namedtuple("UnknownHint", ["stuff"]) - - -class Hints(unittest.TestCase): - def test_endpoint_from_hint_obj(self): - def efho(hint, tor=None): - return endpoint_from_hint_obj(hint, tor, reactor) - self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) - self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) - - # tor=None - self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) - - tor = mock.Mock() - def tor_ep(hostname, port): - if hostname == "non-public": - return None - return ("tor_ep", hostname, port) - tor.stream_via = mock.Mock(side_effect=tor_ep) - - self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), - ("tor_ep", "host", 1234)) - self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) - - self.assertEqual(efho(UnknownHint("foo")), None) - - def test_comparable(self): - h1 = DirectTCPV1Hint("hostname", "port1", 0.0) - h1b = DirectTCPV1Hint("hostname", "port1", 0.0) - h2 = DirectTCPV1Hint("hostname", "port2", 0.0) - r1 = RelayV1Hint(tuple(sorted([h1, h2]))) - r2 = RelayV1Hint(tuple(sorted([h2, h1]))) - r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) - self.assertEqual(r1, r2) - self.assertEqual(r2, r3) - self.assertEqual(len(set([r1, r2, r3])), 1) - - def test_parse_tcp_v1_hint(self): - p = parse_tcp_v1_hint - self.assertEqual(p({"type": "unknown"}), None) - h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - 
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) - h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "tor-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) - self.assertEqual(p({ - "type": "direct-tcp-v1" - }), None) # missing hostname - self.assertEqual(p({ - "type": "direct-tcp-v1", - "hostname": 12 - }), None) # invalid hostname - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo" - }), None) # missing port - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": "not a number" - }), None) # invalid port - - def test_parse_hint_argv(self): - def p(hint): - stderr = io.StringIO() - value = parse_hint_argv(hint, stderr=stderr) - return value, stderr.getvalue() - - h, stderr = p("tcp:host:1234") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:priority=2.6") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:unknown=stuff") - self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("$!@#^") - self.assertEqual(h, None) - self.assertEqual(stderr, "unparseable hint '$!@#^'\n") - - h, stderr = p("unknown:stuff") - self.assertEqual(h, None) - self.assertEqual(stderr, - "unknown hint type 'unknown' in 'unknown:stuff'\n") - - h, stderr = p("tcp:just-a-hostname") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") - - h, stderr = p("tcp:host:number") - self.assertEqual(h, None) - self.assertEqual(stderr, - "non-numeric port in TCP hint 
'tcp:host:number'\n") - - h, stderr = p("tcp:host:1234:priority=bad") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") - - def test_describe_hint_obj(self): - d = transit.describe_hint_obj - self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), - "->tcp:host:1234") - self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), - "->tor:host:1234") - self.assertEqual(d(UnknownHint("stuff"), False, False), - "->%s" % str(UnknownHint("stuff"))) - # ipaddrs.py currently uses native strings: bytes on py2, unicode on # py3 From d64c94a1dc53a0d5dca0bc5cab940556fd26f972 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:37:07 -0500 Subject: [PATCH 35/49] test_hints: finish coverage of hints.py --- src/wormhole/test/test_hints.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/wormhole/test/test_hints.py b/src/wormhole/test/test_hints.py index 8cc71e3..92c1def 100644 --- a/src/wormhole/test/test_hints.py +++ b/src/wormhole/test/test_hints.py @@ -21,11 +21,12 @@ class Hints(unittest.TestCase): # tor=None self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) + self.assertEqual(efho(UnknownHint("foo")), None) tor = mock.Mock() def tor_ep(hostname, port): if hostname == "non-public": - return None + raise ValueError return ("tor_ep", hostname, port) tor.stream_via = mock.Mock(side_effect=tor_ep) @@ -35,7 +36,7 @@ class Hints(unittest.TestCase): ("tor_ep", "host2.onion", 1234)) self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) - self.assertEqual(efho(UnknownHint("foo")), None) + self.assertEqual(efho(UnknownHint("foo"), tor), None) def test_comparable(self): h1 = DirectTCPV1Hint("hostname", "port1", 0.0) @@ -136,6 +137,12 @@ class Hints(unittest.TestCase): d = describe_hint_obj self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), "->tcp:host:1234") + 
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, False), + "->relay:tcp:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, True), + "tor->tcp:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, True), + "tor->relay:tcp:host:1234") self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), "->tor:host:1234") self.assertEqual(d(UnknownHint("stuff"), False, False), From b4c90b40a294c721b24e76d52134011b8377a03c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 21 Dec 2018 23:55:58 -0500 Subject: [PATCH 36/49] move parse_hint/encode_hint into _hints.py, add tests --- src/wormhole/_dilation/connector.py | 35 +-------------------- src/wormhole/_dilation/manager.py | 3 +- src/wormhole/_hints.py | 33 +++++++++++++++++++ src/wormhole/test/test_hints.py | 49 ++++++++++++++++++++++++++++- 4 files changed, 84 insertions(+), 36 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 3d4d773..97c68eb 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -21,42 +21,9 @@ from .roles import LEADER from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, - parse_tcp_v1_hint) + encode_hint) -def parse_hint(hint_struct): - hint_type = hint_struct.get("type", "") - if hint_type == "relay-v1": - # the struct can include multiple ways to reach the same relay - rhints = filter(lambda h: h, # drop None (unrecognized) - [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) - return RelayV1Hint(rhints) - return parse_tcp_v1_hint(hint_struct) - - -def encode_hint(h): - if isinstance(h, DirectTCPV1Hint): - return {"type": "direct-tcp-v1", - "priority": h.priority, - "hostname": h.hostname, - "port": h.port, # integer - } - elif isinstance(h, RelayV1Hint): - rhint = {"type": "relay-v1", "hints": []} - for rh in h.hints: - rhint["hints"].append({"type": 
"direct-tcp-v1", - "priority": rh.priority, - "hostname": rh.hostname, - "port": rh.port}) - return rhint - elif isinstance(h, TorTCPV1Hint): - return {"type": "tor-tcp-v1", - "priority": h.priority, - "hostname": h.hostname, - "port": h.port, # integer - } - raise ValueError("unknown hint type", h) - def HKDF(skm, outlen, salt=None, CTXinfo=b""): return Hkdf(salt, skm).expand(CTXinfo, outlen) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index d9db016..18d1770 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -15,7 +15,8 @@ from .encode import to_be4 from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress, ControlEndpoint, SubchannelConnectorEndpoint, SubchannelListenerEndpoint) -from .connector import Connector, parse_hint +from .connector import Connector +from .._hints import parse_hint from .roles import LEADER, FOLLOWER from .connection import KCM, Ping, Pong, Open, Data, Close, Ack from .inbound import Inbound diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py index 0a44b3b..a8ee258 100644 --- a/src/wormhole/_hints.py +++ b/src/wormhole/_hints.py @@ -103,3 +103,36 @@ def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) else: return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + +def parse_hint(hint_struct): + hint_type = hint_struct.get("type", "") + if hint_type == "relay-v1": + # the struct can include multiple ways to reach the same relay + rhints = filter(lambda h: h, # drop None (unrecognized) + [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) + return RelayV1Hint(list(rhints)) + return parse_tcp_v1_hint(hint_struct) + + +def encode_hint(h): + if isinstance(h, DirectTCPV1Hint): + return {"type": "direct-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + elif isinstance(h, RelayV1Hint): + rhint = {"type": "relay-v1", "hints": 
[]} + for rh in h.hints: + rhint["hints"].append({"type": "direct-tcp-v1", + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) + return rhint + elif isinstance(h, TorTCPV1Hint): + return {"type": "tor-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + raise ValueError("unknown hint type", h) diff --git a/src/wormhole/test/test_hints.py b/src/wormhole/test/test_hints.py index 92c1def..7dcffd6 100644 --- a/src/wormhole/test/test_hints.py +++ b/src/wormhole/test/test_hints.py @@ -5,7 +5,7 @@ import mock from twisted.internet import endpoints, reactor from twisted.trial import unittest from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, - describe_hint_obj, + describe_hint_obj, parse_hint, encode_hint, DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) UnknownHint = namedtuple("UnknownHint", ["stuff"]) @@ -89,6 +89,24 @@ class Hints(unittest.TestCase): "port": "not a number" }), None) # invalid port + def test_parse_hint(self): + p = parse_hint + self.assertEqual(p({"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12}), + DirectTCPV1Hint("foo", 12, 0.0)) + self.assertEqual(p({"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12}, + {"type": "unrecognized"}, + {"type": "direct-tcp-v1", + "hostname": "bar", + "port": 13}]}), + RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0), + DirectTCPV1Hint("bar", 13, 0.0)])) + def test_parse_hint_argv(self): def p(hint): stderr = io.StringIO() @@ -147,3 +165,32 @@ class Hints(unittest.TestCase): "->tor:host:1234") self.assertEqual(d(UnknownHint("stuff"), False, False), "->%s" % str(UnknownHint("stuff"))) + + def test_encode_hint(self): + e = encode_hint + self.assertEqual(e(DirectTCPV1Hint("host", 1234, 1.0)), + {"type": "direct-tcp-v1", + "priority": 1.0, + "hostname": "host", + "port": 1234}) + self.assertEqual(e(RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0), + DirectTCPV1Hint("bar", 13, 0.0)])), + 
{"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12, + "priority": 0.0}, + {"type": "direct-tcp-v1", + "hostname": "bar", + "port": 13, + "priority": 0.0}, + ]}) + self.assertEqual(e(TorTCPV1Hint("host", 1234, 1.0)), + {"type": "tor-tcp-v1", + "priority": 1.0, + "hostname": "host", + "port": 1234}) + e = self.assertRaises(ValueError, e, "not a Hint") + self.assertIn("unknown hint type", str(e)) + self.assertIn("not a Hint", str(e)) From e7cb1df785d0fc034b37b5c5a03b0482abceb242 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sat, 22 Dec 2018 17:27:54 -0500 Subject: [PATCH 37/49] factor out HKDF --- src/wormhole/_dilation/connector.py | 7 +------ src/wormhole/_key.py | 7 +------ src/wormhole/test/test_transit.py | 3 ++- src/wormhole/transit.py | 7 +------ src/wormhole/util.py | 4 ++++ 5 files changed, 9 insertions(+), 19 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 97c68eb..d148862 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -11,11 +11,11 @@ from twisted.internet.defer import DeferredList from twisted.internet.endpoints import serverFromString from twisted.internet.protocol import ClientFactory, ServerFactory from twisted.python import log -from hkdf import Hkdf from .. 
import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager from ..timing import DebugTiming from ..observer import EmptyableSet +from ..util import HKDF from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER @@ -24,11 +24,6 @@ from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, encode_hint) - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - - def build_sided_relay_handshake(key, side): assert isinstance(side, type(u"")) assert len(side) == 8 * 2 diff --git a/src/wormhole/_key.py b/src/wormhole/_key.py index 849ed41..e14e51a 100644 --- a/src/wormhole/_key.py +++ b/src/wormhole/_key.py @@ -6,7 +6,6 @@ import six from attr import attrib, attrs from attr.validators import instance_of, provides from automat import MethodicalMachine -from hkdf import Hkdf from nacl import utils from nacl.exceptions import CryptoError from nacl.secret import SecretBox @@ -15,16 +14,12 @@ from zope.interface import implementer from . import _interfaces from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, - hexstr_to_bytes, to_bytes) + hexstr_to_bytes, to_bytes, HKDF) CryptoError __all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"] -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - - def derive_key(key, purpose, length=SecretBox.KEY_SIZE): if not isinstance(key, type(b"")): raise TypeError(type(key)) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index b3d7590..df6d5a6 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -19,6 +19,7 @@ from wormhole_transit_relay import transit_server from .. 
import transit from .._hints import DirectTCPV1Hint from ..errors import InternalError +from ..util import HKDF from .common import ServerBase @@ -1526,7 +1527,7 @@ class Transit(unittest.TestCase): class RelayHandshake(unittest.TestCase): def old_build_relay_handshake(self, key): - token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") return (token, b"please relay " + hexlify(token) + b"\n") def test_old(self): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 63aafdb..98e1b72 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -9,7 +9,6 @@ from binascii import hexlify, unhexlify from collections import deque import six -from hkdf import Hkdf from nacl.secret import SecretBox from twisted.internet import (address, defer, endpoints, error, interfaces, protocol, reactor, task) @@ -22,16 +21,12 @@ from zope.interface import implementer from . import ipaddrs from .errors import InternalError from .timing import DebugTiming -from .util import bytes_to_hexstr +from .util import bytes_to_hexstr, HKDF from ._hints import (DirectTCPV1Hint, RelayV1Hint, parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, parse_tcp_v1_hint) -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - - class TransitError(Exception): pass diff --git a/src/wormhole/util.py b/src/wormhole/util.py index 0b57c5e..26f234f 100644 --- a/src/wormhole/util.py +++ b/src/wormhole/util.py @@ -3,8 +3,12 @@ import json import os import unicodedata from binascii import hexlify, unhexlify +from hkdf import Hkdf +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") From 6ad6f8f40f4f8266546312aaf980110f920b2480 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 23 Dec 2018 00:57:19 -0500 Subject: [PATCH 38/49] test and fix half of connector.py 
still to do: * relay delays * connection race * cancellation of losing connections * shutdown of all connections when abandoned --- src/wormhole/_dilation/connector.py | 65 ++-- src/wormhole/test/dilate/test_connector.py | 329 +++++++++++++++++++++ 2 files changed, 364 insertions(+), 30 deletions(-) create mode 100644 src/wormhole/test/dilate/test_connector.py diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index d148862..92888f9 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -34,8 +34,11 @@ def build_sided_relay_handshake(key, side): PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" -NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" +NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" +def build_noise(): + from noise.connection import NoiseConnection + return NoiseConnection.from_name(NOISEPROTO) @attrs @implementer(IDilationConnector) @@ -53,7 +56,7 @@ class Connector(object): _role = attrib() m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover RELAY_DELAY = 2.0 @@ -85,8 +88,7 @@ class Connector(object): # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses # ephemeral keys plus a pre-shared symmetric key (the Transit key), a # different one for each potential connection. 
- from noise.connection import NoiseConnection - noise = NoiseConnection.from_name(NOISEPROTO) + noise = build_noise() noise.set_psks(self._dilation_key) if self._role is LEADER: noise.set_as_initiator() @@ -144,6 +146,9 @@ class Connector(object): @m.output() def publish_hints(self, hint_objs): + self._publish_hints(hint_objs) + + def _publish_hints(self, hint_objs): self._manager.send_hints([encode_hint(h) for h in hint_objs]) @m.output() @@ -229,7 +234,7 @@ class Connector(object): def start(self): self._start_listener() if self._transit_relays: - self.publish_hints(self._transit_relays) + self._publish_hints(self._transit_relays) self._use_hints(self._transit_relays) def _start_listener(self): @@ -260,14 +265,15 @@ class Connector(object): def _use_hints(self, hints): # first, pull out all the relays, we'll connect to them later - relays = defaultdict(list) + relays = [] direct = defaultdict(list) for h in hints: if isinstance(h, RelayV1Hint): - relays[h.priority].append(h) + relays.append(h) else: direct[h.priority].append(h) delay = 0.0 + made_direct = False priorities = sorted(set(direct.keys()), reverse=True) for p in priorities: for h in direct[p]: @@ -277,7 +283,9 @@ class Connector(object): desc = describe_hint_obj(h, False, self._tor) d = deferLater(self._reactor, delay, self._connect, ep, desc, is_relay=False) + d.addErrback(log.err) self._pending_connectors.add(d) + made_direct = True # Make all direct connections immediately. Later, we'll change # the add_candidate() function to look at the priority when # deciding whether to accept a successful connection or not, @@ -286,34 +294,32 @@ class Connector(object): # putting an inter-direct-hint delay here to influence the # process. # delay += 1.0 - if delay > 0.0: - # Start trying the relays a few seconds after we start to try the - # direct hints. The idea is to prefer direct connections, but not - # be afraid of using a relay when we have direct hints that don't - # resolve quickly. 
Many direct hints will be to unused - # local-network IP addresses, which won't answer, and would take - # the full TCP timeout (30s or more) to fail. If there were no - # direct hints, don't delay at all. - delay += self.RELAY_DELAY - # prefer direct connections by stalling relay connections by a few - # seconds, unless we're using --no-listen in which case we're probably - # going to have to use the relay - delay = self.RELAY_DELAY if self._no_listen else 0.0 + if made_direct and not self._no_listen: + # Prefer direct connections by stalling relay connections by a + # few seconds. We don't wait until direct connections have + # failed, because many direct hints will be to unused + # local-network IP address, which won't answer, and can take the + # full 30s TCP timeout to fail. + # + # If we didn't make any direct connections, or we're using + # --no-listen, then we're probably going to have to use the + # relay, so don't delay it at all. + delay += self.RELAY_DELAY # It might be nice to wire this so that a failure in the direct hints # causes the relay hints to be used right away (fast failover). But # none of our current use cases would take advantage of that: if we # have any viable direct hints, then they're either going to succeed # quickly or hang for a long time. 
- for p in priorities: - for r in relays[p]: - for h in r.hints: - ep = endpoint_from_hint_obj(h, self._tor, self._reactor) - desc = describe_hint_obj(h, True, self._tor) - d = deferLater(self._reactor, delay, - self._connect, ep, desc, is_relay=True) - self._pending_connectors.add(d) + for r in relays: + for h in r.hints: + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) + desc = describe_hint_obj(h, True, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay=True) + d.addErrback(log.err) + self._pending_connectors.add(d) # TODO: # if not contenders: # raise TransitError("No contenders for connection") @@ -321,7 +327,7 @@ class Connector(object): # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for # the initial connection - def _connect(self, h, ep, description, is_relay=False): + def _connect(self, ep, description, is_relay=False): relay_handshake = None if is_relay: relay_handshake = build_sided_relay_handshake(self._dilation_key, @@ -369,7 +375,6 @@ class OutboundConnectionFactory(ClientFactory, object): @attrs class InboundConnectionFactory(ServerFactory, object): _connector = attrib(validator=provides(IDilationConnector)) - protocol = DilatedConnectionProtocol def buildProtocol(self, addr): p = self._connector.build_protocol(addr) diff --git a/src/wormhole/test/dilate/test_connector.py b/src/wormhole/test/dilate/test_connector.py new file mode 100644 index 0000000..8d05f65 --- /dev/null +++ b/src/wormhole/test/dilate/test_connector.py @@ -0,0 +1,329 @@ +from __future__ import print_function, unicode_literals + +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.internet.defer import Deferred +#from twisted.internet import endpoints +from ...eventual import EventualQueue +from ..._interfaces import IDilationManager, IDilationConnector +from ..._dilation import roles +from ..._hints import DirectTCPV1Hint, 
RelayV1Hint, TorTCPV1Hint +from ..._dilation.connector import (#describe_hint_obj, parse_hint_argv, + #parse_tcp_v1_hint, parse_hint, encode_hint, + Connector, + build_sided_relay_handshake, + build_noise, + OutboundConnectionFactory, + InboundConnectionFactory, + PROLOGUE_LEADER, PROLOGUE_FOLLOWER, + ) + +class Handshake(unittest.TestCase): + def test_build(self): + key = b"k"*32 + side = "12345678abcdabcd" + self.assertEqual(build_sided_relay_handshake(key, side), + b"please relay 3f4147851dbd2589d25b654ee9fb35ed0d3e5f19c5c5403e8e6a195c70f0577a for side 12345678abcdabcd\n") + +class Outbound(unittest.TestCase): + def test_no_relay(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + relay_handshake = None + f = OutboundConnectionFactory(c, relay_handshake) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(p.mock_calls, []) + self.assertIdentical(p.factory, f) + + def test_with_relay(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + relay_handshake = b"relay handshake" + f = OutboundConnectionFactory(c, relay_handshake) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)]) + self.assertIdentical(p.factory, f) + +class Inbound(unittest.TestCase): + def test_build(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + f = InboundConnectionFactory(c) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertIdentical(p.factory, f) + +def make_connector(listen=True, 
tor=False, relay=None, role=roles.LEADER): + class Holder: + pass + h = Holder() + h.dilation_key = b"key" + h.relay = relay + h.manager = mock.Mock() + alsoProvides(h.manager, IDilationManager) + h.clock = Clock() + h.reactor = h.clock + h.eq = EventualQueue(h.clock) + h.tor = None + if tor: + h.tor = mock.Mock() + timing = None + h.side = u"abcd1234abcd5678" + h.role = role + c = Connector(h.dilation_key, h.relay, h.manager, h.reactor, h.eq, + not listen, h.tor, timing, h.side, h.role) + return c, h + +class TestConnector(unittest.TestCase): + def test_build(self): + c, h = make_connector() + c, h = make_connector(relay="tcp:host:1234") + + def test_connection_abilities(self): + self.assertEqual(Connector.get_connection_abilities(), + [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ]) + + def test_build_noise(self): + build_noise() + + def test_build_protocol_leader(self): + c, h = make_connector(role=roles.LEADER) + n0 = mock.Mock() + p0 = mock.Mock() + addr = object() + with mock.patch("wormhole._dilation.connector.build_noise", + return_value=n0) as bn: + with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", + return_value=p0) as dcp: + p = c.build_protocol(addr) + self.assertEqual(bn.mock_calls, [mock.call()]) + self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), + mock.call.set_as_initiator()]) + self.assertIdentical(p, p0) + self.assertEqual(dcp.mock_calls, + [mock.call(h.eq, h.role, c, n0, + PROLOGUE_LEADER, PROLOGUE_FOLLOWER)]) + + def test_build_protocol_follower(self): + c, h = make_connector(role=roles.FOLLOWER) + n0 = mock.Mock() + p0 = mock.Mock() + addr = object() + with mock.patch("wormhole._dilation.connector.build_noise", + return_value=n0) as bn: + with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", + return_value=p0) as dcp: + p = c.build_protocol(addr) + self.assertEqual(bn.mock_calls, [mock.call()]) + self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), + 
mock.call.set_as_responder()]) + self.assertIdentical(p, p0) + self.assertEqual(dcp.mock_calls, + [mock.call(h.eq, h.role, c, n0, + PROLOGUE_FOLLOWER, PROLOGUE_LEADER)]) + + def test_start_stop(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + c.stop() + # we stop while we're connecting, so no connections must be stopped + + def test_basic(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + + ep0 = mock.Mock() + ep0_connect_d = Deferred() + ep0.connect = mock.Mock(return_value=ep0_connect_d) + efho = mock.Mock(side_effect=[ep0]) + hint0 = DirectTCPV1Hint("foo", 55, 0.0) + dho = mock.Mock(side_effect=["desc0"]) + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + efho): + with mock.patch("wormhole._dilation.connector.describe_hint_obj", dho): + c.got_hints([hint0]) + self.assertEqual(efho.mock_calls, [mock.call(hint0, h.tor, h.reactor)]) + self.assertEqual(dho.mock_calls, [mock.call(hint0, False, h.tor)]) + f0 = mock.Mock() + with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", + return_value=f0) as ocf: + h.clock.advance(c.RELAY_DELAY / 2 + 0.01) + self.assertEqual(ocf.mock_calls, [mock.call(c, None)]) + self.assertEqual(ep0.connect.mock_calls, [mock.call(f0)]) + + p = mock.Mock() + ep0_connect_d.callback(p) + self.assertEqual(p.mock_calls, + [mock.call.when_disconnected(), + mock.call.when_disconnected().addCallback(c._pending_connections.discard)]) + + def test_listen(self): + c, h = make_connector(listen=True, role=roles.LEADER) + d = Deferred() + ep = mock.Mock() + ep.listen = mock.Mock(return_value=d) + f = mock.Mock() + with 
mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): + with mock.patch("wormhole._dilation.connector.serverFromString", + side_effect=[ep]): + with mock.patch("wormhole._dilation.connector.InboundConnectionFactory", + return_value=f): + c.start() + # no relays and the listener isn't ready yet, so no hints yet + self.assertEqual(h.manager.mock_calls, []) + # but a listener was started + self.assertEqual(ep.mock_calls, [mock.call.listen(f)]) + lp = mock.Mock() + host = mock.Mock() + host.port = 2345 + lp.getHost = mock.Mock(return_value=host) + d.callback(lp) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints( + [{"type": "direct-tcp-v1", "hostname": "1.2.3.4", + "port": 2345, "priority": 0.0}, + {"type": "direct-tcp-v1", "hostname": "5.6.7.8", + "port": 2345, "priority": 0.0}, + ])]) + + def test_listen_but_tor(self): + c, h = make_connector(listen=True, tor=True, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa: + c.start() + # don't even look up addresses + self.assertEqual(fa.mock_calls, []) + # no relays and the listener isn't ready yet, so no hints yet + self.assertEqual(h.manager.mock_calls, []) + + def test_listen_only_loopback(self): + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. 
+ c, h = make_connector(listen=True, role=roles.LEADER) + d = Deferred() + ep = mock.Mock() + ep.listen = mock.Mock(return_value=d) + f = mock.Mock() + with mock.patch("wormhole.ipaddrs.find_addresses", return_value=["127.0.0.1"]): + with mock.patch("wormhole._dilation.connector.serverFromString", + side_effect=[ep]): + with mock.patch("wormhole._dilation.connector.InboundConnectionFactory", + return_value=f): + c.start() + # no relays and the listener isn't ready yet, so no hints yet + self.assertEqual(h.manager.mock_calls, []) + # but a listener was started + self.assertEqual(ep.mock_calls, [mock.call.listen(f)]) + lp = mock.Mock() + host = mock.Mock() + host.port = 2345 + lp.getHost = mock.Mock(return_value=host) + d.callback(lp) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints( + [{"type": "direct-tcp-v1", "hostname": "127.0.0.1", + "port": 2345, "priority": 0.0}, + ])]) + + def OFFtest_relay_delay(self): + # given a direct connection and a relay, we should see the direct + # connection initiated at T+0 seconds, and the relay at T+RELAY_DELAY + c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER) + c.start() + hint1 = DirectTCPV1Hint("foo", 55, 0.0) + hint2 = DirectTCPV1Hint("bar", 55, 0.0) + hint3 = RelayV1Hint([DirectTCPV1Hint("relay", 55, 0.0)]) + ep1, ep2, ep3 = mock.Mock(), mock.Mock(), mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep1, ep2, ep3]): + c.got_hints([hint1, hint2, hint3]) + self.assertEqual(ep1.mock_calls, []) + self.assertEqual(ep2.mock_calls, []) + self.assertEqual(ep3.mock_calls, []) + + h.clock.advance(c.RELAY_DELAY / 2 + 0.01) + self.assertEqual(len(ep1.mock_calls), 2) + self.assertEqual(len(ep2.mock_calls), 2) + self.assertEqual(ep3.mock_calls, []) + + h.clock.advance(c.RELAY_DELAY) + self.assertEqual(len(ep1.mock_calls), 2) + self.assertEqual(len(ep2.mock_calls), 2) + self.assertEqual(len(ep3.mock_calls), 2) + + def test_initial_relay(self): + c, 
h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER) + ep = mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]) as efho: + c.start() + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 55, + "priority": 0.0 + }, + ], + }])]) + self.assertEqual(len(efho.mock_calls), 1) + + def test_tor_no_manager(self): + # tor hints should be ignored if we don't have a Tor manager to use them + c, h = make_connector(listen=False, role=roles.LEADER) + c.start() + hint = TorTCPV1Hint("foo", 55, 0.0) + ep = mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]): + c.got_hints([hint]) + self.assertEqual(ep.mock_calls, []) + + h.clock.advance(c.RELAY_DELAY * 2) + self.assertEqual(ep.mock_calls, []) + + def test_tor_with_manager(self): + # tor hints should be processed if we do have a Tor manager + c, h = make_connector(listen=False, tor=True, role=roles.LEADER) + c.start() + hint = TorTCPV1Hint("foo", 55, 0.0) + ep = mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]): + c.got_hints([hint]) + self.assertEqual(ep.mock_calls, []) + + h.clock.advance(c.RELAY_DELAY * 2) + self.assertEqual(len(ep.mock_calls), 2) + + + def test_priorities(self): + # given two hints with different priorities, we should somehow prefer + # one. This is a placeholder to fill in once we implement priorities. 
+ pass From a458fe9ab9865394612cb7d4ff213dde98191f74 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 23 Dec 2018 15:01:16 -0500 Subject: [PATCH 39/49] finish test coverage/fixes for connector.py --- src/wormhole/_dilation/connector.py | 64 ++-- src/wormhole/test/dilate/test_connector.py | 375 ++++++++++++++------- 2 files changed, 288 insertions(+), 151 deletions(-) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 92888f9..545d9de 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -116,10 +116,9 @@ class Connector(object): pass # pragma: no cover # TODO: unify the tense of these method-name verbs - @m.input() - def listener_ready(self, hint_objs): - pass + # add_relay() and got_hints() are called by the Manager as it receives + # messages from our peer. stop() is called when the Manager shuts down @m.input() def add_relay(self, hint_objs): pass @@ -129,16 +128,25 @@ class Connector(object): pass @m.input() - def add_candidate(self, c): # called by DilatedConnectionProtocol + def stop(self): pass + # called by ourselves, when _start_listener() is ready + @m.input() + def listener_ready(self, hint_objs): + pass + + # called when DilatedConnectionProtocol submits itself, after KCM + # received + @m.input() + def add_candidate(self, c): + pass + + # called by ourselves, via consider() @m.input() def accept(self, c): pass - @m.input() - def stop(self): - pass @m.output() def use_hints(self, hint_objs): @@ -199,17 +207,12 @@ class Connector(object): [c.loseConnection() for c in self._pending_connections] return d - def stop_winner(self): - d = self._winner.when_disconnected() - self._winner.disconnect() - return d - def break_cycles(self): # help GC by forgetting references to things that reference us self._listeners.clear() self._pending_connectors.clear() self._pending_connections.clear() - self._winner = None + self._winning_connection = None connecting.upon(listener_ready, 
enter=connecting, outputs=[publish_hints]) connecting.upon(add_relay, enter=connecting, outputs=[use_hints, @@ -224,6 +227,8 @@ class Connector(object): connected.upon(listener_ready, enter=connected, outputs=[]) connected.upon(add_relay, enter=connected, outputs=[]) connected.upon(got_hints, enter=connected, outputs=[]) + # TODO: tell them to disconnect? will they hang out forever? I *think* + # they'll drop this once they get a KCM on the winning connection. connected.upon(add_candidate, enter=connected, outputs=[]) connected.upon(accept, enter=connected, outputs=[]) connected.upon(stop, enter=stopped, outputs=[stop_everything]) @@ -232,20 +237,23 @@ class Connector(object): # maybe add_candidate, accept def start(self): - self._start_listener() + if not self._no_listen and not self._tor: + addresses = self._get_listener_addresses() + self._start_listener(addresses) if self._transit_relays: self._publish_hints(self._transit_relays) self._use_hints(self._transit_relays) - def _start_listener(self): - if self._no_listen or self._tor: - return + def _get_listener_addresses(self): addresses = ipaddrs.find_addresses() non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"] if non_loopback_addresses: # some test hosts, including the appveyor VMs, *only* have # 127.0.0.1, and the tests will hang badly if we remove it. addresses = non_loopback_addresses + return addresses + + def _start_listener(self, addresses): # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also # to make firewall configs easier # TODO: retain listening port between connection generations? 
@@ -263,6 +271,14 @@ class Connector(object): d.addCallback(_listening) d.addErrback(log.err) + def _schedule_connection(self, delay, h, is_relay): + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) + desc = describe_hint_obj(h, is_relay, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay) + d.addErrback(log.err) + self._pending_connectors.add(d) + def _use_hints(self, hints): # first, pull out all the relays, we'll connect to them later relays = [] @@ -279,12 +295,7 @@ class Connector(object): for h in direct[p]: if isinstance(h, TorTCPV1Hint) and not self._tor: continue - ep = endpoint_from_hint_obj(h, self._tor, self._reactor) - desc = describe_hint_obj(h, False, self._tor) - d = deferLater(self._reactor, delay, - self._connect, ep, desc, is_relay=False) - d.addErrback(log.err) - self._pending_connectors.add(d) + self._schedule_connection(delay, h, is_relay=False) made_direct = True # Make all direct connections immediately. Later, we'll change # the add_candidate() function to look at the priority when @@ -314,12 +325,7 @@ class Connector(object): # quickly or hang for a long time. 
for r in relays: for h in r.hints: - ep = endpoint_from_hint_obj(h, self._tor, self._reactor) - desc = describe_hint_obj(h, True, self._tor) - d = deferLater(self._reactor, delay, - self._connect, ep, desc, is_relay=True) - d.addErrback(log.err) - self._pending_connectors.add(d) + self._schedule_connection(delay, h, is_relay=True) # TODO: # if not contenders: # raise TransitError("No contenders for connection") diff --git a/src/wormhole/test/dilate/test_connector.py b/src/wormhole/test/dilate/test_connector.py index 8d05f65..64291e5 100644 --- a/src/wormhole/test/dilate/test_connector.py +++ b/src/wormhole/test/dilate/test_connector.py @@ -5,20 +5,19 @@ from zope.interface import alsoProvides from twisted.trial import unittest from twisted.internet.task import Clock from twisted.internet.defer import Deferred -#from twisted.internet import endpoints from ...eventual import EventualQueue from ..._interfaces import IDilationManager, IDilationConnector from ..._dilation import roles from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint -from ..._dilation.connector import (#describe_hint_obj, parse_hint_argv, - #parse_tcp_v1_hint, parse_hint, encode_hint, - Connector, - build_sided_relay_handshake, - build_noise, - OutboundConnectionFactory, - InboundConnectionFactory, - PROLOGUE_LEADER, PROLOGUE_FOLLOWER, - ) +from ..._dilation.connection import KCM +from ..._dilation.connector import (Connector, + build_sided_relay_handshake, + build_noise, + OutboundConnectionFactory, + InboundConnectionFactory, + PROLOGUE_LEADER, PROLOGUE_FOLLOWER, + ) +from .common import clear_mock_calls class Handshake(unittest.TestCase): def test_build(self): @@ -149,67 +148,125 @@ class TestConnector(unittest.TestCase): c.stop() # we stop while we're connecting, so no connections must be stopped - def test_basic(self): + def test_empty(self): c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() c.start() # no relays, so it 
publishes no hints self.assertEqual(h.manager.mock_calls, []) # and no listener, so nothing happens until we provide a hint + self.assertEqual(c._schedule_connection.mock_calls, []) + c.stop() - ep0 = mock.Mock() - ep0_connect_d = Deferred() - ep0.connect = mock.Mock(return_value=ep0_connect_d) - efho = mock.Mock(side_effect=[ep0]) - hint0 = DirectTCPV1Hint("foo", 55, 0.0) - dho = mock.Mock(side_effect=["desc0"]) + def test_basic(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + self.assertEqual(c._schedule_connection.mock_calls, []) + + hint = DirectTCPV1Hint("foo", 55, 0.0) + c.got_hints([hint]) + + # received hints don't get published + self.assertEqual(h.manager.mock_calls, []) + # they just schedule a connection + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=False)]) + + def test_listen_addresses(self): + c, h = make_connector(listen=True, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): + self.assertEqual(c._get_listener_addresses(), + ["1.2.3.4", "5.6.7.8"]) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1"]): + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. 
+ self.assertEqual(c._get_listener_addresses(), ["127.0.0.1"]) + + def test_listen(self): + c, h = make_connector(listen=True, role=roles.LEADER) + c._start_listener = mock.Mock() + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): + c.start() + self.assertEqual(c._start_listener.mock_calls, + [mock.call(["1.2.3.4", "5.6.7.8"])]) + + def test_start_listen(self): + c, h = make_connector(listen=True, role=roles.LEADER) + ep = mock.Mock() + d = Deferred() + ep.listen = mock.Mock(return_value=d) + with mock.patch("wormhole._dilation.connector.serverFromString", + return_value=ep) as sfs: + c._start_listener(["1.2.3.4", "5.6.7.8"]) + self.assertEqual(sfs.mock_calls, [mock.call(h.reactor, "tcp:0")]) + lp = mock.Mock() + host = mock.Mock() + host.port = 66 + lp.getHost = mock.Mock(return_value=host) + d.callback(lp) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "direct-tcp-v1", + "hostname": "1.2.3.4", + "port": 66, + "priority": 0.0 + }, + {"type": "direct-tcp-v1", + "hostname": "5.6.7.8", + "port": 66, + "priority": 0.0 + }, + ])]) + + def test_schedule_connection_no_relay(self): + c, h = make_connector(listen=True, role=roles.LEADER) + hint = DirectTCPV1Hint("foo", 55, 0.0) + ep = mock.Mock() with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", - efho): - with mock.patch("wormhole._dilation.connector.describe_hint_obj", dho): - c.got_hints([hint0]) - self.assertEqual(efho.mock_calls, [mock.call(hint0, h.tor, h.reactor)]) - self.assertEqual(dho.mock_calls, [mock.call(hint0, False, h.tor)]) - f0 = mock.Mock() + side_effect=[ep]) as efho: + c._schedule_connection(0.0, hint, False) + self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)]) + self.assertEqual(ep.mock_calls, []) + d = Deferred() + ep.connect = mock.Mock(side_effect=[d]) + # direct hints are scheduled for T+0.0 + f = mock.Mock() with 
mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", - return_value=f0) as ocf: - h.clock.advance(c.RELAY_DELAY / 2 + 0.01) + return_value=f) as ocf: + h.clock.advance(1.0) self.assertEqual(ocf.mock_calls, [mock.call(c, None)]) - self.assertEqual(ep0.connect.mock_calls, [mock.call(f0)]) - + self.assertEqual(ep.connect.mock_calls, [mock.call(f)]) p = mock.Mock() - ep0_connect_d.callback(p) + d.callback(p) self.assertEqual(p.mock_calls, [mock.call.when_disconnected(), mock.call.when_disconnected().addCallback(c._pending_connections.discard)]) - def test_listen(self): + def test_schedule_connection_relay(self): c, h = make_connector(listen=True, role=roles.LEADER) - d = Deferred() + hint = DirectTCPV1Hint("foo", 55, 0.0) ep = mock.Mock() - ep.listen = mock.Mock(return_value=d) + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]) as efho: + c._schedule_connection(0.0, hint, True) + self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)]) + self.assertEqual(ep.mock_calls, []) + d = Deferred() + ep.connect = mock.Mock(side_effect=[d]) + # direct hints are scheduled for T+0.0 f = mock.Mock() - with mock.patch("wormhole.ipaddrs.find_addresses", - return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): - with mock.patch("wormhole._dilation.connector.serverFromString", - side_effect=[ep]): - with mock.patch("wormhole._dilation.connector.InboundConnectionFactory", - return_value=f): - c.start() - # no relays and the listener isn't ready yet, so no hints yet - self.assertEqual(h.manager.mock_calls, []) - # but a listener was started - self.assertEqual(ep.mock_calls, [mock.call.listen(f)]) - lp = mock.Mock() - host = mock.Mock() - host.port = 2345 - lp.getHost = mock.Mock(return_value=host) - d.callback(lp) - self.assertEqual(h.manager.mock_calls, - [mock.call.send_hints( - [{"type": "direct-tcp-v1", "hostname": "1.2.3.4", - "port": 2345, "priority": 0.0}, - {"type": "direct-tcp-v1", "hostname": "5.6.7.8", - 
"port": 2345, "priority": 0.0}, - ])]) + with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", + return_value=f) as ocf: + h.clock.advance(1.0) + handshake = build_sided_relay_handshake(h.dilation_key, h.side) + self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)]) def test_listen_but_tor(self): c, h = make_connector(listen=True, tor=True, role=roles.LEADER) @@ -221,67 +278,36 @@ class TestConnector(unittest.TestCase): # no relays and the listener isn't ready yet, so no hints yet self.assertEqual(h.manager.mock_calls, []) - def test_listen_only_loopback(self): - # some test hosts, including the appveyor VMs, *only* have - # 127.0.0.1, and the tests will hang badly if we remove it. - c, h = make_connector(listen=True, role=roles.LEADER) - d = Deferred() - ep = mock.Mock() - ep.listen = mock.Mock(return_value=d) - f = mock.Mock() - with mock.patch("wormhole.ipaddrs.find_addresses", return_value=["127.0.0.1"]): - with mock.patch("wormhole._dilation.connector.serverFromString", - side_effect=[ep]): - with mock.patch("wormhole._dilation.connector.InboundConnectionFactory", - return_value=f): - c.start() - # no relays and the listener isn't ready yet, so no hints yet + def test_no_listen(self): + c, h = make_connector(listen=False, tor=False, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa: + c.start() + # don't even look up addresses + self.assertEqual(fa.mock_calls, []) self.assertEqual(h.manager.mock_calls, []) - # but a listener was started - self.assertEqual(ep.mock_calls, [mock.call.listen(f)]) - lp = mock.Mock() - host = mock.Mock() - host.port = 2345 - lp.getHost = mock.Mock(return_value=host) - d.callback(lp) - self.assertEqual(h.manager.mock_calls, - [mock.call.send_hints( - [{"type": "direct-tcp-v1", "hostname": "127.0.0.1", - "port": 2345, "priority": 0.0}, - ])]) - def OFFtest_relay_delay(self): + def test_relay_delay(self): # given a direct connection 
and a relay, we should see the direct # connection initiated at T+0 seconds, and the relay at T+RELAY_DELAY - c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER) + c, h = make_connector(listen=True, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c._start_listener = mock.Mock() c.start() hint1 = DirectTCPV1Hint("foo", 55, 0.0) hint2 = DirectTCPV1Hint("bar", 55, 0.0) hint3 = RelayV1Hint([DirectTCPV1Hint("relay", 55, 0.0)]) - ep1, ep2, ep3 = mock.Mock(), mock.Mock(), mock.Mock() - with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", - side_effect=[ep1, ep2, ep3]): - c.got_hints([hint1, hint2, hint3]) - self.assertEqual(ep1.mock_calls, []) - self.assertEqual(ep2.mock_calls, []) - self.assertEqual(ep3.mock_calls, []) - - h.clock.advance(c.RELAY_DELAY / 2 + 0.01) - self.assertEqual(len(ep1.mock_calls), 2) - self.assertEqual(len(ep2.mock_calls), 2) - self.assertEqual(ep3.mock_calls, []) - - h.clock.advance(c.RELAY_DELAY) - self.assertEqual(len(ep1.mock_calls), 2) - self.assertEqual(len(ep2.mock_calls), 2) - self.assertEqual(len(ep3.mock_calls), 2) + c.got_hints([hint1, hint2, hint3]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, hint1, is_relay=False), + mock.call(0.0, hint2, is_relay=False), + mock.call(c.RELAY_DELAY, hint3.hints[0], is_relay=True), + ]) def test_initial_relay(self): c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER) - ep = mock.Mock() - with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", - side_effect=[ep]) as efho: - c.start() + c._schedule_connection = mock.Mock() + c.start() self.assertEqual(h.manager.mock_calls, [mock.call.send_hints([{"type": "relay-v1", "hints": [ @@ -292,38 +318,143 @@ class TestConnector(unittest.TestCase): }, ], }])]) - self.assertEqual(len(efho.mock_calls), 1) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=True)]) + + def 
test_add_relay(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(h.manager.mock_calls, []) + self.assertEqual(c._schedule_connection.mock_calls, []) + hint = RelayV1Hint([DirectTCPV1Hint("foo", 55, 0.0)]) + c.add_relay([hint]) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 55, + "priority": 0.0 + }, + ], + }])]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=True)]) def test_tor_no_manager(self): # tor hints should be ignored if we don't have a Tor manager to use them c, h = make_connector(listen=False, role=roles.LEADER) + c._schedule_connection = mock.Mock() c.start() hint = TorTCPV1Hint("foo", 55, 0.0) - ep = mock.Mock() - with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", - side_effect=[ep]): - c.got_hints([hint]) - self.assertEqual(ep.mock_calls, []) - - h.clock.advance(c.RELAY_DELAY * 2) - self.assertEqual(ep.mock_calls, []) + c.got_hints([hint]) + self.assertEqual(h.manager.mock_calls, []) + self.assertEqual(c._schedule_connection.mock_calls, []) def test_tor_with_manager(self): # tor hints should be processed if we do have a Tor manager c, h = make_connector(listen=False, tor=True, role=roles.LEADER) + c._schedule_connection = mock.Mock() c.start() hint = TorTCPV1Hint("foo", 55, 0.0) - ep = mock.Mock() - with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", - side_effect=[ep]): - c.got_hints([hint]) - self.assertEqual(ep.mock_calls, []) - - h.clock.advance(c.RELAY_DELAY * 2) - self.assertEqual(len(ep.mock_calls), 2) - + c.got_hints([hint]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, hint, is_relay=False)]) def test_priorities(self): # given two hints with different priorities, we should somehow prefer # one. 
This is a placeholder to fill in once we implement priorities. pass + + +class Race(unittest.TestCase): + def test_one_leader(self): + c, h = make_connector(listen=True, role=roles.LEADER) + lp = mock.Mock() + def start_listener(addresses): + c._listeners.add(lp) + c._start_listener = start_listener + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(c._listeners, set([lp])) + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + self.assertEqual(lp.mock_calls[0], mock.call.stopListening()) + # stop_listeners() uses a DeferredList, so we ignore the second call + + def test_one_follower(self): + c, h = make_connector(listen=True, role=roles.FOLLOWER) + lp = mock.Mock() + def start_listener(addresses): + c._listeners.add(lp) + c._start_listener = start_listener + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(c._listeners, set([lp])) + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + # just like LEADER, but follower doesn't send KCM now (it sent one + # earlier, to tell the leader that this connection looks viable) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager)]) + self.assertEqual(lp.mock_calls[0], mock.call.stopListening()) + # stop_listeners() uses a DeferredList, so we ignore the second call + + # TODO: make sure a pending connection is abandoned when the listener + # answers successfully + + # TODO: make sure a second pending connection is abandoned when the first + # connection succeeds + + def test_late(self): + c, h = make_connector(listen=False, role=roles.LEADER) + 
c._schedule_connection = mock.Mock() + c.start() + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + clear_mock_calls(h.manager) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + + # late connection is ignored + p2 = mock.Mock() + c.add_candidate(p2) + self.assertEqual(h.manager.mock_calls, []) + + + # make sure an established connection is dropped when stop() is called + def test_stop(self): + c, h = make_connector(listen=False, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + + c.stop() + From 29c269ac8d03e23898a8a3d057bda60bfcb11d57 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 00:07:06 -0500 Subject: [PATCH 40/49] get tests to work on py2.7 only install 'noiseprotocol' (which is necessary for dilation to work) if the "dilate" feature is requested (e.g. 
`pip install magic-wormhole[dilate]`) --- setup.py | 2 +- src/wormhole/_dilation/_noise.py | 11 +++++++++++ src/wormhole/_dilation/connection.py | 3 +-- src/wormhole/_dilation/connector.py | 9 ++++----- src/wormhole/test/dilate/test_connector.py | 5 ++++- src/wormhole/test/dilate/test_record.py | 2 +- src/wormhole/util.py | 4 ++++ 7 files changed, 26 insertions(+), 10 deletions(-) create mode 100644 src/wormhole/_dilation/_noise.py diff --git a/setup.py b/setup.py index c275e8b..f428250 100644 --- a/setup.py +++ b/setup.py @@ -48,13 +48,13 @@ setup(name="magic-wormhole", "click", "humanize", "txtorcon >= 18.0.2", # 18.0.2 fixes py3.4 support - "noiseprotocol", ], extras_require={ ':sys_platform=="win32"': ["pywin32"], "dev": ["mock", "tox", "pyflakes", "magic-wormhole-transit-relay==0.1.2", "magic-wormhole-mailbox-server==0.3.1"], + "dilate": ["noiseprotocol"], }, test_suite="wormhole.test", cmdclass=commands, diff --git a/src/wormhole/_dilation/_noise.py b/src/wormhole/_dilation/_noise.py new file mode 100644 index 0000000..bb4cf58 --- /dev/null +++ b/src/wormhole/_dilation/_noise.py @@ -0,0 +1,11 @@ +try: + from noise.exceptions import NoiseInvalidMessage +except ImportError: + class NoiseInvalidMessage(Exception): + pass + +try: + from noise.connection import NoiseConnection +except ImportError: + # allow imports to work on py2.7, even if dilation doesn't + NoiseConnection = None diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index d142eed..b8f3ec6 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -12,6 +12,7 @@ from .._interfaces import IDilationConnector from ..observer import OneShotObserver from .encode import to_be4, from_be4 from .roles import FOLLOWER +from ._noise import NoiseInvalidMessage # InboundFraming is given data and returns Frames (Noise wire-side # bytestrings). It handles the relay handshake and the prologue. 
The Frames it @@ -346,7 +347,6 @@ class _Record(object): @n.output() def process_handshake(self, frame): - from noise.exceptions import NoiseInvalidMessage try: payload = self._noise.read_message(frame) # Noise can include unencrypted data in the handshake, but we don't @@ -359,7 +359,6 @@ class _Record(object): @n.output() def decrypt_message(self, frame): - from noise.exceptions import NoiseInvalidMessage try: message = self._noise.decrypt(frame) except NoiseInvalidMessage as e: diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py index 545d9de..aa5f8e0 100644 --- a/src/wormhole/_dilation/connector.py +++ b/src/wormhole/_dilation/connector.py @@ -1,7 +1,6 @@ from __future__ import print_function, unicode_literals from collections import defaultdict from binascii import hexlify -import six from attr import attrs, attrib from attr.validators import instance_of, provides, optional from automat import MethodicalMachine @@ -15,13 +14,14 @@ from .. import ipaddrs # TODO: move into _dilation/ from .._interfaces import IDilationConnector, IDilationManager from ..timing import DebugTiming from ..observer import EmptyableSet -from ..util import HKDF +from ..util import HKDF, to_unicode from .connection import DilatedConnectionProtocol, KCM from .roles import LEADER from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, encode_hint) +from ._noise import NoiseConnection def build_sided_relay_handshake(key, side): @@ -37,14 +37,13 @@ PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" def build_noise(): - from noise.connection import NoiseConnection return NoiseConnection.from_name(NOISEPROTO) @attrs @implementer(IDilationConnector) class Connector(object): _dilation_key = attrib(validator=instance_of(type(b""))) - _transit_relay_location = attrib(validator=optional(instance_of(str))) + 
_transit_relay_location = attrib(validator=optional(instance_of(type(u"")))) _manager = attrib(validator=provides(IDilationManager)) _reactor = attrib() _eventual_queue = attrib() @@ -265,7 +264,7 @@ class Connector(object): # lp is an IListeningPort self._listeners.add(lp) # for shutdown and tests portnum = lp.getHost().port - direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) + direct_hints = [DirectTCPV1Hint(to_unicode(addr), portnum, 0.0) for addr in addresses] self.listener_ready(direct_hints) d.addCallback(_listening) diff --git a/src/wormhole/test/dilate/test_connector.py b/src/wormhole/test/dilate/test_connector.py index 64291e5..2bb8809 100644 --- a/src/wormhole/test/dilate/test_connector.py +++ b/src/wormhole/test/dilate/test_connector.py @@ -7,8 +7,9 @@ from twisted.internet.task import Clock from twisted.internet.defer import Deferred from ...eventual import EventualQueue from ..._interfaces import IDilationManager, IDilationConnector -from ..._dilation import roles from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint +from ..._dilation import roles +from ..._dilation._noise import NoiseConnection from ..._dilation.connection import KCM from ..._dilation.connector import (Connector, build_sided_relay_handshake, @@ -101,6 +102,8 @@ class TestConnector(unittest.TestCase): ]) def test_build_noise(self): + if not NoiseConnection: + raise unittest.SkipTest("noiseprotocol unavailable") build_noise() def test_build_protocol_leader(self): diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py index 41b36e3..63a784c 100644 --- a/src/wormhole/test/dilate/test_record.py +++ b/src/wormhole/test/dilate/test_record.py @@ -2,7 +2,7 @@ from __future__ import print_function, unicode_literals import mock from zope.interface import alsoProvides from twisted.trial import unittest -from noise.exceptions import NoiseInvalidMessage +from ..._dilation._noise import NoiseInvalidMessage from ..._dilation.connection 
import (IFramer, Frame, Prologue, _Record, Handshake, Disconnect, Ping) diff --git a/src/wormhole/util.py b/src/wormhole/util.py index 26f234f..971de7e 100644 --- a/src/wormhole/util.py +++ b/src/wormhole/util.py @@ -12,6 +12,10 @@ def HKDF(skm, outlen, salt=None, CTXinfo=b""): def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") +def to_unicode(any): + if isinstance(any, type(u"")): + return any + return any.decode("ascii") def bytes_to_hexstr(b): assert isinstance(b, type(b"")) From 96f52b931def3d27efc0b601efaadb4589db7ae5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 00:16:36 -0500 Subject: [PATCH 41/49] drop support for py33 current pkg_resources requires py3.4 or newer (or py27) txtorcon appears to work on py3.4 again, so remove it from allow_failures --- .travis.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3bcad1a..07b23cd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,6 @@ after_success: matrix: include: - python: 2.7 - - python: 3.3 - python: 3.4 - python: 3.5 - python: 3.6 @@ -35,9 +34,6 @@ matrix: - python: nightly allow_failures: - python: 2.7 - - python: 3.3 - # txtorcon is currently broken on py3.4 - - python: 3.4 # travis doesn't support py3.7 yet - python: 3.7 - python: nightly From b0db8add2a2d2d3d00dbfd42d709c8737f14c634 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 00:19:58 -0500 Subject: [PATCH 42/49] travis: stop allowing failures on py2.7 and py3.7 py2.7 now works py3.7 is now supported by travis --- .travis.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 07b23cd..8bd3a06 100644 --- a/.travis.yml +++ b/.travis.yml @@ -33,7 +33,4 @@ matrix: dist: xenial - python: nightly allow_failures: - - python: 2.7 - # travis doesn't support py3.7 yet - - python: 3.7 - python: nightly From 69bab3d814691a69119d8d9a6be476f15ab56d4c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 14:33:41 
-0500 Subject: [PATCH 43/49] docs/api: minor fixes --- docs/api.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/api.md b/docs/api.md index 43cc318..a2f0e61 100644 --- a/docs/api.md +++ b/docs/api.md @@ -524,6 +524,8 @@ object twice. ## Dilation +(NOTE: this API is still in development) + To send bulk data, or anything more than a handful of messages, a Wormhole can be "dilated" into a form that uses a direct TCP connection between the two endpoints. @@ -639,7 +641,7 @@ def FileSendingProtocol(internet.Protocol): self.transport.loseConnection() f.close() def _send(metadata, filename): - f = protocol.ClientCreator(reactor, + f = protocol.ClientCreator(reactor, FileSendingProtocol, metadata, filename) subchannel_client_ep.connect(f) def FileReceivingProtocol(internet.Protocol): From 4083beeb6c81e2784e3dc6e54d09c14c8bd24663 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 14:34:02 -0500 Subject: [PATCH 44/49] wormhole.py: disable dilate() API until ready more importantly, turn off the "we can do Dilation" advertisement for now, since we really can't --- src/wormhole/wormhole.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 95f5146..e17b7a1 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -192,6 +192,7 @@ class _DeferredWormhole(object): return derive_key(self._key, to_bytes(purpose), length) def dilate(self): + raise NotImplementedError return self._boss.dilate() # fires with (endpoints) def close(self): @@ -275,6 +276,7 @@ def create( "can-dilate": DILATION_VERSIONS, "dilation-abilities": Connector.get_connection_abilities(), } + wormhole_versions = {} # don't advertise Dilation yet: not ready wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): From b01f48ad888b1655ac82a52f0215ad7bbae1b90d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 14:37:19 -0500 
Subject: [PATCH 45/49] tox: test dilation on py3, but not on py2 --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 4d7a007..09520a3 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ minversion = 2.4.0 [testenv] usedevelop = True -extras = dev +extras = dev,dilate deps = pyflakes >= 1.2.3 commands = @@ -18,6 +18,8 @@ commands = wormhole --version python -m wormhole.test.run_trial {posargs:wormhole} +[testenv:py27] +extras = dev # on windows, trial is installed as venv/bin/trial.py, not .exe, but (at # least appveyor) adds .PY to $PATHEXT. So "trial wormhole" might work on From 937a7d93e82707530a4c5f0e246ca64e9d29034b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 14:49:25 -0500 Subject: [PATCH 46/49] tox: only run coverage on py3.7 tox/coverage doesn't know to avoid the "dilate" extra, so it fails on py2.7 and py3.4 --- .travis.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 8bd3a06..d237669 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,11 @@ before_script: flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ; fi script: - - tox -e coverage + - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then + tox -e coverage ; + else + tox ; + fi after_success: - codecov matrix: From 061ff9838393c3ef79a1e0c468ba0987ad676123 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 14:59:13 -0500 Subject: [PATCH 47/49] fix travis don't run all of tox, just a single environment that uses the default python (selected by travis) --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d237669..58a0d0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ script: - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then tox -e coverage ; else - tox ; + tox -e py ; fi after_success: - codecov From 942a04952f4f13a85ca540e4f5a1a72c4b5b9940 Mon Sep 17 00:00:00 2001 
From: Brian Warner Date: Mon, 24 Dec 2018 22:54:01 -0500 Subject: [PATCH 48/49] try to fix travis again add a new tox target "no-dilate" to use on py2.7, and use "coverage" everywhere else --- .travis.yml | 6 +++--- tox.ini | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 58a0d0e..adfc4c5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,10 +17,10 @@ before_script: flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ; fi script: - - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then - tox -e coverage ; + - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then + tox -e no-dilate ; else - tox -e py ; + tox -e coverage ; fi after_success: - codecov diff --git a/tox.ini b/tox.ini index 09520a3..3c01425 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,7 @@ commands = wormhole --version python -m wormhole.test.run_trial {posargs:wormhole} -[testenv:py27] +[testenv:no-dilate] extras = dev # on windows, trial is installed as venv/bin/trial.py, not .exe, but (at From 803aa07f35df4244eeca7c08a62270fb4e135357 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 24 Dec 2018 23:00:00 -0500 Subject: [PATCH 49/49] travis: don't test dilation on py3.4 either --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index adfc4c5..fcfc1d4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ before_script: flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ; fi script: - - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then + - if [[ $TRAVIS_PYTHON_VERSION == 2.7 || $TRAVIS_PYTHON_VERSION == 3.4 ]]; then tox -e no-dilate ; else tox -e coverage ;