diff --git a/.travis.yml b/.travis.yml index 33c3dad..fcfc1d4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,13 +17,16 @@ before_script: flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ; fi script: - - tox -e coverage + - if [[ $TRAVIS_PYTHON_VERSION == 2.7 || $TRAVIS_PYTHON_VERSION == 3.4 ]]; then + tox -e no-dilate ; + else + tox -e coverage ; + fi after_success: - codecov matrix: include: - python: 2.7 - - python: 3.3 - python: 3.4 - python: 3.5 - python: 3.6 @@ -34,5 +37,4 @@ matrix: dist: xenial - python: nightly allow_failures: - - python: 3.3 - python: nightly diff --git a/docs/api.md b/docs/api.md index 39a4e43..a2f0e61 100644 --- a/docs/api.md +++ b/docs/api.md @@ -524,25 +524,33 @@ object twice. ## Dilation -(NOTE: this section is speculative: this code has not yet been written) +(NOTE: this API is still in development) -In the longer term, the Wormhole object will incorporate the "Transit" -functionality (see transit.md) directly, removing the need to instantiate a -second object. A Wormhole can be "dilated" into a form that is suitable for -bulk data transfer. +To send bulk data, or anything more than a handful of messages, a Wormhole +can be "dilated" into a form that uses a direct TCP connection between the +two endpoints. All wormholes start out "undilated". In this state, all messages are queued on the Rendezvous Server for the lifetime of the wormhole, and server-imposed number/size/rate limits apply. Calling `w.dilate()` initiates the dilation -process, and success is signalled via either `d=w.when_dilated()` firing, or -`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used -as an IConsumer/IProducer, and messages will be sent on a direct connection -(if possible) or through the transit relay (if not). +process, and eventually yields a set of Endpoints. Once dilated, the usual +`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and +these endpoints can be used to establish multiple (encrypted) "subchannel" +connections to the other side. + +Each subchannel behaves like a regular Twisted `ITransport`, so they can be +glued to the Protocol instance of your choice. They also implement the +IConsumer/IProducer interfaces. + +These subchannels are *durable*: as long as the processes on both sides keep +running, the subchannel will survive the network connection being dropped. +For example, a file transfer can be started from a laptop, then while it is +running, the laptop can be closed, moved to a new wifi network, opened back +up, and the transfer will resume from the new IP address. What's good about a non-dilated wormhole?: * setup is faster: no delay while it tries to make a direct connection -* survives temporary network outages, since messages are queued * works with "journaled mode", allowing progress to be made even when both sides are never online at the same time, by serializing the wormhole @@ -556,21 +564,103 @@ Use non-dilated wormholes when your application only needs to exchange a couple of messages, for example to set up public keys or provision access tokens. Use a dilated wormhole to move files. -Dilated wormholes can provide multiple "channels": these are multiplexed -through the single (encrypted) TCP connection. Each channel is a separate -stream (offering IProducer/IConsumer) +Dilated wormholes can provide multiple "subchannels": these are multiplexed +through the single (encrypted) TCP connection. Each subchannel is a separate +stream (offering IProducer/IConsumer for flow control), and is opened and +closed independently. A special "control channel" is available to both sides +so they can coordinate how they use the subchannels. -To create a channel, call `c = w.create_channel()` on a dilated wormhole. The -"channel ID" can be obtained with `c.get_id()`. This ID will be a short -(unicode) string, which can be sent to the other side via a normal -`w.send()`, or any other means. On the other side, use `c = -w.open_channel(channel_id)` to get a matching channel object. +The `d = w.dilate()` Deferred fires with a triple of Endpoints: -Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire -them up with `c.registerProducer()`. Note that channels do not close until -the wormhole connection is closed, so they do not have separate `close()` -methods or events. Therefore if you plan to send files through them, you'll -need to inform the recipient ahead of time about how many bytes to expect. +```python +d = w.dilate() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res +d.addCallback(_dilated) +``` + +The `control_channel_ep` endpoint is a client-style endpoint, so both sides +will connect to it with `ep.connect(factory)`. This endpoint is single-use: +calling `.connect()` a second time will fail. The control channel is +symmetric: it doesn't matter which side is the application-level +client/server or initiator/responder, if the application even has such +concepts. The two applications can use the control channel to negotiate who +goes first, if necessary. + +The subchannel endpoints are *not* symmetric: for each subchannel, one side +must listen as a server, and the other must connect as a client. Subchannels +can be established by either side at any time. This supports e.g. +bidirectional file transfer, where either user of a GUI app can drop files +into the "wormhole" whenever they like. + +The `subchannel_client_ep` on one side is used to connect to the other side's +`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The +server endpoint is single-use: `.listen(factory)` may only be called once. + +Applications are under no obligation to use subchannels: for many use cases, +the control channel is enough. + +To use subchannels, once the wormhole is dilated and the endpoints are +available, the listening-side application should attach a listener to the +`subchannel_server_ep` endpoint: + +```python +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(MyListeningProtocol) + subchannel_server_ep.listen(f) +``` + +When the connecting-side application wants to connect to that listening +protocol, it should use `.connect()` with a suitable connecting protocol +factory: + +```python +def _connect(): + f = Factory(MyConnectingProtocol) + subchannel_client_ep.connect(f) +``` + +For a bidirectional file-transfer application, both sides will establish a +listening protocol. Later, if/when the user drops a file on the application +window, that side will initiate a connection, use the resulting subchannel to +transfer the single file, and then close the subchannel. + +```python +def FileSendingProtocol(internet.Protocol): + def __init__(self, metadata, filename): + self.file_metadata = metadata + self.file_name = filename + def connectionMade(self): + self.transport.write(self.file_metadata) + sender = protocols.basic.FileSender() + f = open(self.file_name,"rb") + d = sender.beginFileTransfer(f, self.transport) + d.addBoth(self._done, f) + def _done(res, f): + self.transport.loseConnection() + f.close() +def _send(metadata, filename): + f = protocol.ClientCreator(reactor, + FileSendingProtocol, metadata, filename) + subchannel_client_ep.connect(f) +def FileReceivingProtocol(internet.Protocol): + state = INITIAL + def dataReceived(self, data): + if state == INITIAL: + self.state = DATA + metadata = parse(data) + self.f = open(metadata.filename, "wb") + else: + # local file writes are blocking, so don't bother with IConsumer + self.f.write(data) + def connectionLost(self, reason): + self.f.close() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(FileReceivingProtocol) + subchannel_server_ep.listen(f) +``` ## Bytes, Strings, Unicode, and Python 3 diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md new file mode 100644 index 0000000..7625353 --- /dev/null +++ b/docs/dilation-protocol.md @@ -0,0 +1,540 @@ +# Dilation Internals + +Wormhole dilation involves several moving parts. Both sides exchange messages +through the Mailbox server to coordinate the establishment of a more direct +connection. This connection might flow in either direction, so they trade +"connection hints" to point at potential listening ports. This process might +succeed in making multiple connections at about the same time, so one side +must select the best one to use, and cleanly shut down the others. To make +the dilated connection *durable*, this side must also decide when the +connection has been lost, and then coordinate the construction of a +replacement. Within this connection, a series of queued-and-acked subchannel +messages are used to open/use/close the application-visible subchannels. + +## Versions and can-dilate + +The Wormhole protocol includes a `versions` message sent immediately after +the shared PAKE key is established. This also serves as a key-confirmation +message, allowing each side to confirm that the other side knows the right +key. The body of the `versions` message is a JSON-formatted string with keys +that are available for learning the abilities of the peer. Dilation is +signaled by a key named `can-dilate`, whose value is a list of strings. Any +version present in both side's lists is eligible for use. + +## Leaders and Followers + +Each side of a Wormhole has a randomly-generated dilation `side` string (this +is included in the `please-dilate` message, and is independent of the +Wormhole's mailbox "side"). When the wormhole is dilated, the side with the +lexicographically-higher "side" value is named the "Leader", and the other +side is named the "Follower". The general wormhole protocol treats both sides +identically, but the distinction matters for the dilation protocol. + +Both sides send a `please-dilate` as soon as dilation is triggered. Each side +discovers whether it is the Leader or the Follower when the peer's +"please-dilate" arrives. The Leader has exclusive control over whether a +given connection is considered established or not: if there are multiple +potential connections to use, the Leader decides which one to use, and the +Leader gets to decide when the connection is no longer viable (and triggers +the establishment of a new one). + +The `please-dilate` includes a `use-version` key, computed as the "best" +version of the intersection of the two sides' abilities as reported in the +`versions` message. Both sides will use whichever `use-version` was specified +by the Leader (they learn which side is the Leader at the same moment they +learn the peer's `use-version` value). If the Follower cannot handle the +`use-version` value, dilation fails (this shouldn't happen, as the Leader +knew what the Follower was and was not capable of before sending that +message). + +## Connection Layers + +We describe the protocol as a series of layers. Messages sent on one layer +may be encoded or transformed before being delivered on some other layer. + +L1 is the mailbox channel (queued store-and-forward messages that always go +to the mailbox server, and then are forwarded to other clients subscribed to +the same mailbox). Both clients remain connected to the mailbox server until +the Wormhole is closed. They send DILATE-n messages to each other to manage +the dilation process, including records like `please`, `connection-hints`, +`reconnect`, and `reconnecting`. + +L2 is the set of competing connection attempts for a given generation of +connection. Each time the Leader decides to establish a new connection, a new +generation number is used. Hopefully these are direct TCP connections between +the two peers, but they may also include connections through the transit +relay. Each connection must go through an encrypted handshake process before +it is considered viable. Viable connections are then submitted to a selection +process (on the Leader side), which chooses exactly one to use, and drops the +others. It may wait an extra few seconds in the hopes of getting a "better" +connection (faster, cheaper, etc), but eventually it will select one. + +L3 is the current selected connection. There is one L3 for each generation. +At all times, the wormhole will have exactly zero or one L3 connection. L3 is +responsible for the selection process, connection monitoring/keepalives, and +serialization/deserialization of the plaintext frames. L3 delivers decoded +frames and connection-establishment events up to L4. + +L4 is the persistent higher-level channel. It is created as soon as the first +L3 connection is selected, and lasts until wormhole is closed entirely. L4 +contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number +(scoped to the L4 connection and the direction of travel), and the ACK +messages reference those sequence numbers. When a message is given to the L4 +channel for delivery to the remote side, it is always queued, then +transmitted if there is an L3 connection available. This message remains in +the queue until an ACK is received to retire it. If a new L3 connection is +made, all queued messages will be re-sent (in seqnum order). + +L5 are subchannels. There is one pre-established subchannel 0 known as the +"control channel", which does not require an OPEN message. All other +subchannels are created by the receipt of an OPEN message with the subchannel +number. DATA frames are delivered to a specific subchannel. When the +subchannel is no longer needed, one side will invoke the ``close()`` API +(``loseConnection()`` in Twisted), which will cause a CLOSE message to be +sent, and the local L5 object will be put into the "closing "state. When the +other side receives the CLOSE, it will send its own CLOSE for the same +subchannel, and fully close its local object (``connectionLost()``). When the +first side receives CLOSE in the "closing" state, it will fully close its +local object too. + +All L5 subchannels will be paused (``pauseProducing()``) when the L3 +connection is paused or lost. They are resumed when the L3 connection is +resumed or reestablished. + +## Initiating Dilation + +Dilation is triggered by calling the `w.dilate()` API. This returns a +Deferred that will fire once the first L3 connection is established. It fires +with a 3-tuple of endpoints that can be used to establish subchannels. + +For dilation to succeed, both sides must call `w.dilate()`, since the +resulting endpoints are the only way to access the subchannels. If the other +side never calls `w.dilate()`, the Deferred will never fire. + +The L1 (mailbox) path is used to deliver dilation requests and connection +hints. The current mailbox protocol uses named "phases" to distinguish +messages (rather than behaving like a regular ordered channel of arbitrary +frames or bytes), and all-number phase names are reserved for application +data (sent via `w.send_message()`). Therefore the dilation control messages +use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own +counter, so one side might be up to e.g. `DILATE-5` while the other has only +gotten as far as `DILATE-2`. This effectively creates a pair of +unidirectional streams of `DILATE-n` messages, each containing one or more +dilation record, of various types described below. Note that all phases +beyond the initial VERSION and PAKE phases are encrypted by the shared +session key. + +A future mailbox protocol might provide a simple ordered stream of typed +messages, with application records and dilation records mixed together. + +Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that +has a string value. The dictionary will have other keys that depend upon the +type. + +`w.dilate()` triggers transmission of a `please` (i.e. "please dilate") +record with a set of versions that can be accepted. Versions use strings, +rather than integers, to support experimental protocols, however there is +still a total ordering of preferability. + +``` +{ "type": "please", + "side": "abcdef", + "accepted-versions": ["1"] +} +``` + +If one side receives a `please` before `w.dilate()` has been called locally, +the contents are stored in case `w.dilate()` is called in the future. Once +both `w.dilate()` has been called and the peer's `please` has been received, +the side determines whether it is the Leader or the Follower. Both sides also +compare `accepted-versions` fields to choose the best mutually-compatible +version to use: they should always pick the same one. + +Then both sides begin the connection process for generation 1 by opening +listening sockets and sending `connection-hint` records for each one. After a +slight delay they will also open connections to the Transit Relay of their +choice and produce hints for it too. The receipt of inbound hints (on both +sides) will trigger outbound connection attempts. + +Some number of these connections may succeed, and the Leader decides which to +use (via an in-band signal on the established connection). The others are +dropped. + +If something goes wrong with the established connection and the Leader +decides a new one is necessary, the Leader will send a `reconnect` message. +This might happen while connections are still being established, or while the +Follower thinks it still has a viable connection (the Leader might observe +problems that the Follower does not), or after the Follower thinks the +connection has been lost. In all cases, the Leader is the only side which +should send `reconnect`. The state machine code looks the same on both sides, +for simplicity, but one path on each side is never used. + +Upon receiving a `reconnect`, the Follower should stop any pending connection +attempts and terminate any existing connections (even if they appear viable). +Listening sockets may be retained, but any previous connection made through +them must be dropped. + +Once all connections have stopped, the Follower should send a `reconnecting` +message, then start the connection process for the next generation, which +will send new `connection-hint` messages for all listening sockets. + +Generations are non-overlapping. The Leader will drop all connections from +generation 1 before sending the `reconnect` for generation 2, and will not +initiate any gen-2 connections until it receives the matching `reconnecting` +from the Follower. The Follower must drop all gen-1 connections before it +sends the `reconnecting` response (even if it thinks they are still +functioning: if the Leader thought the gen-1 connection still worked, it +wouldn't have started gen-2). + +(TODO: what about a follower->leader connection that was started before +start-dilation is received, and gets established on the Leader side after +start-dilation is sent? the follower will drop it after it receives +start-dilation, but meanwhile the leader may accept it as gen2) + +(probably need to include the generation number in the handshake, or in the +derived key) + +(TODO: reduce the number of round-trip stalls here, I've added too many) + +Each side is in the "connecting" state (which encompasses both making +connection attempts and having an established connection) starting with the +receipt of a `please-dilate` message and a local `w.dilate()` call. The +Leader remains in that state until it abandons the connection and sends a +`reconnect` message, at which point it remains in the "flushing" state until +the Follower's `reconnecting` message is received. The Follower remains in +"connecting" until it receives `reconnect`, then it stays in "dropping" until +it finishes halting all outstanding connections, after which it sends +`reconnecting` and switches back to "connecting". + +"Connection hints" are type/address/port records that tell the other side of +likely targets for L2 connections. Both sides will try to determine their +external IP addresses, listen on a TCP port, and advertise `(tcp, +external-IP, port)` as a connection hint. The Transit Relay is also used as a +(lower-priority) hint. These are sent in `connection-hint` records, which can +be sent any time after both sending and receiving a `please` record. Each +side will initiate connections upon receipt of the hints. + +``` +{ "type": "connection-hints", + "hints": [ ... ] +} +``` + +Hints can arrive at any time. One side might immediately send hints that can +be computed quickly, then send additional hints later as they become +available. For example, it might enumerate the local network interfaces and +send hints for all of the LAN addresses first, then send port-forwarding +(UPnP) requests to the local router. When the forwarding is established +(providing an externally-visible IP address and port), it can send additional +hints for that new endpoint. If the other peer happens to be on the same LAN, +the local connection can be established without waiting for the router's +response. + + +### Connection Hint Format + +Each member of the `hints` field describes a potential L2 connection target +endpoint, with an associated priority and a set of hints. + +The priority is a number (positive or negative float), where larger numbers +indicate that the client supplying that hint would prefer to use this +connection over others of lower number. This indicates a sense of cost or +performance. For example, the Transit Relay is lower priority than a direct +TCP connection, because it incurs a bandwidth cost (on the relay operator), +as well as adding latency. + +Each endpoint has a set of hints, because the same target might be reachable +by multiple hints. Once one hint succeeds, there is no point in using the +other hints. + +TODO: think this through some more. What's the example of a single endpoint +reachable by multiple hints? Should each hint have its own priority, or just +each endpoint? + +## L2 protocol + +Upon ``connectionMade()``, both sides send their handshake message. The +Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower +sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should +trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP +servers that were contacted by accident). If the wrong handshake is received, +the connection will be dropped. For debugging purposes, the node might want +to keep looking at data beyond the first incorrect character and log +everything until the first newline. + +Everything beyond that point is a Noise protocol message, which consists of a +4-byte big-endian length field, followed by the indicated number of bytes. +This uses the `NNpsk0` pattern with the Leader as the first party ("-> psk, +e" in the Noise spec), and the Follower as the second ("<- e, ee"). The +pre-shared-key is the "dilation key", which is statically derived from the +master PAKE key using HKDF. Each L2 connection uses the same dilation key, +but different ephemeral keys, so each gets a different session key. + +The Leader sends the first message, which is a psk-encrypted ephemeral key. +The Follower sends the next message, its own psk-encrypted ephemeral key. The +Follower then sends an empty packet as the "key confirmation message", which +will be encrypted by the shared key. + +The Leader sees the KCM and knows the connection is viable. It delivers the +protocol object to the L3 manager, which will decide which connection to +select. When the L2 connection is selected to be the new L3, it will send an +empty KCM of its own, to let the Follower know the connection being selected. +All other L2 connections (either viable or still in handshake) are dropped, +all other connection attempts are cancelled. All listening sockets may or may +not be shut down (TODO: think about it). + +The Follower will wait for either an empty KCM (at which point the L2 +connection is delivered to the Dilation manager as the new L3), a +disconnection, or an invalid message (which causes the connection to be +dropped). Other connections and/or listening sockets are stopped. + +Internally, the L2Protocol object manages the Noise session itself. It knows +(via a constructor argument) whether it is on the Leader or Follower side, +which affects both the role is plays in the Noise pattern, and the reaction +to receiving the ephemeral key (for which only the Follower sends an empty +KCM message). After that, the L2Protocol notifies the L3 object in three +situations: + +* the Noise session produces a valid decrypted frame (for Leader, this + includes the Follower's KCM, and thus indicates a viable candidate for + connection selection) +* the Noise session reports a failed decryption +* the TCP session is lost + +All notifications include a reference to the L2Protocol object (`self`). The +L3 object uses this reference to either close the connection (for errors or +when the selection process chooses someone else), to send the KCM message +(after selection, only for the Leader), or to send other L4 messages. The L3 +object will retain a reference to the winning L2 object. + +## L3 protocol + +The L3 layer is responsible for connection selection, monitoring/keepalives, +and message (de)serialization. Framing is handled by L2, so the inbound L3 +codepath receives single-message bytestrings, and delivers the same down to +L2 for encryption, framing, and transmission. + +Connection selection takes place exclusively on the Leader side, and includes +the following: + +* receipt of viable L2 connections from below (indicated by the first valid + decrypted frame received for any given connection) +* expiration of a timer +* comparison of TBD quality/desirability/cost metrics of viable connections +* selection of winner +* instructions to losing connections to disconnect +* delivery of KCM message through winning connection +* retain reference to winning connection + +On the Follower side, the L3 manager just waits for the first connection to +receive the Leader's KCM, at which point it is retained and all others are +dropped. + +The L3 manager knows which "generation" of connection is being established. +Each generation uses a different dilation key (?), and is triggered by a new +set of L1 messages. Connections from one generation should not be confused +with those of a different generation. + +Each time a new L3 connection is established, the L4 protocol is notified. It +will will immediately send all the L4 messages waiting in its outbound queue. +The L3 protocol simply wraps these in Noise frames and sends them to the +other side. + +The L3 manager monitors the viability of the current connection, and declares +it as lost when bidirectional traffic cannot be maintained. It uses PING and +PONG messages to detect this. These also serve to keep NAT entries alive, +since many firewalls will stop forwarding packets if they don't observe any +traffic for e.g. 5 minutes. + +Our goals are: + +* don't allow more than 30? seconds to pass without at least *some* data + being sent along each side of the connection +* allow the Leader to detect silent connection loss within 60? seconds +* minimize overhead + +We need both sides to: + +* maintain a 30-second repeating timer +* set a flag each time we write to the connection +* each time the timer fires, if the flag was clear then send a PONG, + otherwise clear the flag + +In addition, the Leader must: + +* run a 60-second repeating timer (ideally somewhat offset from the other) +* set a flag each time we receive data from the connection +* each time the timer fires, if the flag was clear then drop the connection, + otherwise clear the flag + +In the future, we might have L2 links that are less connection-oriented, +which might have a unidirectional failure mode, at which point we'll need to +monitor full roundtrips. To accomplish this, the Leader will send periodic +unconditional PINGs, and the Follower will respond with PONGs. If the +Leader->Follower connection is down, the PINGs won't arrive and no PONGs will +be produced. If the Follower->Leader direction has failed, the PONGs won't +arrive. The delivery of both will be delayed by actual data, so the timeouts +should be adjusted if we see regular data arriving. + +If the connection is dropped before the wormhole is closed (either the other +end explicitly dropped it, we noticed a problem and told TCP to drop it, or +TCP noticed a problem itself), the Leader-side L3 manager will initiate a +reconnection attempt. This uses L1 to send a new DILATE message through the +mailbox server, along with new connection hints. Eventually this will result +in a new L3 connection being established. + +Finally, L3 is responsible for message serialization and deserialization. L2 +performs decryption and delivers plaintext frames to L3. Each frame starts +with a one-byte type indicator. The rest of the message depends upon the +type: + +* 0x00 PING, 4-byte ping-id +* 0x01 PONG, 4-byte ping-id +* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum +* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload +* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum +* 0x05 ACK, 4-byte response-seqnum + +All seqnums are big-endian, and are provided by the L4 protocol. The other +fields are arbitrary and not interpreted as integers. The subchannel-ids must +be allocated by both sides without collision, but otherwise they are only +used to look up L5 objects for dispatch. The response-seqnum is always copied +from the OPEN/DATA/CLOSE packet being acknowledged. + +L3 consumes the PING and PONG messages. Receiving any PING will provoke a +PONG in response, with a copy of the ping-id field. The 30-second timer will +produce unprovoked PONGs with a ping-id of all zeros. A future viability +protocol will use PINGs to test for roundtrip functionality. + +All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered +"upstairs" to the L4 protocol handler. + +The current L3 connection's `IProducer`/`IConsumer` interface is made +available to the L4 flow-control manager. + +## L4 protocol + +The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages. +Since each will be enclosed in a Noise frame before they pass to L3, they do +not need length fields or other framing. + +Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically +increasing by 1 for each message. Each direction has a separate number space. + +The L4 manager maintains a double-ended queue of unacknowledged outbound +messages. Subchannel activity (opening, closing, sending data) cause messages +to be added to this queue. If an L3 connection is available, these messages +are also sent over that connection, but they remain in the queue in case the +connection is lost and they must be retransmitted on some future replacement +connection. Messages stay in the queue until they can be retired by the +receipt of an ACK with a matching response-sequence-number. This provides +reliable message delivery that survives the L3 connection being replaced. + +ACKs are not acked, nor do they have seqnums of their own. Each inbound side +remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE +messages with that sequence number or higher. This ensures in-order +at-most-once processing of OPEN/DATA/CLOSE messages. + +Each inbound OPEN message causes a new L5 subchannel object to be created. +Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to +that object. + +Each time an L3 connection is established, the side will immediately send all +L4 messages waiting in the outbound queue. A future protocol might reduce +this duplication by including the highest received sequence number in the L1 +PLEASE-DILATE message, which would effectively retire queued messages before +initiating the L2 connection process. On any given L3 connection, all +messages are sent in-order. The receipt of an ACK for seqnum `N` allows all +messages with `seqnum <= N` to be retired. + +The L4 layer is also responsible for managing flow control among the L3 +connection and the various L5 subchannels. + +## L5 subchannels + +The L5 layer consists of a collection of "subchannel" objects, a dispatcher, +and the endpoints that provide the Twisted-flavored API. + +Other than the "control channel", all subchannels are created by a client +endpoint connection API. The side that calls this API is named the Initiator, +and the other side is named the Acceptor. Subchannels can be initiated in +either direction, independent of the Leader/Follower distinction. For a +typical file-transfer application, the subchannel would be initiated by the +side seeking to send a file. + +Each subchannel uses a distinct subchannel-id, which is a four-byte +identifier. Both directions share a number space (unlike L4 seqnums), so the +rule is that the Leader side sets the last bit of the last byte to a 1, while +the Follower sets it to a 0. These are not generally treated as integers, +however for the sake of debugging, the implementation generates them with a +simple big-endian-encoded counter (`counter*2+1` for the Leader, +`counter*2+2` for the Follower, with id `0` reserved for the control +channel). + +When the `client_ep.connect()` API is called, the Initiator allocates a +subchannel-id and sends an OPEN. It can then immediately send DATA messages +with the outbound data (there is no special response to an OPEN, so there is +no need to wait). The Acceptor will trigger their `.connectionMade` handler +upon receipt of the OPEN. + +Subchannels are durable: they do not close until one side calls +`.loseConnection` on the subchannel object (or the enclosing Wormhole is +closed). Either the Initiator or the Acceptor can call `.loseConnection`. +This causes a CLOSE message to be sent (with the subchannel-id). The other +side will send its own CLOSE message in response. Each side will signal the +`.connectionLost()` event upon receipt of a CLOSE. + +There is no equivalent to TCP's "half-closed" state, however if only one side +calls `close()`, then all data written before that call will be delivered +before the other side observes `.connectionLost()`. Any inbound data that was +queued for delivery before the other side sees the CLOSE will still be +delivered to the side that called `close()` before it sees +`.connectionLost()`. Internally, the side which called `.loseConnection` will +remain in a special "closing" state until the CLOSE response arrives, during +which time DATA payloads are still delivered. After calling `close()` (or +receiving CLOSE), any outbound `.write()` calls will trigger an error. + +DATA payloads that arrive for a non-open subchannel are logged and discarded. + +This protocol calls for one OPEN and two CLOSE messages for each subchannel, +with some arbitrary number of DATA messages in between. Subchannel-ids should +not be reused (it would probably work, the protocol hasn't been analyzed +enough to be sure). + +The "control channel" is special. It uses a subchannel-id of all zeros, and +is opened implicitly by both sides as soon as the first L3 connection is +selected. It is routed to a special client-on-both-sides endpoint, rather +than causing the listening endpoint to accept a new connection. This avoids +the need for application-level code to negotiate who should be the one to +open it (the Leader/Follower distinction is private to the Wormhole +internals: applications are not obligated to pick a side). + +OPEN and CLOSE messages for the control channel are logged and discarded. The +control-channel client endpoints can only be used once, and does not close +until the Wormhole itself is closed. + +Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing, +delivery, and eventual retirement. The L5 layer does not keep track of old +messages. + +### Flow Control + +Subchannels are flow-controlled by pausing their writes when the L3 +connection is paused, and pausing the L3 connection when the subchannel +signals a pause. When the outbound L3 connection is full, *all* subchannels +are paused. Likewise the inbound connection is paused if *any* of the +subchannels asks for a pause. This is much easier to implement and improves +our utilization factor (we can use TCP's window-filling algorithm, instead of +rolling our own), but will block all subchannels even if only one of them +gets full. This shouldn't matter for many applications, but might be +noticeable when combining very different kinds of traffic (e.g. a chat +conversation sharing a wormhole with file-transfer might prefer the IM text +to take priority). + +Each subchannel implements Twisted's `ITransport`, `IProducer`, and +`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to +be created (by the caller's factory) and glued to the subchannel object in +the `.transport` property, as is standard in Twisted-based applications. + +All subchannels are also paused when the L3 connection is lost, and are +unpaused when a new replacement connection is selected. diff --git a/docs/new-protocol.svg b/docs/new-protocol.svg new file mode 100644 index 0000000..a69b79a --- /dev/null +++ b/docs/new-protocol.svg @@ -0,0 +1,2000 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + connectionMade() + dataReceived() + dataReceived() + connectionLost() + + + empty + + + + open + + + + open + + + + closing + + + + empty + + + + + connect() + + + Open1 + + write() + write() + loseConnection() + connectionLost() + + + open + + + + + + Data1 + + + + Data1 + + + + Close1 + + + + Close1 + + + + + + + + + + + + + + empty + + + + open + + + + open + + + + empty + + + + + + open + + + + + + + + + + + + + + + Open1 + + + + Data1 + + + + Data1 + + + + Close1 + + connection 1 + + + Open1 + + + + Data1 + + 0 + 1 + 2 + 3 + + + ack 0 + + connection 2 + + + Data1 + + + + ack 1 + + + + Data1 + + + + Close1 + + + + ack 2 + + + + ack 3 + + + + Close1 + + + + ack 0' + + 0' + 0 + 1 + 1 + 2 + 3 + 0' + logical + 0 + 1 + 2 + 3 + + diff --git a/docs/state-machines/machines.dot b/docs/state-machines/machines.dot index ecfdf9a..09cc02c 100644 --- a/docs/state-machines/machines.dot +++ b/docs/state-machines/machines.dot @@ -23,12 +23,13 @@ digraph { Terminator [shape="box" color="blue" fontcolor="blue"] InputHelperAPI [shape="oval" label="input\nhelper\nAPI" color="blue" fontcolor="blue"] + Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"] #Connection -> websocket [color="blue"] #Connection -> Order [color="blue"] Wormhole -> Boss [style="dashed" - label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)" + label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)" color="red" fontcolor="red"] #Wormhole -> Boss [color="blue"] Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"] @@ -112,4 +113,7 @@ digraph { Terminator -> Boss [style="dashed" label="closed\n(once)"] Boss -> Terminator [style="dashed" color="red" fontcolor="red" label="close"] + + Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"] + Dilator -> Send [style="dashed" label="send(dilate-N)"] } diff --git a/setup.cfg b/setup.cfg index aec9318..c2199e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,3 +7,6 @@ versionfile_source = src/wormhole/_version.py versionfile_build = wormhole/_version.py tag_prefix = parentdir_prefix = magic-wormhole + +[flake8] +max-line-length = 85 diff --git a/setup.py b/setup.py index a99a2d2..f428250 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ setup(name="magic-wormhole", "dev": ["mock", "tox", "pyflakes", "magic-wormhole-transit-relay==0.1.2", "magic-wormhole-mailbox-server==0.3.1"], + "dilate": ["noiseprotocol"], }, test_suite="wormhole.test", cmdclass=commands, diff --git a/src/wormhole/__main__.py b/src/wormhole/__main__.py index d6c1b97..37e97b7 100644 --- a/src/wormhole/__main__.py +++ b/src/wormhole/__main__.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, unicode_literals -from .cli import cli - -if __name__ != "__main__": - raise ImportError('this module should not be imported') - -cli.wormhole() +if __name__ == "__main__": + from .cli import cli + cli.wormhole() +else: + # raise ImportError('this module should not be imported') + pass diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 097efe7..ce650b7 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -12,6 +12,7 @@ from zope.interface import implementer from . import _interfaces from ._allocator import Allocator from ._code import Code, validate_code +from ._dilation.manager import Dilator from ._input import Input from ._key import Key from ._lister import Lister @@ -38,6 +39,8 @@ class Boss(object): _versions = attrib(validator=instance_of(dict)) _client_version = attrib(validator=instance_of(tuple)) _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() _journal = attrib(validator=provides(_interfaces.IJournal)) _tor = attrib(validator=optional(provides(_interfaces.ITorManager))) _timing = attrib(validator=provides(_interfaces.ITiming)) @@ -64,6 +67,8 @@ class Boss(object): self._I = Input(self._timing) self._C = Code(self._timing) self._T = Terminator() + self._D = Dilator(self._reactor, self._eventual_queue, + self._cooperator) self._N.wire(self._M, self._I, self._RC, self._T) self._M.wire(self._N, self._RC, self._O, self._T) @@ -77,6 +82,7 @@ class Boss(object): self._I.wire(self._C, self._L) self._C.wire(self, self._A, self._N, self._K, self._I) self._T.wire(self, self._RC, self._N, self._M) + self._D.wire(self._S) def _init_other_state(self): self._did_start_code = False @@ -84,6 +90,9 @@ class Boss(object): self._next_rx_phase = 0 self._rx_phases = {} # phase -> plaintext + self._next_rx_dilate_seqnum = 0 + self._rx_dilate_seqnums = {} # seqnum -> plaintext + self._result = "empty" # these methods are called from outside @@ -196,6 +205,9 @@ class Boss(object): self._did_start_code = True self._C.set_code(code) + def dilate(self): + return self._D.dilate() # fires with endpoints + @m.input() def send(self, plaintext): pass @@ -255,8 +267,11 @@ class Boss(object): def got_message(self, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) + d_mo = re.search(r'^dilate-(\d+)$', phase) if phase == "version": self._got_version(plaintext) + elif d_mo: + self._got_dilate(int(d_mo.group(1)), plaintext) elif re.search(r'^\d+$', phase): self._got_phase(int(phase), plaintext) else: @@ -272,6 +287,10 @@ class Boss(object): def _got_phase(self, phase, plaintext): pass + @m.input() + def _got_dilate(self, seqnum, plaintext): + pass + @m.input() def got_key(self, key): pass @@ -294,6 +313,7 @@ class Boss(object): # most of this is wormhole-to-wormhole, ignored for now # in the future, this is how Dilation is signalled self._their_versions = bytes_to_dict(plaintext) + self._D.got_wormhole_versions(self._their_versions) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) self._W.got_versions(app_versions) @@ -335,6 +355,10 @@ class Boss(object): def W_got_key(self, key): self._W.got_key(key) + @m.output() + def D_got_key(self, key): + self._D.got_key(key) + @m.output() def W_got_verifier(self, verifier): self._W.got_verifier(verifier) @@ -348,6 +372,16 @@ class Boss(object): self._W.received(self._rx_phases.pop(self._next_rx_phase)) self._next_rx_phase += 1 + @m.output() + def D_received_dilate(self, seqnum, plaintext): + assert isinstance(seqnum, six.integer_types), type(seqnum) + # strict phase order, no gaps + self._rx_dilate_seqnums[seqnum] = plaintext + while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums: + m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum) + self._D.received_dilate(m) + self._next_rx_dilate_seqnum += 1 + @m.output() def W_close_with_error(self, err): self._result = err # exception @@ -370,7 +404,7 @@ class Boss(object): S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared]) S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely]) S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send]) - S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key]) + S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key]) S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error]) S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error]) @@ -378,6 +412,7 @@ class Boss(object): S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier]) S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received]) S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version]) + S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate]) S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared]) S2_happy.upon(close, enter=S3_closing, outputs=[close_happy]) S2_happy.upon(send, enter=S2_happy, outputs=[S_send]) @@ -389,6 +424,7 @@ class Boss(object): S3_closing.upon(got_verifier, enter=S3_closing, outputs=[]) S3_closing.upon(_got_phase, enter=S3_closing, outputs=[]) S3_closing.upon(_got_version, enter=S3_closing, outputs=[]) + S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[]) S3_closing.upon(happy, enter=S3_closing, outputs=[]) S3_closing.upon(scared, enter=S3_closing, outputs=[]) S3_closing.upon(close, enter=S3_closing, outputs=[]) @@ -400,6 +436,7 @@ class Boss(object): S4_closed.upon(got_verifier, enter=S4_closed, outputs=[]) S4_closed.upon(_got_phase, enter=S4_closed, outputs=[]) S4_closed.upon(_got_version, enter=S4_closed, outputs=[]) + S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[]) S4_closed.upon(happy, enter=S4_closed, outputs=[]) S4_closed.upon(scared, enter=S4_closed, outputs=[]) S4_closed.upon(close, enter=S4_closed, outputs=[]) diff --git a/src/wormhole/_dilation/__init__.py b/src/wormhole/_dilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wormhole/_dilation/_noise.py b/src/wormhole/_dilation/_noise.py new file mode 100644 index 0000000..bb4cf58 --- /dev/null +++ b/src/wormhole/_dilation/_noise.py @@ -0,0 +1,11 @@ +try: + from noise.exceptions import NoiseInvalidMessage +except ImportError: + class NoiseInvalidMessage(Exception): + pass + +try: + from noise.connection import NoiseConnection +except ImportError: + # allow imports to work on py2.7, even if dilation doesn't + NoiseConnection = None diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py new file mode 100644 index 0000000..b8f3ec6 --- /dev/null +++ b/src/wormhole/_dilation/connection.py @@ -0,0 +1,520 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +import six +from attr import attrs, attrib +from attr.validators import instance_of, provides +from automat import MethodicalMachine +from zope.interface import Interface, implementer +from twisted.python import log +from twisted.internet.protocol import Protocol +from twisted.internet.interfaces import ITransport +from .._interfaces import IDilationConnector +from ..observer import OneShotObserver +from .encode import to_be4, from_be4 +from .roles import FOLLOWER +from ._noise import NoiseInvalidMessage + +# InboundFraming is given data and returns Frames (Noise wire-side +# bytestrings). It handles the relay handshake and the prologue. The Frames it +# returns are either the ephemeral key (the Noise "handshake") or ciphertext +# messages. + +# The next object up knows whether it's expecting a Handshake or a message. It +# feeds the first into Noise as a handshake, it feeds the rest into Noise as a +# message (which produces a plaintext stream). It emits tokens that are either +# "i've finished with the handshake (so you can send the KCM if you want)", or +# "here is a decrypted message (which might be the KCM)". + +# the transmit direction goes directly to transport.write, and doesn't touch +# the state machine. we can do this because the way we encode/encrypt/frame +# things doesn't depend upon the receiver state. It would be more safe to e.g. +# prohibit sending ciphertext frames unless we're in the received-handshake +# state, but then we'll be in the middle of an inbound state transition ("we +# just received the handshake, so you can send a KCM now") when we perform an +# operation that depends upon the state (send_plaintext(kcm)), which is not a +# coherent/safe place to touch the state machine. + +# we could set a flag and test it from inside send_plaintext, which kind of +# violates the state machine owning the state (ideally all "if" statements +# would be translated into same-input transitions from different starting +# states). For the specific question of sending plaintext frames, Noise will +# refuse us unless it's ready anyways, so the question is probably moot. + + +class IFramer(Interface): + pass + + +class IRecord(Interface): + pass + + +def first(l): + return l[0] + + +class Disconnect(Exception): + pass + + +RelayOK = namedtuple("RelayOk", []) +Prologue = namedtuple("Prologue", []) +Frame = namedtuple("Frame", ["frame"]) + + +@attrs +@implementer(IFramer) +class _Framer(object): + _transport = attrib(validator=provides(ITransport)) + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + _buffer = b"" + _can_send_frames = False + + # in: use_relay + # in: connectionMade, dataReceived + # out: prologue_received, frame_received + # out (shared): transport.loseConnection + # out (shared): transport.write (relay handshake, prologue) + # states: want_relay, want_prologue, want_frame + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + @m.state() + def want_relay(self): + pass # pragma: no cover + + @m.state(initial=True) + def want_prologue(self): + pass # pragma: no cover + + @m.state() + def want_frame(self): + pass # pragma: no cover + + @m.input() + def use_relay(self, relay_handshake): + pass + + @m.input() + def connectionMade(self): + pass + + @m.input() + def parse(self): + pass + + @m.input() + def got_relay_ok(self): + pass + + @m.input() + def got_prologue(self): + pass + + @m.output() + def store_relay_handshake(self, relay_handshake): + self._outbound_relay_handshake = relay_handshake + self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + + @m.output() + def send_relay_handshake(self): + self._transport.write(self._outbound_relay_handshake) + + @m.output() + def send_prologue(self): + self._transport.write(self._outbound_prologue) + + @m.output() + def parse_relay_ok(self): + if self._get_expected("relay_ok", self._expected_relay_handshake): + return RelayOK() + + @m.output() + def parse_prologue(self): + if self._get_expected("prologue", self._inbound_prologue): + return Prologue() + + @m.output() + def can_send_frames(self): + self._can_send_frames = True # for assertion in send_frame() + + @m.output() + def parse_frame(self): + if len(self._buffer) < 4: + return None + frame_length = from_be4(self._buffer[0:4]) + if len(self._buffer) < 4 + frame_length: + return None + frame = self._buffer[4:4 + frame_length] + self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy + return Frame(frame=frame) + + want_prologue.upon(use_relay, outputs=[store_relay_handshake], + enter=want_relay) + + want_relay.upon(connectionMade, outputs=[send_relay_handshake], + enter=want_relay) + want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay, + collector=first) + want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue) + + want_prologue.upon(connectionMade, outputs=[send_prologue], + enter=want_prologue) + want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue, + collector=first) + want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame) + + want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, + collector=first) + + def _get_expected(self, name, expected): + lb = len(self._buffer) + le = len(expected) + if self._buffer.startswith(expected): + # if the buffer starts with the expected string, consume it and + # return True + self._buffer = self._buffer[le:] + return True + if not expected.startswith(self._buffer): + # we're not on track: the data we've received so far does not + # match the expected value, so this can't possibly be right. + # Don't complain until we see the expected length, or a newline, + # so we can capture the weird input in the log for debugging. + if (b"\n" in self._buffer or lb >= le): + log.msg("bad {}: {}".format(name, self._buffer[:le])) + raise Disconnect() + return False # wait a bit longer + # good so far, just waiting for the rest + return False + + # external API is: connectionMade, add_and_parse, and send_frame + + def add_and_parse(self, data): + # we can't make this an @m.input because we can't change the state + # from within an input. Instead, let the state choose the parser to + # use, and use the parsed token drive a state transition. + self._buffer += data + while True: + # it'd be nice to use an iterator here, but since self.parse() + # dispatches to a different parser (depending upon the current + # state), we'd be using multiple iterators + token = self.parse() + if isinstance(token, RelayOK): + self.got_relay_ok() + elif isinstance(token, Prologue): + self.got_prologue() + yield token # triggers send_handshake + elif isinstance(token, Frame): + yield token + else: + break + + def send_frame(self, frame): + assert self._can_send_frames + self._transport.write(to_be4(len(frame)) + frame) + +# RelayOK: Newline-terminated buddy-is-connected response from Relay. +# First data received from relay. +# Prologue: double-newline-terminated this-is-really-wormhole response +# from peer. First data received from peer. +# Frame: Either handshake or encrypted message. Length-prefixed on wire. +# Handshake: the Noise ephemeral key, first framed message +# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK +# KCM: Key Confirmation Message (encrypted b"\x00"). First frame +# from peer. Sent immediately by Follower, after Selection by Leader. +# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong + + +Handshake = namedtuple("Handshake", []) +# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack +KCM = namedtuple("KCM", []) +Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value +Pong = namedtuple("Pong", ["ping_id"]) +Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer +Data = namedtuple("Data", ["seqnum", "scid", "data"]) +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer +Records = (KCM, Ping, Pong, Open, Data, Close, Ack) +Handshake_or_Records = (Handshake,) + Records + +T_KCM = b"\x00" +T_PING = b"\x01" +T_PONG = b"\x02" +T_OPEN = b"\x03" +T_DATA = b"\x04" +T_CLOSE = b"\x05" +T_ACK = b"\x06" + + +def parse_record(plaintext): + msgtype = plaintext[0:1] + if msgtype == T_KCM: + return KCM() + if msgtype == T_PING: + ping_id = plaintext[1:5] + return Ping(ping_id) + if msgtype == T_PONG: + ping_id = plaintext[1:5] + return Pong(ping_id) + if msgtype == T_OPEN: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Open(seqnum, scid) + if msgtype == T_DATA: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + data = plaintext[9:] + return Data(seqnum, scid, data) + if msgtype == T_CLOSE: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Close(seqnum, scid) + if msgtype == T_ACK: + resp_seqnum = from_be4(plaintext[1:5]) + return Ack(resp_seqnum) + log.err("received unknown message type: {}".format(plaintext)) + raise ValueError() + + +def encode_record(r): + if isinstance(r, KCM): + return b"\x00" + if isinstance(r, Ping): + return b"\x01" + r.ping_id + if isinstance(r, Pong): + return b"\x02" + r.ping_id + if isinstance(r, Open): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Data): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data + if isinstance(r, Close): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Ack): + assert isinstance(r.resp_seqnum, six.integer_types) + return b"\x06" + to_be4(r.resp_seqnum) + raise TypeError(r) + + +@attrs +@implementer(IRecord) +class _Record(object): + _framer = attrib(validator=provides(IFramer)) + _noise = attrib() + + n = MethodicalMachine() + # TODO: set_trace + + def __attrs_post_init__(self): + self._noise.start_handshake() + + # in: role= + # in: prologue_received, frame_received + # out: handshake_received, record_received + # out: transport.write (noise handshake, encrypted records) + # states: want_prologue, want_handshake, want_record + + @n.state(initial=True) + def want_prologue(self): + pass # pragma: no cover + + @n.state() + def want_handshake(self): + pass # pragma: no cover + + @n.state() + def want_message(self): + pass # pragma: no cover + + @n.input() + def got_prologue(self): + pass + + @n.input() + def got_frame(self, frame): + pass + + @n.output() + def send_handshake(self): + handshake = self._noise.write_message() # generate the ephemeral key + self._framer.send_frame(handshake) + + @n.output() + def process_handshake(self, frame): + try: + payload = self._noise.read_message(frame) + # Noise can include unencrypted data in the handshake, but we don't + # use it + del payload + except NoiseInvalidMessage as e: + log.err(e, "bad inbound noise handshake") + raise Disconnect() + return Handshake() + + @n.output() + def decrypt_message(self, frame): + try: + message = self._noise.decrypt(frame) + except NoiseInvalidMessage as e: + # if this happens during tests, flunk the test + log.err(e, "bad inbound noise frame") + raise Disconnect() + return parse_record(message) + + want_prologue.upon(got_prologue, outputs=[send_handshake], + enter=want_handshake) + want_handshake.upon(got_frame, outputs=[process_handshake], + collector=first, enter=want_message) + want_message.upon(got_frame, outputs=[decrypt_message], + collector=first, enter=want_message) + + # external API is: connectionMade, dataReceived, send_record + + def connectionMade(self): + self._framer.connectionMade() + + def add_and_unframe(self, data): + for token in self._framer.add_and_parse(data): + if isinstance(token, Prologue): + self.got_prologue() # triggers send_handshake + else: + assert isinstance(token, Frame) + yield self.got_frame(token.frame) # Handshake or a Record type + + def send_record(self, r): + message = encode_record(r) + frame = self._noise.encrypt(message) + self._framer.send_frame(frame) + + +@attrs +class DilatedConnectionProtocol(Protocol, object): + """I manage an L2 connection. + + When a new L2 connection is needed (as determined by the Leader), + both Leader and Follower will initiate many simultaneous connections + (probably TCP, but conceivably others). A subset will actually + connect. A subset of those will successfully pass negotiation by + exchanging handshakes to demonstrate knowledge of the session key. + One of the negotiated connections will be selected by the Leader for + active use, and the others will be dropped. + + At any given time, there is at most one active L2 connection. + """ + + _eventual_queue = attrib() + _role = attrib() + _connector = attrib(validator=provides(IDilationConnector)) + _noise = attrib() + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + + _use_relay = False + _relay_handshake = None + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + self._manager = None # set if/when we are selected + self._disconnected = OneShotObserver(self._eventual_queue) + self._can_send_records = False + + @m.state(initial=True) + def unselected(self): + pass # pragma: no cover + + @m.state() + def selecting(self): + pass # pragma: no cover + + @m.state() + def selected(self): + pass # pragma: no cover + + @m.input() + def got_kcm(self): + pass + + @m.input() + def select(self, manager): + pass # fires set_manager() + + @m.input() + def got_record(self, record): + pass + + @m.output() + def add_candidate(self): + self._connector.add_candidate(self) + + @m.output() + def set_manager(self, manager): + self._manager = manager + + @m.output() + def can_send_records(self, manager): + self._can_send_records = True + + @m.output() + def deliver_record(self, record): + self._manager.got_record(record) + + unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting) + selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected) + selected.upon(got_record, outputs=[deliver_record], enter=selected) + + # called by Connector + + def use_relay(self, relay_handshake): + assert isinstance(relay_handshake, bytes) + self._use_relay = True + self._relay_handshake = relay_handshake + + def when_disconnected(self): + return self._disconnected.when_fired() + + def disconnect(self): + self.transport.loseConnection() + + # select() called by Connector + + # called by Manager + def send_record(self, record): + assert self._can_send_records + self._record.send_record(record) + + # IProtocol methods + + def connectionMade(self): + framer = _Framer(self.transport, + self._outbound_prologue, self._inbound_prologue) + if self._use_relay: + framer.use_relay(self._relay_handshake) + self._record = _Record(framer, self._noise) + self._record.connectionMade() + + def dataReceived(self, data): + try: + for token in self._record.add_and_unframe(data): + assert isinstance(token, Handshake_or_Records) + if isinstance(token, Handshake): + if self._role is FOLLOWER: + self._record.send_record(KCM()) + elif isinstance(token, KCM): + # if we're the leader, add this connection as a candiate. + # if we're the follower, accept this connection. + self.got_kcm() # connector.add_candidate() + else: + self.got_record(token) # manager.got_record() + except Disconnect: + self.transport.loseConnection() + + def connectionLost(self, why=None): + self._disconnected.fire(self) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py new file mode 100644 index 0000000..aa5f8e0 --- /dev/null +++ b/src/wormhole/_dilation/connector.py @@ -0,0 +1,387 @@ +from __future__ import print_function, unicode_literals +from collections import defaultdict +from binascii import hexlify +from attr import attrs, attrib +from attr.validators import instance_of, provides, optional +from automat import MethodicalMachine +from zope.interface import implementer +from twisted.internet.task import deferLater +from twisted.internet.defer import DeferredList +from twisted.internet.endpoints import serverFromString +from twisted.internet.protocol import ClientFactory, ServerFactory +from twisted.python import log +from .. import ipaddrs # TODO: move into _dilation/ +from .._interfaces import IDilationConnector, IDilationManager +from ..timing import DebugTiming +from ..observer import EmptyableSet +from ..util import HKDF, to_unicode +from .connection import DilatedConnectionProtocol, KCM +from .roles import LEADER + +from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + encode_hint) +from ._noise import NoiseConnection + + +def build_sided_relay_handshake(key, side): + assert isinstance(side, type(u"")) + assert len(side) == 8 * 2 + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") + return (b"please relay " + hexlify(token) + + b" for side " + side.encode("ascii") + b"\n") + + +PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" +PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" +NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + +def build_noise(): + return NoiseConnection.from_name(NOISEPROTO) + +@attrs +@implementer(IDilationConnector) +class Connector(object): + _dilation_key = attrib(validator=instance_of(type(b""))) + _transit_relay_location = attrib(validator=optional(instance_of(type(u"")))) + _manager = attrib(validator=provides(IDilationManager)) + _reactor = attrib() + _eventual_queue = attrib() + _no_listen = attrib(validator=instance_of(bool)) + _tor = attrib() + _timing = attrib() + _side = attrib(validator=instance_of(type(u""))) + # was self._side = bytes_to_hexstr(os.urandom(8)) # unicode + _role = attrib() + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + RELAY_DELAY = 2.0 + + def __attrs_post_init__(self): + if self._transit_relay_location: + # TODO: allow multiple hints for a single relay + relay_hint = parse_hint_argv(self._transit_relay_location) + relay = RelayV1Hint(hints=(relay_hint,)) + self._transit_relays = [relay] + else: + self._transit_relays = [] + self._listeners = set() # IListeningPorts that can be stopped + self._pending_connectors = set() # Deferreds that can be cancelled + self._pending_connections = EmptyableSet( + _eventual_queue=self._eventual_queue) # Protocols to be stopped + self._contenders = set() # viable connections + self._winning_connection = None + self._timing = self._timing or DebugTiming() + self._timing.add("transit") + + # this describes what our Connector can do, for the initial advertisement + @classmethod + def get_connection_abilities(klass): + return [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ] + + def build_protocol(self, addr): + # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses + # ephemeral keys plus a pre-shared symmetric key (the Transit key), a + # different one for each potential connection. + noise = build_noise() + noise.set_psks(self._dilation_key) + if self._role is LEADER: + noise.set_as_initiator() + outbound_prologue = PROLOGUE_LEADER + inbound_prologue = PROLOGUE_FOLLOWER + else: + noise.set_as_responder() + outbound_prologue = PROLOGUE_FOLLOWER + inbound_prologue = PROLOGUE_LEADER + p = DilatedConnectionProtocol(self._eventual_queue, self._role, + self, noise, + outbound_prologue, inbound_prologue) + return p + + @m.state(initial=True) + def connecting(self): + pass # pragma: no cover + + @m.state() + def connected(self): + pass # pragma: no cover + + @m.state(terminal=True) + def stopped(self): + pass # pragma: no cover + + # TODO: unify the tense of these method-name verbs + + # add_relay() and got_hints() are called by the Manager as it receives + # messages from our peer. stop() is called when the Manager shuts down + @m.input() + def add_relay(self, hint_objs): + pass + + @m.input() + def got_hints(self, hint_objs): + pass + + @m.input() + def stop(self): + pass + + # called by ourselves, when _start_listener() is ready + @m.input() + def listener_ready(self, hint_objs): + pass + + # called when DilatedConnectionProtocol submits itself, after KCM + # received + @m.input() + def add_candidate(self, c): + pass + + # called by ourselves, via consider() + @m.input() + def accept(self, c): + pass + + + @m.output() + def use_hints(self, hint_objs): + self._use_hints(hint_objs) + + @m.output() + def publish_hints(self, hint_objs): + self._publish_hints(hint_objs) + + def _publish_hints(self, hint_objs): + self._manager.send_hints([encode_hint(h) for h in hint_objs]) + + @m.output() + def consider(self, c): + self._contenders.add(c) + if self._role is LEADER: + # for now, just accept the first one. TODO: be clever. + self._eventual_queue.eventually(self.accept, c) + else: + # the follower always uses the first contender, since that's the + # only one the leader picked + self._eventual_queue.eventually(self.accept, c) + + @m.output() + def select_and_stop_remaining(self, c): + self._winning_connection = c + self._contenders.clear() # we no longer care who else came close + # remove this winner from the losers, so we don't shut it down + self._pending_connections.discard(c) + # shut down losing connections + self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist + self.stop_pending_connectors() + self.stop_pending_connections() + + c.select(self._manager) # subsequent frames go directly to the manager + if self._role is LEADER: + # TODO: this should live in Connection + c.send_record(KCM()) # leader sends KCM now + self._manager.use_connection(c) # manager sends frames to Connection + + @m.output() + def stop_everything(self): + self.stop_listeners() + self.stop_pending_connectors() + self.stop_pending_connections() + self.break_cycles() + + def stop_listeners(self): + d = DeferredList([l.stopListening() for l in self._listeners]) + self._listeners.clear() + return d # synchronization for tests + + def stop_pending_connectors(self): + return DeferredList([d.cancel() for d in self._pending_connectors]) + + def stop_pending_connections(self): + d = self._pending_connections.when_next_empty() + [c.loseConnection() for c in self._pending_connections] + return d + + def break_cycles(self): + # help GC by forgetting references to things that reference us + self._listeners.clear() + self._pending_connectors.clear() + self._pending_connections.clear() + self._winning_connection = None + + connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints]) + connecting.upon(add_relay, enter=connecting, outputs=[use_hints, + publish_hints]) + connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) + connecting.upon(add_candidate, enter=connecting, outputs=[consider]) + connecting.upon(accept, enter=connected, outputs=[ + select_and_stop_remaining]) + connecting.upon(stop, enter=stopped, outputs=[stop_everything]) + + # once connected, we ignore everything except stop + connected.upon(listener_ready, enter=connected, outputs=[]) + connected.upon(add_relay, enter=connected, outputs=[]) + connected.upon(got_hints, enter=connected, outputs=[]) + # TODO: tell them to disconnect? will they hang out forever? I *think* + # they'll drop this once they get a KCM on the winning connection. + connected.upon(add_candidate, enter=connected, outputs=[]) + connected.upon(accept, enter=connected, outputs=[]) + connected.upon(stop, enter=stopped, outputs=[stop_everything]) + + # from Manager: start, got_hints, stop + # maybe add_candidate, accept + + def start(self): + if not self._no_listen and not self._tor: + addresses = self._get_listener_addresses() + self._start_listener(addresses) + if self._transit_relays: + self._publish_hints(self._transit_relays) + self._use_hints(self._transit_relays) + + def _get_listener_addresses(self): + addresses = ipaddrs.find_addresses() + non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"] + if non_loopback_addresses: + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. + addresses = non_loopback_addresses + return addresses + + def _start_listener(self, addresses): + # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also + # to make firewall configs easier + # TODO: retain listening port between connection generations? + ep = serverFromString(self._reactor, "tcp:0") + f = InboundConnectionFactory(self) + d = ep.listen(f) + + def _listening(lp): + # lp is an IListeningPort + self._listeners.add(lp) # for shutdown and tests + portnum = lp.getHost().port + direct_hints = [DirectTCPV1Hint(to_unicode(addr), portnum, 0.0) + for addr in addresses] + self.listener_ready(direct_hints) + d.addCallback(_listening) + d.addErrback(log.err) + + def _schedule_connection(self, delay, h, is_relay): + ep = endpoint_from_hint_obj(h, self._tor, self._reactor) + desc = describe_hint_obj(h, is_relay, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay) + d.addErrback(log.err) + self._pending_connectors.add(d) + + def _use_hints(self, hints): + # first, pull out all the relays, we'll connect to them later + relays = [] + direct = defaultdict(list) + for h in hints: + if isinstance(h, RelayV1Hint): + relays.append(h) + else: + direct[h.priority].append(h) + delay = 0.0 + made_direct = False + priorities = sorted(set(direct.keys()), reverse=True) + for p in priorities: + for h in direct[p]: + if isinstance(h, TorTCPV1Hint) and not self._tor: + continue + self._schedule_connection(delay, h, is_relay=False) + made_direct = True + # Make all direct connections immediately. Later, we'll change + # the add_candidate() function to look at the priority when + # deciding whether to accept a successful connection or not, + # and it can wait for more options if it sees a higher-priority + # one still running. But if we bail on that, we might consider + # putting an inter-direct-hint delay here to influence the + # process. + # delay += 1.0 + + if made_direct and not self._no_listen: + # Prefer direct connections by stalling relay connections by a + # few seconds. We don't wait until direct connections have + # failed, because many direct hints will be to unused + # local-network IP address, which won't answer, and can take the + # full 30s TCP timeout to fail. + # + # If we didn't make any direct connections, or we're using + # --no-listen, then we're probably going to have to use the + # relay, so don't delay it at all. + delay += self.RELAY_DELAY + + # It might be nice to wire this so that a failure in the direct hints + # causes the relay hints to be used right away (fast failover). But + # none of our current use cases would take advantage of that: if we + # have any viable direct hints, then they're either going to succeed + # quickly or hang for a long time. + for r in relays: + for h in r.hints: + self._schedule_connection(delay, h, is_relay=True) + # TODO: + # if not contenders: + # raise TransitError("No contenders for connection") + + # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for + # the initial connection + + def _connect(self, ep, description, is_relay=False): + relay_handshake = None + if is_relay: + relay_handshake = build_sided_relay_handshake(self._dilation_key, + self._side) + f = OutboundConnectionFactory(self, relay_handshake) + d = ep.connect(f) + # fires with protocol, or ConnectError + + def _connected(p): + self._pending_connections.add(p) + # c might not be in _pending_connections, if it turned out to be a + # winner, which is why we use discard() and not remove() + p.when_disconnected().addCallback(self._pending_connections.discard) + d.addCallback(_connected) + return d + + # Connection selection. All instances of DilatedConnectionProtocol which + # look viable get passed into our add_contender() method. + + # On the Leader side, "viable" means we've seen their KCM frame, which is + # the first Noise-encrypted packet on any given connection, and it has an + # empty body. We gather viable connections until we see one that we like, + # or a timer expires. Then we "select" it, close the others, and tell our + # Manager to use it. + + # On the Follower side, we'll only see a KCM on the one connection selected + # by the Leader, so the first viable connection wins. + + # our Connection protocols call: add_candidate + + +@attrs +class OutboundConnectionFactory(ClientFactory, object): + _connector = attrib(validator=provides(IDilationConnector)) + _relay_handshake = attrib(validator=optional(instance_of(bytes))) + + def buildProtocol(self, addr): + p = self._connector.build_protocol(addr) + p.factory = self + if self._relay_handshake is not None: + p.use_relay(self._relay_handshake) + return p + + +@attrs +class InboundConnectionFactory(ServerFactory, object): + _connector = attrib(validator=provides(IDilationConnector)) + + def buildProtocol(self, addr): + p = self._connector.build_protocol(addr) + p.factory = self + return p diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py new file mode 100644 index 0000000..80e9902 --- /dev/null +++ b/src/wormhole/_dilation/encode.py @@ -0,0 +1,19 @@ +from __future__ import print_function, unicode_literals +import struct + +assert len(struct.pack("L", value) + + +def from_be4(b): + if not isinstance(b, bytes): + raise TypeError(repr(b)) + if len(b) != 4: + raise ValueError + return struct.unpack(">L", b)[0] diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py new file mode 100644 index 0000000..2f6ffaf --- /dev/null +++ b/src/wormhole/_dilation/inbound.py @@ -0,0 +1,134 @@ +from __future__ import print_function, unicode_literals +from attr import attrs, attrib +from attr.validators import provides +from zope.interface import implementer +from twisted.python import log +from .._interfaces import IDilationManager, IInbound +from .subchannel import (SubChannel, _SubchannelAddress) + + +class DuplicateOpenError(Exception): + pass + + +class DataForMissingSubchannelError(Exception): + pass + + +class CloseForMissingSubchannelError(Exception): + pass + + +@attrs +@implementer(IInbound) +class Inbound(object): + # Inbound flow control: TCP delivers data to Connection.dataReceived, + # Connection delivers to our handle_data, we deliver to + # SubChannel.remote_data, subchannel delivers to proto.dataReceived + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib() + + def __attrs_post_init__(self): + # we route inbound Data records to Subchannels .dataReceived + self._open_subchannels = {} # scid -> Subchannel + self._paused_subchannels = set() # Subchannels that have paused us + # the set is non-empty, we pause the transport + self._highest_inbound_acked = -1 + self._connection = None + + # from our Manager + def set_listener_endpoint(self, listener_endpoint): + self._listener_endpoint = listener_endpoint + + def set_subchannel_zero(self, scid0, sc0): + self._open_subchannels[scid0] = sc0 + + def use_connection(self, c): + self._connection = c + # We can pause the connection's reads when we receive too much data. If + # this is a non-initial connection, then we might already have + # subchannels that are paused from before, so we might need to pause + # the new connection before it can send us any data + if self._paused_subchannels: + self._connection.pauseProducing() + + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + def is_record_old(self, r): + if r.seqnum <= self._highest_inbound_acked: + return True + return False + + def update_ack_watermark(self, r): + self._highest_inbound_acked = max(self._highest_inbound_acked, + r.seqnum) + + def handle_open(self, scid): + if scid in self._open_subchannels: + log.err(DuplicateOpenError( + "received duplicate OPEN for {}".format(scid))) + return + peer_addr = _SubchannelAddress(scid) + sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) + self._open_subchannels[scid] = sc + self._listener_endpoint._got_open(sc, peer_addr) + + def handle_data(self, scid, data): + sc = self._open_subchannels.get(scid) + if sc is None: + log.err(DataForMissingSubchannelError( + "received DATA for non-existent subchannel {}".format(scid))) + return + sc.remote_data(data) + + def handle_close(self, scid): + sc = self._open_subchannels.get(scid) + if sc is None: + log.err(CloseForMissingSubchannelError( + "received CLOSE for non-existent subchannel {}".format(scid))) + return + sc.remote_close() + + def subchannel_closed(self, scid, sc): + # connectionLost has just been signalled + assert self._open_subchannels[scid] is sc + del self._open_subchannels[scid] + + def stop_using_connection(self): + self._connection = None + + # from our Subchannel, or rather from the Protocol above it and sent + # through the subchannel + + # The subchannel is an IProducer, and application protocols can always + # thell them to pauseProducing if we're delivering inbound data too + # quickly. They don't need to register anything. + + def subchannel_pauseProducing(self, sc): + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.add(sc) + if self._connection and not was_paused: + self._connection.pauseProducing() + + def subchannel_resumeProducing(self, sc): + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.discard(sc) + if self._connection and was_paused and not self._paused_subchannels: + self._connection.resumeProducing() + + def subchannel_stopProducing(self, sc): + # This protocol doesn't want any additional data. If we were a normal + # (single-owner) Transport, we'd call .loseConnection now. But our + # Connection is shared among many subchannels, so instead we just + # stop letting them pause the connection. + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.discard(sc) + if self._connection and was_paused and not self._paused_subchannels: + self._connection.resumeProducing() + + # TODO: we might refactor these pause/resume/stop methods by building a + # context manager that look at the paused/not-paused state first, then + # lets the caller modify self._paused_subchannels, then looks at it a + # second time, and calls c.pauseProducing/c.resumeProducing as + # appropriate. I'm not sure it would be any cleaner, though. diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py new file mode 100644 index 0000000..18d1770 --- /dev/null +++ b/src/wormhole/_dilation/manager.py @@ -0,0 +1,580 @@ +from __future__ import print_function, unicode_literals +import os +from collections import deque +from attr import attrs, attrib +from attr.validators import provides, instance_of, optional +from automat import MethodicalMachine +from zope.interface import implementer +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.python import log +from .._interfaces import IDilator, IDilationManager, ISend +from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr +from ..observer import OneShotObserver +from .._key import derive_key +from .encode import to_be4 +from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress, + ControlEndpoint, SubchannelConnectorEndpoint, + SubchannelListenerEndpoint) +from .connector import Connector +from .._hints import parse_hint +from .roles import LEADER, FOLLOWER +from .connection import KCM, Ping, Pong, Open, Data, Close, Ack +from .inbound import Inbound +from .outbound import Outbound + + +# exported to Wormhole() for inclusion in versions message +DILATION_VERSIONS = ["1"] + + +class OldPeerCannotDilateError(Exception): + pass + + +class UnknownDilationMessageType(Exception): + pass + + +class ReceivedHintsTooEarly(Exception): + pass + + +class UnexpectedKCM(Exception): + pass + + +class UnknownMessageType(Exception): + pass + + +def make_side(): + return bytes_to_hexstr(os.urandom(6)) + + +# new scheme: +# * both sides send PLEASE as soon as they have an unverified key and +# w.dilate has been called, +# * PLEASE includes a dilation-specific "side" (independent of the "side" +# used by mailbox messages) +# * higher "side" is Leader, lower is Follower +# * PLEASE includes can-dilate list of version integers, requires overlap +# "1" is current + +# * we start dilation after both w.dilate() and receiving VERSION, putting us +# in WANTING, then we process all previously-queued inbound DILATE-n +# messages. When PLEASE arrives, we move to CONNECTING +# * HINTS sent after dilation starts +# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This +# is the only difference between the two sides, and is not enforced +# by the protocol (i.e. if the Follower sends RECONNECT to the Leader, +# the Leader will obey, although TODO how confusing will this get?) +# * upon receiving RECONNECT: drop Connector, start new Connector, send +# RECONNECTING, start sending HINTS +# * upon sending RECONNECT: go into FLUSHING state and ignore all HINTS until +# RECONNECTING received. The new Connector can be spun up earlier, and it +# can send HINTS, but it must not be given any HINTS that arrive before +# RECONNECTING (since they're probably stale) + +# * after VERSIONS(KCM) received, we might learn that they other side cannot +# dilate. w.dilate errbacks at this point + +# * maybe signal warning if we stay in a "want" state for too long +# * nobody sends HINTS until they're ready to receive +# * nobody sends HINTS unless they've called w.dilate() and received PLEASE +# * nobody connects to inbound hints unless they've called w.dilate() +# * if leader calls w.dilate() but not follower, leader waits forever in +# "want" (doesn't send anything) +# * if follower calls w.dilate() but not leader, follower waits forever +# in "want", leader waits forever in "wanted" + +@attrs +@implementer(IDilationManager) +class Manager(object): + _S = attrib(validator=provides(ISend)) + _my_side = attrib(validator=instance_of(type(u""))) + _transit_key = attrib(validator=instance_of(bytes)) + _transit_relay_location = attrib(validator=optional(instance_of(str))) + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + _no_listen = False # TODO + _tor = None # TODO + _timing = None # TODO + _next_subchannel_id = None # initialized in choose_role + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + + self._my_role = None # determined upon rx_PLEASE + + self._connection = None + self._made_first_connection = False + self._first_connected = OneShotObserver(self._eventual_queue) + self._host_addr = _WormholeAddress() + + self._next_dilation_phase = 0 + + # I kept getting confused about which methods were for inbound data + # (and thus flow-control methods go "out") and which were for + # outbound data (with flow-control going "in"), so I split them up + # into separate pieces. + self._inbound = Inbound(self, self._host_addr) + self._outbound = Outbound(self, self._cooperator) # from us to peer + + def set_listener_endpoint(self, listener_endpoint): + self._inbound.set_listener_endpoint(listener_endpoint) + + def set_subchannel_zero(self, scid0, sc0): + self._inbound.set_subchannel_zero(scid0, sc0) + + def when_first_connected(self): + return self._first_connected.when_fired() + + def send_dilation_phase(self, **fields): + dilation_phase = self._next_dilation_phase + self._next_dilation_phase += 1 + self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) + + def send_hints(self, hints): # from Connector + self.send_dilation_phase(type="connection-hints", hints=hints) + + # forward inbound-ish things to _Inbound + + def subchannel_pauseProducing(self, sc): + self._inbound.subchannel_pauseProducing(sc) + + def subchannel_resumeProducing(self, sc): + self._inbound.subchannel_resumeProducing(sc) + + def subchannel_stopProducing(self, sc): + self._inbound.subchannel_stopProducing(sc) + + # forward outbound-ish things to _Outbound + def subchannel_registerProducer(self, sc, producer, streaming): + self._outbound.subchannel_registerProducer(sc, producer, streaming) + + def subchannel_unregisterProducer(self, sc): + self._outbound.subchannel_unregisterProducer(sc) + + def send_open(self, scid): + self._queue_and_send(Open, scid) + + def send_data(self, scid, data): + self._queue_and_send(Data, scid, data) + + def send_close(self, scid): + self._queue_and_send(Close, scid) + + def _queue_and_send(self, record_type, *args): + r = self._outbound.build_record(record_type, *args) + # Outbound owns the send_record() pipe, so that it can stall new + # writes after a new connection is made until after all queued + # messages are written (to preserve ordering). + self._outbound.queue_and_send_record(r) # may trigger pauseProducing + + def subchannel_closed(self, scid, sc): + # let everyone clean up. This happens just after we delivered + # connectionLost to the Protocol, except for the control channel, + # which might get connectionLost later after they use ep.connect. + # TODO: is this inversion a problem? + self._inbound.subchannel_closed(scid, sc) + self._outbound.subchannel_closed(scid, sc) + + # our Connector calls these + + def connector_connection_made(self, c): + self.connection_made() # state machine update + self._connection = c + self._inbound.use_connection(c) + self._outbound.use_connection(c) # does c.registerProducer + if not self._made_first_connection: + self._made_first_connection = True + self._first_connected.fire(None) + pass + + def connector_connection_lost(self): + self._stop_using_connection() + if self._my_role is LEADER: + self.connection_lost_leader() # state machine + else: + self.connection_lost_follower() + + def _stop_using_connection(self): + # the connection is already lost by this point + self._connection = None + self._inbound.stop_using_connection() + self._outbound.stop_using_connection() # does c.unregisterProducer + + # from our active Connection + + def got_record(self, r): + # records with sequence numbers: always ack, ignore old ones + if isinstance(r, (Open, Data, Close)): + self.send_ack(r.seqnum) # always ack, even for old ones + if self._inbound.is_record_old(r): + return + self._inbound.update_ack_watermark(r.seqnum) + if isinstance(r, Open): + self._inbound.handle_open(r.scid) + elif isinstance(r, Data): + self._inbound.handle_data(r.scid, r.data) + else: # isinstance(r, Close) + self._inbound.handle_close(r.scid) + return + if isinstance(r, KCM): + log.err(UnexpectedKCM()) + elif isinstance(r, Ping): + self.handle_ping(r.ping_id) + elif isinstance(r, Pong): + self.handle_pong(r.ping_id) + elif isinstance(r, Ack): + self._outbound.handle_ack(r.resp_seqnum) # retire queued messages + else: + log.err(UnknownMessageType("{}".format(r))) + + # pings, pongs, and acks are not queued + def send_ping(self, ping_id): + self._outbound.send_if_connected(Ping(ping_id)) + + def send_pong(self, ping_id): + self._outbound.send_if_connected(Pong(ping_id)) + + def send_ack(self, resp_seqnum): + self._outbound.send_if_connected(Ack(resp_seqnum)) + + def handle_ping(self, ping_id): + self.send_pong(ping_id) + + def handle_pong(self, ping_id): + # TODO: update is-alive timer + pass + + # subchannel maintenance + def allocate_subchannel_id(self): + scid_num = self._next_subchannel_id + self._next_subchannel_id += 2 + return to_be4(scid_num) + + # state machine + + # We are born WANTING after the local app calls w.dilate(). We start + # CONNECTING when we receive PLEASE from the remote side + + def start(self): + self.send_please() + + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + + @m.state(initial=True) + def WANTING(self): + pass # pragma: no cover + + @m.state() + def CONNECTING(self): + pass # pragma: no cover + + @m.state() + def CONNECTED(self): + pass # pragma: no cover + + @m.state() + def FLUSHING(self): + pass # pragma: no cover + + @m.state() + def ABANDONING(self): + pass # pragma: no cover + + @m.state() + def LONELY(self): + pass # pragma: no cover + + @m.state() + def STOPPING(self): + pass # pragma: no cover + + @m.state(terminal=True) + def STOPPED(self): + pass # pragma: no cover + + @m.input() + def rx_PLEASE(self, message): + pass # pragma: no cover + + @m.input() # only sent by Follower + def rx_HINTS(self, hint_message): + pass # pragma: no cover + + @m.input() # only Leader sends RECONNECT, so only Follower receives it + def rx_RECONNECT(self): + pass # pragma: no cover + + @m.input() # only Follower sends RECONNECTING, so only Leader receives it + def rx_RECONNECTING(self): + pass # pragma: no cover + + # Connector gives us connection_made() + @m.input() + def connection_made(self): + pass # pragma: no cover + + # our connection_lost() fires connection_lost_leader or + # connection_lost_follower depending upon our role. If either side sees a + # problem with the connection (timeouts, bad authentication) then they + # just drop it and let connection_lost() handle the cleanup. + @m.input() + def connection_lost_leader(self): + pass # pragma: no cover + + @m.input() + def connection_lost_follower(self): + pass + + @m.input() + def stop(self): + pass # pragma: no cover + + @m.output() + def choose_role(self, message): + their_side = message["side"] + if self._my_side > their_side: + self._my_role = LEADER + # scid 0 is reserved for the control channel. the leader uses odd + # numbers starting with 1 + self._next_subchannel_id = 1 + elif their_side > self._my_side: + self._my_role = FOLLOWER + # the follower uses even numbers starting with 2 + self._next_subchannel_id = 2 + else: + raise ValueError("their side shouldn't be equal: reflection?") + + # these Outputs behave differently for the Leader vs the Follower + + @m.output() + def start_connecting_ignore_message(self, message): + del message # ignored + return self._start_connecting() + + @m.output() + def start_connecting(self): + self._start_connecting() + + def _start_connecting(self): + assert self._my_role is not None + self._connector = Connector(self._transit_key, + self._transit_relay_location, + self, + self._reactor, self._eventual_queue, + self._no_listen, self._tor, + self._timing, + self._my_side, # needed for relay handshake + self._my_role) + self._connector.start() + + @m.output() + def send_reconnect(self): + self.send_dilation_phase(type="reconnect") # TODO: generation number? + + @m.output() + def send_reconnecting(self): + self.send_dilation_phase(type="reconnecting") # TODO: generation? + + @m.output() + def use_hints(self, hint_message): + hint_objs = filter(lambda h: h, # ignore None, unrecognizable + [parse_hint(hs) for hs in hint_message["hints"]]) + hint_objs = list(hint_objs) + self._connector.got_hints(hint_objs) + + @m.output() + def stop_connecting(self): + self._connector.stop() + + @m.output() + def abandon_connection(self): + # we think we're still connected, but the Leader disagrees. Or we've + # been told to shut down. + self._connection.disconnect() # let connection_lost do cleanup + + # we start CONNECTING when we get rx_PLEASE + WANTING.upon(rx_PLEASE, enter=CONNECTING, + outputs=[choose_role, start_connecting_ignore_message]) + + CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[]) + + # Leader + CONNECTED.upon(connection_lost_leader, enter=FLUSHING, + outputs=[send_reconnect]) + FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, + outputs=[start_connecting]) + + # Follower + # if we notice a lost connection, just wait for the Leader to notice too + CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[]) + LONELY.upon(rx_RECONNECT, enter=CONNECTING, + outputs=[send_reconnecting, start_connecting]) + # but if they notice it first, abandon our (seemingly functional) + # connection, then tell them that we're ready to try again + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, outputs=[abandon_connection]) + ABANDONING.upon(connection_lost_follower, enter=CONNECTING, + outputs=[send_reconnecting, start_connecting]) + # and if they notice a problem while we're still connecting, abandon our + # incomplete attempt and try again. in this case we don't have to wait + # for a connection to finish shutdown + CONNECTING.upon(rx_RECONNECT, enter=CONNECTING, + outputs=[stop_connecting, + send_reconnecting, + start_connecting]) + + # rx_HINTS never changes state, they're just accepted or ignored + WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early + CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore + LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore + ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen + STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) + + WANTING.upon(stop, enter=STOPPED, outputs=[]) + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) + ABANDONING.upon(stop, enter=STOPPING, outputs=[]) + FLUSHING.upon(stop, enter=STOPPED, outputs=[]) + LONELY.upon(stop, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) + + +@attrs +@implementer(IDilator) +class Dilator(object): + """I launch the dilation process. + + I am created with every Wormhole (regardless of whether .dilate() + was called or not), and I handle the initial phase of dilation, + before we know whether we'll be the Leader or the Follower. Once we + hear the other side's VERSION message (which tells us that we have a + connection, they are capable of dilating, and which side we're on), + then we build a Manager and hand control to it. + """ + + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + self._started = False + self._endpoints = OneShotObserver(self._eventual_queue) + self._pending_inbound_dilate_messages = deque() + self._manager = None + + def wire(self, sender): + self._S = ISend(sender) + + # this is the primary entry point, called when w.dilate() is invoked + def dilate(self, transit_relay_location=None): + self._transit_relay_location = transit_relay_location + if not self._started: + self._started = True + self._start().addBoth(self._endpoints.fire) + return self._endpoints.when_fired() + + @inlineCallbacks + def _start(self): + # first, we wait until we hear the VERSION message, which tells us 1: + # the PAKE key works, so we can talk securely, 2: that they can do + # dilation at all (if they can't then w.dilate() errbacks) + + dilation_version = yield self._got_versions_d + + # TODO: we could probably return the endpoints earlier, if we flunk + # any connection/listen attempts upon OldPeerCannotDilateError, or + # if/when we give up on the initial connection + + if not dilation_version: # "1" or None + # TODO: be more specific about the error. dilation_version==None + # means we had no version in common with them, which could either + # be because they're so old they don't dilate at all, or because + # they're so new that they no longer accomodate our old version + raise OldPeerCannotDilateError() + + my_dilation_side = make_side() + self._manager = Manager(self._S, my_dilation_side, + self._transit_key, + self._transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator) + self._manager.start() + + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + self.received_dilate(plaintext) + + yield self._manager.when_first_connected() + + # we can open subchannels as soon as we get our first connection + scid0 = b"\x00\x00\x00\x00" + self._host_addr = _WormholeAddress() # TODO: share with Manager + peer_addr0 = _SubchannelAddress(scid0) + control_ep = ControlEndpoint(peer_addr0) + sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) + control_ep._subchannel_zero_opened(sc0) + self._manager.set_subchannel_zero(scid0, sc0) + + connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + + listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) + self._manager.set_listener_endpoint(listen_ep) + + endpoints = (control_ep, connect_ep, listen_ep) + returnValue(endpoints) + + # from Boss + + def got_key(self, key): + # TODO: verify this happens before got_wormhole_versions, or add a gate + # to tolerate either ordering + purpose = b"dilation-v1" + LENGTH = 32 # TODO: whatever Noise wants, I guess + self._transit_key = derive_key(key, purpose, LENGTH) + + def got_wormhole_versions(self, their_wormhole_versions): + assert self._transit_key is not None + # this always happens before received_dilate + dilation_version = None + their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) + my_versions = set(DILATION_VERSIONS) + shared_versions = my_versions.intersection(their_dilation_versions) + if "1" in shared_versions: + dilation_version = "1" + self._got_versions_d.callback(dilation_version) + + def received_dilate(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. + + # this can appear before our .dilate() method is called, in which case + # we queue them for later + if not self._manager: + self._pending_inbound_dilate_messages.append(plaintext) + return + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self._manager.rx_PLEASE(message) + elif type == "connection-hints": + self._manager.rx_HINTS(message) + elif type == "reconnect": + self._manager.rx_RECONNECT() + elif type == "reconnecting": + self._manager.rx_RECONNECTING() + else: + log.err(UnknownDilationMessageType(message)) + return diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py new file mode 100644 index 0000000..96fbd3d --- /dev/null +++ b/src/wormhole/_dilation/outbound.py @@ -0,0 +1,389 @@ +from __future__ import print_function, unicode_literals +from collections import deque +from attr import attrs, attrib +from attr.validators import provides +from zope.interface import implementer +from twisted.internet.interfaces import IPushProducer, IPullProducer +from twisted.python import log +from twisted.python.reflect import safe_str +from .._interfaces import IDilationManager, IOutbound +from .connection import KCM, Ping, Pong, Ack + + +# Outbound flow control: app writes to subchannel, we write to Connection + +# The app can register an IProducer of their choice, to let us throttle their +# outbound data. Not all subchannels will have producers registered, and the +# producer probably won't be the IProtocol instance (it'll be something else +# which feeds data out through the protocol, like a t.p.basic.FileSender). If +# a producerless subchannel writes too much, we won't be able to stop them, +# and we'll keep writing records into the Connection even though it's asked +# us to pause. Likewise, when the connection is down (and we're busily trying +# to reestablish a new one), registered subchannels will be paused, but +# unregistered ones will just dump everything in _outbound_queue, and we'll +# consume memory without bound until they stop. + +# We need several things: +# +# * Add each registered IProducer to a list, whose order remains stable. We +# want fairness under outbound throttling: each time the outbound +# connection opens up (our resumeProducing method is called), we should let +# just one producer have an opportunity to do transport.write, and then we +# should pause them again, and not come back to them until everyone else +# has gotten a turn. So we want an ordered list of producers to track this +# rotation. +# +# * Remove the IProducer if/when the protocol uses unregisterProducer +# +# * Remove any registered IProducer when the associated Subchannel is closed. +# This isn't a problem for normal transports, because usually there's a +# one-to-one mapping from Protocol to Transport, so when the Transport you +# forget the only reference to the Producer anyways. Our situation is +# unusual because we have multiple Subchannels that get merged into the +# same underlying Connection: each Subchannel's Protocol can register a +# producer on the Subchannel (which is an ITransport), but that adds it to +# a set of Producers for the Connection (which is also an ITransport). So +# if the Subchannel is closed, we need to remove its Producer (if any) even +# though the Connection remains open. +# +# * Register ourselves as an IPushProducer with each successive Connection +# object. These connections will come and go, but there will never be more +# than one. When the connection goes away, pause all our producers. When a +# new one is established, write all our queued messages, then unpause our +# producers as we would in resumeProducing. +# +# * Inside our resumeProducing call, we'll cycle through all producers, +# calling their individual resumeProducing methods one at a time. If they +# write so much data that the Connection pauses us again, we'll find out +# because our pauseProducing will be called inside that loop. When that +# happens, we need to stop looping. If we make it through the whole loop +# without being paused, then all subchannel Producers are left unpaused, +# and are free to write whenever they want. During this loop, some +# Producers will be paused, and others will be resumed +# +# * If our pauseProducing is called, all Producers must be paused, and a flag +# should be set to notify the resumeProducing loop to exit +# +# * In between calls to our resumeProducing method, we're in one of two +# states. +# * If we're writing data too fast, then we'll be left in the "paused" +# state, in which all Subchannel producers are paused, and the aggregate +# is paused too (our Connection told us to pauseProducing and hasn't yet +# told us to resumeProducing). In this state, activity is driven by the +# outbound TCP window opening up, which calls resumeProducing and allows +# (probably just) one message to be sent. We receive pauseProducing in +# the middle of their transport.write, so the loop exits early, and the +# only state change is that some other Producer should get to go next +# time. +# * If we're writing too slowly, we'll be left in the "unpaused" state: all +# Subchannel producers are unpaused, and the aggregate is unpaused too +# (resumeProducing is the last thing we've been told). In this satte, +# activity is driven by the Subchannels doing a transport.write, which +# queues some data on the TCP connection (and then might call +# pauseProducing if it's now full). +# +# * We want to guard against: +# +# * application protocol registering a Producer without first unregistering +# the previous one +# +# * application protocols writing data despite being told to pause +# (Subchannels without a registered Producer cannot be throttled, and we +# can't do anything about that, but we must also handle the case where +# they give us a pause switch and then proceed to ignore it) +# +# * our Connection calling resumeProducing or pauseProducing without an +# intervening call of the other kind +# +# * application protocols that don't handle a resumeProducing or +# pauseProducing call without an intervening call of the other kind (i.e. +# we should keep track of the last thing we told them, and not repeat +# ourselves) +# +# * If the Wormhole is closed, all Subchannels should close. This is not our +# responsibility: it lives in (Manager? Inbound?) +# +# * If we're given an IPullProducer, we should keep calling its +# resumeProducing until it runs out of data. We still want fairness, so we +# won't call it a second time until everyone else has had a turn. + + +# There are a couple of different ways to approach this. The one I've +# selected is: +# +# * keep a dict that maps from Subchannel to Producer, which only contains +# entries for Subchannels that have registered a producer. We use this to +# remove Producers when Subchannels are closed +# +# * keep a Deque of Producers. This represents the fair-throttling rotation: +# the left-most item gets the next upcoming turn, and then they'll be moved +# to the end of the queue. +# +# * keep a set of IPushProducers which are paused, a second set of +# IPushProducers which are unpaused, and a third set of IPullProducers +# (which are always left paused) Enforce the invariant that these three +# sets are disjoint, and that their union equals the contents of the deque. +# +# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and +# set upon entry to pauseProducing. The loop inside resumeProducing checks +# this flag after each call to producer.resumeProducing, to sense whether +# they used their turn to write data, and if that write was large enough to +# fill the TCP window. If set, we break out of the loop. If not, we look +# for the next producer to unpause. The loop finishes when all producers +# are unpaused (evidenced by the two sets of paused producers being empty) +# +# * the "paused" flag also determines whether new IPushProducers are added to +# the paused or unpaused set (IPullProducers are always added to the +# pull+paused set). If we have any IPullProducers, we're always in the +# "writing data too fast" state. + +# other approaches that I didn't decide to do at this time (but might use in +# the future): +# +# * use one set instead of two. pros: fewer moving parts. cons: harder to +# spot decoherence bugs like adding a producer to the deque but forgetting +# to add it to one of the +# +# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as +# a visible boolean flag. This conflates Subchannels with their associated +# Producer (so if we went this way, we should also let them track their own +# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc +# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer +# mappings lying around to disagree with one another. Cons: exposes a bit +# too much of the Subchannel internals + + +@attrs +@implementer(IOutbound) +class Outbound(object): + # Manage outbound data: subchannel writes to us, we write to transport + _manager = attrib(validator=provides(IDilationManager)) + _cooperator = attrib() + + def __attrs_post_init__(self): + # _outbound_queue holds all messages we've ever sent but not retired + self._outbound_queue = deque() + self._next_outbound_seqnum = 0 + # _queued_unsent are messages to retry with our new connection + self._queued_unsent = deque() + + # outbound flow control: the Connection throttles our writes + self._subchannel_producers = {} # Subchannel -> IProducer + self._paused = True # our Connection called our pauseProducing + self._all_producers = deque() # rotates, left-is-next + self._paused_producers = set() + self._unpaused_producers = set() + self._check_invariants() + + self._connection = None + + def _check_invariants(self): + assert self._unpaused_producers.isdisjoint(self._paused_producers) + assert (self._paused_producers.union(self._unpaused_producers) == + set(self._all_producers)) + + def build_record(self, record_type, *args): + seqnum = self._next_outbound_seqnum + self._next_outbound_seqnum += 1 + r = record_type(seqnum, *args) + assert hasattr(r, "seqnum"), r # only Open/Data/Close + return r + + def queue_and_send_record(self, r): + # we always queue it, to resend on a subsequent connection if + # necessary + self._outbound_queue.append(r) + + if self._connection: + if self._queued_unsent: + # to maintain correct ordering, queue this instead of sending it + self._queued_unsent.append(r) + else: + # we're allowed to send it immediately + self._connection.send_record(r) + + def send_if_connected(self, r): + assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum + if self._connection: + self._connection.send_record(r) + + # our subchannels call these to register a producer + + def subchannel_registerProducer(self, sc, producer, streaming): + # streaming==True: IPushProducer (pause/resume) + # streaming==False: IPullProducer (just resume) + if sc in self._subchannel_producers: + raise ValueError( + "registering producer %s before previous one (%s) was " + "unregistered" % (producer, + self._subchannel_producers[sc])) + # our underlying Connection uses streaming==True, so to make things + # easier, use an adapter when the Subchannel asks for streaming=False + if not streaming: + def unregister(): + self.subchannel_unregisterProducer(sc) + producer = PullToPush(producer, unregister, self._cooperator) + + self._subchannel_producers[sc] = producer + self._all_producers.append(producer) + if self._paused: + self._paused_producers.add(producer) + else: + self._unpaused_producers.add(producer) + self._check_invariants() + if streaming: + if self._paused: + # IPushProducers need to be paused immediately, before they + # speak + producer.pauseProducing() # you wake up sleeping + else: + # our PullToPush adapter must be started, but if we're paused then + # we tell it to pause before it gets a chance to write anything + producer.startStreaming(self._paused) + + def subchannel_unregisterProducer(self, sc): + # TODO: what if the subchannel closes, so we unregister their + # producer for them, then the application reacts to connectionLost + # with a duplicate unregisterProducer? + p = self._subchannel_producers.pop(sc) + if isinstance(p, PullToPush): + p.stopStreaming() + self._all_producers.remove(p) + self._paused_producers.discard(p) + self._unpaused_producers.discard(p) + self._check_invariants() + + def subchannel_closed(self, sc): + self._check_invariants() + if sc in self._subchannel_producers: + self.subchannel_unregisterProducer(sc) + + # our Manager tells us when we've got a new Connection to work with + + def use_connection(self, c): + self._connection = c + assert not self._queued_unsent + self._queued_unsent.extend(self._outbound_queue) + # the connection can tell us to pause when we send too much data + c.registerProducer(self, True) # IPushProducer: pause+resume + # send our queued messages + self.resumeProducing() + + def stop_using_connection(self): + self._connection.unregisterProducer() + self._connection = None + self._queued_unsent.clear() + self.pauseProducing() + # TODO: I expect this will call pauseProducing twice: the first time + # when we get stopProducing (since we're registere with the + # underlying connection as the producer), and again when the manager + # notices the connectionLost and calls our _stop_using_connection + + def handle_ack(self, resp_seqnum): + # we've received an inbound ack, so retire something + while (self._outbound_queue and + self._outbound_queue[0].seqnum <= resp_seqnum): + self._outbound_queue.popleft() + while (self._queued_unsent and + self._queued_unsent[0].seqnum <= resp_seqnum): + self._queued_unsent.popleft() + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + # IProducer: the active connection calls these because we used + # c.registerProducer to ask for them + + def pauseProducing(self): + if self._paused: + return # someone is confused and called us twice + self._paused = True + for p in self._all_producers: + if p in self._unpaused_producers: + self._unpaused_producers.remove(p) + self._paused_producers.add(p) + p.pauseProducing() + + def resumeProducing(self): + if not self._paused: + return # someone is confused and called us twice + self._paused = False + + while not self._paused: + if self._queued_unsent: + r = self._queued_unsent.popleft() + self._connection.send_record(r) + continue + p = self._get_next_unpaused_producer() + if not p: + break + self._paused_producers.remove(p) + self._unpaused_producers.add(p) + p.resumeProducing() + + def _get_next_unpaused_producer(self): + self._check_invariants() + if not self._paused_producers: + return None + while True: + p = self._all_producers[0] + self._all_producers.rotate(-1) # p moves to the end of the line + # the only unpaused Producers are at the end of the list + assert p in self._paused_producers + return p + + def stopProducing(self): + # we'll hopefully have a new connection to work with in the future, + # so we don't shut anything down. We do pause everyone, though. + self.pauseProducing() + + +# modelled after twisted.internet._producer_helper._PullToPush , but with a +# configurable Cooperator, a pause-immediately argument to startStreaming() +@implementer(IPushProducer) +@attrs(cmp=False) +class PullToPush(object): + _producer = attrib(validator=provides(IPullProducer)) + _unregister = attrib(validator=lambda _a, _b, v: callable(v)) + _cooperator = attrib() + _finished = False + + def _pull(self): + while True: + try: + self._producer.resumeProducing() + except Exception: + log.err(None, "%s failed, producing will be stopped:" % + (safe_str(self._producer),)) + try: + self._unregister() + # The consumer should now call stopStreaming() on us, + # thus stopping the streaming. + except Exception: + # Since the consumer blew up, we may not have had + # stopStreaming() called, so we just stop on our own: + log.err(None, "%s failed to unregister producer:" % + (safe_str(self._unregister),)) + self._finished = True + return + yield None + + def startStreaming(self, paused): + self._coopTask = self._cooperator.cooperate(self._pull()) + if paused: + self.pauseProducing() # timer is scheduled, but task is removed + + def stopStreaming(self): + if self._finished: + return + self._finished = True + self._coopTask.stop() + + def pauseProducing(self): + self._coopTask.pause() + + def resumeProducing(self): + self._coopTask.resume() + + def stopProducing(self): + self.stopStreaming() + self._producer.stopProducing() diff --git a/src/wormhole/_dilation/roles.py b/src/wormhole/_dilation/roles.py new file mode 100644 index 0000000..8f9adac --- /dev/null +++ b/src/wormhole/_dilation/roles.py @@ -0,0 +1 @@ +LEADER, FOLLOWER = object(), object() diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py new file mode 100644 index 0000000..abd1939 --- /dev/null +++ b/src/wormhole/_dilation/subchannel.py @@ -0,0 +1,300 @@ +from attr import attrs, attrib +from attr.validators import instance_of, provides +from zope.interface import implementer +from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue, + succeed) +from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, + IAddress, IListeningPort, + IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.internet.error import ConnectionDone +from automat import MethodicalMachine +from .._interfaces import ISubChannel, IDilationManager + + +@attrs +class Once(object): + _errtype = attrib() + + def __attrs_post_init__(self): + self._called = False + + def __call__(self): + if self._called: + raise self._errtype() + self._called = True + + +class SingleUseEndpointError(Exception): + pass + +# created in the (OPEN) state, by either: +# * receipt of an OPEN message +# * or local client_endpoint.connect() +# then transitions are: +# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN) +# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED) +# (OPEN) local .write(): send DATA, -> (OPEN) +# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING) +# (CLOSING) local .write(): error +# (CLOSING) local .loseConnection(): error +# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING) +# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) +# object is deleted upon transition to (CLOSED) + + +class AlreadyClosedError(Exception): + pass + + +@implementer(IAddress) +class _WormholeAddress(object): + pass + + +@implementer(IAddress) +@attrs +class _SubchannelAddress(object): + _scid = attrib() + + +@attrs +@implementer(ITransport) +@implementer(IProducer) +@implementer(IConsumer) +@implementer(ISubChannel) +class SubChannel(object): + _id = attrib(validator=instance_of(bytes)) + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _peer_addr = attrib(validator=instance_of(_SubchannelAddress)) + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, + f: None) # pragma: no cover + + def __attrs_post_init__(self): + # self._mailbox = None + # self._pending_outbound = {} + # self._processed = set() + self._protocol = None + self._pending_dataReceived = [] + self._pending_connectionLost = (False, None) + + @m.state(initial=True) + def open(self): + pass # pragma: no cover + + @m.state() + def closing(): + pass # pragma: no cover + + @m.state() + def closed(): + pass # pragma: no cover + + @m.input() + def remote_data(self, data): + pass + + @m.input() + def remote_close(self): + pass + + @m.input() + def local_data(self, data): + pass + + @m.input() + def local_close(self): + pass + + @m.output() + def send_data(self, data): + self._manager.send_data(self._id, data) + + @m.output() + def send_close(self): + self._manager.send_close(self._id) + + @m.output() + def signal_dataReceived(self, data): + if self._protocol: + self._protocol.dataReceived(data) + else: + self._pending_dataReceived.append(data) + + @m.output() + def signal_connectionLost(self): + if self._protocol: + self._protocol.connectionLost(ConnectionDone()) + else: + self._pending_connectionLost = (True, ConnectionDone()) + self._manager.subchannel_closed(self) + # we're deleted momentarily + + @m.output() + def error_closed_write(self, data): + raise AlreadyClosedError("write not allowed on closed subchannel") + + @m.output() + def error_closed_close(self): + raise AlreadyClosedError( + "loseConnection not allowed on closed subchannel") + + # primary transitions + open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) + open.upon(local_data, enter=open, outputs=[send_data]) + open.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + open.upon(local_close, enter=closing, outputs=[send_close]) + closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) + closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + + # error cases + # we won't ever see an OPEN, since L4 will log+ignore those for us + closing.upon(local_data, enter=closing, outputs=[error_closed_write]) + closing.upon(local_close, enter=closing, outputs=[error_closed_close]) + # the CLOSED state won't ever see messages, since we'll be deleted + + # our endpoints use this + + def _set_protocol(self, protocol): + assert not self._protocol + self._protocol = protocol + if self._pending_dataReceived: + for data in self._pending_dataReceived: + self._protocol.dataReceived(data) + self._pending_dataReceived = [] + cl, what = self._pending_connectionLost + if cl: + self._protocol.connectionLost(what) + + # ITransport + def write(self, data): + assert isinstance(data, type(b"")) + self.local_data(data) + + def writeSequence(self, iovec): + self.write(b"".join(iovec)) + + def loseConnection(self): + self.local_close() + + def getHost(self): + # we define "host addr" as the overall wormhole + return self._host_addr + + def getPeer(self): + # and "peer addr" as the subchannel within that wormhole + return self._peer_addr + + # IProducer: throttle inbound data (wormhole "up" to local app's Protocol) + def stopProducing(self): + self._manager.subchannel_stopProducing(self) + + def pauseProducing(self): + self._manager.subchannel_pauseProducing(self) + + def resumeProducing(self): + self._manager.subchannel_resumeProducing(self) + + # IConsumer: allow the wormhole to throttle outbound data (app->wormhole) + def registerProducer(self, producer, streaming): + self._manager.subchannel_registerProducer(self, producer, streaming) + + def unregisterProducer(self): + self._manager.subchannel_unregisterProducer(self) + + +@implementer(IStreamClientEndpoint) +class ControlEndpoint(object): + _used = False + + def __init__(self, peer_addr): + self._subchannel_zero = Deferred() + self._peer_addr = peer_addr + self._once = Once(SingleUseEndpointError) + + # from manager + def _subchannel_zero_opened(self, subchannel): + assert ISubChannel.providedBy(subchannel), subchannel + self._subchannel_zero.callback(subchannel) + + @inlineCallbacks + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + self._once() + t = yield self._subchannel_zero + p = protocolFactory.buildProtocol(self._peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + returnValue(p) + + +@implementer(IStreamClientEndpoint) +@attrs +class SubchannelConnectorEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + scid = self._manager.allocate_subchannel_id() + self._manager.send_open(scid) + peer_addr = _SubchannelAddress(scid) + # ? f.doStart() + # ? f.startedConnecting(CONNECTOR) # ?? + t = SubChannel(scid, self._manager, self._host_addr, peer_addr) + p = protocolFactory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + return succeed(p) + + +@implementer(IStreamServerEndpoint) +@attrs +class SubchannelListenerEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=provides(IAddress)) + + def __attrs_post_init__(self): + self._factory = None + self._pending_opens = [] + + # from manager + def _got_open(self, t, peer_addr): + if self._factory: + self._connect(t, peer_addr) + else: + self._pending_opens.append((t, peer_addr)) + + def _connect(self, t, peer_addr): + p = self._factory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) + + # IStreamServerEndpoint + + def listen(self, protocolFactory): + self._factory = protocolFactory + for (t, peer_addr) in self._pending_opens: + self._connect(t, peer_addr) + self._pending_opens = [] + lp = SubchannelListeningPort(self._host_addr) + return succeed(lp) + + +@implementer(IListeningPort) +@attrs +class SubchannelListeningPort(object): + _host_addr = attrib(validator=provides(IAddress)) + + def startListening(self): + pass + + def stopListening(self): + # TODO + pass + + def getHost(self): + return self._host_addr diff --git a/src/wormhole/_hints.py b/src/wormhole/_hints.py new file mode 100644 index 0000000..a8ee258 --- /dev/null +++ b/src/wormhole/_hints.py @@ -0,0 +1,138 @@ +from __future__ import print_function, unicode_literals +import sys +import re +import six +from collections import namedtuple +from twisted.internet.endpoints import HostnameEndpoint +from twisted.python import log + +# These namedtuples are "hint objects". The JSON-serializable dictionaries +# are "hint dicts". + +# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: +# * make a TCP connection (possibly via Tor) +# * send the sender/receiver handshake bytes first +# * expect to see the receiver/sender handshake bytes from the other side +# * the sender writes "go\n", the receiver waits for "go\n" +# * the rest of the connection contains transit data +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", + ["hostname", "port", "priority"]) +TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) +# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we +# use a tuple rather than a list so they'll be hashable into a set). For each +# one, make the TCP connection, send the relay handshake, then complete the +# rest of the V1 protocol. Only one hint per relay is useful. +RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + +def describe_hint_obj(hint, relay, tor): + prefix = "tor->" if tor else "->" + if relay: + prefix = prefix + "relay:" + if isinstance(hint, DirectTCPV1Hint): + return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) + elif isinstance(hint, TorTCPV1Hint): + return prefix + "tor:%s:%d" % (hint.hostname, hint.port) + else: + return prefix + str(hint) + +def parse_hint_argv(hint, stderr=sys.stderr): + assert isinstance(hint, type(u"")) + # return tuple or None for an unparseable hint + priority = 0.0 + mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) + if not mo: + print("unparseable hint '%s'" % (hint, ), file=stderr) + return None + hint_type = mo.group(1) + if hint_type != "tcp": + print("unknown hint type '%s' in '%s'" % (hint_type, hint), + file=stderr) + return None + hint_value = mo.group(2) + pieces = hint_value.split(":") + if len(pieces) < 2: + print("unparseable TCP hint (need more colons) '%s'" % (hint, ), + file=stderr) + return None + mo = re.search(r'^(\d+)$', pieces[1]) + if not mo: + print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) + return None + hint_host = pieces[0] + hint_port = int(pieces[1]) + for more in pieces[2:]: + if more.startswith("priority="): + more_pieces = more.split("=") + try: + priority = float(more_pieces[1]) + except ValueError: + print("non-float priority= in TCP hint '%s'" % (hint, ), + file=stderr) + return None + return DirectTCPV1Hint(hint_host, hint_port, priority) + +def endpoint_from_hint_obj(hint, tor, reactor): + if tor: + if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): + # this Tor object will throw ValueError for non-public IPv4 + # addresses and any IPv6 address + try: + return tor.stream_via(hint.hostname, hint.port) + except ValueError: + return None + return None + if isinstance(hint, DirectTCPV1Hint): + return HostnameEndpoint(reactor, hint.hostname, hint.port) + return None + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + hint_type = hint.get("type", "") + if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: + log.msg("unknown hint type: %r" % (hint, )) + return None + if not ("hostname" in hint and + isinstance(hint["hostname"], type(""))): + log.msg("invalid hostname in hint: %r" % (hint, )) + return None + if not ("port" in hint and + isinstance(hint["port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint, )) + return None + priority = hint.get("priority", 0.0) + if hint_type == "direct-tcp-v1": + return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) + else: + return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + +def parse_hint(hint_struct): + hint_type = hint_struct.get("type", "") + if hint_type == "relay-v1": + # the struct can include multiple ways to reach the same relay + rhints = filter(lambda h: h, # drop None (unrecognized) + [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) + return RelayV1Hint(list(rhints)) + return parse_tcp_v1_hint(hint_struct) + + +def encode_hint(h): + if isinstance(h, DirectTCPV1Hint): + return {"type": "direct-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + elif isinstance(h, RelayV1Hint): + rhint = {"type": "relay-v1", "hints": []} + for rh in h.hints: + rhint["hints"].append({"type": "direct-tcp-v1", + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) + return rhint + elif isinstance(h, TorTCPV1Hint): + return {"type": "tor-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + raise ValueError("unknown hint type", h) diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 9f59ff4..15302d1 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -433,3 +433,27 @@ class IInputHelper(Interface): class IJournal(Interface): # TODO: this needs to be public pass + + +class IDilator(Interface): + pass + + +class IDilationManager(Interface): + pass + + +class IDilationConnector(Interface): + pass + + +class ISubChannel(Interface): + pass + + +class IInbound(Interface): + pass + + +class IOutbound(Interface): + pass diff --git a/src/wormhole/_key.py b/src/wormhole/_key.py index 849ed41..e14e51a 100644 --- a/src/wormhole/_key.py +++ b/src/wormhole/_key.py @@ -6,7 +6,6 @@ import six from attr import attrib, attrs from attr.validators import instance_of, provides from automat import MethodicalMachine -from hkdf import Hkdf from nacl import utils from nacl.exceptions import CryptoError from nacl.secret import SecretBox @@ -15,16 +14,12 @@ from zope.interface import implementer from . import _interfaces from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, - hexstr_to_bytes, to_bytes) + hexstr_to_bytes, to_bytes, HKDF) CryptoError __all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"] -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) - - def derive_key(key, purpose, length=SecretBox.KEY_SIZE): if not isinstance(key, type(b"")): raise TypeError(type(key)) diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index 65d6545..db02166 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -89,6 +89,7 @@ class RendezvousConnector(object): # if the initial connection fails, signal an error and shut down. do # this in a different reactor turn to avoid some hazards d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res)) + # TODO: use EventualQueue d.addErrback(self._initial_connection_failed) self._debug_record_inbound_f = None diff --git a/src/wormhole/observer.py b/src/wormhole/observer.py index 43afa70..34fbf1b 100644 --- a/src/wormhole/observer.py +++ b/src/wormhole/observer.py @@ -70,3 +70,24 @@ class SequenceObserver(object): if self._observers: d = self._observers.pop(0) self._eq.eventually(d.callback, self._results.pop(0)) + + +class EmptyableSet(set): + # manage a set which grows and shrinks over time. Fire a Deferred the first + # time it becomes empty after you start watching for it. + + def __init__(self, *args, **kwargs): + self._eq = kwargs.pop("_eventual_queue") # required + super(EmptyableSet, self).__init__(*args, **kwargs) + self._observer = None + + def when_next_empty(self): + if not self._observer: + self._observer = OneShotObserver(self._eq) + return self._observer.when_fired() + + def discard(self, o): + super(EmptyableSet, self).discard(o) + if self._observer and not self: + self._observer.fire(None) + self._observer = None diff --git a/src/wormhole/test/dilate/__init__.py b/src/wormhole/test/dilate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py new file mode 100644 index 0000000..4f398d7 --- /dev/null +++ b/src/wormhole/test/dilate/common.py @@ -0,0 +1,21 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from ..._interfaces import IDilationManager, IWormhole + + +def mock_manager(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + return m + + +def mock_wormhole(): + m = mock.Mock() + alsoProvides(m, IWormhole) + return m + + +def clear_mock_calls(*args): + for a in args: + a.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py new file mode 100644 index 0000000..ee761fd --- /dev/null +++ b/src/wormhole/test/dilate/test_connection.py @@ -0,0 +1,217 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.internet.interfaces import ITransport +from ...eventual import EventualQueue +from ..._interfaces import IDilationConnector +from ..._dilation.roles import LEADER, FOLLOWER +from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, + KCM, Open, Ack) +from .common import clear_mock_calls + + +def make_con(role, use_relay=False): + clock = Clock() + eq = EventualQueue(clock) + connector = mock.Mock() + alsoProvides(connector, IDilationConnector) + n = mock.Mock() # pretends to be a Noise object + n.write_message = mock.Mock(side_effect=[b"handshake"]) + c = DilatedConnectionProtocol(eq, role, connector, n, + b"outbound_prologue\n", b"inbound_prologue\n") + if use_relay: + c.use_relay(b"relay_handshake\n") + t = mock.Mock() + alsoProvides(t, ITransport) + return c, n, connector, t, eq + + +class Connection(unittest.TestCase): + def test_bad_prologue(self): + c, n, connector, t, eq = make_con(LEADER) + c.makeConnection(t) + d = c.when_disconnected() + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"prologue\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + + eq.flush_sync() + self.assertNoResult(d) + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def _test_no_relay(self, role): + c, n, connector, t, eq = make_con(role) + t_kcm = KCM() + t_open = Open(seqnum=1, scid=0x11223344) + t_ack = Ack(resp_seqnum=2) + n.decrypt = mock.Mock(side_effect=[ + encode_record(t_kcm), + encode_record(t_open), + ]) + exp_kcm = b"\x00\x00\x00\x03kcm" + n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) + m = mock.Mock() # Manager + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x0Ahandshake2") + if role is LEADER: + # we're the leader, so we don't send the KCM right away + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(c._manager, None) + else: + # we're the follower, so we encrypt and send the KCM immediately + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2"), + mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [ + mock.call.write(exp_kcm)]) + self.assertEqual(c._manager, None) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x03KCM") + # leader: inbound KCM means we add the candidate + # follower: inbound KCM means we've been selected. + # in both cases we notify Connector.add_candidate(), and the Connector + # decides if/when to call .select() + + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")]) + self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)]) + self.assertEqual(t.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + # now pretend this connection wins (either the Leader decides to use + # this one among all the candiates, or we're the Follower and the + # Connector is reacting to add_candidate() by recognizing we're the + # only candidate there is) + c.select(m) + self.assertIdentical(c._manager, m) + if role is LEADER: + # TODO: currently Connector.select_and_stop_remaining() is + # responsible for sending the KCM just before calling c.select() + # iff we're the LEADER, therefore Connection.select won't send + # anything. This should be moved to c.select(). + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + + c.send_record(KCM()) + self.assertEqual(n.mock_calls, [ + mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) + self.assertEqual(m.mock_calls, []) + else: + # follower: we already sent the KCM, do nothing + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x04msg1") + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)]) + clear_mock_calls(n, connector, t, m) + + c.send_record(t_ack) + exp_ack = b"\x06\x00\x00\x00\x02" + self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.disconnect() + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + def test_no_relay_leader(self): + return self._test_no_relay(LEADER) + + def test_no_relay_follower(self): + return self._test_no_relay(FOLLOWER) + + def test_relay(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t) + + def test_relay_jilted(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + d = c.when_disconnected() + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def test_relay_bad_response(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"not ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + clear_mock_calls(n, connector, t) diff --git a/src/wormhole/test/dilate/test_connector.py b/src/wormhole/test/dilate/test_connector.py new file mode 100644 index 0000000..2bb8809 --- /dev/null +++ b/src/wormhole/test/dilate/test_connector.py @@ -0,0 +1,463 @@ +from __future__ import print_function, unicode_literals + +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.internet.defer import Deferred +from ...eventual import EventualQueue +from ..._interfaces import IDilationManager, IDilationConnector +from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint +from ..._dilation import roles +from ..._dilation._noise import NoiseConnection +from ..._dilation.connection import KCM +from ..._dilation.connector import (Connector, + build_sided_relay_handshake, + build_noise, + OutboundConnectionFactory, + InboundConnectionFactory, + PROLOGUE_LEADER, PROLOGUE_FOLLOWER, + ) +from .common import clear_mock_calls + +class Handshake(unittest.TestCase): + def test_build(self): + key = b"k"*32 + side = "12345678abcdabcd" + self.assertEqual(build_sided_relay_handshake(key, side), + b"please relay 3f4147851dbd2589d25b654ee9fb35ed0d3e5f19c5c5403e8e6a195c70f0577a for side 12345678abcdabcd\n") + +class Outbound(unittest.TestCase): + def test_no_relay(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + relay_handshake = None + f = OutboundConnectionFactory(c, relay_handshake) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(p.mock_calls, []) + self.assertIdentical(p.factory, f) + + def test_with_relay(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + relay_handshake = b"relay handshake" + f = OutboundConnectionFactory(c, relay_handshake) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)]) + self.assertIdentical(p.factory, f) + +class Inbound(unittest.TestCase): + def test_build(self): + c = mock.Mock() + alsoProvides(c, IDilationConnector) + p0 = mock.Mock() + c.build_protocol = mock.Mock(return_value=p0) + f = InboundConnectionFactory(c) + addr = object() + p = f.buildProtocol(addr) + self.assertIdentical(p, p0) + self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)]) + self.assertIdentical(p.factory, f) + +def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER): + class Holder: + pass + h = Holder() + h.dilation_key = b"key" + h.relay = relay + h.manager = mock.Mock() + alsoProvides(h.manager, IDilationManager) + h.clock = Clock() + h.reactor = h.clock + h.eq = EventualQueue(h.clock) + h.tor = None + if tor: + h.tor = mock.Mock() + timing = None + h.side = u"abcd1234abcd5678" + h.role = role + c = Connector(h.dilation_key, h.relay, h.manager, h.reactor, h.eq, + not listen, h.tor, timing, h.side, h.role) + return c, h + +class TestConnector(unittest.TestCase): + def test_build(self): + c, h = make_connector() + c, h = make_connector(relay="tcp:host:1234") + + def test_connection_abilities(self): + self.assertEqual(Connector.get_connection_abilities(), + [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ]) + + def test_build_noise(self): + if not NoiseConnection: + raise unittest.SkipTest("noiseprotocol unavailable") + build_noise() + + def test_build_protocol_leader(self): + c, h = make_connector(role=roles.LEADER) + n0 = mock.Mock() + p0 = mock.Mock() + addr = object() + with mock.patch("wormhole._dilation.connector.build_noise", + return_value=n0) as bn: + with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", + return_value=p0) as dcp: + p = c.build_protocol(addr) + self.assertEqual(bn.mock_calls, [mock.call()]) + self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), + mock.call.set_as_initiator()]) + self.assertIdentical(p, p0) + self.assertEqual(dcp.mock_calls, + [mock.call(h.eq, h.role, c, n0, + PROLOGUE_LEADER, PROLOGUE_FOLLOWER)]) + + def test_build_protocol_follower(self): + c, h = make_connector(role=roles.FOLLOWER) + n0 = mock.Mock() + p0 = mock.Mock() + addr = object() + with mock.patch("wormhole._dilation.connector.build_noise", + return_value=n0) as bn: + with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol", + return_value=p0) as dcp: + p = c.build_protocol(addr) + self.assertEqual(bn.mock_calls, [mock.call()]) + self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key), + mock.call.set_as_responder()]) + self.assertIdentical(p, p0) + self.assertEqual(dcp.mock_calls, + [mock.call(h.eq, h.role, c, n0, + PROLOGUE_FOLLOWER, PROLOGUE_LEADER)]) + + def test_start_stop(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + c.stop() + # we stop while we're connecting, so no connections must be stopped + + def test_empty(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + self.assertEqual(c._schedule_connection.mock_calls, []) + c.stop() + + def test_basic(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + # no relays, so it publishes no hints + self.assertEqual(h.manager.mock_calls, []) + # and no listener, so nothing happens until we provide a hint + self.assertEqual(c._schedule_connection.mock_calls, []) + + hint = DirectTCPV1Hint("foo", 55, 0.0) + c.got_hints([hint]) + + # received hints don't get published + self.assertEqual(h.manager.mock_calls, []) + # they just schedule a connection + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=False)]) + + def test_listen_addresses(self): + c, h = make_connector(listen=True, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): + self.assertEqual(c._get_listener_addresses(), + ["1.2.3.4", "5.6.7.8"]) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1"]): + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. + self.assertEqual(c._get_listener_addresses(), ["127.0.0.1"]) + + def test_listen(self): + c, h = make_connector(listen=True, role=roles.LEADER) + c._start_listener = mock.Mock() + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]): + c.start() + self.assertEqual(c._start_listener.mock_calls, + [mock.call(["1.2.3.4", "5.6.7.8"])]) + + def test_start_listen(self): + c, h = make_connector(listen=True, role=roles.LEADER) + ep = mock.Mock() + d = Deferred() + ep.listen = mock.Mock(return_value=d) + with mock.patch("wormhole._dilation.connector.serverFromString", + return_value=ep) as sfs: + c._start_listener(["1.2.3.4", "5.6.7.8"]) + self.assertEqual(sfs.mock_calls, [mock.call(h.reactor, "tcp:0")]) + lp = mock.Mock() + host = mock.Mock() + host.port = 66 + lp.getHost = mock.Mock(return_value=host) + d.callback(lp) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "direct-tcp-v1", + "hostname": "1.2.3.4", + "port": 66, + "priority": 0.0 + }, + {"type": "direct-tcp-v1", + "hostname": "5.6.7.8", + "port": 66, + "priority": 0.0 + }, + ])]) + + def test_schedule_connection_no_relay(self): + c, h = make_connector(listen=True, role=roles.LEADER) + hint = DirectTCPV1Hint("foo", 55, 0.0) + ep = mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]) as efho: + c._schedule_connection(0.0, hint, False) + self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)]) + self.assertEqual(ep.mock_calls, []) + d = Deferred() + ep.connect = mock.Mock(side_effect=[d]) + # direct hints are scheduled for T+0.0 + f = mock.Mock() + with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", + return_value=f) as ocf: + h.clock.advance(1.0) + self.assertEqual(ocf.mock_calls, [mock.call(c, None)]) + self.assertEqual(ep.connect.mock_calls, [mock.call(f)]) + p = mock.Mock() + d.callback(p) + self.assertEqual(p.mock_calls, + [mock.call.when_disconnected(), + mock.call.when_disconnected().addCallback(c._pending_connections.discard)]) + + def test_schedule_connection_relay(self): + c, h = make_connector(listen=True, role=roles.LEADER) + hint = DirectTCPV1Hint("foo", 55, 0.0) + ep = mock.Mock() + with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj", + side_effect=[ep]) as efho: + c._schedule_connection(0.0, hint, True) + self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)]) + self.assertEqual(ep.mock_calls, []) + d = Deferred() + ep.connect = mock.Mock(side_effect=[d]) + # direct hints are scheduled for T+0.0 + f = mock.Mock() + with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory", + return_value=f) as ocf: + h.clock.advance(1.0) + handshake = build_sided_relay_handshake(h.dilation_key, h.side) + self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)]) + + def test_listen_but_tor(self): + c, h = make_connector(listen=True, tor=True, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa: + c.start() + # don't even look up addresses + self.assertEqual(fa.mock_calls, []) + # no relays and the listener isn't ready yet, so no hints yet + self.assertEqual(h.manager.mock_calls, []) + + def test_no_listen(self): + c, h = make_connector(listen=False, tor=False, role=roles.LEADER) + with mock.patch("wormhole.ipaddrs.find_addresses", + return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa: + c.start() + # don't even look up addresses + self.assertEqual(fa.mock_calls, []) + self.assertEqual(h.manager.mock_calls, []) + + def test_relay_delay(self): + # given a direct connection and a relay, we should see the direct + # connection initiated at T+0 seconds, and the relay at T+RELAY_DELAY + c, h = make_connector(listen=True, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c._start_listener = mock.Mock() + c.start() + hint1 = DirectTCPV1Hint("foo", 55, 0.0) + hint2 = DirectTCPV1Hint("bar", 55, 0.0) + hint3 = RelayV1Hint([DirectTCPV1Hint("relay", 55, 0.0)]) + c.got_hints([hint1, hint2, hint3]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, hint1, is_relay=False), + mock.call(0.0, hint2, is_relay=False), + mock.call(c.RELAY_DELAY, hint3.hints[0], is_relay=True), + ]) + + def test_initial_relay(self): + c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 55, + "priority": 0.0 + }, + ], + }])]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=True)]) + + def test_add_relay(self): + c, h = make_connector(listen=False, relay=None, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(h.manager.mock_calls, []) + self.assertEqual(c._schedule_connection.mock_calls, []) + hint = RelayV1Hint([DirectTCPV1Hint("foo", 55, 0.0)]) + c.add_relay([hint]) + self.assertEqual(h.manager.mock_calls, + [mock.call.send_hints([{"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 55, + "priority": 0.0 + }, + ], + }])]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0), + is_relay=True)]) + + def test_tor_no_manager(self): + # tor hints should be ignored if we don't have a Tor manager to use them + c, h = make_connector(listen=False, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + hint = TorTCPV1Hint("foo", 55, 0.0) + c.got_hints([hint]) + self.assertEqual(h.manager.mock_calls, []) + self.assertEqual(c._schedule_connection.mock_calls, []) + + def test_tor_with_manager(self): + # tor hints should be processed if we do have a Tor manager + c, h = make_connector(listen=False, tor=True, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + hint = TorTCPV1Hint("foo", 55, 0.0) + c.got_hints([hint]) + self.assertEqual(c._schedule_connection.mock_calls, + [mock.call(0.0, hint, is_relay=False)]) + + def test_priorities(self): + # given two hints with different priorities, we should somehow prefer + # one. This is a placeholder to fill in once we implement priorities. + pass + + +class Race(unittest.TestCase): + def test_one_leader(self): + c, h = make_connector(listen=True, role=roles.LEADER) + lp = mock.Mock() + def start_listener(addresses): + c._listeners.add(lp) + c._start_listener = start_listener + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(c._listeners, set([lp])) + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + self.assertEqual(lp.mock_calls[0], mock.call.stopListening()) + # stop_listeners() uses a DeferredList, so we ignore the second call + + def test_one_follower(self): + c, h = make_connector(listen=True, role=roles.FOLLOWER) + lp = mock.Mock() + def start_listener(addresses): + c._listeners.add(lp) + c._start_listener = start_listener + c._schedule_connection = mock.Mock() + c.start() + self.assertEqual(c._listeners, set([lp])) + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + # just like LEADER, but follower doesn't send KCM now (it sent one + # earlier, to tell the leader that this connection looks viable) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager)]) + self.assertEqual(lp.mock_calls[0], mock.call.stopListening()) + # stop_listeners() uses a DeferredList, so we ignore the second call + + # TODO: make sure a pending connection is abandoned when the listener + # answers successfully + + # TODO: make sure a second pending connection is abandoned when the first + # connection succeeds + + def test_late(self): + c, h = make_connector(listen=False, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + clear_mock_calls(h.manager) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + + # late connection is ignored + p2 = mock.Mock() + c.add_candidate(p2) + self.assertEqual(h.manager.mock_calls, []) + + + # make sure an established connection is dropped when stop() is called + def test_stop(self): + c, h = make_connector(listen=False, role=roles.LEADER) + c._schedule_connection = mock.Mock() + c.start() + + p1 = mock.Mock() # DilatedConnectionProtocol instance + c.add_candidate(p1) + self.assertEqual(h.manager.mock_calls, []) + h.eq.flush_sync() + self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)]) + self.assertEqual(p1.mock_calls, + [mock.call.select(h.manager), + mock.call.send_record(KCM())]) + + c.stop() + diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py new file mode 100644 index 0000000..6bf2c7a --- /dev/null +++ b/src/wormhole/test/dilate/test_encoding.py @@ -0,0 +1,26 @@ +from __future__ import print_function, unicode_literals +from twisted.trial import unittest +from ..._dilation.encode import to_be4, from_be4 + + +class Encoding(unittest.TestCase): + + def test_be4(self): + self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") + self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") + self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") + self.assertEqual(to_be4(257), b"\x00\x00\x01\x01") + with self.assertRaises(ValueError): + to_be4(-1) + with self.assertRaises(ValueError): + to_be4(2**32) + + self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0) + self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1) + self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256) + self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257) + + with self.assertRaises(TypeError): + from_be4(0) + with self.assertRaises(ValueError): + from_be4(b"\x01\x00\x00\x00\x00") diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py new file mode 100644 index 0000000..ba07fe0 --- /dev/null +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -0,0 +1,98 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import ISubChannel +from ..._dilation.subchannel import (ControlEndpoint, + SubchannelConnectorEndpoint, + SubchannelListenerEndpoint, + SubchannelListeningPort, + _WormholeAddress, _SubchannelAddress, + SingleUseEndpointError) +from .common import mock_manager + + +class Endpoints(unittest.TestCase): + def test_control(self): + scid0 = b"scid0" + peeraddr = _SubchannelAddress(scid0) + ep = ControlEndpoint(peeraddr) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + t = mock.Mock() + alsoProvides(t, ISubChannel) + ep._subchannel_zero_opened(t) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def assert_makeConnection(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "makeConnection") + self.assertEqual(len(mock_calls[0][1]), 1) + return mock_calls[0][1][0] + + def test_connector(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(b"scid") + ep = SubchannelConnectorEndpoint(m, hostaddr) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_listener(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + ep = SubchannelListenerEndpoint(m, hostaddr) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + # OPEN that arrives before we ep.listen() should be queued + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(b"peer1") + ep._got_open(t1, peeraddr1) + + d = ep.listen(f) + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(b"peer2") + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + + lp.stopListening() # TODO: should this do more? diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py new file mode 100644 index 0000000..51ac039 --- /dev/null +++ b/src/wormhole/test/dilate/test_framer.py @@ -0,0 +1,112 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect + + +def make_framer(): + t = mock.Mock() + alsoProvides(t, ITransport) + f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") + return f, t + + +class Framer(unittest.TestCase): + def test_bad_prologue_length(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual(t.mock_calls, []) + + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"not the prologue after all")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not the p"))]) + self.assertEqual(t.mock_calls, []) + + def test_bad_prologue_newline(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + + self.assertEqual([], list(f.add_and_parse(b"not"))) + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not\n"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_prologue(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + # now send_frame should work + f.send_frame(b"frame") + self.assertEqual(t.mock_calls, + [mock.call.write(b"\x00\x00\x00\x05frame")]) + + def test_bad_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"goodbye\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad relay_ok: {}".format(b"goo"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + + self.assertEqual([], list(f.add_and_parse(b"ok\n"))) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + + def test_frame(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + encoded_frame = b"\x00\x00\x00\x05frame" + self.assertEqual([], list(f.add_and_parse(encoded_frame[:2]))) + self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6]))) + self.assertEqual([Frame(frame=b"frame")], + list(f.add_and_parse(encoded_frame[6:]))) diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py new file mode 100644 index 0000000..392a661 --- /dev/null +++ b/src/wormhole/test/dilate/test_inbound.py @@ -0,0 +1,174 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import IDilationManager +from ..._dilation.connection import Open, Data, Close +from ..._dilation.inbound import (Inbound, DuplicateOpenError, + DataForMissingSubchannelError, + CloseForMissingSubchannelError) + + +def make_inbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + host_addr = object() + i = Inbound(m, host_addr) + return i, m, host_addr + + +class InboundTest(unittest.TestCase): + def test_seqnum(self): + i, m, host_addr = make_inbound() + r1 = Open(scid=513, seqnum=1) + r2 = Data(scid=513, seqnum=2, data=b"") + r3 = Close(scid=513, seqnum=3) + self.assertFalse(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r1) + self.assertTrue(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r2) + self.assertTrue(i.is_record_old(r1)) + self.assertTrue(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + def test_open_data_close(self): + i, m, host_addr = make_inbound() + scid1 = b"scid" + scid2 = b"scXX" + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + i.use_connection(c) + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)]) + self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)]) + self.assertEqual(sca.mock_calls, [mock.call(scid1)]) + lep.mock_calls[:] = [] + + # a subsequent duplicate OPEN should be ignored + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + i.handle_data(scid1, b"data") + self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")]) + sc1.mock_calls[:] = [] + + i.handle_data(scid2, b"for non-existent subchannel") + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(DataForMissingSubchannelError) + + i.handle_close(scid1) + self.assertEqual(sc1.mock_calls, [mock.call.remote_close()]) + sc1.mock_calls[:] = [] + + i.handle_close(scid2) + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(CloseForMissingSubchannelError) + + # after the subchannel is closed, the Manager will notify Inbound + i.subchannel_closed(scid1, sc1) + + i.stop_using_connection() + + def test_control_channel(self): + i, m, host_addr = make_inbound() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + scid0 = b"scid" + sc0 = mock.Mock() + i.set_subchannel_zero(scid0, sc0) + + # OPEN on the control channel identifier should be ignored as a + # duplicate, since the control channel is already registered + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid0) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + # and DATA to it should be delivered correctly + i.handle_data(scid0, b"data") + self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")]) + sc0.mock_calls[:] = [] + + def test_pause(self): + i, m, host_addr = make_inbound() + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + # add two subchannels, pause one, then add a connection + scid1 = b"sci1" + scid2 = b"sci2" + sc1 = mock.Mock() + sc2 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1, sc2]): + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + return_value=peer_addr): + i.handle_open(scid1) + i.handle_open(scid2) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] + + # consumers aren't really supposed to do this, but tolerate it + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) # was already paused + + # tolerate duplicate pauseProducing + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) + + # stopProducing is treated like a terminal resumeProducing + i.subchannel_stopProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_stopProducing(sc2) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py new file mode 100644 index 0000000..e223a1c --- /dev/null +++ b/src/wormhole/test/dilate/test_manager.py @@ -0,0 +1,650 @@ +from __future__ import print_function, unicode_literals +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.defer import Deferred +from twisted.internet.task import Clock, Cooperator +import mock +from ...eventual import EventualQueue +from ..._interfaces import ISend, IDilationManager +from ...util import dict_to_bytes +from ..._dilation import roles +from ..._dilation.encode import to_be4 +from ..._dilation.manager import (Dilator, Manager, make_side, + OldPeerCannotDilateError, + UnknownDilationMessageType, + UnexpectedKCM, + UnknownMessageType) +from ..._dilation.subchannel import _WormholeAddress +from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong +from .common import clear_mock_calls + + +def make_dilator(): + reactor = object() + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): + return term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + send = mock.Mock() + alsoProvides(send, ISend) + dil = Dilator(reactor, eq, coop) + dil.wire(send) + return dil, send, reactor, eq, clock, coop + + +class TestDilator(unittest.TestCase): + def test_manager_and_endpoints(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + d2 = dil.dilate() + self.assertNoResult(d1) + self.assertNoResult(d2) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + self.assertIdentical(dil._transit_key, transit_key) + self.assertNoResult(d1) + self.assertNoResult(d2) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = wfc_d = Deferred() + with mock.patch("wormhole._dilation.manager.Manager", + return_value=m) as ml: + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) + # that should create the Manager + self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, + None, reactor, eq, coop)]) + # and tell it to start, and get wait-for-it-to-connect Deferred + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.when_first_connected(), + ]) + clear_mock_calls(m) + self.assertNoResult(d1) + self.assertNoResult(d2) + + host_addr = _WormholeAddress() + m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress", + return_value=host_addr) + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) + + with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m: + wfc_d.callback(None) + eq.flush_sync() + scid0 = b"\x00\x00\x00\x00" + self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) + self.assertEqual(m_sc_m.mock_calls, + [mock.call(scid0, m, host_addr, peer_addr)]) + self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) + self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) + self.assertEqual(m.mock_calls, + [mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), + ]) + clear_mock_calls(m) + + eps = self.successResultOf(d1) + self.assertEqual(eps, self.successResultOf(d2)) + d3 = dil.dilate() + eq.flush_sync() + self.assertEqual(eps, self.successResultOf(d3)) + + # all subsequent DILATE-n messages should get passed to the manager + self.assertEqual(m.mock_calls, []) + pleasemsg = dict(type="please", side="them") + dil.received_dilate(dict_to_bytes(pleasemsg)) + self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)]) + clear_mock_calls(m) + + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) + clear_mock_calls(m) + + # we're nominally the LEADER, and the leader would not normally be + # receiving a RECONNECT, but since we've mocked out the Manager it + # won't notice + dil.received_dilate(dict_to_bytes(dict(type="reconnect"))) + self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="reconnecting"))) + self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="unknown"))) + self.assertEqual(m.mock_calls, []) + self.flushLoggedErrors(UnknownDilationMessageType) + + def test_peer_cannot_dilate(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil._transit_key = b"\x01" * 32 + dil.got_wormhole_versions({}) # missing "can-dilate" + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + def test_disjoint_versions(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil._transit_key = b"key" + dil.got_wormhole_versions({"can-dilate": [-1]}) + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + def test_early_dilate_messages(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + d1 = dil.dilate() + self.assertNoResult(d1) + pleasemsg = dict(type="please", side="them") + dil.received_dilate(dict_to_bytes(pleasemsg)) + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = Deferred() + + with mock.patch("wormhole._dilation.manager.Manager", + return_value=m) as ml: + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.rx_PLEASE(pleasemsg), + mock.call.rx_HINTS(hintmsg), + mock.call.when_first_connected()]) + + def test_transit_relay(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + relay = object() + d1 = dil.dilate(transit_relay_location=relay) + self.assertNoResult(d1) + + with mock.patch("wormhole._dilation.manager.Manager") as ml: + with mock.patch("wormhole._dilation.manager.make_side", + return_value="us"): + dil.got_wormhole_versions({"can-dilate": ["1"]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + relay, reactor, eq, coop), + mock.call().start(), + mock.call().when_first_connected()]) + + +LEADER = "ff3456abcdef" +FOLLOWER = "123456abcdef" + + +def make_manager(leader=True): + class Holder: + pass + h = Holder() + h.send = mock.Mock() + alsoProvides(h.send, ISend) + if leader: + side = LEADER + else: + side = FOLLOWER + h.key = b"\x00" * 32 + h.relay = None + h.reactor = object() + h.clock = Clock() + h.eq = EventualQueue(h.clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): + return term + h.coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=h.eq.eventually) + h.inbound = mock.Mock() + h.Inbound = mock.Mock(return_value=h.inbound) + h.outbound = mock.Mock() + h.Outbound = mock.Mock(return_value=h.outbound) + h.hostaddr = object() + with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound): + with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound): + with mock.patch("wormhole._dilation.manager._WormholeAddress", + return_value=h.hostaddr): + m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop) + return m, h + + +class TestManager(unittest.TestCase): + def test_make_side(self): + side = make_side() + self.assertEqual(type(side), type(u"")) + self.assertEqual(len(side), 2 * 6) + + def test_create(self): + m, h = make_manager() + + def test_leader(self): + m, h = make_manager(leader=True) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)]) + self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)]) + + m.start() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-0", + dict_to_bytes({"type": "please", "side": LEADER})) + ]) + clear_mock_calls(h.send) + + wfc_d = m.when_first_connected() + self.assertNoResult(wfc_d) + + # ignore early hints + m.rx_HINTS({}) + self.assertEqual(h.send.mock_calls, []) + + c = mock.Mock() + connector = mock.Mock(return_value=c) + with mock.patch("wormhole._dilation.manager.Connector", connector): + # receiving this PLEASE triggers creation of the Connector + m.rx_PLEASE({"side": FOLLOWER}) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + LEADER, roles.LEADER), + ]) + self.assertEqual(c.mock_calls, [mock.call.start()]) + clear_mock_calls(connector, c) + + self.assertNoResult(wfc_d) + + # now any inbound hints should get passed to our Connector + with mock.patch("wormhole._dilation.manager.parse_hint", + side_effect=["p1", None, "p3"]) as ph: + m.rx_HINTS({"hints": [1, 2, 3]}) + self.assertEqual(ph.mock_calls, [mock.call(1), mock.call(2), mock.call(3)]) + self.assertEqual(c.mock_calls, [mock.call.got_hints(["p1", "p3"])]) + clear_mock_calls(ph, c) + + # and we send out any (listening) hints from our Connector + m.send_hints([1, 2]) + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-1", + dict_to_bytes({"type": "connection-hints", + "hints": [1, 2]})) + ]) + clear_mock_calls(h.send) + + # the first successful connection fires when_first_connected(), so + # the Dilator can create and return the endpoints + c1 = mock.Mock() + m.connector_connection_made(c1) + + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) + clear_mock_calls(h.inbound, h.outbound) + + h.eq.flush_sync() + self.successResultOf(wfc_d) # fires with None + wfc_d2 = m.when_first_connected() + h.eq.flush_sync() + self.successResultOf(wfc_d2) + + scid0 = b"\x00\x00\x00\x00" + sc0 = mock.Mock() + m.set_subchannel_zero(scid0, sc0) + listen_ep = mock.Mock() + m.set_listener_endpoint(listen_ep) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.set_subchannel_zero(scid0, sc0), + mock.call.set_listener_endpoint(listen_ep), + ]) + clear_mock_calls(h.inbound) + + # the Leader making a new outbound channel should get scid=1 + scid1 = to_be4(1) + self.assertEqual(m.allocate_subchannel_id(), scid1) + r1 = Open(10, scid1) # seqnum=10 + h.outbound.build_record = mock.Mock(return_value=r1) + m.send_open(scid1) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Open, scid1), + mock.call.queue_and_send_record(r1), + ]) + clear_mock_calls(h.outbound) + + r2 = Data(11, scid1, b"data") + h.outbound.build_record = mock.Mock(return_value=r2) + m.send_data(scid1, b"data") + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Data, scid1, b"data"), + mock.call.queue_and_send_record(r2), + ]) + clear_mock_calls(h.outbound) + + r3 = Close(12, scid1) + h.outbound.build_record = mock.Mock(return_value=r3) + m.send_close(scid1) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.build_record(Close, scid1), + mock.call.queue_and_send_record(r3), + ]) + clear_mock_calls(h.outbound) + + # ack the OPEN + m.got_record(Ack(10)) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.handle_ack(10) + ]) + clear_mock_calls(h.outbound) + + # test that inbound records get acked and routed to Inbound + h.inbound.is_record_old = mock.Mock(return_value=False) + scid2 = to_be4(2) + o200 = Open(200, scid2) + m.got_record(o200) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(200)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(o200), + mock.call.update_ack_watermark(200), + mock.call.handle_open(scid2), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # old (duplicate) records should provoke new Acks, but not get + # forwarded + h.inbound.is_record_old = mock.Mock(return_value=True) + m.got_record(o200) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(200)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(o200), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # check Data and Close too + h.inbound.is_record_old = mock.Mock(return_value=False) + d201 = Data(201, scid2, b"data") + m.got_record(d201) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(201)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(d201), + mock.call.update_ack_watermark(201), + mock.call.handle_data(scid2, b"data"), + ]) + clear_mock_calls(h.outbound, h.inbound) + + c202 = Close(202, scid2) + m.got_record(c202) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.send_if_connected(Ack(202)) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.is_record_old(c202), + mock.call.update_ack_watermark(202), + mock.call.handle_close(scid2), + ]) + clear_mock_calls(h.outbound, h.inbound) + + # Now we lose the connection. The Leader should tell the other side + # that we're reconnecting. + + m.connector_connection_lost() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-2", + dict_to_bytes({"type": "reconnect"})) + ]) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + clear_mock_calls(h.send, h.inbound, h.outbound) + + # leader does nothing (stays in FLUSHING) until the follower acks by + # sending RECONNECTING + + # inbound hints should be ignored during FLUSHING + with mock.patch("wormhole._dilation.manager.parse_hint", + return_value=None) as ph: + m.rx_HINTS({"hints": [1, 2, 3]}) + self.assertEqual(ph.mock_calls, []) # ignored + + c2 = mock.Mock() + connector2 = mock.Mock(return_value=c2) + with mock.patch("wormhole._dilation.manager.Connector", connector2): + # this triggers creation of a new Connector + m.rx_RECONNECTING() + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector2.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + LEADER, roles.LEADER), + ]) + self.assertEqual(c2.mock_calls, [mock.call.start()]) + clear_mock_calls(connector2, c2) + + self.assertEqual(h.inbound.mock_calls, []) + self.assertEqual(h.outbound.mock_calls, []) + + # and a new connection should re-register with Inbound/Outbound, + # which are responsible for re-sending unacked queued messages + c3 = mock.Mock() + m.connector_connection_made(c3) + + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c3)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c3)]) + clear_mock_calls(h.inbound, h.outbound) + + def test_follower(self): + m, h = make_manager(leader=False) + + m.start() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-0", + dict_to_bytes({"type": "please", "side": FOLLOWER})) + ]) + clear_mock_calls(h.send) + + c = mock.Mock() + connector = mock.Mock(return_value=c) + with mock.patch("wormhole._dilation.manager.Connector", connector): + # receiving this PLEASE triggers creation of the Connector + m.rx_PLEASE({"side": LEADER}) + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(connector.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c.mock_calls, [mock.call.start()]) + clear_mock_calls(connector, c) + + # get connected, then lose the connection + c1 = mock.Mock() + m.connector_connection_made(c1) + self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)]) + self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)]) + clear_mock_calls(h.inbound, h.outbound) + + # now lose the connection. As the follower, we don't notify the + # leader, we just wait for them to notice + m.connector_connection_lost() + self.assertEqual(h.send.mock_calls, []) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.stop_using_connection() + ]) + clear_mock_calls(h.send, h.inbound, h.outbound) + + # now we get a RECONNECT: we should send RECONNECTING + c2 = mock.Mock() + connector2 = mock.Mock(return_value=c2) + with mock.patch("wormhole._dilation.manager.Connector", connector2): + m.rx_RECONNECT() + self.assertEqual(h.send.mock_calls, [ + mock.call.send("dilate-1", + dict_to_bytes({"type": "reconnecting"})) + ]) + self.assertEqual(connector2.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c2.mock_calls, [mock.call.start()]) + clear_mock_calls(connector2, c2) + + # while we're trying to connect, we get told to stop again, so we + # should abandon the connection attempt and start another + c3 = mock.Mock() + connector3 = mock.Mock(return_value=c3) + with mock.patch("wormhole._dilation.manager.Connector", connector3): + m.rx_RECONNECT() + self.assertEqual(c2.mock_calls, [mock.call.stop()]) + self.assertEqual(connector3.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c3.mock_calls, [mock.call.start()]) + clear_mock_calls(c2, connector3, c3) + + m.connector_connection_made(c3) + # finally if we're already connected, rx_RECONNECT means we should + # abandon this connection (even though it still looks ok to us), then + # when the attempt is finished stopping, we should start another + + m.rx_RECONNECT() + + c4 = mock.Mock() + connector4 = mock.Mock(return_value=c4) + with mock.patch("wormhole._dilation.manager.Connector", connector4): + m.connector_connection_lost() + self.assertEqual(c3.mock_calls, [mock.call.disconnect()]) + self.assertEqual(connector4.mock_calls, [ + mock.call(b"\x00" * 32, None, m, h.reactor, h.eq, + False, # no_listen + None, # tor + None, # timing + FOLLOWER, roles.FOLLOWER), + ]) + self.assertEqual(c4.mock_calls, [mock.call.start()]) + clear_mock_calls(c3, connector4, c4) + + def test_mirror(self): + # receive a PLEASE with the same side as us: shouldn't happen + m, h = make_manager(leader=True) + + m.start() + clear_mock_calls(h.send) + e = self.assertRaises(ValueError, m.rx_PLEASE, {"side": LEADER}) + self.assertEqual(str(e), "their side shouldn't be equal: reflection?") + + def test_ping_pong(self): + m, h = make_manager(leader=False) + + m.got_record(KCM()) + self.flushLoggedErrors(UnexpectedKCM) + + m.got_record(Ping(1)) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(1))]) + clear_mock_calls(h.outbound) + + m.got_record(Pong(2)) + # currently ignored, will eventually update a timer + + m.got_record("not recognized") + e = self.flushLoggedErrors(UnknownMessageType) + self.assertEqual(len(e), 1) + self.assertEqual(str(e[0].value), "not recognized") + + m.send_ping(3) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(3))]) + clear_mock_calls(h.outbound) + + def test_subchannel(self): + m, h = make_manager(leader=True) + sc = object() + + m.subchannel_pauseProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_pauseProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_resumeProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_resumeProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_stopProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_stopProducing(sc)]) + clear_mock_calls(h.inbound) + + p = object() + streaming = object() + + m.subchannel_registerProducer(sc, p, streaming) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_registerProducer(sc, p, streaming)]) + clear_mock_calls(h.outbound) + + m.subchannel_unregisterProducer(sc) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_unregisterProducer(sc)]) + clear_mock_calls(h.outbound) + + m.subchannel_closed("scid", sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + clear_mock_calls(h.inbound, h.outbound) diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py new file mode 100644 index 0000000..6ba5264 --- /dev/null +++ b/src/wormhole/test/dilate/test_outbound.py @@ -0,0 +1,655 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +from itertools import cycle +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock, Cooperator +from twisted.internet.interfaces import IPullProducer +from ...eventual import EventualQueue +from ..._interfaces import IDilationManager +from ..._dilation.connection import KCM, Open, Data, Close, Ack +from ..._dilation.outbound import Outbound, PullToPush +from .common import clear_mock_calls + +Pauser = namedtuple("Pauser", ["seqnum"]) +NonPauser = namedtuple("NonPauser", ["seqnum"]) +Stopper = namedtuple("Stopper", ["sc"]) + + +def make_outbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): + return term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + o = Outbound(m, coop) + c = mock.Mock() # Connection + + def maybe_pause(r): + if isinstance(r, Pauser): + o.pauseProducing() + elif isinstance(r, Stopper): + o.subchannel_unregisterProducer(r.sc) + c.send_record = mock.Mock(side_effect=maybe_pause) + o._test_eq = eq + o._test_term = term + return o, m, c + + +class OutboundTest(unittest.TestCase): + def test_build_record(self): + o, m, c = make_outbound() + scid1 = b"scid" + self.assertEqual(o.build_record(Open, scid1), + Open(seqnum=0, scid=b"scid")) + self.assertEqual(o.build_record(Data, scid1, b"dataaa"), + Data(seqnum=1, scid=b"scid", data=b"dataaa")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=2, scid=b"scid")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=3, scid=b"scid")) + + def test_outbound_queue(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + + # we would never normally receive an ACK without first getting a + # connection + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + + o.handle_ack(r3.seqnum) + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r3.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r1.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + def test_duplicate_registerProducer(self): + o, m, c = make_outbound() + sc1 = object() + p1 = mock.Mock() + o.subchannel_registerProducer(sc1, p1, True) + with self.assertRaises(ValueError) as ar: + o.subchannel_registerProducer(sc1, p1, True) + s = str(ar.exception) + self.assertIn("registering producer", s) + self.assertIn("before previous one", s) + self.assertIn("was unregistered", s) + + def test_connection_send_queued_unpaused(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # as soon as the connection is established, everything is sent + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1), + mock.call.send_record(r2)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + + def test_connection_send_queued_paused(self): + o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # pausing=True, so our mock Manager will pause the Outbound producer + # after each write. So only r1 should have been sent before getting + # paused + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + # Outbound is responsible for sending all records, so when Manager + # wants to send a new one, and Outbound is still in the middle of + # draining the beginning-of-connection queue, the new message gets + # queued behind the rest (in addition to being queued in + # _outbound_queue until an ACK retires it). + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r1.seqnum) + self.assertEqual(list(o._outbound_queue), [r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + def test_premptive_ack(self): + # one mode I have in mind is for each side to send an immediate ACK, + # with everything they've ever seen, as the very first message on each + # new connection. The idea is that you might preempt sending stuff from + # the _queued_unsent list if it arrives fast enough (in practice this + # is more likely to be delivered via the DILATE mailbox message, but + # the effects might be vaguely similar, so it seems worth testing + # here). A similar situation would be if each side sends ACKs with the + # highest seqnum they've ever seen, instead of merely ACKing the + # message which was just received. + o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + self.assertEqual(list(o._queued_unsent), [r3]) + self.assertEqual(c.mock_calls, []) + + def test_pause(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)]) + self.assertEqual(list(o._outbound_queue), []) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + sc1, sc2, sc3 = object(), object(), object() + p1, p2, p3 = mock.Mock(name="p1"), mock.Mock( + name="p2"), mock.Mock(name="p3") + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + + r1 = Pauser(seqnum=1) + o.queue_and_send_record(r1) + # now we should be paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1, c) + + # so an IPushProducer will be paused right away + o.subchannel_registerProducer(sc2, p2, True) + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p2) + + o.subchannel_registerProducer(sc3, p3, True) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(o._paused_producers, set([p1, p2, p3])) + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + clear_mock_calls(p3) + + # one resumeProducing should cause p1 to get a turn, since p2 was added + # after we were paused and p1 was at the "end" of a one-element list. + # If it writes anything, it will get paused again immediately. + r2 = Pauser(seqnum=2) + p1.resumeProducing.side_effect = lambda: c.send_record(r2) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(p3.mock_calls, []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2)]) + clear_mock_calls(p1, p2, p3, c) + # p2 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p2, p3, p1]) + + # next turn: p2 has nothing to send, but p3 does. we should see p3 + # called but not p1. The actual sequence of expected calls is: + # p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause) + r3 = Pauser(seqnum=3) + p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: c.send_record(r3) + o.resumeProducing() + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # next turn: p1 has data to send, but not enough to cause a pause. same + # for p2. p3 causes a pause + r4 = NonPauser(seqnum=4) + r5 = NonPauser(seqnum=5) + r6 = Pauser(seqnum=6) + p1.resumeProducing.side_effect = lambda: c.send_record(r4) + p2.resumeProducing.side_effect = lambda: c.send_record(r5) + p3.resumeProducing.side_effect = lambda: c.send_record(r6) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r4), + mock.call.send_record(r5), + mock.call.send_record(r6), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # now we let it catch up. p1 and p2 send non-pausing data, p3 sends + # nothing. + r7 = NonPauser(seqnum=4) + r8 = NonPauser(seqnum=5) + p1.resumeProducing.side_effect = lambda: c.send_record(r7) + p2.resumeProducing.side_effect = lambda: c.send_record(r8) + p3.resumeProducing.side_effect = lambda: None + + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r7), + mock.call.send_record(r8), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + self.assertFalse(o._paused) + + # now a producer disconnects itself (spontaneously, not from inside a + # resumeProducing) + o.subchannel_unregisterProducer(sc1) + self.assertEqual(list(o._all_producers), [p2, p3]) + self.assertEqual(p1.mock_calls, []) + self.assertFalse(o._paused) + + # and another disconnects itself when called + p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer( + sc3) + o.pauseProducing() + o.resumeProducing() + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + clear_mock_calls(p2, p3) + self.assertEqual(list(o._all_producers), [p2]) + self.assertFalse(o._paused) + + def test_subchannel_closed(self): + o, m, c = make_outbound() + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1) + + o.subchannel_closed(sc1) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(list(o._all_producers), []) + + sc2 = mock.Mock() + o.subchannel_closed(sc2) + + def test_disconnect(self): + o, m, c = make_outbound() + o.use_connection(c) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + o.stop_using_connection() + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + + def OFF_test_push_pull(self): + # use one IPushProducer and one IPullProducer. They should take turns + o, m, c = make_outbound() + o.use_connection(c) + clear_mock_calls(c) + + sc1, sc2 = object(), object() + p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2") + r1 = Pauser(seqnum=1) + r2 = NonPauser(seqnum=2) + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) # push + o.queue_and_send_record(r1) + # now we're paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(p2.mock_calls, []) + clear_mock_calls(p1, p2, c) + + p1.resumeProducing.side_effect = lambda: c.send_record(r1) + p2.resumeProducing.side_effect = lambda: c.send_record(r2) + o.subchannel_registerProducer(sc2, p2, False) # pull: always ready + + # p1 is still first, since p2 was just added (at the end) + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p1, p2]) + clear_mock_calls(p1, p2, c) + + # resume should send r1, which should pause everything + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next + clear_mock_calls(p1, p2, c) + + # next should fire p2, then p1 + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2), + mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat + clear_mock_calls(p1, p2, c) + + def test_pull_producer(self): + # a single pull producer should write until it is paused, rate-limited + # by the cooperator (so we'll see back-to-back resumeProducing calls + # until the Connection is paused, or 10ms have passed, whichever comes + # first, and if it's stopped by the timer, then the next EventualQueue + # turn will start it off again) + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + + records = [NonPauser(seqnum=1)] * 10 + records.append(Pauser(seqnum=2)) + records.append(Stopper(sc1)) + it = iter(records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) + o.subchannel_registerProducer(sc1, p1, False) + eq.flush_sync() # fast forward into the glorious (paused) future + + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in records[:-1]]) + self.assertEqual(p1.mock_calls, + [mock.call.resumeProducing()] * (len(records) - 1)) + clear_mock_calls(c, p1) + + # next resumeProducing should cause it to disconnect + o.resumeProducing() + eq.flush_sync() + self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()]) + self.assertEqual(len(o._all_producers), 0) + self.assertFalse(o._paused) + + def test_two_pull_producers(self): + # we should alternate between them until paused + p1_records = ([NonPauser(seqnum=i) for i in range(5)] + + [Pauser(seqnum=5)] + + [NonPauser(seqnum=i) for i in range(6, 10)]) + p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] + + [Pauser(seqnum=19)]) + expected1 = [NonPauser(0), NonPauser(10), + NonPauser(1), NonPauser(11), + NonPauser(2), NonPauser(12), + NonPauser(3), NonPauser(13), + NonPauser(4), NonPauser(14), + Pauser(5)] + expected2 = [NonPauser(15), + NonPauser(6), NonPauser(16), + NonPauser(7), NonPauser(17), + NonPauser(8), NonPauser(18), + NonPauser(9), Pauser(19), + ] + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + it1 = iter(p1_records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it1)) + o.subchannel_registerProducer(sc1, p1, False) + + sc2 = mock.Mock() + p2 = mock.Mock(name="p2") + alsoProvides(p2, IPullProducer) + it2 = iter(p2_records) + p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) + o.subchannel_registerProducer(sc2, p2, False) + + eq.flush_sync() # fast forward into the glorious (paused) future + + sends = [mock.call.resumeProducing()] + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected1]) + self.assertEqual(p1.mock_calls, 6 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) + clear_mock_calls(c, p1, p2) + + o.resumeProducing() + eq.flush_sync() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected2]) + self.assertEqual(p1.mock_calls, 4 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) + clear_mock_calls(c, p1, p2) + + def test_send_if_connected(self): + o, m, c = make_outbound() + o.send_if_connected(Ack(1)) # not connected yet + + o.use_connection(c) + o.send_if_connected(KCM()) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(KCM())]) + + def test_tolerate_duplicate_pause_resume(self): + o, m, c = make_outbound() + self.assertTrue(o._paused) # no connection + o.use_connection(c) + self.assertFalse(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + + def test_stopProducing(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertFalse(o._paused) + o.stopProducing() # connection does this before loss + self.assertTrue(o._paused) + o.stop_using_connection() + self.assertTrue(o._paused) + + def test_resume_error(self): + o, m, c = make_outbound() + o.use_connection(c) + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + p1.resumeProducing.side_effect = PretendResumptionError + o.subchannel_registerProducer(sc1, p1, False) + o._test_eq.flush_sync() + # the error is supposed to automatically unregister the producer + self.assertEqual(list(o._all_producers), []) + self.flushLoggedErrors(PretendResumptionError) + + +def make_pushpull(pauses): + p = mock.Mock() + alsoProvides(p, IPullProducer) + unregister = mock.Mock() + + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): + return term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + pp = PullToPush(p, unregister, coop) + + it = cycle(pauses) + + def action(i): + if isinstance(i, Exception): + raise i + elif i: + pp.pauseProducing() + p.resumeProducing.side_effect = lambda: action(next(it)) + return p, unregister, pp, eq + + +class PretendResumptionError(Exception): + pass + + +class PretendUnregisterError(Exception): + pass + + +class PushPull(unittest.TestCase): + # test our wrapper utility, which I copied from + # twisted.internet._producer_helpers since it isn't publically exposed + + def test_start_unpaused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + # if it starts unpaused, it gets one write before being halted + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) + clear_mock_calls(p) + + # now each time we call resumeProducing, we should see one delivered to + # the underlying IPullProducer + pp.resumeProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) + + pp.stopStreaming() + pp.stopStreaming() # should tolerate this + + def test_start_unpaused_two_writes(self): + p, unr, pp, eq = make_pushpull([False, True]) # pause every other time + # it should get two writes, since the first didn't pause + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2) + + def test_start_paused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + pp.startStreaming(True) + eq.flush_sync() + self.assertEqual(p.mock_calls, []) + pp.stopStreaming() + + def test_stop(self): + p, unr, pp, eq = make_pushpull([True]) + pp.startStreaming(True) + pp.stopProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.stopProducing()]) + + def test_error(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = lambda: pp.stopStreaming() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError) + + def test_error_during_unregister(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = PretendUnregisterError() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) + + # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I + # could capture the inter-call ordering that way diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py new file mode 100644 index 0000000..f7276a6 --- /dev/null +++ b/src/wormhole/test/dilate/test_parse.py @@ -0,0 +1,44 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from ..._dilation.connection import (parse_record, encode_record, + KCM, Ping, Pong, Open, Data, Close, Ack) + + +class Parse(unittest.TestCase): + def test_parse(self): + self.assertEqual(parse_record(b"\x00"), KCM()) + self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"), + Ping(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"), + Pong(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"), + Open(scid=513, seqnum=256)) + self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"), + Data(scid=514, seqnum=257, data=b"dataaa")) + self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"), + Close(scid=515, seqnum=258)) + self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"), + Ack(resp_seqnum=259)) + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(ValueError): + parse_record(b"\x07unknown") + self.assertEqual(le.mock_calls, + [mock.call("received unknown message type: {}".format( + b"\x07unknown"))]) + + def test_encode(self): + self.assertEqual(encode_record(KCM()), b"\x00") + self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping") + self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong") + self.assertEqual(encode_record(Open(scid=65536, seqnum=16)), + b"\x03\x00\x01\x00\x00\x00\x00\x00\x10") + self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")), + b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa") + self.assertEqual(encode_record(Close(scid=65538, seqnum=18)), + b"\x05\x00\x01\x00\x02\x00\x00\x00\x12") + self.assertEqual(encode_record(Ack(resp_seqnum=19)), + b"\x06\x00\x00\x00\x13") + with self.assertRaises(TypeError) as ar: + encode_record("not a record") + self.assertEqual(str(ar.exception), "not a record") diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py new file mode 100644 index 0000000..63a784c --- /dev/null +++ b/src/wormhole/test/dilate/test_record.py @@ -0,0 +1,269 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._dilation._noise import NoiseInvalidMessage +from ..._dilation.connection import (IFramer, Frame, Prologue, + _Record, Handshake, + Disconnect, Ping) + + +def make_record(): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() # pretends to be a Noise object + r = _Record(f, n) + return r, f, n + + +class Record(unittest.TestCase): + def test_good2(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(side_effect=[ + [], + [Prologue()], + [Frame(frame=b"rx-handshake")], + [Frame(frame=b"frame1"), Frame(frame=b"frame2")], + ]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + p1, p2 = object(), object() + n.decrypt = mock.Mock(side_effect=[p1, p2]) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # Pretend to deliver the prologue in two parts. The text we send in + # doesn't matter: the side_effect= is what causes the prologue to be + # recognized by the second call. + self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, []) + + # recognizing the prologue causes a handshake frame to be sent + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(b"tx-handshake")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + n.mock_calls[:] = [] + + # next add_and_unframe is recognized as the Handshake + self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")]) + n.mock_calls[:] = [] + + # next is a pair of Records + r1, r2 = object(), object() + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[r1, r2]) as pr: + self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"), + mock.call.decrypt(b"frame2")]) + self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)]) + + def test_bad_handshake(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.read_message = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad inbound noise handshake")]) + + def test_bad_message(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake"), + Frame(frame=b"bad-message")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.decrypt = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad inbound noise frame")]) + + def test_send_record(self): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() + f1 = object() + n.encrypt = mock.Mock(return_value=f1) + r1 = Ping(b"pingid") + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + m1 = object() + with mock.patch("wormhole._dilation.connection.encode_record", + return_value=m1) as er: + r.send_record(r1) + self.assertEqual(er.mock_calls, [mock.call(r1)]) + self.assertEqual(n.mock_calls, [mock.call.start_handshake(), + mock.call.encrypt(m1)]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)]) + + def test_good(self): + # Exercise the success path. The Record instance is given each chunk + # of data as it arrives on Protocol.dataReceived, and is supposed to + # return a series of Tokens (maybe none, if the chunk was incomplete, + # or more than one, if the chunk was larger). Internally, it delivers + # the chunks to the Framer for unframing (which returns 0 or more + # frames), manages the Noise decryption object, and parses any + # decrypted messages into tokens (some of which are consumed + # internally, others for delivery upstairs). + # + # in the normal flow, we get: + # + # | | Inbound | NoiseAction | Outbound | ToUpstairs | + # | | - | - | - | - | + # | 1 | | | prologue | | + # | 2 | prologue | | | | + # | 3 | | write_message | handshake | | + # | 4 | handshake | read_message | | Handshake | + # | 5 | | encrypt | KCM | | + # | 6 | KCM | decrypt | | KCM | + # | 7 | msg1 | decrypt | | msg1 | + + # 1: instantiating the Record instance causes the outbound prologue + # to be sent + + # 2+3: receipt of the inbound prologue triggers creation of the + # ephemeral key (the "handshake") by calling noise.write_message() + # and then writes the handshake to the outbound transport + + # 4: when the peer's handshake is received, it is delivered to + # noise.read_message(), which generates the shared key (enabling + # noise.send() and noise.decrypt()). It also delivers the Handshake + # token upstairs, which might (on the Follower) trigger immediate + # transmission of the Key Confirmation Message (KCM) + + # 5: the outbound KCM is framed and fed into noise.encrypt(), then + # sent outbound + + # 6: the peer's KCM is decrypted then delivered upstairs. The + # Follower treats this as a signal that it should use this connection + # (and drop all others). + + # 7: the peer's first message is decrypted, parsed, and delivered + # upstairs. This might be an Open or a Data, depending upon what + # queued messages were left over from the previous connection + + r, f, n = make_record() + outbound_handshake = object() + kcm, msg1 = object(), object() + f_kcm, f_msg1 = object(), object() + n.write_message = mock.Mock(return_value=outbound_handshake) + n.decrypt = mock.Mock(side_effect=[kcm, msg1]) + n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) + f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet + [Prologue()], + [Frame("f_handshake")], + [Frame("f_kcm"), + Frame("f_msg1")], + ]) + + self.assertEqual(f.mock_calls, []) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # 1. The Framer is responsible for sending the prologue, so we don't + # have to check that here, we just check that the Framer was told + # about connectionMade properly. + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + # 2 + # we dribble the prologue in over two messages, to make sure we can + # handle a dataReceived that doesn't complete the token + + # remember, add_and_unframe is a generator + self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + # 3: write_message, send outbound handshake + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(outbound_handshake), + ]) + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + # 4 + # Now deliver the Noise "handshake", the ephemeral public key. This + # is framed, but not a record, so it shouldn't decrypt or parse + # anything, but the handshake is delivered to the Noise object, and + # it does return a Handshake token so we can let the next layer up + # react (by sending the KCM frame if we're a Follower, or not if + # we're the Leader) + + self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")]) + self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + # 5: at this point we ought to be able to send a messge, the KCM + with mock.patch("wormhole._dilation.connection.encode_record", + side_effect=[b"r-kcm"]) as er: + r.send_record(kcm) + self.assertEqual(er.mock_calls, [mock.call(kcm)]) + self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] + + # 6: Now we deliver two messages stacked up: the KCM (Key + # Confirmation Message) and the first real message. Concatenating + # them tests that we can handle more than one token in a single + # chunk. We need to mock parse_record() because everything past the + # handshake is decrypted and parsed. + + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[kcm, msg1]) as pr: + self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")), + [kcm, msg1]) + self.assertEqual(f.mock_calls, + [mock.call.add_and_parse(b"kcm,msg1")]) + self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"), + mock.call.decrypt("f_msg1")]) + self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py new file mode 100644 index 0000000..d56cdf1 --- /dev/null +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -0,0 +1,144 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from twisted.internet.error import ConnectionDone +from ..._dilation.subchannel import (Once, SubChannel, + _WormholeAddress, _SubchannelAddress, + AlreadyClosedError) +from .common import mock_manager + + +def make_sc(set_protocol=True): + scid = b"scid" + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(scid) + m = mock_manager() + sc = SubChannel(scid, m, hostaddr, peeraddr) + p = mock.Mock() + if set_protocol: + sc._set_protocol(p) + return sc, m, scid, hostaddr, peeraddr, p + + +class SubChannelAPI(unittest.TestCase): + def test_once(self): + o = Once(ValueError) + o() + with self.assertRaises(ValueError): + o() + + def test_create(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + self.assert_(ITransport.providedBy(sc)) + self.assertEqual(m.mock_calls, []) + self.assertIdentical(sc.getHost(), hostaddr) + self.assertIdentical(sc.getPeer(), peeraddr) + + def test_write(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + sc.writeSequence([b"more", b"data"]) + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")]) + + def test_write_when_closing(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.write(b"data") + self.assertEqual(str(e.exception), + "write not allowed on closed subchannel") + + def test_local_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + # late arriving data is still delivered + sc.remote_data(b"late") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")]) + p.mock_calls[:] = [] + + sc.remote_close() + self.assert_connectionDone(p.mock_calls) + + def test_local_double_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.loseConnection() + self.assertEqual(str(e.exception), + "loseConnection not allowed on closed subchannel") + + def assert_connectionDone(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "connectionLost") + self.assertEqual(len(mock_calls[0][1]), 1) + self.assertIsInstance(mock_calls[0][1][0], ConnectionDone) + + def test_remote_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_close() + self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)]) + self.assert_connectionDone(p.mock_calls) + + def test_data(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"not") + sc.remote_data(b"coalesced") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"), + mock.call.dataReceived(b"coalesced"), + ]) + + def test_data_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"more") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")]) + + def test_close_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_close() + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assert_connectionDone(p.mock_calls) + + def test_producer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.pauseProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)]) + m.mock_calls[:] = [] + sc.resumeProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)]) + m.mock_calls[:] = [] + sc.stopProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)]) + m.mock_calls[:] = [] + + def test_consumer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + # TODO: more, once this is implemented + sc.registerProducer(None, True) + sc.unregisterProducer() diff --git a/src/wormhole/test/test_hints.py b/src/wormhole/test/test_hints.py new file mode 100644 index 0000000..7dcffd6 --- /dev/null +++ b/src/wormhole/test/test_hints.py @@ -0,0 +1,196 @@ +from __future__ import print_function, unicode_literals +import io +from collections import namedtuple +import mock +from twisted.internet import endpoints, reactor +from twisted.trial import unittest +from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint, + describe_hint_obj, parse_hint, encode_hint, + DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint) + +UnknownHint = namedtuple("UnknownHint", ["stuff"]) + + +class Hints(unittest.TestCase): + def test_endpoint_from_hint_obj(self): + def efho(hint, tor=None): + return endpoint_from_hint_obj(hint, tor, reactor) + self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)), + endpoints.HostnameEndpoint) + self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) + + # tor=None + self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None) + self.assertEqual(efho(UnknownHint("foo")), None) + + tor = mock.Mock() + def tor_ep(hostname, port): + if hostname == "non-public": + raise ValueError + return ("tor_ep", hostname, port) + tor.stream_via = mock.Mock(side_effect=tor_ep) + + self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor), + ("tor_ep", "host", 1234)) + self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor), + ("tor_ep", "host2.onion", 1234)) + self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None) + + self.assertEqual(efho(UnknownHint("foo"), tor), None) + + def test_comparable(self): + h1 = DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = DirectTCPV1Hint("hostname", "port2", 0.0) + r1 = RelayV1Hint(tuple(sorted([h1, h2]))) + r2 = RelayV1Hint(tuple(sorted([h2, h1]))) + r3 = RelayV1Hint(tuple(sorted([h1b, h2]))) + self.assertEqual(r1, r2) + self.assertEqual(r2, r3) + self.assertEqual(len(set([r1, r2, r3])), 1) + + def test_parse_tcp_v1_hint(self): + p = parse_tcp_v1_hint + self.assertEqual(p({"type": "unknown"}), None) + h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5)) + h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0)) + h = p({ + "type": "tor-tcp-v1", + "hostname": "foo", + "port": 1234, + "priority": 2.5 + }) + self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5)) + self.assertEqual(p({ + "type": "direct-tcp-v1" + }), None) # missing hostname + self.assertEqual(p({ + "type": "direct-tcp-v1", + "hostname": 12 + }), None) # invalid hostname + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo" + }), None) # missing port + self.assertEqual( + p({ + "type": "direct-tcp-v1", + "hostname": "foo", + "port": "not a number" + }), None) # invalid port + + def test_parse_hint(self): + p = parse_hint + self.assertEqual(p({"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12}), + DirectTCPV1Hint("foo", 12, 0.0)) + self.assertEqual(p({"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12}, + {"type": "unrecognized"}, + {"type": "direct-tcp-v1", + "hostname": "bar", + "port": 13}]}), + RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0), + DirectTCPV1Hint("bar", 13, 0.0)])) + + def test_parse_hint_argv(self): + def p(hint): + stderr = io.StringIO() + value = parse_hint_argv(hint, stderr=stderr) + return value, stderr.getvalue() + + h, stderr = p("tcp:host:1234") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:priority=2.6") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6)) + self.assertEqual(stderr, "") + + h, stderr = p("tcp:host:1234:unknown=stuff") + self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h, stderr = p("$!@#^") + self.assertEqual(h, None) + self.assertEqual(stderr, "unparseable hint '$!@#^'\n") + + h, stderr = p("unknown:stuff") + self.assertEqual(h, None) + self.assertEqual(stderr, + "unknown hint type 'unknown' in 'unknown:stuff'\n") + + h, stderr = p("tcp:just-a-hostname") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + + h, stderr = p("tcp:host:number") + self.assertEqual(h, None) + self.assertEqual(stderr, + "non-numeric port in TCP hint 'tcp:host:number'\n") + + h, stderr = p("tcp:host:1234:priority=bad") + self.assertEqual(h, None) + self.assertEqual( + stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") + + def test_describe_hint_obj(self): + d = describe_hint_obj + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False), + "->tcp:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, False), + "->relay:tcp:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, True), + "tor->tcp:host:1234") + self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, True), + "tor->relay:tcp:host:1234") + self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False), + "->tor:host:1234") + self.assertEqual(d(UnknownHint("stuff"), False, False), + "->%s" % str(UnknownHint("stuff"))) + + def test_encode_hint(self): + e = encode_hint + self.assertEqual(e(DirectTCPV1Hint("host", 1234, 1.0)), + {"type": "direct-tcp-v1", + "priority": 1.0, + "hostname": "host", + "port": 1234}) + self.assertEqual(e(RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0), + DirectTCPV1Hint("bar", 13, 0.0)])), + {"type": "relay-v1", + "hints": [ + {"type": "direct-tcp-v1", + "hostname": "foo", + "port": 12, + "priority": 0.0}, + {"type": "direct-tcp-v1", + "hostname": "bar", + "port": 13, + "priority": 0.0}, + ]}) + self.assertEqual(e(TorTCPV1Hint("host", 1234, 1.0)), + {"type": "tor-tcp-v1", + "priority": 1.0, + "hostname": "host", + "port": 1234}) + e = self.assertRaises(ValueError, e, "not a Hint") + self.assertIn("unknown hint type", str(e)) + self.assertIn("not a Hint", str(e)) diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 860b0b8..dff3fb0 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -12,8 +12,8 @@ import mock from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister, _mailbox, _nameplate, _order, _receive, _rendezvous, _send, _terminator, errors, timing) -from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister, - IMailbox, INameplate, IOrder, IReceive, +from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, + ILister, IMailbox, INameplate, IOrder, IReceive, IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data from ..journal import ImmediateJournal @@ -1286,17 +1286,21 @@ class Boss(unittest.TestCase): "closed") versions = {"app": "version1"} reactor = None + eq = None + cooperator = None journal = ImmediateJournal() tor_manager = None client_version = ("python", __version__) b = MockBoss(wormhole, "side", "url", "appid", versions, - client_version, reactor, journal, tor_manager, + client_version, reactor, eq, cooperator, journal, + tor_manager, timing.DebugTiming()) b._T = Dummy("t", events, ITerminator, "close") b._S = Dummy("s", events, ISend, "send") b._RC = Dummy("rc", events, IRendezvousConnector, "start") b._C = Dummy("c", events, ICode, "allocate_code", "input_code", "set_code") + b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key") return b, events def test_basic(self): @@ -1324,7 +1328,9 @@ class Boss(unittest.TestCase): b.got_message("0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), + ("d.got_key", b"key"), ("w.got_verifier", b"verifier"), + ("d.got_wormhole_versions", {}), ("w.got_versions", {}), ("w.received", b"msg1"), ]) diff --git a/src/wormhole/test/test_observer.py b/src/wormhole/test/test_observer.py index 8d89ddd..6ca8957 100644 --- a/src/wormhole/test/test_observer.py +++ b/src/wormhole/test/test_observer.py @@ -3,7 +3,7 @@ from twisted.python.failure import Failure from twisted.trial import unittest from ..eventual import EventualQueue -from ..observer import OneShotObserver, SequenceObserver +from ..observer import OneShotObserver, SequenceObserver, EmptyableSet class OneShot(unittest.TestCase): @@ -121,3 +121,28 @@ class Sequence(unittest.TestCase): d2 = o.when_next_event() eq.flush_sync() self.assertIdentical(self.failureResultOf(d2), f) + + +class Empty(unittest.TestCase): + def test_set(self): + eq = EventualQueue(Clock()) + s = EmptyableSet(_eventual_queue=eq) + d1 = s.when_next_empty() + eq.flush_sync() + self.assertNoResult(d1) + s.add(1) + eq.flush_sync() + self.assertNoResult(d1) + s.add(2) + s.discard(1) + d2 = s.when_next_empty() + eq.flush_sync() + self.assertNoResult(d1) + self.assertNoResult(d2) + s.discard(2) + eq.flush_sync() + self.assertEqual(self.successResultOf(d1), None) + self.assertEqual(self.successResultOf(d2), None) + + s.add(3) + s.discard(3) diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 63216ae..df6d5a6 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -3,7 +3,6 @@ from __future__ import print_function, unicode_literals import gc import io from binascii import hexlify, unhexlify -from collections import namedtuple import six from nacl.exceptions import CryptoError @@ -18,7 +17,9 @@ import mock from wormhole_transit_relay import transit_server from .. import transit +from .._hints import DirectTCPV1Hint from ..errors import InternalError +from ..util import HKDF from .common import ServerBase @@ -140,141 +141,6 @@ class Misc(unittest.TestCase): self.assertIsInstance(portno, int) -UnknownHint = namedtuple("UnknownHint", ["stuff"]) - - -class Hints(unittest.TestCase): - def test_endpoint_from_hint_obj(self): - c = transit.Common("") - efho = c._endpoint_from_hint_obj - self.assertIsInstance( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - endpoints.HostnameEndpoint) - self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) - # c._tor is currently None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) - c._tor = mock.Mock() - - def tor_ep(hostname, port): - if hostname == "non-public": - return None - return ("tor_ep", hostname, port) - - c._tor.stream_via = mock.Mock(side_effect=tor_ep) - self.assertEqual( - efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), - ("tor_ep", "host", 1234)) - self.assertEqual( - efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), - ("tor_ep", "host2.onion", 1234)) - self.assertEqual( - efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) - self.assertEqual(efho(UnknownHint("foo")), None) - - def test_comparable(self): - h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0) - h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0) - r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) - r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) - r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) - self.assertEqual(r1, r2) - self.assertEqual(r2, r3) - self.assertEqual(len(set([r1, r2, r3])), 1) - - def test_parse_tcp_v1_hint(self): - c = transit.Common("") - p = c._parse_tcp_v1_hint - self.assertEqual(p({"type": "unknown"}), None) - h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) - h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) - h = p({ - "type": "tor-tcp-v1", - "hostname": "foo", - "port": 1234, - "priority": 2.5 - }) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) - self.assertEqual(p({ - "type": "direct-tcp-v1" - }), None) # missing hostname - self.assertEqual(p({ - "type": "direct-tcp-v1", - "hostname": 12 - }), None) # invalid hostname - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo" - }), None) # missing port - self.assertEqual( - p({ - "type": "direct-tcp-v1", - "hostname": "foo", - "port": "not a number" - }), None) # invalid port - - def test_parse_hint_argv(self): - def p(hint): - stderr = io.StringIO() - value = transit.parse_hint_argv(hint, stderr=stderr) - return value, stderr.getvalue() - - h, stderr = p("tcp:host:1234") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:priority=2.6") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) - self.assertEqual(stderr, "") - - h, stderr = p("tcp:host:1234:unknown=stuff") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) - self.assertEqual(stderr, "") - - h, stderr = p("$!@#^") - self.assertEqual(h, None) - self.assertEqual(stderr, "unparseable hint '$!@#^'\n") - - h, stderr = p("unknown:stuff") - self.assertEqual(h, None) - self.assertEqual(stderr, - "unknown hint type 'unknown' in 'unknown:stuff'\n") - - h, stderr = p("tcp:just-a-hostname") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") - - h, stderr = p("tcp:host:number") - self.assertEqual(h, None) - self.assertEqual(stderr, - "non-numeric port in TCP hint 'tcp:host:number'\n") - - h, stderr = p("tcp:host:1234:priority=bad") - self.assertEqual(h, None) - self.assertEqual( - stderr, - "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") - - def test_describe_hint_obj(self): - d = transit.describe_hint_obj - self.assertEqual( - d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") - self.assertEqual( - d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") - self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) - # ipaddrs.py currently uses native strings: bytes on py2, unicode on # py3 @@ -437,7 +303,7 @@ class Listener(unittest.TestCase): hints, ep = c._build_listener() self.assertIsInstance(hints, (list, set)) if hints: - self.assertIsInstance(hints[0], transit.DirectTCPV1Hint) + self.assertIsInstance(hints[0], DirectTCPV1Hint) self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint) def test_get_direct_hints(self): @@ -1507,8 +1373,8 @@ class Transit(unittest.TestCase): self.assertEqual(self.successResultOf(d), "winner") self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) - def _endpoint_from_hint_obj(self, hint): - if isinstance(hint, transit.DirectTCPV1Hint): + def _endpoint_from_hint_obj(self, hint, _tor, _reactor): + if isinstance(hint, DirectTCPV1Hint): if hint.hostname == "unavailable": return None return hint.hostname @@ -1523,20 +1389,21 @@ class Transit(unittest.TestCase): del hints s.add_connection_hints( [DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # the direct connectors are tried right away, but the relay - # connectors are stalled for a few seconds - self.assertEqual(self._connectors, ["direct"]) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # the direct connectors are tried right away, but the relay + # connectors are stalled for a few seconds + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_priorities(self): @@ -1586,31 +1453,32 @@ class Transit(unittest.TestCase): }] }, ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # direct connector should be used first, then the priority=3.0 relay, - # then the two 2.0 relays, then the (default) 0.0 relay + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # direct connector should be used first, then the priority=3.0 relay, + # then the two 2.0 relays, then the (default) 0.0 relay - self.assertEqual(self._connectors, ["direct"]) + self.assertEqual(self._connectors, ["direct"]) - clock.advance(s.RELAY_DELAY + 1.0) - self.assertEqual(self._connectors, ["direct", "relay3"]) + clock.advance(s.RELAY_DELAY + 1.0) + self.assertEqual(self._connectors, ["direct", "relay3"]) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4"], - ["direct", "relay3", "relay4", "relay2"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4"], + ["direct", "relay3", "relay4", "relay2"])) - clock.advance(s.RELAY_DELAY) - self.assertIn(self._connectors, - (["direct", "relay3", "relay2", "relay4", "relay"], - ["direct", "relay3", "relay4", "relay2", "relay"])) + clock.advance(s.RELAY_DELAY) + self.assertIn(self._connectors, + (["direct", "relay3", "relay2", "relay4", "relay"], + ["direct", "relay3", "relay4", "relay2", "relay"])) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_direct_hints(self): @@ -1624,20 +1492,21 @@ class Transit(unittest.TestCase): UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON, UNAVAILABLE_RELAY_HINT_JSON ]) - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - self.assertNoResult(d) - # since there are no usable direct hints, the relay connector will - # only be stalled for 0 seconds - self.assertEqual(self._connectors, []) + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + self.assertNoResult(d) + # since there are no usable direct hints, the relay connector will + # only be stalled for 0 seconds + self.assertEqual(self._connectors, []) - clock.advance(0) - self.assertEqual(self._connectors, ["relay"]) + clock.advance(0) + self.assertEqual(self._connectors, ["relay"]) - self._waiters[0].callback("winner") - self.assertEqual(self.successResultOf(d), "winner") + self._waiters[0].callback("winner") + self.assertEqual(self.successResultOf(d), "winner") @inlineCallbacks def test_no_contenders(self): @@ -1647,17 +1516,18 @@ class Transit(unittest.TestCase): hints = yield s.get_connection_hints() # start the listener del hints s.add_connection_hints([]) # no hints at all - s._endpoint_from_hint_obj = self._endpoint_from_hint_obj s._start_connector = self._start_connector - d = s.connect() - f = self.failureResultOf(d, transit.TransitError) - self.assertEqual(str(f.value), "No contenders for connection") + with mock.patch("wormhole.transit.endpoint_from_hint_obj", + self._endpoint_from_hint_obj): + d = s.connect() + f = self.failureResultOf(d, transit.TransitError) + self.assertEqual(str(f.value), "No contenders for connection") class RelayHandshake(unittest.TestCase): def old_build_relay_handshake(self, key): - token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") return (token, b"please relay " + hexlify(token) + b"\n") def test_old(self): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 96d8f95..98e1b72 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -2,15 +2,13 @@ from __future__ import absolute_import, print_function import os -import re import socket import sys import time from binascii import hexlify, unhexlify -from collections import deque, namedtuple +from collections import deque import six -from hkdf import Hkdf from nacl.secret import SecretBox from twisted.internet import (address, defer, endpoints, error, interfaces, protocol, reactor, task) @@ -23,11 +21,10 @@ from zope.interface import implementer from . import ipaddrs from .errors import InternalError from .timing import DebugTiming -from .util import bytes_to_hexstr - - -def HKDF(skm, outlen, salt=None, CTXinfo=b""): - return Hkdf(salt, skm).expand(CTXinfo, outlen) +from .util import bytes_to_hexstr, HKDF +from ._hints import (DirectTCPV1Hint, RelayV1Hint, + parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj, + parse_tcp_v1_hint) class TransitError(Exception): @@ -95,72 +92,6 @@ def build_sided_relay_handshake(key, side): "ascii") + b"\n" -# These namedtuples are "hint objects". The JSON-serializable dictionaries -# are "hint dicts". - -# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: -# * make a TCP connection (possibly via Tor) -# * send the sender/receiver handshake bytes first -# * expect to see the receiver/sender handshake bytes from the other side -# * the sender writes "go\n", the receiver waits for "go\n" -# * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", - ["hostname", "port", "priority"]) -TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) -# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we -# use a tuple rather than a list so they'll be hashable into a set). For each -# one, make the TCP connection, send the relay handshake, then complete the -# rest of the V1 protocol. Only one hint per relay is useful. -RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) - - -def describe_hint_obj(hint): - if isinstance(hint, DirectTCPV1Hint): - return u"tcp:%s:%d" % (hint.hostname, hint.port) - elif isinstance(hint, TorTCPV1Hint): - return u"tor:%s:%d" % (hint.hostname, hint.port) - else: - return str(hint) - - -def parse_hint_argv(hint, stderr=sys.stderr): - assert isinstance(hint, type(u"")) - # return tuple or None for an unparseable hint - priority = 0.0 - mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) - if not mo: - print("unparseable hint '%s'" % (hint, ), file=stderr) - return None - hint_type = mo.group(1) - if hint_type != "tcp": - print( - "unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) - return None - hint_value = mo.group(2) - pieces = hint_value.split(":") - if len(pieces) < 2: - print( - "unparseable TCP hint (need more colons) '%s'" % (hint, ), - file=stderr) - return None - mo = re.search(r'^(\d+)$', pieces[1]) - if not mo: - print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr) - return None - hint_host = pieces[0] - hint_port = int(pieces[1]) - for more in pieces[2:]: - if more.startswith("priority="): - more_pieces = more.split("=") - try: - priority = float(more_pieces[1]) - except ValueError: - print( - "non-float priority= in TCP hint '%s'" % (hint, ), - file=stderr) - return None - return DirectTCPV1Hint(hint_host, hint_port, priority) - TIMEOUT = 60 # seconds @@ -746,30 +677,11 @@ class Common: self._listener_d.addErrback(lambda f: None) self._listener_d.cancel() - def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj - hint_type = hint.get(u"type", u"") - if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]: - log.msg("unknown hint type: %r" % (hint, )) - return None - if not (u"hostname" in hint and - isinstance(hint[u"hostname"], type(u""))): - log.msg("invalid hostname in hint: %r" % (hint, )) - return None - if not (u"port" in hint and - isinstance(hint[u"port"], six.integer_types)): - log.msg("invalid port in hint: %r" % (hint, )) - return None - priority = hint.get(u"priority", 0.0) - if hint_type == u"direct-tcp-v1": - return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - else: - return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) - def add_connection_hints(self, hints): for h in hints: # hint structs hint_type = h.get(u"type", u"") if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]: - dh = self._parse_tcp_v1_hint(h) + dh = parse_tcp_v1_hint(h) if dh: self._their_direct_hints.append(dh) # hint_obj elif hint_type == u"relay-v1": @@ -779,7 +691,7 @@ class Common: # together like this. relay_hints = [] for rhs in h.get(u"hints", []): - h = self._parse_tcp_v1_hint(rhs) + h = parse_tcp_v1_hint(rhs) if h: relay_hints.append(h) if relay_hints: @@ -875,13 +787,11 @@ class Common: # Check the hint type to see if we can support it (e.g. skip # onion hints on a non-Tor client). Do not increase relay_delay # unless we have at least one viable hint. - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description - d = self._start_connector(ep, description) + d = self._start_connector(ep, + describe_hint_obj(hint_obj, False, self._tor)) contenders.append(d) relay_delay = self.RELAY_DELAY @@ -902,18 +812,15 @@ class Common: for priority in sorted(prioritized_relays, reverse=True): for hint_obj in prioritized_relays[priority]: - ep = self._endpoint_from_hint_obj(hint_obj) + ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor) if not ep: continue - description = "->relay:%s" % describe_hint_obj(hint_obj) - if self._tor: - description = "tor" + description d = task.deferLater( self._reactor, relay_delay, self._start_connector, ep, - description, + describe_hint_obj(hint_obj, True, self._tor), is_relay=True) contenders.append(d) relay_delay += self.RELAY_DELAY @@ -951,21 +858,6 @@ class Common: d.addCallback(lambda p: p.startNegotiation()) return d - def _endpoint_from_hint_obj(self, hint): - if self._tor: - if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): - # this Tor object will throw ValueError for non-public IPv4 - # addresses and any IPv6 address - try: - return self._tor.stream_via(hint.hostname, hint.port) - except ValueError: - return None - return None - if isinstance(hint, DirectTCPV1Hint): - return endpoints.HostnameEndpoint(self._reactor, hint.hostname, - hint.port) - return None - def connection_ready(self, p): # inbound/outbound Connection protocols call this when they finish # negotiation. The first one wins and gets a "go". Any subsequent diff --git a/src/wormhole/util.py b/src/wormhole/util.py index 0b57c5e..971de7e 100644 --- a/src/wormhole/util.py +++ b/src/wormhole/util.py @@ -3,11 +3,19 @@ import json import os import unicodedata from binascii import hexlify, unhexlify +from hkdf import Hkdf +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + def to_bytes(u): return unicodedata.normalize("NFC", u).encode("utf-8") +def to_unicode(any): + if isinstance(any, type(u"")): + return any + return any.decode("ascii") def bytes_to_hexstr(b): assert isinstance(b, type(b"")) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 610069c..e17b7a1 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -5,9 +5,12 @@ import sys from attr import attrib, attrs from twisted.python import failure +from twisted.internet.task import Cooperator from zope.interface import implementer from ._boss import Boss +from ._dilation.manager import DILATION_VERSIONS +from ._dilation.connector import Connector from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key from .errors import NoKeyError, WormholeClosed @@ -122,7 +125,8 @@ class _DelegatedWormhole(object): @implementer(IWormhole, IDeferredWormhole) class _DeferredWormhole(object): - def __init__(self, eq): + def __init__(self, reactor, eq): + self._reactor = reactor self._welcome_observer = OneShotObserver(eq) self._code_observer = OneShotObserver(eq) self._key = None @@ -187,6 +191,10 @@ class _DeferredWormhole(object): raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) + def dilate(self): + raise NotImplementedError + return self._boss.dilate() # fires with (endpoints) + def close(self): # fails with WormholeError unless we established a connection # (state=="happy"). Fails with WrongPasswordError (a subclass of @@ -258,18 +266,24 @@ def create( side = bytes_to_hexstr(os.urandom(5)) journal = journal or ImmediateJournal() eq = _eventual_queue or EventualQueue(reactor) + cooperator = Cooperator(scheduler=eq.eventually) if delegate: w = _DelegatedWormhole(delegate) else: - w = _DeferredWormhole(eq) - wormhole_versions = {} # will be used to indicate Wormhole capabilities + w = _DeferredWormhole(reactor, eq) + # this indicates Wormhole capabilities + wormhole_versions = { + "can-dilate": DILATION_VERSIONS, + "dilation-abilities": Connector.get_connection_abilities(), + } + wormhole_versions = {} # don't advertise Dilation yet: not ready wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace") client_version = ("python", v) b = Boss(w, side, relay_url, appid, wormhole_versions, client_version, - reactor, journal, tor, timing) + reactor, eq, cooperator, journal, tor, timing) w._set_boss(b) b.start() return w diff --git a/tox.ini b/tox.ini index 4d7a007..3c01425 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ minversion = 2.4.0 [testenv] usedevelop = True -extras = dev +extras = dev,dilate deps = pyflakes >= 1.2.3 commands = @@ -18,6 +18,8 @@ commands = wormhole --version python -m wormhole.test.run_trial {posargs:wormhole} +[testenv:no-dilate] +extras = dev # on windows, trial is installed as venv/bin/trial.py, not .exe, but (at # least appveyor) adds .PY to $PATHEXT. So "trial wormhole" might work on