diff --git a/docs/api.md b/docs/api.md index 39a4e43..43cc318 100644 --- a/docs/api.md +++ b/docs/api.md @@ -524,25 +524,31 @@ object twice. ## Dilation -(NOTE: this section is speculative: this code has not yet been written) - -In the longer term, the Wormhole object will incorporate the "Transit" -functionality (see transit.md) directly, removing the need to instantiate a -second object. A Wormhole can be "dilated" into a form that is suitable for -bulk data transfer. +To send bulk data, or anything more than a handful of messages, a Wormhole +can be "dilated" into a form that uses a direct TCP connection between the +two endpoints. All wormholes start out "undilated". In this state, all messages are queued on the Rendezvous Server for the lifetime of the wormhole, and server-imposed number/size/rate limits apply. Calling `w.dilate()` initiates the dilation -process, and success is signalled via either `d=w.when_dilated()` firing, or -`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used -as an IConsumer/IProducer, and messages will be sent on a direct connection -(if possible) or through the transit relay (if not). +process, and eventually yields a set of Endpoints. Once dilated, the usual +`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and +these endpoints can be used to establish multiple (encrypted) "subchannel" +connections to the other side. + +Each subchannel behaves like a regular Twisted `ITransport`, so they can be +glued to the Protocol instance of your choice. They also implement the +IConsumer/IProducer interfaces. + +These subchannels are *durable*: as long as the processes on both sides keep +running, the subchannel will survive the network connection being dropped. +For example, a file transfer can be started from a laptop, then while it is +running, the laptop can be closed, moved to a new wifi network, opened back +up, and the transfer will resume from the new IP address. What's good about a non-dilated wormhole?: * setup is faster: no delay while it tries to make a direct connection -* survives temporary network outages, since messages are queued * works with "journaled mode", allowing progress to be made even when both sides are never online at the same time, by serializing the wormhole @@ -556,21 +562,103 @@ Use non-dilated wormholes when your application only needs to exchange a couple of messages, for example to set up public keys or provision access tokens. Use a dilated wormhole to move files. -Dilated wormholes can provide multiple "channels": these are multiplexed -through the single (encrypted) TCP connection. Each channel is a separate -stream (offering IProducer/IConsumer) +Dilated wormholes can provide multiple "subchannels": these are multiplexed +through the single (encrypted) TCP connection. Each subchannel is a separate +stream (offering IProducer/IConsumer for flow control), and is opened and +closed independently. A special "control channel" is available to both sides +so they can coordinate how they use the subchannels. -To create a channel, call `c = w.create_channel()` on a dilated wormhole. The -"channel ID" can be obtained with `c.get_id()`. This ID will be a short -(unicode) string, which can be sent to the other side via a normal -`w.send()`, or any other means. On the other side, use `c = -w.open_channel(channel_id)` to get a matching channel object. +The `d = w.dilate()` Deferred fires with a triple of Endpoints: -Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire -them up with `c.registerProducer()`. Note that channels do not close until -the wormhole connection is closed, so they do not have separate `close()` -methods or events. Therefore if you plan to send files through them, you'll -need to inform the recipient ahead of time about how many bytes to expect. +```python +d = w.dilate() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res +d.addCallback(_dilated) +``` + +The `control_channel_ep` endpoint is a client-style endpoint, so both sides +will connect to it with `ep.connect(factory)`. This endpoint is single-use: +calling `.connect()` a second time will fail. The control channel is +symmetric: it doesn't matter which side is the application-level +client/server or initiator/responder, if the application even has such +concepts. The two applications can use the control channel to negotiate who +goes first, if necessary. + +The subchannel endpoints are *not* symmetric: for each subchannel, one side +must listen as a server, and the other must connect as a client. Subchannels +can be established by either side at any time. This supports e.g. +bidirectional file transfer, where either user of a GUI app can drop files +into the "wormhole" whenever they like. + +The `subchannel_client_ep` on one side is used to connect to the other side's +`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The +server endpoint is single-use: `.listen(factory)` may only be called once. + +Applications are under no obligation to use subchannels: for many use cases, +the control channel is enough. + +To use subchannels, once the wormhole is dilated and the endpoints are +available, the listening-side application should attach a listener to the +`subchannel_server_ep` endpoint: + +```python +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(MyListeningProtocol) + subchannel_server_ep.listen(f) +``` + +When the connecting-side application wants to connect to that listening +protocol, it should use `.connect()` with a suitable connecting protocol +factory: + +```python +def _connect(): + f = Factory(MyConnectingProtocol) + subchannel_client_ep.connect(f) +``` + +For a bidirectional file-transfer application, both sides will establish a +listening protocol. Later, if/when the user drops a file on the application +window, that side will initiate a connection, use the resulting subchannel to +transfer the single file, and then close the subchannel. + +```python +def FileSendingProtocol(internet.Protocol): + def __init__(self, metadata, filename): + self.file_metadata = metadata + self.file_name = filename + def connectionMade(self): + self.transport.write(self.file_metadata) + sender = protocols.basic.FileSender() + f = open(self.file_name,"rb") + d = sender.beginFileTransfer(f, self.transport) + d.addBoth(self._done, f) + def _done(res, f): + self.transport.loseConnection() + f.close() +def _send(metadata, filename): + f = protocol.ClientCreator(reactor, + FileSendingProtocol, metadata, filename) + subchannel_client_ep.connect(f) +def FileReceivingProtocol(internet.Protocol): + state = INITIAL + def dataReceived(self, data): + if state == INITIAL: + self.state = DATA + metadata = parse(data) + self.f = open(metadata.filename, "wb") + else: + # local file writes are blocking, so don't bother with IConsumer + self.f.write(data) + def connectionLost(self, reason): + self.f.close() +def _dilated(res): + (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res + f = Factory(FileReceivingProtocol) + subchannel_server_ep.listen(f) +``` ## Bytes, Strings, Unicode, and Python 3 diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md new file mode 100644 index 0000000..4f0dac1 --- /dev/null +++ b/docs/dilation-protocol.md @@ -0,0 +1,500 @@ +# Dilation Internals + +Wormhole dilation involves several moving parts. Both sides exchange messages +through the Mailbox server to coordinate the establishment of a more direct +connection. This connection might flow in either direction, so they trade +"connection hints" to point at potential listening ports. This process might +succeed in making multiple connections at about the same time, so one side +must select the best one to use, and cleanly shut down the others. To make +the dilated connection *durable*, this side must also decide when the +connection has been lost, and then coordinate the construction of a +replacement. Within this connection, a series of queued-and-acked subchannel +messages are used to open/use/close the application-visible subchannels. + +## Leaders and Followers + +Each side of a Wormhole has a randomly-generated "side" string. When the +wormhole is dilated, the side with the lexicographically-higher "side" value +is named the "Leader", and the other side is named the "Follower". The general +wormhole protocol treats both sides identically, but the distinction matters +for the dilation protocol. + +Either side can trigger dilation, but the Follower does so by asking the +Leader to start the process, whereas the Leader just starts the process +unilaterally. The Leader has exclusive control over whether a given +connection is considered established or not: if there are multiple potential +connections to use, the Leader decides which one to use, and the Leader gets +to decide when the connection is no longer viable (and triggers the +establishment of a new one). + +## Connection Layers + +We describe the protocol as a series of layers. Messages sent on one layer +may be encoded or transformed before being delivered on some other layer. + +L1 is the mailbox channel (queued store-and-forward messages that always go +to the mailbox server, and then are forwarded to other clients subscribed to +the same mailbox). Both clients remain connected to the mailbox server until +the Wormhole is closed. They send DILATE-n messages to each other to manage +the dilation process, including records like `please-dilate`, +`start-dilation`, `ok-dilation`, and `connection-hints` + +L2 is the set of competing connection attempts for a given generation of +connection. Each time the Leader decides to establish a new connection, a new +generation number is used. Hopefully these are direct TCP connections between +the two peers, but they may also include connections through the transit +relay. Each connection must go through an encrypted handshake process before +it is considered viable. Viable connections are then submitted to a selection +process (on the Leader side), which chooses exactly one to use, and drops the +others. It may wait an extra few seconds in the hopes of getting a "better" +connection (faster, cheaper, etc), but eventually it will select one. + +L3 is the current selected connection. There is one L3 for each generation. +At all times, the wormhole will have exactly zero or one L3 connection. L3 is +responsible for the selection process, connection monitoring/keepalives, and +serialization/deserialization of the plaintext frames. L3 delivers decoded +frames and connection-establishment events up to L4. + +L4 is the persistent higher-level channel. It is created as soon as the first +L3 connection is selected, and lasts until wormhole is closed entirely. L4 +contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number +(scoped to the L4 connection and the direction of travel), and the ACK +messages reference those sequence numbers. When a message is given to the L4 +channel for delivery to the remote side, it is always queued, then +transmitted if there is an L3 connection available. This message remains in +the queue until an ACK is received to release it. If a new L3 connection is +made, all queued messages will be re-sent (in seqnum order). + +L5 are subchannels. There is one pre-established subchannel 0 known as the +"control channel", which does not require an OPEN message. All other +subchannels are created by the receipt of an OPEN message with the subchannel +number. DATA frames are delivered to a specific subchannel. When the +subchannel is no longer needed, one side will invoke the ``close()`` API +(``loseConnection()`` in Twisted), which will cause a CLOSE message to be +sent, and the local L5 object will be put into the "closing "state. When the +other side receives the CLOSE, it will send its own CLOSE for the same +subchannel, and fully close its local object (``connectionLost()``). When the +first side receives CLOSE in the "closing" state, it will fully close its +local object too. + +All L5 subchannels will be paused (``pauseProducing()``) when the L3 +connection is paused or lost. They are resumed when the L3 connection is +resumed or reestablished. + +## Initiating Dilation + +Dilation is triggered by calling the `w.dilate()` API. This returns a +Deferred that will fire once the first L3 connection is established. It fires +with a 3-tuple of endpoints that can be used to establish subchannels. + +For dilation to succeed, both sides must call `w.dilate()`, since the +resulting endpoints are the only way to access the subchannels. If the other +side never calls `w.dilate()`, the Deferred will never fire. + +The L1 (mailbox) path is used to deliver dilation requests and connection +hints. The current mailbox protocol uses named "phases" to distinguish +messages (rather than behaving like a regular ordered channel of arbitrary +frames or bytes), and all-number phase names are reserved for application +data (sent via `w.send_message()`). Therefore the dilation control messages +use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own +counter, so one side might be up to e.g. `DILATE-5` while the other has only +gotten as far as `DILATE-2`. This effectively creates a unidirectional stream +of `DILATE-n` messages, each containing one or more dilation record, of +various types described below. Note that all phases beyond the initial +VERSION and PAKE phases are encrypted by the shared session key. + +A future mailbox protocol might provide a simple ordered stream of messages, +with application records and dilation records mixed together. + +Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that +has a string value. The dictionary will have other keys that depend upon the +type. + +`w.dilate()` triggers a `please-dilate` record with a set of versions that +can be accepted. Both Leader and Follower emit this record, although the +Leader is responsible for version decisions. Versions use strings, rather +than integers, to support experimental protocols, however there is still a +total ordering of preferability. + +``` +{ "type": "please-dilate", + "accepted-versions": ["1"] +} +``` + +The Leader then sends a `start-dilation` message with a `version` field (the +"best" mutually-supported value) and the new "L2 generation" number in the +`generation` field. Generation numbers are integers, monotonically increasing +by 1 each time. + +``` +{ "type": start-dilation, + "version": "1", + "generation": 1, +} +``` + +The Follower responds with a `ok-dilation` message with matching `version` +and `generation` fields. + +The Leader decides when a new dilation connection is necessary, both for the +initial connection and any subsequent reconnects. Therefore the Leader has +the exclusive right to send the `start-dilation` record. It won't send this +until after it has sent its own `please-dilate`, and after it has received +the Follower's `please-dilate`. As a result, local preparations may begin as +soon as `w.dilate()` is called, but L2 connections do not begin until the +Leader declares the start of a new L2 generation with the `start-dilation` +message. + +Generations are non-overlapping. The Leader will drop all connections from +generation 1 before sending the `start-dilation` for generation 2, and will +not initiate any gen-2 connections until it receives the matching +`ok-dilation` from the Follower. The Follower must drop all gen-1 connections +before it sends the `ok-dilation` response (even if it thinks they are still +functioning: if the Leader thought the gen-1 connection still worked, it +wouldn't have started gen-2). Listening sockets can be retained, but any +previous connection made through them must be dropped. This should avoid a +race. + +(TODO: what about a follower->leader connection that was started before +start-dilation is received, and gets established on the Leader side after +start-dilation is sent? the follower will drop it after it receives +start-dilation, but meanwhile the leader may accept it as gen2) + +(probably need to include the generation number in the handshake, or in the +derived key) + +(TODO: reduce the number of round-trip stalls here, I've added too many) + +"Connection hints" are type/address/port records that tell the other side of +likely targets for L2 connections. Both sides will try to determine their +external IP addresses, listen on a TCP port, and advertise `(tcp, +external-IP, port)` as a connection hint. The Transit Relay is also used as a +(lower-priority) hint. These are sent in `connection-hint` records, which can +be sent by the Leader any time after the `start-dilation` record, and by the +Follower after the `ok-dilation` record. Each side will initiate connections +upon receipt of the hints. + +``` +{ "type": "connection-hints", + "hints": [ ... ] +} +``` + +Hints can arrive at any time. One side might immediately send hints that can +be computed quickly, then send additional hints later as they become +available. For example, it might enumerate the local network interfaces and +send hints for all of the LAN addresses first, then send port-forwarding +(UPnP) requests to the local router. When the forwarding is established +(providing an externally-visible IP address and port), it can send additional +hints for that new endpoint. If the other peer happens to be on the same LAN, +the local connection can be established without waiting for the router's +response. + + +### Connection Hint Format + +Each member of the `hints` field describes a potential L2 connection target +endpoint, with an associated priority and a set of hints. + +The priority is a number (positive or negative float), where larger numbers +indicate that the client supplying that hint would prefer to use this +connection over others of lower number. This indicates a sense of cost or +performance. For example, the Transit Relay is lower priority than a direct +TCP connection, because it incurs a bandwidth cost (on the relay operator), +as well as adding latency. + +Each endpoint has a set of hints, because the same target might be reachable +by multiple hints. Once one hint succeeds, there is no point in using the +other hints. + +TODO: think this through some more. What's the example of a single endpoint +reachable by multiple hints? Should each hint have its own priority, or just +each endpoint? + +## L2 protocol + +Upon ``connectionMade()``, both sides send their handshake message. The +Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower +sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should +trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP +servers that were contacted by accident). If the wrong handshake is received, +the connection will be dropped. For debugging purposes, the node might want +to keep looking at data beyond the first incorrect character and log +everything until the first newline. + +Everything beyond that point is a Noise protocol message, which consists of a +4-byte big-endian length field, followed by the indicated number of bytes. +This ises the `NNpsk0` pattern with the Leader as the first party ("-> psk, +e" in the Noise spec), and the Follower as the second ("<- e, ee"). The +pre-shared-key is the "dilation key", which is statically derived from the +master PAKE key using HKDF. Each L2 connection uses the same dilation key, +but different ephemeral keys, so each gets a different session key. + +The Leader sends the first message, which is a psk-encrypted ephemeral key. +The Follower sends the next message, its own psk-encrypted ephemeral key. The +Follower then sends an empty packet as the "key confirmation message", which +will be encrypted by the shared key. + +The Leader sees the KCM and knows the connection is viable. It delivers the +protocol object to the L3 manager, which will decide which connection to +select. When the L2 connection is selected to be the new L3, it will send an +empty KCM of its own, to let the Follower know the connection being selected. +All other L2 connections (either viable or still in handshake) are dropped, +all other connection attempts are cancelled, and all listening sockets are +shut down. + +The Follower will wait for either an empty KCM (at which point the L2 +connection is delivered to the Dilation manager as the new L3), a +disconnection, or an invalid message (which causes the connection to be +dropped). Other connections and/or listening sockets are stopped. + +Internally, the L2Protocol object manages the Noise session itself. It knows +(via a constructor argument) whether it is on the Leader or Follower side, +which affects both the role is plays in the Noise pattern, and the reaction +to receiving the ephemeral key (for which only the Follower sends an empty +KCM message). After that, the L2Protocol notifies the L3 object in three +situations: + +* the Noise session produces a valid decrypted frame (for Leader, this + includes the Follower's KCM, and thus indicates a viable candidate for + connection selection) +* the Noise session reports a failed decryption +* the TCP session is lost + +All notifications include a reference to the L2Protocol object (`self`). The +L3 object uses this reference to either close the connection (for errors or +when the selection process chooses someone else), to send the KCM message +(after selection, only for the Leader), or to send other L4 messages. The L3 +object will retain a reference to the winning L2 object. + +## L3 protocol + +The L3 layer is responsible for connection selection, monitoring/keepalives, +and message (de)serialization. Framing is handled by L2, so the inbound L3 +codepath receives single-message bytestrings, and delivers the same down to +L2 for encryption, framing, and transmission. + +Connection selection takes place exclusively on the Leader side, and includes +the following: + +* receipt of viable L2 connections from below (indicated by the first valid + decrypted frame received for any given connection) +* expiration of a timer +* comparison of TBD quality/desirability/cost metrics of viable connections +* selection of winner +* instructions to losing connections to disconnect +* delivery of KCM message through winning connection +* retain reference to winning connection + +On the Follower side, the L3 manager just waits for the first connection to +receive the Leader's KCM, at which point it is retained and all others are +dropped. + +The L3 manager knows which "generation" of connection is being established. +Each generation uses a different dilation key (?), and is triggered by a new +set of L1 messages. Connections from one generation should not be confused +with those of a different generation. + +Each time a new L3 connection is established, the L4 protocol is notified. It +will will immediately send all the L4 messages waiting in its outbound queue. +The L3 protocol simply wraps these in Noise frames and sends them to the +other side. + +The L3 manager monitors the viability of the current connection, and declares +it as lost when bidirectional traffic cannot be maintained. It uses PING and +PONG messages to detect this. These also serve to keep NAT entries alive, +since many firewalls will stop forwarding packets if they don't observe any +traffic for e.g. 5 minutes. + +Our goals are: + +* don't allow more than 30? seconds to pass without at least *some* data + being sent along each side of the connection +* allow the Leader to detect silent connection loss within 60? seconds +* minimize overhead + +We need both sides to: + +* maintain a 30-second repeating timer +* set a flag each time we write to the connection +* each time the timer fires, if the flag was clear then send a PONG, + otherwise clear the flag + +In addition, the Leader must: + +* run a 60-second repeating timer (ideally somewhat offset from the other) +* set a flag each time we receive data from the connection +* each time the timer fires, if the flag was clear then drop the connection, + otherwise clear the flag + +In the future, we might have L2 links that are less connection-oriented, +which might have a unidirectional failure mode, at which point we'll need to +monitor full roundtrips. To accomplish this, the Leader will send periodic +unconditional PINGs, and the Follower will respond with PONGs. If the +Leader->Follower connection is down, the PINGs won't arrive and no PONGs will +be produced. If the Follower->Leader direction has failed, the PONGs won't +arrive. The delivery of both will be delayed by actual data, so the timeouts +should be adjusted if we see regular data arriving. + +If the connection is dropped before the wormhole is closed (either the other +end explicitly dropped it, we noticed a problem and told TCP to drop it, or +TCP noticed a problem itself), the Leader-side L3 manager will initiate a +reconnection attempt. This uses L1 to send a new DILATE message through the +mailbox server, along with new connection hints. Eventually this will result +in a new L3 connection being established. + +Finally, L3 is responsible for message serialization and deserialization. L2 +performs decryption and delivers plaintext frames to L3. Each frame starts +with a one-byte type indicator. The rest of the message depends upon the +type: + +* 0x00 PING, 4-byte ping-id +* 0x01 PONG, 4-byte ping-id +* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum +* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload +* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum +* 0x05 ACK, 4-byte response-seqnum + +All seqnums are big-endian, and are provided by the L4 protocol. The other +fields are arbitrary and not interpreted as integers. The subchannel-ids must +be allocated by both sides without collision, but otherwise they are only +used to look up L5 objects for dispatch. The response-seqnum is always copied +from the OPEN/DATA/CLOSE packet being acknowledged. + +L3 consumes the PING and PONG messages. Receiving any PING will provoke a +PONG in response, with a copy of the ping-id field. The 30-second timer will +produce unprovoked PONGs with a ping-id of all zeros. A future viability +protocol will use PINGs to test for roundtrip functionality. + +All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered +"upstairs" to the L4 protocol handler. + +The current L3 connection's `IProducer`/`IConsumer` interface is made +available to the L4 flow-control manager. + +## L4 protocol + +The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages. +Since each will be enclosed in a Noise frame before they pass to L3, they do +not need length fields or other framing. + +Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically +increasing by 1 for each message. Each direction has a separate number space. + +The L4 manager maintains a double-ended queue of unacknowledged outbound +messages. Subchannel activity (opening, closing, sending data) cause messages +to be added to this queue. If an L3 connection is available, these messages +are also sent over that connection, but they remain in the queue in case the +connection is lost and they must be retransmitted on some future replacement +connection. Messages stay in the queue until they can be retired by the +receipt of an ACK with a matching response-sequence-number. This provides +reliable message delivery that survives the L3 connection being replaced. + +ACKs are not acked, nor do they have seqnums of their own. Each inbound side +remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE +messages with that sequence number or higher. This ensures in-order +at-most-once processing of OPEN/DATA/CLOSE messages. + +Each inbound OPEN message causes a new L5 subchannel object to be created. +Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to +that object. + +Each time an L3 connection is established, the side will immediately send all +L4 messages waiting in the outbound queue. A future protocol might reduce +this duplication by including the highest received sequence number in the L1 +PLEASE-DILATE message, which would effectively retire queued messages before +initiating the L2 connection process. On any given L3 connection, all +messages are sent in-order. The receipt of an ACK for seqnum `N` allows all +messages with `seqnum <= N` to be retired. + +The L4 layer is also responsible for managing flow control among the L3 +connection and the various L5 subchannels. + +## L5 subchannels + +The L5 layer consists of a collection of "subchannel" objects, a dispatcher, +and the endpoints that provide the Twisted-flavored API. + +Other than the "control channel", all subchannels are created by a client +endpoint connection API. The side that calls this API is named the Initiator, +and the other side is named the Acceptor. Subchannels can be initiated in +either direction, independent of the Leader/Follower distinction. For a +typical file-transfer application, the subchannel would be initiated by the +side seeking to send a file. + +Each subchannel uses a distinct subchannel-id, which is a four-byte +identifier. Both directions share a number space (unlike L4 seqnums), so the +rule is that the Leader side sets the last bit of the last byte to a 0, while +the Follower sets it to a 1. These are not generally treated as integers, +however for the sake of debugging, the implementation generates them with a +simple big-endian-encoded counter (`next(counter)*2` for the Leader, +`next(counter)*2+1` for the Follower). + +When the `client_ep.connect()` API is called, the Initiator allocates a +subchannel-id and sends an OPEN. It can then immediately send DATA messages +with the outbound data (there is no special response to an OPEN, so there is +no need to wait). The Acceptor will trigger their `.connectionMade` handler +upon receipt of the OPEN. + +Subchannels are durable: they do not close until one side calls +`.loseConnection` on the subchannel object (or the enclosing Wormhole is +closed). Either the Initiator or the Acceptor can call `.loseConnection`. +This causes a CLOSE message to be sent (with the subchannel-id). The other +side will send its own CLOSE message in response. Each side will signal the +`.connectionLost()` event upon receipt of a CLOSE. + +There is no equivalent to TCP's "half-closed" state, however if only one side +calls `close()`, then all data written before that call will be delivered +before the other side observes `.connectionLost()`. Any inbound data that was +queued for delivery before the other side sees the CLOSE will still be +delivered to the side that called `close()` before it sees +`.connectionLost()`. Internally, the side which called `.loseConnection` will +remain in a special "closing" state until the CLOSE response arrives, during +which time DATA payloads are still delivered. After calling `close()` (or +receiving CLOSE), any outbound `.write()` calls will trigger an error. + +DATA payloads that arrive for a non-open subchannel are logged and discarded. + +This protocol calls for one OPEN and two CLOSE messages for each subchannel, +with some arbitrary number of DATA messages in between. Subchannel-ids should +not be reused (it would probably work, the protocol hasn't been analyzed +enough to be sure). + +The "control channel" is special. It uses a subchannel-id of all zeros, and +is opened implicitly by both sides as soon as the first L3 connection is +selected. It is routed to a special client-on-both-sides endpoint, rather +than causing the listening endpoint to accept a new connection. This avoids +the need for application-level code to negotiate who should be the one to +open it (the Leader/Follower distinction is private to the Wormhole +internals: applications are not obligated to pick a side). + +OPEN and CLOSE messages for the control channel are logged and discarded. The +control-channel client endpoints can only be used once, and does not close +until the Wormhole itself is closed. + +Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing, +delivery, and eventual retirement. The L5 layer does not keep track of old +messages. + +### Flow Control + +Subchannels are flow-controlled by pausing their writes when the L3 +connection is paused, and pausing the L3 connection when the subchannel +signals a pause. When the outbound L3 connection is full, *all* subchannels +are paused. Likewise the inbound connection is paused if *any* of the +subchannels asks for a pause. This is much easier to implement and improves +our utilization factor (we can use TCP's window-filling algorithm, instead of +rolling our own), but will block all subchannels even if only one of them +gets full. This shouldn't matter for many applications, but might be +noticeable when combining very different kinds of traffic (e.g. a chat +conversation sharing a wormhole with file-transfer might prefer the IM text +to take priority). + +Each subchannel implements Twisted's `ITransport`, `IProducer`, and +`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to +be created (by the caller's factory) and glued to the subchannel object in +the `.transport` property, as is standard in Twisted-based applications. + +All subchannels are also paused when the L3 connection is lost, and are +unpaused when a new replacement connection is selected. diff --git a/docs/new-protocol.svg b/docs/new-protocol.svg new file mode 100644 index 0000000..a69b79a --- /dev/null +++ b/docs/new-protocol.svg @@ -0,0 +1,2000 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + connectionMade() + dataReceived() + dataReceived() + connectionLost() + + + empty + + + + open + + + + open + + + + closing + + + + empty + + + + + connect() + + + Open1 + + write() + write() + loseConnection() + connectionLost() + + + open + + + + + + Data1 + + + + Data1 + + + + Close1 + + + + Close1 + + + + + + + + + + + + + + empty + + + + open + + + + open + + + + empty + + + + + + open + + + + + + + + + + + + + + + Open1 + + + + Data1 + + + + Data1 + + + + Close1 + + connection 1 + + + Open1 + + + + Data1 + + 0 + 1 + 2 + 3 + + + ack 0 + + connection 2 + + + Data1 + + + + ack 1 + + + + Data1 + + + + Close1 + + + + ack 2 + + + + ack 3 + + + + Close1 + + + + ack 0' + + 0' + 0 + 1 + 1 + 2 + 3 + 0' + logical + 0 + 1 + 2 + 3 + + diff --git a/docs/state-machines/machines.dot b/docs/state-machines/machines.dot index ecfdf9a..09cc02c 100644 --- a/docs/state-machines/machines.dot +++ b/docs/state-machines/machines.dot @@ -23,12 +23,13 @@ digraph { Terminator [shape="box" color="blue" fontcolor="blue"] InputHelperAPI [shape="oval" label="input\nhelper\nAPI" color="blue" fontcolor="blue"] + Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"] #Connection -> websocket [color="blue"] #Connection -> Order [color="blue"] Wormhole -> Boss [style="dashed" - label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)" + label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)" color="red" fontcolor="red"] #Wormhole -> Boss [color="blue"] Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"] @@ -112,4 +113,7 @@ digraph { Terminator -> Boss [style="dashed" label="closed\n(once)"] Boss -> Terminator [style="dashed" color="red" fontcolor="red" label="close"] + + Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"] + Dilator -> Send [style="dashed" label="send(dilate-N)"] } diff --git a/setup.py b/setup.py index a99a2d2..c275e8b 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ setup(name="magic-wormhole", "click", "humanize", "txtorcon >= 18.0.2", # 18.0.2 fixes py3.4 support + "noiseprotocol", ], extras_require={ ':sys_platform=="win32"': ["pywin32"], diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 373c0f4..cc6f4fa 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -12,6 +12,7 @@ from zope.interface import implementer from . import _interfaces from ._allocator import Allocator from ._code import Code, validate_code +from ._dilation.manager import Dilator from ._input import Input from ._key import Key from ._lister import Lister @@ -66,6 +67,7 @@ class Boss(object): self._I = Input(self._timing) self._C = Code(self._timing) self._T = Terminator() + self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator) self._N.wire(self._M, self._I, self._RC, self._T) self._M.wire(self._N, self._RC, self._O, self._T) @@ -79,6 +81,7 @@ class Boss(object): self._I.wire(self._C, self._L) self._C.wire(self, self._A, self._N, self._K, self._I) self._T.wire(self, self._RC, self._N, self._M) + self._D.wire(self._S) def _init_other_state(self): self._did_start_code = False @@ -86,6 +89,9 @@ class Boss(object): self._next_rx_phase = 0 self._rx_phases = {} # phase -> plaintext + self._next_rx_dilate_seqnum = 0 + self._rx_dilate_seqnums = {} # seqnum -> plaintext + self._result = "empty" # these methods are called from outside @@ -198,6 +204,9 @@ class Boss(object): self._did_start_code = True self._C.set_code(code) + def dilate(self): + return self._D.dilate() # fires with endpoints + @m.input() def send(self, plaintext): pass @@ -258,8 +267,11 @@ class Boss(object): # this is only called for side != ours assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) + d_mo = re.search(r'^dilate-(\d+)$', phase) if phase == "version": self._got_version(side, plaintext) + elif d_mo: + self._got_dilate(int(d_mo.group(1)), plaintext) elif re.search(r'^\d+$', phase): self._got_phase(int(phase), plaintext) else: @@ -275,6 +287,10 @@ class Boss(object): def _got_phase(self, phase, plaintext): pass + @m.input() + def _got_dilate(self, seqnum, plaintext): + pass + @m.input() def got_key(self, key): pass @@ -298,6 +314,8 @@ class Boss(object): # in the future, this is how Dilation is signalled self._their_side = side self._their_versions = bytes_to_dict(plaintext) + self._D.got_wormhole_versions(self._side, self._their_side, + self._their_versions) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) self._W.got_versions(app_versions) @@ -339,6 +357,10 @@ class Boss(object): def W_got_key(self, key): self._W.got_key(key) + @m.output() + def D_got_key(self, key): + self._D.got_key(key) + @m.output() def W_got_verifier(self, verifier): self._W.got_verifier(verifier) @@ -352,6 +374,16 @@ class Boss(object): self._W.received(self._rx_phases.pop(self._next_rx_phase)) self._next_rx_phase += 1 + @m.output() + def D_received_dilate(self, seqnum, plaintext): + assert isinstance(seqnum, six.integer_types), type(seqnum) + # strict phase order, no gaps + self._rx_dilate_seqnums[seqnum] = plaintext + while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums: + m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum) + self._D.received_dilate(m) + self._next_rx_dilate_seqnum += 1 + @m.output() def W_close_with_error(self, err): self._result = err # exception @@ -374,7 +406,7 @@ class Boss(object): S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared]) S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely]) S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send]) - S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key]) + S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key]) S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error]) S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error]) @@ -382,6 +414,7 @@ class Boss(object): S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier]) S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received]) S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version]) + S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate]) S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared]) S2_happy.upon(close, enter=S3_closing, outputs=[close_happy]) S2_happy.upon(send, enter=S2_happy, outputs=[S_send]) @@ -393,6 +426,7 @@ class Boss(object): S3_closing.upon(got_verifier, enter=S3_closing, outputs=[]) S3_closing.upon(_got_phase, enter=S3_closing, outputs=[]) S3_closing.upon(_got_version, enter=S3_closing, outputs=[]) + S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[]) S3_closing.upon(happy, enter=S3_closing, outputs=[]) S3_closing.upon(scared, enter=S3_closing, outputs=[]) S3_closing.upon(close, enter=S3_closing, outputs=[]) @@ -404,6 +438,7 @@ class Boss(object): S4_closed.upon(got_verifier, enter=S4_closed, outputs=[]) S4_closed.upon(_got_phase, enter=S4_closed, outputs=[]) S4_closed.upon(_got_version, enter=S4_closed, outputs=[]) + S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[]) S4_closed.upon(happy, enter=S4_closed, outputs=[]) S4_closed.upon(scared, enter=S4_closed, outputs=[]) S4_closed.upon(close, enter=S4_closed, outputs=[]) diff --git a/src/wormhole/_dilation/__init__.py b/src/wormhole/_dilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py new file mode 100644 index 0000000..f0c8e35 --- /dev/null +++ b/src/wormhole/_dilation/connection.py @@ -0,0 +1,482 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +import six +from attr import attrs, attrib +from attr.validators import instance_of, provides +from automat import MethodicalMachine +from zope.interface import Interface, implementer +from twisted.python import log +from twisted.internet.protocol import Protocol +from twisted.internet.interfaces import ITransport +from .._interfaces import IDilationConnector +from ..observer import OneShotObserver +from .encode import to_be4, from_be4 +from .roles import FOLLOWER + +# InboundFraming is given data and returns Frames (Noise wire-side +# bytestrings). It handles the relay handshake and the prologue. The Frames it +# returns are either the ephemeral key (the Noise "handshake") or ciphertext +# messages. + +# The next object up knows whether it's expecting a Handshake or a message. It +# feeds the first into Noise as a handshake, it feeds the rest into Noise as a +# message (which produces a plaintext stream). It emits tokens that are either +# "i've finished with the handshake (so you can send the KCM if you want)", or +# "here is a decrypted message (which might be the KCM)". + +# the transmit direction goes directly to transport.write, and doesn't touch +# the state machine. we can do this because the way we encode/encrypt/frame +# things doesn't depend upon the receiver state. It would be more safe to e.g. +# prohibit sending ciphertext frames unless we're in the received-handshake +# state, but then we'll be in the middle of an inbound state transition ("we +# just received the handshake, so you can send a KCM now") when we perform an +# operation that depends upon the state (send_plaintext(kcm)), which is not a +# coherent/safe place to touch the state machine. + +# we could set a flag and test it from inside send_plaintext, which kind of +# violates the state machine owning the state (ideally all "if" statements +# would be translated into same-input transitions from different starting +# states). For the specific question of sending plaintext frames, Noise will +# refuse us unless it's ready anyways, so the question is probably moot. + +class IFramer(Interface): + pass +class IRecord(Interface): + pass + +def first(l): + return l[0] + +class Disconnect(Exception): + pass +RelayOK = namedtuple("RelayOk", []) +Prologue = namedtuple("Prologue", []) +Frame = namedtuple("Frame", ["frame"]) + +@attrs +@implementer(IFramer) +class _Framer(object): + _transport = attrib(validator=provides(ITransport)) + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + _buffer = b"" + _can_send_frames = False + + # in: use_relay + # in: connectionMade, dataReceived + # out: prologue_received, frame_received + # out (shared): transport.loseConnection + # out (shared): transport.write (relay handshake, prologue) + # states: want_relay, want_prologue, want_frame + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + @m.state() + def want_relay(self): pass # pragma: no cover + @m.state(initial=True) + def want_prologue(self): pass # pragma: no cover + @m.state() + def want_frame(self): pass # pragma: no cover + + @m.input() + def use_relay(self, relay_handshake): pass + @m.input() + def connectionMade(self): pass + @m.input() + def parse(self): pass + @m.input() + def got_relay_ok(self): pass + @m.input() + def got_prologue(self): pass + + @m.output() + def store_relay_handshake(self, relay_handshake): + self._outbound_relay_handshake = relay_handshake + self._expected_relay_handshake = b"ok\n" # TODO: make this configurable + @m.output() + def send_relay_handshake(self): + self._transport.write(self._outbound_relay_handshake) + + @m.output() + def send_prologue(self): + self._transport.write(self._outbound_prologue) + + @m.output() + def parse_relay_ok(self): + if self._get_expected("relay_ok", self._expected_relay_handshake): + return RelayOK() + + @m.output() + def parse_prologue(self): + if self._get_expected("prologue", self._inbound_prologue): + return Prologue() + + @m.output() + def can_send_frames(self): + self._can_send_frames = True # for assertion in send_frame() + + @m.output() + def parse_frame(self): + if len(self._buffer) < 4: + return None + frame_length = from_be4(self._buffer[0:4]) + if len(self._buffer) < 4+frame_length: + return None + frame = self._buffer[4:4+frame_length] + self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy + return Frame(frame=frame) + + want_prologue.upon(use_relay, outputs=[store_relay_handshake], + enter=want_relay) + + want_relay.upon(connectionMade, outputs=[send_relay_handshake], + enter=want_relay) + want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay, + collector=first) + want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue) + + want_prologue.upon(connectionMade, outputs=[send_prologue], + enter=want_prologue) + want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue, + collector=first) + want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame) + + want_frame.upon(parse, outputs=[parse_frame], enter=want_frame, + collector=first) + + + def _get_expected(self, name, expected): + lb = len(self._buffer) + le = len(expected) + if self._buffer.startswith(expected): + # if the buffer starts with the expected string, consume it and + # return True + self._buffer = self._buffer[le:] + return True + if not expected.startswith(self._buffer): + # we're not on track: the data we've received so far does not + # match the expected value, so this can't possibly be right. + # Don't complain until we see the expected length, or a newline, + # so we can capture the weird input in the log for debugging. + if (b"\n" in self._buffer or lb >= le): + log.msg("bad {}: {}".format(name, self._buffer[:le])) + raise Disconnect() + return False # wait a bit longer + # good so far, just waiting for the rest + return False + + # external API is: connectionMade, add_and_parse, and send_frame + + def add_and_parse(self, data): + # we can't make this an @m.input because we can't change the state + # from within an input. Instead, let the state choose the parser to + # use, and use the parsed token drive a state transition. + self._buffer += data + while True: + # it'd be nice to use an iterator here, but since self.parse() + # dispatches to a different parser (depending upon the current + # state), we'd be using multiple iterators + token = self.parse() + if isinstance(token, RelayOK): + self.got_relay_ok() + elif isinstance(token, Prologue): + self.got_prologue() + yield token # triggers send_handshake + elif isinstance(token, Frame): + yield token + else: + break + + def send_frame(self, frame): + assert self._can_send_frames + self._transport.write(to_be4(len(frame)) + frame) + +# RelayOK: Newline-terminated buddy-is-connected response from Relay. +# First data received from relay. +# Prologue: double-newline-terminated this-is-really-wormhole response +# from peer. First data received from peer. +# Frame: Either handshake or encrypted message. Length-prefixed on wire. +# Handshake: the Noise ephemeral key, first framed message +# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK +# KCM: Key Confirmation Message (encrypted b"\x00"). First frame +# from peer. Sent immediately by Follower, after Selection by Leader. +# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong + +Handshake = namedtuple("Handshake", []) +# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack +KCM = namedtuple("KCM", []) +Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value +Pong = namedtuple("Pong", ["ping_id"]) +Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer +Data = namedtuple("Data", ["seqnum", "scid", "data"]) +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer +Records = (KCM, Ping, Pong, Open, Data, Close, Ack) +Handshake_or_Records = (Handshake,) + Records + +T_KCM = b"\x00" +T_PING = b"\x01" +T_PONG = b"\x02" +T_OPEN = b"\x03" +T_DATA = b"\x04" +T_CLOSE = b"\x05" +T_ACK = b"\x06" + +def parse_record(plaintext): + msgtype = plaintext[0:1] + if msgtype == T_KCM: + return KCM() + if msgtype == T_PING: + ping_id = plaintext[1:5] + return Ping(ping_id) + if msgtype == T_PONG: + ping_id = plaintext[1:5] + return Pong(ping_id) + if msgtype == T_OPEN: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Open(seqnum, scid) + if msgtype == T_DATA: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + data = plaintext[9:] + return Data(seqnum, scid, data) + if msgtype == T_CLOSE: + scid = from_be4(plaintext[1:5]) + seqnum = from_be4(plaintext[5:9]) + return Close(seqnum, scid) + if msgtype == T_ACK: + resp_seqnum = from_be4(plaintext[1:5]) + return Ack(resp_seqnum) + log.err("received unknown message type: {}".format(plaintext)) + raise ValueError() + +def encode_record(r): + if isinstance(r, KCM): + return b"\x00" + if isinstance(r, Ping): + return b"\x01" + r.ping_id + if isinstance(r, Pong): + return b"\x02" + r.ping_id + if isinstance(r, Open): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Data): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data + if isinstance(r, Close): + assert isinstance(r.scid, six.integer_types) + assert isinstance(r.seqnum, six.integer_types) + return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum) + if isinstance(r, Ack): + assert isinstance(r.resp_seqnum, six.integer_types) + return b"\x06" + to_be4(r.resp_seqnum) + raise TypeError(r) + +@attrs +@implementer(IRecord) +class _Record(object): + _framer = attrib(validator=provides(IFramer)) + _noise = attrib() + + n = MethodicalMachine() + # TODO: set_trace + + def __attrs_post_init__(self): + self._noise.start_handshake() + + # in: role= + # in: prologue_received, frame_received + # out: handshake_received, record_received + # out: transport.write (noise handshake, encrypted records) + # states: want_prologue, want_handshake, want_record + + @n.state(initial=True) + def want_prologue(self): pass # pragma: no cover + @n.state() + def want_handshake(self): pass # pragma: no cover + @n.state() + def want_message(self): pass # pragma: no cover + + @n.input() + def got_prologue(self): + pass + @n.input() + def got_frame(self, frame): + pass + + @n.output() + def send_handshake(self): + handshake = self._noise.write_message() # generate the ephemeral key + self._framer.send_frame(handshake) + + @n.output() + def process_handshake(self, frame): + from noise.exceptions import NoiseInvalidMessage + try: + payload = self._noise.read_message(frame) + # Noise can include unencrypted data in the handshake, but we don't + # use it + del payload + except NoiseInvalidMessage as e: + log.err(e, "bad inbound noise handshake") + raise Disconnect() + return Handshake() + + @n.output() + def decrypt_message(self, frame): + from noise.exceptions import NoiseInvalidMessage + try: + message = self._noise.decrypt(frame) + except NoiseInvalidMessage as e: + # if this happens during tests, flunk the test + log.err(e, "bad inbound noise frame") + raise Disconnect() + return parse_record(message) + + want_prologue.upon(got_prologue, outputs=[send_handshake], + enter=want_handshake) + want_handshake.upon(got_frame, outputs=[process_handshake], + collector=first, enter=want_message) + want_message.upon(got_frame, outputs=[decrypt_message], + collector=first, enter=want_message) + + # external API is: connectionMade, dataReceived, send_record + + def connectionMade(self): + self._framer.connectionMade() + + def add_and_unframe(self, data): + for token in self._framer.add_and_parse(data): + if isinstance(token, Prologue): + self.got_prologue() # triggers send_handshake + else: + assert isinstance(token, Frame) + yield self.got_frame(token.frame) # Handshake or a Record type + + def send_record(self, r): + message = encode_record(r) + frame = self._noise.encrypt(message) + self._framer.send_frame(frame) + + +@attrs +class DilatedConnectionProtocol(Protocol, object): + """I manage an L2 connection. + + When a new L2 connection is needed (as determined by the Leader), + both Leader and Follower will initiate many simultaneous connections + (probably TCP, but conceivably others). A subset will actually + connect. A subset of those will successfully pass negotiation by + exchanging handshakes to demonstrate knowledge of the session key. + One of the negotiated connections will be selected by the Leader for + active use, and the others will be dropped. + + At any given time, there is at most one active L2 connection. + """ + + _eventual_queue = attrib() + _role = attrib() + _connector = attrib(validator=provides(IDilationConnector)) + _noise = attrib() + _outbound_prologue = attrib(validator=instance_of(bytes)) + _inbound_prologue = attrib(validator=instance_of(bytes)) + + _use_relay = False + _relay_handshake = None + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + self._manager = None # set if/when we are selected + self._disconnected = OneShotObserver(self._eventual_queue) + self._can_send_records = False + + @m.state(initial=True) + def unselected(self): pass # pragma: no cover + @m.state() + def selecting(self): pass # pragma: no cover + @m.state() + def selected(self): pass # pragma: no cover + + @m.input() + def got_kcm(self): + pass + @m.input() + def select(self, manager): + pass # fires set_manager() + @m.input() + def got_record(self, record): + pass + + @m.output() + def add_candidate(self): + self._connector.add_candidate(self) + + @m.output() + def set_manager(self, manager): + self._manager = manager + + @m.output() + def can_send_records(self, manager): + self._can_send_records = True + + @m.output() + def deliver_record(self, record): + self._manager.got_record(record) + + unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting) + selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected) + selected.upon(got_record, outputs=[deliver_record], enter=selected) + + # called by Connector + + def use_relay(self, relay_handshake): + assert isinstance(relay_handshake, bytes) + self._use_relay = True + self._relay_handshake = relay_handshake + + def when_disconnected(self): + return self._disconnected.when_fired() + + def disconnect(self): + self.transport.loseConnection() + + # select() called by Connector + + # called by Manager + def send_record(self, record): + assert self._can_send_records + self._record.send_record(record) + + # IProtocol methods + + def connectionMade(self): + framer = _Framer(self.transport, + self._outbound_prologue, self._inbound_prologue) + if self._use_relay: + framer.use_relay(self._relay_handshake) + self._record = _Record(framer, self._noise) + self._record.connectionMade() + + def dataReceived(self, data): + try: + for token in self._record.add_and_unframe(data): + assert isinstance(token, Handshake_or_Records) + if isinstance(token, Handshake): + if self._role is FOLLOWER: + self._record.send_record(KCM()) + elif isinstance(token, KCM): + # if we're the leader, add this connection as a candiate. + # if we're the follower, accept this connection. + self.got_kcm() # connector.add_candidate() + else: + self.got_record(token) # manager.got_record() + except Disconnect: + self.transport.loseConnection() + + def connectionLost(self, why=None): + self._disconnected.fire(self) diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py new file mode 100644 index 0000000..86f2a72 --- /dev/null +++ b/src/wormhole/_dilation/connector.py @@ -0,0 +1,482 @@ +from __future__ import print_function, unicode_literals +import sys, re +from collections import defaultdict, namedtuple +from binascii import hexlify +import six +from attr import attrs, attrib +from attr.validators import instance_of, provides, optional +from automat import MethodicalMachine +from zope.interface import implementer +from twisted.internet.task import deferLater +from twisted.internet.defer import DeferredList +from twisted.internet.endpoints import HostnameEndpoint, serverFromString +from twisted.internet.protocol import ClientFactory, ServerFactory +from twisted.python import log +from hkdf import Hkdf +from .. import ipaddrs # TODO: move into _dilation/ +from .._interfaces import IDilationConnector, IDilationManager +from ..timing import DebugTiming +from ..observer import EmptyableSet +from .connection import DilatedConnectionProtocol, KCM +from .roles import LEADER + + +# These namedtuples are "hint objects". The JSON-serializable dictionaries +# are "hint dicts". + +# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol: +# * make a TCP connection (possibly via Tor) +# * send the sender/receiver handshake bytes first +# * expect to see the receiver/sender handshake bytes from the other side +# * the sender writes "go\n", the receiver waits for "go\n" +# * the rest of the connection contains transit data +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) +# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we +# use a tuple rather than a list so they'll be hashable into a set). For each +# one, make the TCP connection, send the relay handshake, then complete the +# rest of the V1 protocol. Only one hint per relay is useful. +RelayV1Hint = namedtuple("RelayV1Hint", ["hints"]) + +def describe_hint_obj(hint, relay, tor): + prefix = "tor->" if tor else "->" + if relay: + prefix = prefix + "relay:" + if isinstance(hint, DirectTCPV1Hint): + return prefix + "tcp:%s:%d" % (hint.hostname, hint.port) + elif isinstance(hint, TorTCPV1Hint): + return prefix+"tor:%s:%d" % (hint.hostname, hint.port) + else: + return prefix+str(hint) + +def parse_hint_argv(hint, stderr=sys.stderr): + assert isinstance(hint, type("")) + # return tuple or None for an unparseable hint + priority = 0.0 + mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) + if not mo: + print("unparseable hint '%s'" % (hint,), file=stderr) + return None + hint_type = mo.group(1) + if hint_type != "tcp": + print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) + return None + hint_value = mo.group(2) + pieces = hint_value.split(":") + if len(pieces) < 2: + print("unparseable TCP hint (need more colons) '%s'" % (hint,), + file=stderr) + return None + mo = re.search(r'^(\d+)$', pieces[1]) + if not mo: + print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) + return None + hint_host = pieces[0] + hint_port = int(pieces[1]) + for more in pieces[2:]: + if more.startswith("priority="): + more_pieces = more.split("=") + try: + priority = float(more_pieces[1]) + except ValueError: + print("non-float priority= in TCP hint '%s'" % (hint,), + file=stderr) + return None + return DirectTCPV1Hint(hint_host, hint_port, priority) + +def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj + hint_type = hint.get("type", "") + if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]: + log.msg("unknown hint type: %r" % (hint,)) + return None + if not("hostname" in hint + and isinstance(hint["hostname"], type(""))): + log.msg("invalid hostname in hint: %r" % (hint,)) + return None + if not("port" in hint + and isinstance(hint["port"], six.integer_types)): + log.msg("invalid port in hint: %r" % (hint,)) + return None + priority = hint.get("priority", 0.0) + if hint_type == "direct-tcp-v1": + return DirectTCPV1Hint(hint["hostname"], hint["port"], priority) + else: + return TorTCPV1Hint(hint["hostname"], hint["port"], priority) + +def parse_hint(hint_struct): + hint_type = hint_struct.get("type", "") + if hint_type == "relay-v1": + # the struct can include multiple ways to reach the same relay + rhints = filter(lambda h: h, # drop None (unrecognized) + [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]]) + return RelayV1Hint(rhints) + return parse_tcp_v1_hint(hint_struct) + +def encode_hint(h): + if isinstance(h, DirectTCPV1Hint): + return {"type": "direct-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + elif isinstance(h, RelayV1Hint): + rhint = {"type": "relay-v1", "hints": []} + for rh in h.hints: + rhint["hints"].append({"type": "direct-tcp-v1", + "priority": rh.priority, + "hostname": rh.hostname, + "port": rh.port}) + return rhint + elif isinstance(h, TorTCPV1Hint): + return {"type": "tor-tcp-v1", + "priority": h.priority, + "hostname": h.hostname, + "port": h.port, # integer + } + raise ValueError("unknown hint type", h) + +def HKDF(skm, outlen, salt=None, CTXinfo=b""): + return Hkdf(salt, skm).expand(CTXinfo, outlen) + +def build_sided_relay_handshake(key, side): + assert isinstance(side, type(u"")) + assert len(side) == 8*2 + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") + return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" + +PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n" +PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n" +NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + +@attrs +@implementer(IDilationConnector) +class Connector(object): + _dilation_key = attrib(validator=instance_of(type(b""))) + _transit_relay_location = attrib(validator=optional(instance_of(str))) + _manager = attrib(validator=provides(IDilationManager)) + _reactor = attrib() + _eventual_queue = attrib() + _no_listen = attrib(validator=instance_of(bool)) + _tor = attrib() + _timing = attrib() + _side = attrib(validator=instance_of(type(u""))) + # was self._side = bytes_to_hexstr(os.urandom(8)) # unicode + _role = attrib() + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + RELAY_DELAY = 2.0 + + def __attrs_post_init__(self): + if self._transit_relay_location: + # TODO: allow multiple hints for a single relay + relay_hint = parse_hint_argv(self._transit_relay_location) + relay = RelayV1Hint(hints=(relay_hint,)) + self._transit_relays = [relay] + else: + self._transit_relays = [] + self._listeners = set() # IListeningPorts that can be stopped + self._pending_connectors = set() # Deferreds that can be cancelled + self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped + self._contenders = set() # viable connections + self._winning_connection = None + self._timing = self._timing or DebugTiming() + self._timing.add("transit") + + # this describes what our Connector can do, for the initial advertisement + @classmethod + def get_connection_abilities(klass): + return [{"type": "direct-tcp-v1"}, + {"type": "relay-v1"}, + ] + + def build_protocol(self, addr): + # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses + # ephemeral keys plus a pre-shared symmetric key (the Transit key), a + # different one for each potential connection. + from noise.connection import NoiseConnection + noise = NoiseConnection.from_name(NOISEPROTO) + noise.set_psks(self._dilation_key) + if self._role is LEADER: + noise.set_as_initiator() + outbound_prologue = PROLOGUE_LEADER + inbound_prologue = PROLOGUE_FOLLOWER + else: + noise.set_as_responder() + outbound_prologue = PROLOGUE_FOLLOWER + inbound_prologue = PROLOGUE_LEADER + p = DilatedConnectionProtocol(self._eventual_queue, self._role, + self, noise, + outbound_prologue, inbound_prologue) + return p + + @m.state(initial=True) + def connecting(self): pass # pragma: no cover + @m.state() + def connected(self): pass # pragma: no cover + @m.state(terminal=True) + def stopped(self): pass # pragma: no cover + + # TODO: unify the tense of these method-name verbs + @m.input() + def listener_ready(self, hint_objs): pass + @m.input() + def add_relay(self, hint_objs): pass + @m.input() + def got_hints(self, hint_objs): pass + @m.input() + def add_candidate(self, c): # called by DilatedConnectionProtocol + pass + @m.input() + def accept(self, c): pass + @m.input() + def stop(self): pass + + @m.output() + def use_hints(self, hint_objs): + self._use_hints(hint_objs) + + @m.output() + def publish_hints(self, hint_objs): + self._manager.send_hints([encode_hint(h) for h in hint_objs]) + + @m.output() + def consider(self, c): + self._contenders.add(c) + if self._role is LEADER: + # for now, just accept the first one. TODO: be clever. + self._eventual_queue.eventually(self.accept, c) + else: + # the follower always uses the first contender, since that's the + # only one the leader picked + self._eventual_queue.eventually(self.accept, c) + + @m.output() + def select_and_stop_remaining(self, c): + self._winning_connection = c + self._contenders.clear() # we no longer care who else came close + # remove this winner from the losers, so we don't shut it down + self._pending_connections.discard(c) + # shut down losing connections + self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist + self.stop_pending_connectors() + self.stop_pending_connections() + + c.select(self._manager) # subsequent frames go directly to the manager + if self._role is LEADER: + # TODO: this should live in Connection + c.send_record(KCM()) # leader sends KCM now + self._manager.use_connection(c) # manager sends frames to Connection + + @m.output() + def stop_everything(self): + self.stop_listeners() + self.stop_pending_connectors() + self.stop_pending_connections() + self.break_cycles() + + def stop_listeners(self): + d = DeferredList([l.stopListening() for l in self._listeners]) + self._listeners.clear() + return d # synchronization for tests + + def stop_pending_connectors(self): + return DeferredList([d.cancel() for d in self._pending_connectors]) + + def stop_pending_connections(self): + d = self._pending_connections.when_next_empty() + [c.loseConnection() for c in self._pending_connections] + return d + + def stop_winner(self): + d = self._winner.when_disconnected() + self._winner.disconnect() + return d + + def break_cycles(self): + # help GC by forgetting references to things that reference us + self._listeners.clear() + self._pending_connectors.clear() + self._pending_connections.clear() + self._winner = None + + connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints]) + connecting.upon(add_relay, enter=connecting, outputs=[use_hints, + publish_hints]) + connecting.upon(got_hints, enter=connecting, outputs=[use_hints]) + connecting.upon(add_candidate, enter=connecting, outputs=[consider]) + connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining]) + connecting.upon(stop, enter=stopped, outputs=[stop_everything]) + + # once connected, we ignore everything except stop + connected.upon(listener_ready, enter=connected, outputs=[]) + connected.upon(add_relay, enter=connected, outputs=[]) + connected.upon(got_hints, enter=connected, outputs=[]) + connected.upon(add_candidate, enter=connected, outputs=[]) + connected.upon(accept, enter=connected, outputs=[]) + connected.upon(stop, enter=stopped, outputs=[stop_everything]) + + + # from Manager: start, got_hints, stop + # maybe add_candidate, accept + def start(self): + self._start_listener() + if self._transit_relays: + self.publish_hints(self._transit_relays) + self._use_hints(self._transit_relays) + + def _start_listener(self): + if self._no_listen or self._tor: + return + addresses = ipaddrs.find_addresses() + non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"] + if non_loopback_addresses: + # some test hosts, including the appveyor VMs, *only* have + # 127.0.0.1, and the tests will hang badly if we remove it. + addresses = non_loopback_addresses + # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also + # to make firewall configs easier + # TODO: retain listening port between connection generations? + ep = serverFromString(self._reactor, "tcp:0") + f = InboundConnectionFactory(self) + d = ep.listen(f) + def _listening(lp): + # lp is an IListeningPort + self._listeners.add(lp) # for shutdown and tests + portnum = lp.getHost().port + direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) + for addr in addresses] + self.listener_ready(direct_hints) + d.addCallback(_listening) + d.addErrback(log.err) + + def _use_hints(self, hints): + # first, pull out all the relays, we'll connect to them later + relays = defaultdict(list) + direct = defaultdict(list) + for h in hints: + if isinstance(h, RelayV1Hint): + relays[h.priority].append(h) + else: + direct[h.priority].append(h) + delay = 0.0 + priorities = sorted(set(direct.keys()), reverse=True) + for p in priorities: + for h in direct[p]: + if isinstance(h, TorTCPV1Hint) and not self._tor: + continue + ep = self._endpoint_from_hint_obj(h) + desc = describe_hint_obj(h, False, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay=False) + self._pending_connectors.add(d) + # Make all direct connections immediately. Later, we'll change + # the add_candidate() function to look at the priority when + # deciding whether to accept a successful connection or not, + # and it can wait for more options if it sees a higher-priority + # one still running. But if we bail on that, we might consider + # putting an inter-direct-hint delay here to influence the + # process. + #delay += 1.0 + if delay > 0.0: + # Start trying the relays a few seconds after we start to try the + # direct hints. The idea is to prefer direct connections, but not + # be afraid of using a relay when we have direct hints that don't + # resolve quickly. Many direct hints will be to unused + # local-network IP addresses, which won't answer, and would take + # the full TCP timeout (30s or more) to fail. If there were no + # direct hints, don't delay at all. + delay += self.RELAY_DELAY + + # prefer direct connections by stalling relay connections by a few + # seconds, unless we're using --no-listen in which case we're probably + # going to have to use the relay + delay = self.RELAY_DELAY if self._no_listen else 0.0 + + # It might be nice to wire this so that a failure in the direct hints + # causes the relay hints to be used right away (fast failover). But + # none of our current use cases would take advantage of that: if we + # have any viable direct hints, then they're either going to succeed + # quickly or hang for a long time. + for p in priorities: + for r in relays[p]: + for h in r.hints: + ep = self._endpoint_from_hint_obj(h) + desc = describe_hint_obj(h, True, self._tor) + d = deferLater(self._reactor, delay, + self._connect, ep, desc, is_relay=True) + self._pending_connectors.add(d) + # TODO: + #if not contenders: + # raise TransitError("No contenders for connection") + + # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for + # the initial connection + + def _connect(self, h, ep, description, is_relay=False): + relay_handshake = None + if is_relay: + relay_handshake = build_sided_relay_handshake(self._dilation_key, + self._side) + f = OutboundConnectionFactory(self, relay_handshake) + d = ep.connect(f) + # fires with protocol, or ConnectError + def _connected(p): + self._pending_connections.add(p) + # c might not be in _pending_connections, if it turned out to be a + # winner, which is why we use discard() and not remove() + p.when_disconnected().addCallback(self._pending_connections.discard) + d.addCallback(_connected) + return d + + def _endpoint_from_hint_obj(self, hint): + if self._tor: + if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)): + # this Tor object will throw ValueError for non-public IPv4 + # addresses and any IPv6 address + try: + return self._tor.stream_via(hint.hostname, hint.port) + except ValueError: + return None + return None + if isinstance(hint, DirectTCPV1Hint): + return HostnameEndpoint(self._reactor, hint.hostname, hint.port) + return None + + + # Connection selection. All instances of DilatedConnectionProtocol which + # look viable get passed into our add_contender() method. + + # On the Leader side, "viable" means we've seen their KCM frame, which is + # the first Noise-encrypted packet on any given connection, and it has an + # empty body. We gather viable connections until we see one that we like, + # or a timer expires. Then we "select" it, close the others, and tell our + # Manager to use it. + + # On the Follower side, we'll only see a KCM on the one connection selected + # by the Leader, so the first viable connection wins. + + # our Connection protocols call: add_candidate + +@attrs +class OutboundConnectionFactory(ClientFactory, object): + _connector = attrib(validator=provides(IDilationConnector)) + _relay_handshake = attrib(validator=optional(instance_of(bytes))) + + def buildProtocol(self, addr): + p = self._connector.build_protocol(addr) + p.factory = self + if self._relay_handshake is not None: + p.use_relay(self._relay_handshake) + return p + +@attrs +class InboundConnectionFactory(ServerFactory, object): + _connector = attrib(validator=provides(IDilationConnector)) + protocol = DilatedConnectionProtocol + + def buildProtocol(self, addr): + p = self._connector.build_protocol(addr) + p.factory = self + return p diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py new file mode 100644 index 0000000..eb1b0b2 --- /dev/null +++ b/src/wormhole/_dilation/encode.py @@ -0,0 +1,16 @@ +from __future__ import print_function, unicode_literals +import struct + +assert len(struct.pack("L", value) +def from_be4(b): + if not isinstance(b, bytes): + raise TypeError(repr(b)) + if len(b) != 4: + raise ValueError + return struct.unpack(">L", b)[0] diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py new file mode 100644 index 0000000..9235681 --- /dev/null +++ b/src/wormhole/_dilation/inbound.py @@ -0,0 +1,127 @@ +from __future__ import print_function, unicode_literals +from attr import attrs, attrib +from attr.validators import provides +from zope.interface import implementer +from twisted.python import log +from .._interfaces import IDilationManager, IInbound +from .subchannel import (SubChannel, _SubchannelAddress) + +class DuplicateOpenError(Exception): + pass +class DataForMissingSubchannelError(Exception): + pass +class CloseForMissingSubchannelError(Exception): + pass + +@attrs +@implementer(IInbound) +class Inbound(object): + # Inbound flow control: TCP delivers data to Connection.dataReceived, + # Connection delivers to our handle_data, we deliver to + # SubChannel.remote_data, subchannel delivers to proto.dataReceived + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib() + + def __attrs_post_init__(self): + # we route inbound Data records to Subchannels .dataReceived + self._open_subchannels = {} # scid -> Subchannel + self._paused_subchannels = set() # Subchannels that have paused us + # the set is non-empty, we pause the transport + self._highest_inbound_acked = -1 + self._connection = None + + # from our Manager + def set_listener_endpoint(self, listener_endpoint): + self._listener_endpoint = listener_endpoint + + def set_subchannel_zero(self, scid0, sc0): + self._open_subchannels[scid0] = sc0 + + + def use_connection(self, c): + self._connection = c + # We can pause the connection's reads when we receive too much data. If + # this is a non-initial connection, then we might already have + # subchannels that are paused from before, so we might need to pause + # the new connection before it can send us any data + if self._paused_subchannels: + self._connection.pauseProducing() + + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + def is_record_old(self, r): + if r.seqnum <= self._highest_inbound_acked: + return True + return False + + def update_ack_watermark(self, r): + self._highest_inbound_acked = max(self._highest_inbound_acked, + r.seqnum) + + def handle_open(self, scid): + if scid in self._open_subchannels: + log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid))) + return + peer_addr = _SubchannelAddress(scid) + sc = SubChannel(scid, self._manager, self._host_addr, peer_addr) + self._open_subchannels[scid] = sc + self._listener_endpoint._got_open(sc, peer_addr) + + def handle_data(self, scid, data): + sc = self._open_subchannels.get(scid) + if sc is None: + log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid))) + return + sc.remote_data(data) + + def handle_close(self, scid): + sc = self._open_subchannels.get(scid) + if sc is None: + log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid))) + return + sc.remote_close() + + def subchannel_closed(self, scid, sc): + # connectionLost has just been signalled + assert self._open_subchannels[scid] is sc + del self._open_subchannels[scid] + + def stop_using_connection(self): + self._connection = None + + + # from our Subchannel, or rather from the Protocol above it and sent + # through the subchannel + + # The subchannel is an IProducer, and application protocols can always + # thell them to pauseProducing if we're delivering inbound data too + # quickly. They don't need to register anything. + + def subchannel_pauseProducing(self, sc): + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.add(sc) + if self._connection and not was_paused: + self._connection.pauseProducing() + + def subchannel_resumeProducing(self, sc): + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.discard(sc) + if self._connection and was_paused and not self._paused_subchannels: + self._connection.resumeProducing() + + def subchannel_stopProducing(self, sc): + # This protocol doesn't want any additional data. If we were a normal + # (single-owner) Transport, we'd call .loseConnection now. But our + # Connection is shared among many subchannels, so instead we just + # stop letting them pause the connection. + was_paused = bool(self._paused_subchannels) + self._paused_subchannels.discard(sc) + if self._connection and was_paused and not self._paused_subchannels: + self._connection.resumeProducing() + + # TODO: we might refactor these pause/resume/stop methods by building a + # context manager that look at the paused/not-paused state first, then + # lets the caller modify self._paused_subchannels, then looks at it a + # second time, and calls c.pauseProducing/c.resumeProducing as + # appropriate. I'm not sure it would be any cleaner, though. diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py new file mode 100644 index 0000000..860665a --- /dev/null +++ b/src/wormhole/_dilation/manager.py @@ -0,0 +1,514 @@ +from __future__ import print_function, unicode_literals +from collections import deque +from attr import attrs, attrib +from attr.validators import provides, instance_of, optional +from automat import MethodicalMachine +from zope.interface import implementer +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.python import log +from .._interfaces import IDilator, IDilationManager, ISend +from ..util import dict_to_bytes, bytes_to_dict +from ..observer import OneShotObserver +from .._key import derive_key +from .encode import to_be4 +from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress, + ControlEndpoint, SubchannelConnectorEndpoint, + SubchannelListenerEndpoint) +from .connector import Connector, parse_hint +from .roles import LEADER, FOLLOWER +from .connection import KCM, Ping, Pong, Open, Data, Close, Ack +from .inbound import Inbound +from .outbound import Outbound + +class OldPeerCannotDilateError(Exception): + pass +class UnknownDilationMessageType(Exception): + pass +class ReceivedHintsTooEarly(Exception): + pass + +@attrs +@implementer(IDilationManager) +class _ManagerBase(object): + _S = attrib(validator=provides(ISend)) + _my_side = attrib(validator=instance_of(type(u""))) + _transit_key = attrib(validator=instance_of(bytes)) + _transit_relay_location = attrib(validator=optional(instance_of(str))) + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + _no_listen = False # TODO + _tor = None # TODO + _timing = None # TODO + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + + self._my_role = None # determined upon rx_PLEASE + + self._connection = None + self._made_first_connection = False + self._first_connected = OneShotObserver(self._eventual_queue) + self._host_addr = _WormholeAddress() + + self._next_dilation_phase = 0 + + self._next_subchannel_id = 0 # increments by 2 + + # I kept getting confused about which methods were for inbound data + # (and thus flow-control methods go "out") and which were for + # outbound data (with flow-control going "in"), so I split them up + # into separate pieces. + self._inbound = Inbound(self, self._host_addr) + self._outbound = Outbound(self, self._cooperator) # from us to peer + + def set_listener_endpoint(self, listener_endpoint): + self._inbound.set_listener_endpoint(listener_endpoint) + def set_subchannel_zero(self, scid0, sc0): + self._inbound.set_subchannel_zero(scid0, sc0) + + def when_first_connected(self): + return self._first_connected.when_fired() + + + def send_dilation_phase(self, **fields): + dilation_phase = self._next_dilation_phase + self._next_dilation_phase += 1 + self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) + + def send_hints(self, hints): # from Connector + self.send_dilation_phase(type="connection-hints", hints=hints) + + + # forward inbound-ish things to _Inbound + def subchannel_pauseProducing(self, sc): + self._inbound.subchannel_pauseProducing(sc) + def subchannel_resumeProducing(self, sc): + self._inbound.subchannel_resumeProducing(sc) + def subchannel_stopProducing(self, sc): + self._inbound.subchannel_stopProducing(sc) + + # forward outbound-ish things to _Outbound + def subchannel_registerProducer(self, sc, producer, streaming): + self._outbound.subchannel_registerProducer(sc, producer, streaming) + def subchannel_unregisterProducer(self, sc): + self._outbound.subchannel_unregisterProducer(sc) + + def send_open(self, scid): + self._queue_and_send(Open, scid) + def send_data(self, scid, data): + self._queue_and_send(Data, scid, data) + def send_close(self, scid): + self._queue_and_send(Close, scid) + + def _queue_and_send(self, record_type, *args): + r = self._outbound.build_record(record_type, *args) + # Outbound owns the send_record() pipe, so that it can stall new + # writes after a new connection is made until after all queued + # messages are written (to preserve ordering). + self._outbound.queue_and_send_record(r) # may trigger pauseProducing + + def subchannel_closed(self, scid, sc): + # let everyone clean up. This happens just after we delivered + # connectionLost to the Protocol, except for the control channel, + # which might get connectionLost later after they use ep.connect. + # TODO: is this inversion a problem? + self._inbound.subchannel_closed(scid, sc) + self._outbound.subchannel_closed(scid, sc) + + + def _start_connecting(self, role): + assert self._my_role is not None + self._connector = Connector(self._transit_key, + self._transit_relay_location, + self, + self._reactor, self._eventual_queue, + self._no_listen, self._tor, + self._timing, + self._side, # needed for relay handshake + self._my_role) + self._connector.start() + + # our Connector calls these + + def connector_connection_made(self, c): + self.connection_made() # state machine update + self._connection = c + self._inbound.use_connection(c) + self._outbound.use_connection(c) # does c.registerProducer + if not self._made_first_connection: + self._made_first_connection = True + self._first_connected.fire(None) + pass + def connector_connection_lost(self): + self._stop_using_connection() + if self.role is LEADER: + self.connection_lost_leader() # state machine + else: + self.connection_lost_follower() + + + def _stop_using_connection(self): + # the connection is already lost by this point + self._connection = None + self._inbound.stop_using_connection() + self._outbound.stop_using_connection() # does c.unregisterProducer + + # from our active Connection + + def got_record(self, r): + # records with sequence numbers: always ack, ignore old ones + if isinstance(r, (Open, Data, Close)): + self.send_ack(r.seqnum) # always ack, even for old ones + if self._inbound.is_record_old(r): + return + self._inbound.update_ack_watermark(r.seqnum) + if isinstance(r, Open): + self._inbound.handle_open(r.scid) + elif isinstance(r, Data): + self._inbound.handle_data(r.scid, r.data) + else: # isinstance(r, Close) + self._inbound.handle_close(r.scid) + if isinstance(r, KCM): + log.err("got unexpected KCM") + elif isinstance(r, Ping): + self.handle_ping(r.ping_id) + elif isinstance(r, Pong): + self.handle_pong(r.ping_id) + elif isinstance(r, Ack): + self._outbound.handle_ack(r.resp_seqnum) # retire queued messages + else: + log.err("received unknown message type {}".format(r)) + + # pings, pongs, and acks are not queued + def send_ping(self, ping_id): + self._outbound.send_if_connected(Ping(ping_id)) + + def send_pong(self, ping_id): + self._outbound.send_if_connected(Pong(ping_id)) + + def send_ack(self, resp_seqnum): + self._outbound.send_if_connected(Ack(resp_seqnum)) + + + def handle_ping(self, ping_id): + self.send_pong(ping_id) + + def handle_pong(self, ping_id): + # TODO: update is-alive timer + pass + + # subchannel maintenance + def allocate_subchannel_id(self): + raise NotImplemented # subclass knows if we're leader or follower + +# new scheme: +# * both sides send PLEASE as soon as they have an unverified key and +# w.dilate has been called, +# * PLEASE includes a dilation-specific "side" (independent of the "side" +# used by mailbox messages) +# * higher "side" is Leader, lower is Follower +# * PLEASE includes can-dilate list of version integers, requires overlap +# "1" is current +# * dilation starts as soon as we've sent PLEASE and received PLEASE +# (four-state two-variable IDLE/WANTING/WANTED/STARTED diamond FSM) +# * HINTS sent after dilation starts +# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This +# is the only difference between the two sides, and is not enforced +# by the protocol (i.e. if the Follower sends RECONNECT to the Leader, +# the Leader will obey, although TODO how confusing will this get?) +# * upon receiving RECONNECT: drop Connector, start new Connector, send +# RECONNECTING, start sending HINTS +# * upon sending CONNECT: go into FLUSHING state and ignore all HINTS until +# RECONNECTING received. The new Connector can be spun up earlier, and it +# can send HINTS, but it must not be given any HINTS that arrive before +# RECONNECTING (since they're probably stale) + +# * after VERSIONS(KCM) received, we might learn that they other side cannot +# dilate. w.dilate errbacks at this point + +# * maybe signal warning if we stay in a "want" state for too long +# * nobody sends HINTS until they're ready to receive +# * nobody sends HINTS unless they've called w.dilate() and received PLEASE +# * nobody connects to inbound hints unless they've called w.dilate() +# * if leader calls w.dilate() but not follower, leader waits forever in +# "want" (doesn't send anything) +# * if follower calls w.dilate() but not leader, follower waits forever +# in "want", leader waits forever in "wanted" + +class ManagerShared(_ManagerBase): + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + @m.state(initial=True) + def IDLE(self): pass # pragma: no cover + + @m.state() + def WANTING(self): pass # pragma: no cover + @m.state() + def WANTED(self): pass # pragma: no cover + @m.state() + def CONNECTING(self): pass # pragma: no cover + @m.state() + def CONNECTED(self): pass # pragma: no cover + @m.state() + def FLUSHING(self): pass # pragma: no cover + @m.state() + def ABANDONING(self): pass # pragma: no cover + @m.state() + def LONELY(self): pass # pragme: no cover + @m.state() + def STOPPING(self): pass # pragma: no cover + @m.state(terminal=True) + def STOPPED(self): pass # pragma: no cover + + @m.input() + def start(self): pass # pragma: no cover + @m.input() + def rx_PLEASE(self, message): pass # pragma: no cover + @m.input() # only sent by Follower + def rx_HINTS(self, hint_message): pass # pragma: no cover + @m.input() # only Leader sends RECONNECT, so only Follower receives it + def rx_RECONNECT(self): pass # pragma: no cover + @m.input() # only Follower sends RECONNECTING, so only Leader receives it + def rx_RECONNECTING(self): pass # pragma: no cover + + # Connector gives us connection_made() + @m.input() + def connection_made(self, c): pass # pragma: no cover + + # our connection_lost() fires connection_lost_leader or + # connection_lost_follower depending upon our role. If either side sees a + # problem with the connection (timeouts, bad authentication) then they + # just drop it and let connection_lost() handle the cleanup. + @m.input() + def connection_lost_leader(self): pass # pragma: no cover + @m.input() + def connection_lost_follower(self): pass + + @m.input() + def stop(self): pass # pragma: no cover + + @m.output() + def stash_side(self, message): + their_side = message["side"] + self.my_role = LEADER if self._my_side > their_side else FOLLOWER + + # these Outputs behave differently for the Leader vs the Follower + @m.output() + def send_please(self): + self.send_dilation_phase(type="please", side=self._my_side) + + @m.output() + def start_connecting(self): + self._start_connecting() # TODO: merge + @m.output() + def ignore_message_start_connecting(self, message): + self.start_connecting() + + @m.output() + def send_reconnect(self): + self.send_dilation_phase(type="reconnect") # TODO: generation number? + @m.output() + def send_reconnecting(self): + self.send_dilation_phase(type="reconnecting") # TODO: generation? + + @m.output() + def use_hints(self, hint_message): + hint_objs = filter(lambda h: h, # ignore None, unrecognizable + [parse_hint(hs) for hs in hint_message["hints"]]) + hint_objs = list(hint_objs) + self._connector.got_hints(hint_objs) + @m.output() + def stop_connecting(self): + self._connector.stop() + @m.output() + def abandon_connection(self): + # we think we're still connected, but the Leader disagrees. Or we've + # been told to shut down. + self._connection.disconnect() # let connection_lost do cleanup + + + # we don't start CONNECTING until a local start() plus rx_PLEASE + IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side]) + IDLE.upon(start, enter=WANTING, outputs=[send_please]) + WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting]) + WANTING.upon(rx_PLEASE, enter=CONNECTING, + outputs=[stash_side, + ignore_message_start_connecting]) + + CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[]) + + # Leader + CONNECTED.upon(connection_lost_leader, enter=FLUSHING, + outputs=[send_reconnect]) + FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) + + # Follower + # if we notice a lost connection, just wait for the Leader to notice too + CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[]) + LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting]) + # but if they notice it first, abandon our (seemingly functional) + # connection, then tell them that we're ready to try again + CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss + outputs=[abandon_connection]) + ABANDONING.upon(connection_lost_follower, enter=CONNECTING, + outputs=[send_reconnecting, start_connecting]) + # and if they notice a problem while we're still connecting, abandon our + # incomplete attempt and try again. in this case we don't have to wait + # for a connection to finish shutdown + CONNECTING.upon(rx_RECONNECT, enter=CONNECTING, + outputs=[stop_connecting, + send_reconnecting, + start_connecting]) + + + # rx_HINTS never changes state, they're just accepted or ignored + IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early + WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early + WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early + CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore + LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore + ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen + STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) + + IDLE.upon(stop, enter=STOPPED, outputs=[]) + WANTED.upon(stop, enter=STOPPED, outputs=[]) + WANTING.upon(stop, enter=STOPPED, outputs=[]) + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) + ABANDONING.upon(stop, enter=STOPPING, outputs=[]) + FLUSHING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + LONELY.upon(stop, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[]) + STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[]) + + + def allocate_subchannel_id(self): + # scid 0 is reserved for the control channel. the leader uses odd + # numbers starting with 1 + scid_num = self._next_outbound_seqnum + 1 + self._next_outbound_seqnum += 2 + return to_be4(scid_num) + +@attrs +@implementer(IDilator) +class Dilator(object): + """I launch the dilation process. + + I am created with every Wormhole (regardless of whether .dilate() + was called or not), and I handle the initial phase of dilation, + before we know whether we'll be the Leader or the Follower. Once we + hear the other side's VERSION message (which tells us that we have a + connection, they are capable of dilating, and which side we're on), + then we build a DilationManager and hand control to it. + """ + + _reactor = attrib() + _eventual_queue = attrib() + _cooperator = attrib() + + def __attrs_post_init__(self): + self._got_versions_d = Deferred() + self._started = False + self._endpoints = OneShotObserver(self._eventual_queue) + self._pending_inbound_dilate_messages = deque() + self._manager = None + + def wire(self, sender): + self._S = ISend(sender) + + # this is the primary entry point, called when w.dilate() is invoked + def dilate(self, transit_relay_location=None): + self._transit_relay_location = transit_relay_location + if not self._started: + self._started = True + self._start().addBoth(self._endpoints.fire) + return self._endpoints.when_fired() + + @inlineCallbacks + def _start(self): + # first, we wait until we hear the VERSION message, which tells us 1: + # the PAKE key works, so we can talk securely, 2: their side, so we + # know who will lead, and 3: that they can do dilation at all + + dilation_version = yield self._got_versions_d + + if not dilation_version: # 1 or None + raise OldPeerCannotDilateError() + + my_dilation_side = TODO # random + self._manager = Manager(self._S, my_dilation_side, + self._transit_key, + self._transit_relay_location, + self._reactor, self._eventual_queue, + self._cooperator) + self._manager.start() + + while self._pending_inbound_dilate_messages: + plaintext = self._pending_inbound_dilate_messages.popleft() + self.received_dilate(plaintext) + + # we could probably return the endpoints earlier + yield self._manager.when_first_connected() + # we can open subchannels as soon as we get our first connection + scid0 = b"\x00\x00\x00\x00" + self._host_addr = _WormholeAddress() # TODO: share with Manager + peer_addr0 = _SubchannelAddress(scid0) + control_ep = ControlEndpoint(peer_addr0) + sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) + control_ep._subchannel_zero_opened(sc0) + self._manager.set_subchannel_zero(scid0, sc0) + + connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) + + listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) + self._manager.set_listener_endpoint(listen_ep) + + endpoints = (control_ep, connect_ep, listen_ep) + returnValue(endpoints) + + # from Boss + + def got_key(self, key): + # TODO: verify this happens before got_wormhole_versions, or add a gate + # to tolerate either ordering + purpose = b"dilation-v1" + LENGTH =32 # TODO: whatever Noise wants, I guess + self._transit_key = derive_key(key, purpose, LENGTH) + + def got_wormhole_versions(self, our_side, their_side, + their_wormhole_versions): + # TODO: remove our_side, their_side + assert isinstance(our_side, str), str + assert isinstance(their_side, str), str + # this always happens before received_dilate + dilation_version = None + their_dilation_versions = their_wormhole_versions.get("can-dilate", []) + if 1 in their_dilation_versions: + dilation_version = 1 + self._got_versions_d.callback(dilation_version) + + def received_dilate(self, plaintext): + # this receives new in-order DILATE-n payloads, decrypted but not + # de-JSONed. + + # this can appear before our .dilate() method is called, in which case + # we queue them for later + if not self._manager: + self._pending_inbound_dilate_messages.append(plaintext) + return + + message = bytes_to_dict(plaintext) + type = message["type"] + if type == "please": + self._manager.rx_PLEASE() # message) + elif type == "dilate": + self._manager.rx_DILATE() #message) + elif type == "connection-hints": + self._manager.rx_HINTS(message) + else: + log.err(UnknownDilationMessageType(message)) + return diff --git a/src/wormhole/_dilation/old-follower.py b/src/wormhole/_dilation/old-follower.py new file mode 100644 index 0000000..68e3a38 --- /dev/null +++ b/src/wormhole/_dilation/old-follower.py @@ -0,0 +1,106 @@ + +class ManagerFollower(_ManagerBase): + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) + + @m.state(initial=True) + def IDLE(self): pass # pragma: no cover + + @m.state() + def WANTING(self): pass # pragma: no cover + @m.state() + def CONNECTING(self): pass # pragma: no cover + @m.state() + def CONNECTED(self): pass # pragma: no cover + @m.state(terminal=True) + def STOPPED(self): pass # pragma: no cover + + @m.input() + def start(self): pass # pragma: no cover + @m.input() + def rx_PLEASE(self): pass # pragma: no cover + @m.input() + def rx_DILATE(self): pass # pragma: no cover + @m.input() + def rx_HINTS(self, hint_message): pass # pragma: no cover + + @m.input() + def connection_made(self): pass # pragma: no cover + @m.input() + def connection_lost(self): pass # pragma: no cover + # follower doesn't react to connection_lost, but waits for a new LETS_DILATE + + @m.input() + def stop(self): pass # pragma: no cover + + # these Outputs behave differently for the Leader vs the Follower + @m.output() + def send_please(self): + self.send_dilation_phase(type="please") + + @m.output() + def start_connecting(self): + self._start_connecting(FOLLOWER) + + # these Outputs delegate to the same code in both the Leader and the + # Follower, but they must be replicated here because the Automat instance + # is on the subclass, not the shared superclass + + @m.output() + def use_hints(self, hint_message): + hint_objs = filter(lambda h: h, # ignore None, unrecognizable + [parse_hint(hs) for hs in hint_message["hints"]]) + self._connector.got_hints(hint_objs) + @m.output() + def stop_connecting(self): + self._connector.stop() + @m.output() + def use_connection(self, c): + self._use_connection(c) + @m.output() + def stop_using_connection(self): + self._stop_using_connection() + @m.output() + def signal_error(self): + pass # TODO + @m.output() + def signal_error_hints(self, hint_message): + pass # TODO + + IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early + IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early + # leader shouldn't send us DILATE before receiving our PLEASE + IDLE.upon(stop, enter=STOPPED, outputs=[]) + IDLE.upon(start, enter=WANTING, outputs=[send_please]) + WANTING.upon(rx_DILATE, enter=CONNECTING, outputs=[start_connecting]) + WANTING.upon(stop, enter=STOPPED, outputs=[]) + + CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) + CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) + # shouldn't happen: connection_lost + #CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) + CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, + start_connecting]) + # receiving rx_DILATE while we're still working on the last one means the + # leader thought we'd connected, then thought we'd been disconnected, all + # before we heard about that connection + CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting]) + + CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) + CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, + start_connecting]) + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) + # shouldn't happen: connection_made + + # we should never receive PLEASE, we're the follower + IDLE.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + WANTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + CONNECTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + CONNECTED.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error]) + + def allocate_subchannel_id(self): + # the follower uses even numbers starting with 2 + scid_num = self._next_outbound_seqnum + 2 + self._next_outbound_seqnum += 2 + return to_be4(scid_num) diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py new file mode 100644 index 0000000..6538ffe --- /dev/null +++ b/src/wormhole/_dilation/outbound.py @@ -0,0 +1,392 @@ +from __future__ import print_function, unicode_literals +from collections import deque +from attr import attrs, attrib +from attr.validators import provides +from zope.interface import implementer +from twisted.internet.interfaces import IPushProducer, IPullProducer +from twisted.python import log +from twisted.python.reflect import safe_str +from .._interfaces import IDilationManager, IOutbound +from .connection import KCM, Ping, Pong, Ack + + +# Outbound flow control: app writes to subchannel, we write to Connection + +# The app can register an IProducer of their choice, to let us throttle their +# outbound data. Not all subchannels will have producers registered, and the +# producer probably won't be the IProtocol instance (it'll be something else +# which feeds data out through the protocol, like a t.p.basic.FileSender). If +# a producerless subchannel writes too much, we won't be able to stop them, +# and we'll keep writing records into the Connection even though it's asked +# us to pause. Likewise, when the connection is down (and we're busily trying +# to reestablish a new one), registered subchannels will be paused, but +# unregistered ones will just dump everything in _outbound_queue, and we'll +# consume memory without bound until they stop. + +# We need several things: +# +# * Add each registered IProducer to a list, whose order remains stable. We +# want fairness under outbound throttling: each time the outbound +# connection opens up (our resumeProducing method is called), we should let +# just one producer have an opportunity to do transport.write, and then we +# should pause them again, and not come back to them until everyone else +# has gotten a turn. So we want an ordered list of producers to track this +# rotation. +# +# * Remove the IProducer if/when the protocol uses unregisterProducer +# +# * Remove any registered IProducer when the associated Subchannel is closed. +# This isn't a problem for normal transports, because usually there's a +# one-to-one mapping from Protocol to Transport, so when the Transport you +# forget the only reference to the Producer anyways. Our situation is +# unusual because we have multiple Subchannels that get merged into the +# same underlying Connection: each Subchannel's Protocol can register a +# producer on the Subchannel (which is an ITransport), but that adds it to +# a set of Producers for the Connection (which is also an ITransport). So +# if the Subchannel is closed, we need to remove its Producer (if any) even +# though the Connection remains open. +# +# * Register ourselves as an IPushProducer with each successive Connection +# object. These connections will come and go, but there will never be more +# than one. When the connection goes away, pause all our producers. When a +# new one is established, write all our queued messages, then unpause our +# producers as we would in resumeProducing. +# +# * Inside our resumeProducing call, we'll cycle through all producers, +# calling their individual resumeProducing methods one at a time. If they +# write so much data that the Connection pauses us again, we'll find out +# because our pauseProducing will be called inside that loop. When that +# happens, we need to stop looping. If we make it through the whole loop +# without being paused, then all subchannel Producers are left unpaused, +# and are free to write whenever they want. During this loop, some +# Producers will be paused, and others will be resumed +# +# * If our pauseProducing is called, all Producers must be paused, and a flag +# should be set to notify the resumeProducing loop to exit +# +# * In between calls to our resumeProducing method, we're in one of two +# states. +# * If we're writing data too fast, then we'll be left in the "paused" +# state, in which all Subchannel producers are paused, and the aggregate +# is paused too (our Connection told us to pauseProducing and hasn't yet +# told us to resumeProducing). In this state, activity is driven by the +# outbound TCP window opening up, which calls resumeProducing and allows +# (probably just) one message to be sent. We receive pauseProducing in +# the middle of their transport.write, so the loop exits early, and the +# only state change is that some other Producer should get to go next +# time. +# * If we're writing too slowly, we'll be left in the "unpaused" state: all +# Subchannel producers are unpaused, and the aggregate is unpaused too +# (resumeProducing is the last thing we've been told). In this satte, +# activity is driven by the Subchannels doing a transport.write, which +# queues some data on the TCP connection (and then might call +# pauseProducing if it's now full). +# +# * We want to guard against: +# +# * application protocol registering a Producer without first unregistering +# the previous one +# +# * application protocols writing data despite being told to pause +# (Subchannels without a registered Producer cannot be throttled, and we +# can't do anything about that, but we must also handle the case where +# they give us a pause switch and then proceed to ignore it) +# +# * our Connection calling resumeProducing or pauseProducing without an +# intervening call of the other kind +# +# * application protocols that don't handle a resumeProducing or +# pauseProducing call without an intervening call of the other kind (i.e. +# we should keep track of the last thing we told them, and not repeat +# ourselves) +# +# * If the Wormhole is closed, all Subchannels should close. This is not our +# responsibility: it lives in (Manager? Inbound?) +# +# * If we're given an IPullProducer, we should keep calling its +# resumeProducing until it runs out of data. We still want fairness, so we +# won't call it a second time until everyone else has had a turn. + + +# There are a couple of different ways to approach this. The one I've +# selected is: +# +# * keep a dict that maps from Subchannel to Producer, which only contains +# entries for Subchannels that have registered a producer. We use this to +# remove Producers when Subchannels are closed +# +# * keep a Deque of Producers. This represents the fair-throttling rotation: +# the left-most item gets the next upcoming turn, and then they'll be moved +# to the end of the queue. +# +# * keep a set of IPushProducers which are paused, a second set of +# IPushProducers which are unpaused, and a third set of IPullProducers +# (which are always left paused) Enforce the invariant that these three +# sets are disjoint, and that their union equals the contents of the deque. +# +# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and +# set upon entry to pauseProducing. The loop inside resumeProducing checks +# this flag after each call to producer.resumeProducing, to sense whether +# they used their turn to write data, and if that write was large enough to +# fill the TCP window. If set, we break out of the loop. If not, we look +# for the next producer to unpause. The loop finishes when all producers +# are unpaused (evidenced by the two sets of paused producers being empty) +# +# * the "paused" flag also determines whether new IPushProducers are added to +# the paused or unpaused set (IPullProducers are always added to the +# pull+paused set). If we have any IPullProducers, we're always in the +# "writing data too fast" state. + +# other approaches that I didn't decide to do at this time (but might use in +# the future): +# +# * use one set instead of two. pros: fewer moving parts. cons: harder to +# spot decoherence bugs like adding a producer to the deque but forgetting +# to add it to one of the +# +# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as +# a visible boolean flag. This conflates Subchannels with their associated +# Producer (so if we went this way, we should also let them track their own +# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc +# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer +# mappings lying around to disagree with one another. Cons: exposes a bit +# too much of the Subchannel internals + + +@attrs +@implementer(IOutbound) +class Outbound(object): + # Manage outbound data: subchannel writes to us, we write to transport + _manager = attrib(validator=provides(IDilationManager)) + _cooperator = attrib() + + def __attrs_post_init__(self): + # _outbound_queue holds all messages we've ever sent but not retired + self._outbound_queue = deque() + self._next_outbound_seqnum = 0 + # _queued_unsent are messages to retry with our new connection + self._queued_unsent = deque() + + # outbound flow control: the Connection throttles our writes + self._subchannel_producers = {} # Subchannel -> IProducer + self._paused = True # our Connection called our pauseProducing + self._all_producers = deque() # rotates, left-is-next + self._paused_producers = set() + self._unpaused_producers = set() + self._check_invariants() + + self._connection = None + + def _check_invariants(self): + assert self._unpaused_producers.isdisjoint(self._paused_producers) + assert (self._paused_producers.union(self._unpaused_producers) == + set(self._all_producers)) + + def build_record(self, record_type, *args): + seqnum = self._next_outbound_seqnum + self._next_outbound_seqnum += 1 + r = record_type(seqnum, *args) + assert hasattr(r, "seqnum"), r # only Open/Data/Close + return r + + def queue_and_send_record(self, r): + # we always queue it, to resend on a subsequent connection if + # necessary + self._outbound_queue.append(r) + + if self._connection: + if self._queued_unsent: + # to maintain correct ordering, queue this instead of sending it + self._queued_unsent.append(r) + else: + # we're allowed to send it immediately + self._connection.send_record(r) + + def send_if_connected(self, r): + assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum + if self._connection: + self._connection.send_record(r) + + # our subchannels call these to register a producer + + def subchannel_registerProducer(self, sc, producer, streaming): + # streaming==True: IPushProducer (pause/resume) + # streaming==False: IPullProducer (just resume) + if sc in self._subchannel_producers: + raise ValueError( + "registering producer %s before previous one (%s) was " + "unregistered" % (producer, + self._subchannel_producers[sc])) + # our underlying Connection uses streaming==True, so to make things + # easier, use an adapter when the Subchannel asks for streaming=False + if not streaming: + def unregister(): + self.subchannel_unregisterProducer(sc) + producer = PullToPush(producer, unregister, self._cooperator) + + self._subchannel_producers[sc] = producer + self._all_producers.append(producer) + if self._paused: + self._paused_producers.add(producer) + else: + self._unpaused_producers.add(producer) + self._check_invariants() + if streaming: + if self._paused: + # IPushProducers need to be paused immediately, before they + # speak + producer.pauseProducing() # you wake up sleeping + else: + # our PullToPush adapter must be started, but if we're paused then + # we tell it to pause before it gets a chance to write anything + producer.startStreaming(self._paused) + + def subchannel_unregisterProducer(self, sc): + # TODO: what if the subchannel closes, so we unregister their + # producer for them, then the application reacts to connectionLost + # with a duplicate unregisterProducer? + p = self._subchannel_producers.pop(sc) + if isinstance(p, PullToPush): + p.stopStreaming() + self._all_producers.remove(p) + self._paused_producers.discard(p) + self._unpaused_producers.discard(p) + self._check_invariants() + + def subchannel_closed(self, sc): + self._check_invariants() + if sc in self._subchannel_producers: + self.subchannel_unregisterProducer(sc) + + # our Manager tells us when we've got a new Connection to work with + + def use_connection(self, c): + self._connection = c + assert not self._queued_unsent + self._queued_unsent.extend(self._outbound_queue) + # the connection can tell us to pause when we send too much data + c.registerProducer(self, True) # IPushProducer: pause+resume + # send our queued messages + self.resumeProducing() + + def stop_using_connection(self): + self._connection.unregisterProducer() + self._connection = None + self._queued_unsent.clear() + self.pauseProducing() + # TODO: I expect this will call pauseProducing twice: the first time + # when we get stopProducing (since we're registere with the + # underlying connection as the producer), and again when the manager + # notices the connectionLost and calls our _stop_using_connection + + def handle_ack(self, resp_seqnum): + # we've received an inbound ack, so retire something + while (self._outbound_queue and + self._outbound_queue[0].seqnum <= resp_seqnum): + self._outbound_queue.popleft() + while (self._queued_unsent and + self._queued_unsent[0].seqnum <= resp_seqnum): + self._queued_unsent.popleft() + # Inbound is responsible for tracking the high watermark and deciding + # whether to ignore inbound messages or not + + + # IProducer: the active connection calls these because we used + # c.registerProducer to ask for them + def pauseProducing(self): + if self._paused: + return # someone is confused and called us twice + self._paused = True + for p in self._all_producers: + if p in self._unpaused_producers: + self._unpaused_producers.remove(p) + self._paused_producers.add(p) + p.pauseProducing() + + def resumeProducing(self): + if not self._paused: + return # someone is confused and called us twice + self._paused = False + + while not self._paused: + if self._queued_unsent: + r = self._queued_unsent.popleft() + self._connection.send_record(r) + continue + p = self._get_next_unpaused_producer() + if not p: + break + self._paused_producers.remove(p) + self._unpaused_producers.add(p) + p.resumeProducing() + + def _get_next_unpaused_producer(self): + self._check_invariants() + if not self._paused_producers: + return None + while True: + p = self._all_producers[0] + self._all_producers.rotate(-1) # p moves to the end of the line + # the only unpaused Producers are at the end of the list + assert p in self._paused_producers + return p + + def stopProducing(self): + # we'll hopefully have a new connection to work with in the future, + # so we don't shut anything down. We do pause everyone, though. + self.pauseProducing() + + +# modelled after twisted.internet._producer_helper._PullToPush , but with a +# configurable Cooperator, a pause-immediately argument to startStreaming() +@implementer(IPushProducer) +@attrs(cmp=False) +class PullToPush(object): + _producer = attrib(validator=provides(IPullProducer)) + _unregister = attrib(validator=lambda _a,_b,v: callable(v)) + _cooperator = attrib() + _finished = False + + def _pull(self): + while True: + try: + self._producer.resumeProducing() + except: + log.err(None, "%s failed, producing will be stopped:" % + (safe_str(self._producer),)) + try: + self._unregister() + # The consumer should now call stopStreaming() on us, + # thus stopping the streaming. + except: + # Since the consumer blew up, we may not have had + # stopStreaming() called, so we just stop on our own: + log.err(None, "%s failed to unregister producer:" % + (safe_str(self._unregister),)) + self._finished = True + return + yield None + + def startStreaming(self, paused): + self._coopTask = self._cooperator.cooperate(self._pull()) + if paused: + self.pauseProducing() # timer is scheduled, but task is removed + + def stopStreaming(self): + if self._finished: + return + self._finished = True + self._coopTask.stop() + + + def pauseProducing(self): + self._coopTask.pause() + + + def resumeProducing(self): + self._coopTask.resume() + + + def stopProducing(self): + self.stopStreaming() + self._producer.stopProducing() diff --git a/src/wormhole/_dilation/roles.py b/src/wormhole/_dilation/roles.py new file mode 100644 index 0000000..8f9adac --- /dev/null +++ b/src/wormhole/_dilation/roles.py @@ -0,0 +1 @@ +LEADER, FOLLOWER = object(), object() diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py new file mode 100644 index 0000000..94b4a03 --- /dev/null +++ b/src/wormhole/_dilation/subchannel.py @@ -0,0 +1,269 @@ +from attr import attrs, attrib +from attr.validators import instance_of, provides +from zope.interface import implementer +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed +from twisted.internet.interfaces import (ITransport, IProducer, IConsumer, + IAddress, IListeningPort, + IStreamClientEndpoint, + IStreamServerEndpoint) +from twisted.internet.error import ConnectionDone +from automat import MethodicalMachine +from .._interfaces import ISubChannel, IDilationManager + +@attrs +class Once(object): + _errtype = attrib() + def __attrs_post_init__(self): + self._called = False + + def __call__(self): + if self._called: + raise self._errtype() + self._called = True + +class SingleUseEndpointError(Exception): + pass + +# created in the (OPEN) state, by either: +# * receipt of an OPEN message +# * or local client_endpoint.connect() +# then transitions are: +# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN) +# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED) +# (OPEN) local .write(): send DATA, -> (OPEN) +# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING) +# (CLOSING) local .write(): error +# (CLOSING) local .loseConnection(): error +# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING) +# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED) +# object is deleted upon transition to (CLOSED) + +class AlreadyClosedError(Exception): + pass + +@implementer(IAddress) +class _WormholeAddress(object): + pass + +@implementer(IAddress) +@attrs +class _SubchannelAddress(object): + _scid = attrib() + + +@attrs +@implementer(ITransport) +@implementer(IProducer) +@implementer(IConsumer) +@implementer(ISubChannel) +class SubChannel(object): + _id = attrib(validator=instance_of(bytes)) + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + _peer_addr = attrib(validator=instance_of(_SubchannelAddress)) + + m = MethodicalMachine() + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover + + def __attrs_post_init__(self): + #self._mailbox = None + #self._pending_outbound = {} + #self._processed = set() + self._protocol = None + self._pending_dataReceived = [] + self._pending_connectionLost = (False, None) + + @m.state(initial=True) + def open(self): pass # pragma: no cover + + @m.state() + def closing(): pass # pragma: no cover + + @m.state() + def closed(): pass # pragma: no cover + + @m.input() + def remote_data(self, data): pass + @m.input() + def remote_close(self): pass + + @m.input() + def local_data(self, data): pass + @m.input() + def local_close(self): pass + + + @m.output() + def send_data(self, data): + self._manager.send_data(self._id, data) + + @m.output() + def send_close(self): + self._manager.send_close(self._id) + + @m.output() + def signal_dataReceived(self, data): + if self._protocol: + self._protocol.dataReceived(data) + else: + self._pending_dataReceived.append(data) + + @m.output() + def signal_connectionLost(self): + if self._protocol: + self._protocol.connectionLost(ConnectionDone()) + else: + self._pending_connectionLost = (True, ConnectionDone()) + self._manager.subchannel_closed(self) + # we're deleted momentarily + + @m.output() + def error_closed_write(self, data): + raise AlreadyClosedError("write not allowed on closed subchannel") + @m.output() + def error_closed_close(self): + raise AlreadyClosedError("loseConnection not allowed on closed subchannel") + + # primary transitions + open.upon(remote_data, enter=open, outputs=[signal_dataReceived]) + open.upon(local_data, enter=open, outputs=[send_data]) + open.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + open.upon(local_close, enter=closing, outputs=[send_close]) + closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived]) + closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost]) + + # error cases + # we won't ever see an OPEN, since L4 will log+ignore those for us + closing.upon(local_data, enter=closing, outputs=[error_closed_write]) + closing.upon(local_close, enter=closing, outputs=[error_closed_close]) + # the CLOSED state won't ever see messages, since we'll be deleted + + # our endpoints use this + + def _set_protocol(self, protocol): + assert not self._protocol + self._protocol = protocol + if self._pending_dataReceived: + for data in self._pending_dataReceived: + self._protocol.dataReceived(data) + self._pending_dataReceived = [] + cl, what = self._pending_connectionLost + if cl: + self._protocol.connectionLost(what) + + # ITransport + def write(self, data): + assert isinstance(data, type(b"")) + self.local_data(data) + def writeSequence(self, iovec): + self.write(b"".join(iovec)) + def loseConnection(self): + self.local_close() + def getHost(self): + # we define "host addr" as the overall wormhole + return self._host_addr + def getPeer(self): + # and "peer addr" as the subchannel within that wormhole + return self._peer_addr + + # IProducer: throttle inbound data (wormhole "up" to local app's Protocol) + def stopProducing(self): + self._manager.subchannel_stopProducing(self) + def pauseProducing(self): + self._manager.subchannel_pauseProducing(self) + def resumeProducing(self): + self._manager.subchannel_resumeProducing(self) + + # IConsumer: allow the wormhole to throttle outbound data (app->wormhole) + def registerProducer(self, producer, streaming): + self._manager.subchannel_registerProducer(self, producer, streaming) + def unregisterProducer(self): + self._manager.subchannel_unregisterProducer(self) + + +@implementer(IStreamClientEndpoint) +class ControlEndpoint(object): + _used = False + def __init__(self, peer_addr): + self._subchannel_zero = Deferred() + self._peer_addr = peer_addr + self._once = Once(SingleUseEndpointError) + + # from manager + def _subchannel_zero_opened(self, subchannel): + assert ISubChannel.providedBy(subchannel), subchannel + self._subchannel_zero.callback(subchannel) + + @inlineCallbacks + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + self._once() + t = yield self._subchannel_zero + p = protocolFactory.buildProtocol(self._peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + returnValue(p) + +@implementer(IStreamClientEndpoint) +@attrs +class SubchannelConnectorEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=instance_of(_WormholeAddress)) + + def connect(self, protocolFactory): + # return Deferred that fires with IProtocol or Failure(ConnectError) + scid = self._manager.allocate_subchannel_id() + self._manager.send_open(scid) + peer_addr = _SubchannelAddress(scid) + # ? f.doStart() + # ? f.startedConnecting(CONNECTOR) # ?? + t = SubChannel(scid, self._manager, self._host_addr, peer_addr) + p = protocolFactory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) # set p.transport = t and call connectionMade() + return succeed(p) + +@implementer(IStreamServerEndpoint) +@attrs +class SubchannelListenerEndpoint(object): + _manager = attrib(validator=provides(IDilationManager)) + _host_addr = attrib(validator=provides(IAddress)) + + def __attrs_post_init__(self): + self._factory = None + self._pending_opens = [] + + # from manager + def _got_open(self, t, peer_addr): + if self._factory: + self._connect(t, peer_addr) + else: + self._pending_opens.append( (t, peer_addr) ) + + def _connect(self, t, peer_addr): + p = self._factory.buildProtocol(peer_addr) + t._set_protocol(p) + p.makeConnection(t) + + # IStreamServerEndpoint + + def listen(self, protocolFactory): + self._factory = protocolFactory + for (t, peer_addr) in self._pending_opens: + self._connect(t, peer_addr) + self._pending_opens = [] + lp = SubchannelListeningPort(self._host_addr) + return succeed(lp) + +@implementer(IListeningPort) +@attrs +class SubchannelListeningPort(object): + _host_addr = attrib(validator=provides(IAddress)) + + def startListening(self): + pass + def stopListening(self): + # TODO + pass + def getHost(self): + return self._host_addr diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py index 9f59ff4..52bde35 100644 --- a/src/wormhole/_interfaces.py +++ b/src/wormhole/_interfaces.py @@ -433,3 +433,16 @@ class IInputHelper(Interface): class IJournal(Interface): # TODO: this needs to be public pass + +class IDilator(Interface): + pass +class IDilationManager(Interface): + pass +class IDilationConnector(Interface): + pass +class ISubChannel(Interface): + pass +class IInbound(Interface): + pass +class IOutbound(Interface): + pass diff --git a/src/wormhole/test/dilate/__init__.py b/src/wormhole/test/dilate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py new file mode 100644 index 0000000..2ddacfb --- /dev/null +++ b/src/wormhole/test/dilate/common.py @@ -0,0 +1,18 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from ..._interfaces import IDilationManager, IWormhole + +def mock_manager(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + return m + +def mock_wormhole(): + m = mock.Mock() + alsoProvides(m, IWormhole) + return m + +def clear_mock_calls(*args): + for a in args: + a.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py new file mode 100644 index 0000000..406e42a --- /dev/null +++ b/src/wormhole/test/dilate/test_connection.py @@ -0,0 +1,216 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock +from twisted.internet.interfaces import ITransport +from ...eventual import EventualQueue +from ..._interfaces import IDilationConnector +from ..._dilation.roles import LEADER, FOLLOWER +from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, + KCM, Open, Ack) +from .common import clear_mock_calls + +def make_con(role, use_relay=False): + clock = Clock() + eq = EventualQueue(clock) + connector = mock.Mock() + alsoProvides(connector, IDilationConnector) + n = mock.Mock() # pretends to be a Noise object + n.write_message = mock.Mock(side_effect=[b"handshake"]) + c = DilatedConnectionProtocol(eq, role, connector, n, + b"outbound_prologue\n", b"inbound_prologue\n") + if use_relay: + c.use_relay(b"relay_handshake\n") + t = mock.Mock() + alsoProvides(t, ITransport) + return c, n, connector, t, eq + +class Connection(unittest.TestCase): + def test_bad_prologue(self): + c, n, connector, t, eq = make_con(LEADER) + c.makeConnection(t) + d = c.when_disconnected() + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"prologue\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + + eq.flush_sync() + self.assertNoResult(d) + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def _test_no_relay(self, role): + c, n, connector, t, eq = make_con(role) + t_kcm = KCM() + t_open = Open(seqnum=1, scid=0x11223344) + t_ack = Ack(resp_seqnum=2) + n.decrypt = mock.Mock(side_effect=[ + encode_record(t_kcm), + encode_record(t_open), + ]) + exp_kcm = b"\x00\x00\x00\x03kcm" + n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) + m = mock.Mock() # Manager + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x0Ahandshake2") + if role is LEADER: + # we're the leader, so we don't send the KCM right away + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(c._manager, None) + else: + # we're the follower, so we encrypt and send the KCM immediately + self.assertEqual(n.mock_calls, [ + mock.call.read_message(b"handshake2"), + mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [ + mock.call.write(exp_kcm)]) + self.assertEqual(c._manager, None) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x03KCM") + # leader: inbound KCM means we add the candidate + # follower: inbound KCM means we've been selected. + # in both cases we notify Connector.add_candidate(), and the Connector + # decides if/when to call .select() + + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")]) + self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)]) + self.assertEqual(t.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + # now pretend this connection wins (either the Leader decides to use + # this one among all the candiates, or we're the Follower and the + # Connector is reacting to add_candidate() by recognizing we're the + # only candidate there is) + c.select(m) + self.assertIdentical(c._manager, m) + if role is LEADER: + # TODO: currently Connector.select_and_stop_remaining() is + # responsible for sending the KCM just before calling c.select() + # iff we're the LEADER, therefore Connection.select won't send + # anything. This should be moved to c.select(). + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + + c.send_record(KCM()) + self.assertEqual(n.mock_calls, [ + mock.call.encrypt(encode_record(t_kcm)), + ]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) + self.assertEqual(m.mock_calls, []) + else: + # follower: we already sent the KCM, do nothing + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.dataReceived(b"\x00\x00\x00\x04msg1") + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)]) + clear_mock_calls(n, connector, t, m) + + c.send_record(t_ack) + exp_ack = b"\x06\x00\x00\x00\x02" + self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + c.disconnect() + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + self.assertEqual(m.mock_calls, []) + clear_mock_calls(n, connector, t, m) + + def test_no_relay_leader(self): + return self._test_no_relay(LEADER) + + def test_no_relay_follower(self): + return self._test_no_relay(FOLLOWER) + + + def test_relay(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"inbound_prologue\n") + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + clear_mock_calls(n, connector, t) + + def test_relay_jilted(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + d = c.when_disconnected() + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.connectionLost(b"why") + eq.flush_sync() + self.assertIdentical(self.successResultOf(d), c) + + def test_relay_bad_response(self): + c, n, connector, t, eq = make_con(LEADER, use_relay=True) + + c.makeConnection(t) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")]) + clear_mock_calls(n, connector, t) + + c.dataReceived(b"not ok\n") + self.assertEqual(n.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + self.assertEqual(t.mock_calls, [mock.call.loseConnection()]) + clear_mock_calls(n, connector, t) diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py new file mode 100644 index 0000000..e2c854e --- /dev/null +++ b/src/wormhole/test/dilate/test_encoding.py @@ -0,0 +1,25 @@ +from __future__ import print_function, unicode_literals +from twisted.trial import unittest +from ..._dilation.encode import to_be4, from_be4 + +class Encoding(unittest.TestCase): + + def test_be4(self): + self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") + self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") + self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") + self.assertEqual(to_be4(257), b"\x00\x00\x01\x01") + with self.assertRaises(ValueError): + to_be4(-1) + with self.assertRaises(ValueError): + to_be4(2**32) + + self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0) + self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1) + self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256) + self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257) + + with self.assertRaises(TypeError): + from_be4(0) + with self.assertRaises(ValueError): + from_be4(b"\x01\x00\x00\x00\x00") diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py new file mode 100644 index 0000000..bd8f995 --- /dev/null +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -0,0 +1,97 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import ISubChannel +from ..._dilation.subchannel import (ControlEndpoint, + SubchannelConnectorEndpoint, + SubchannelListenerEndpoint, + SubchannelListeningPort, + _WormholeAddress, _SubchannelAddress, + SingleUseEndpointError) +from .common import mock_manager + +class Endpoints(unittest.TestCase): + def test_control(self): + scid0 = b"scid0" + peeraddr = _SubchannelAddress(scid0) + ep = ControlEndpoint(peeraddr) + + f = mock.Mock() + p = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + d = ep.connect(f) + self.assertNoResult(d) + + t = mock.Mock() + alsoProvides(t, ISubChannel) + ep._subchannel_zero_opened(t) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + d = ep.connect(f) + self.failureResultOf(d, SingleUseEndpointError) + + def assert_makeConnection(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "makeConnection") + self.assertEqual(len(mock_calls[0][1]), 1) + return mock_calls[0][1][0] + + def test_connector(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(b"scid") + ep = SubchannelConnectorEndpoint(m, hostaddr) + + f = mock.Mock() + p = mock.Mock() + t = mock.Mock() + f.buildProtocol = mock.Mock(return_value=p) + with mock.patch("wormhole._dilation.subchannel.SubChannel", + return_value=t) as sc: + d = ep.connect(f) + self.assertIdentical(self.successResultOf(d), p) + self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)]) + self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)]) + self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)]) + self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)]) + + def test_listener(self): + m = mock_manager() + m.allocate_subchannel_id = mock.Mock(return_value=b"scid") + hostaddr = _WormholeAddress() + ep = SubchannelListenerEndpoint(m, hostaddr) + + f = mock.Mock() + p1 = mock.Mock() + p2 = mock.Mock() + f.buildProtocol = mock.Mock(side_effect=[p1, p2]) + + # OPEN that arrives before we ep.listen() should be queued + + t1 = mock.Mock() + peeraddr1 = _SubchannelAddress(b"peer1") + ep._got_open(t1, peeraddr1) + + d = ep.listen(f) + lp = self.successResultOf(d) + self.assertIsInstance(lp, SubchannelListeningPort) + + self.assertEqual(lp.getHost(), hostaddr) + lp.startListening() + + self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)]) + self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)]) + + t2 = mock.Mock() + peeraddr2 = _SubchannelAddress(b"peer2") + ep._got_open(t2, peeraddr2) + + self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) + self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) + + lp.stopListening() # TODO: should this do more? diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py new file mode 100644 index 0000000..81d4cf9 --- /dev/null +++ b/src/wormhole/test/dilate/test_framer.py @@ -0,0 +1,110 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect + +def make_framer(): + t = mock.Mock() + alsoProvides(t, ITransport) + f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") + return f, t + +class Framer(unittest.TestCase): + def test_bad_prologue_length(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual(t.mock_calls, []) + + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"not the prologue after all")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not the p"))]) + self.assertEqual(t.mock_calls, []) + + def test_bad_prologue_newline(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + + self.assertEqual([], list(f.add_and_parse(b"not"))) + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad prologue: {}".format( + b"inbound_not\n"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_prologue(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + # now send_frame should work + f.send_frame(b"frame") + self.assertEqual(t.mock_calls, + [mock.call.write(b"\x00\x00\x00\x05frame")]) + + def test_bad_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + with mock.patch("wormhole._dilation.connection.log.msg") as m: + with self.assertRaises(Disconnect): + list(f.add_and_parse(b"goodbye\n")) + self.assertEqual(m.mock_calls, + [mock.call("bad relay_ok: {}".format(b"goo"))]) + self.assertEqual(t.mock_calls, []) + + def test_good_relay(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + f.use_relay(b"relay handshake\n") + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")]) + t.mock_calls[:] = [] + + self.assertEqual([], list(f.add_and_parse(b"ok\n"))) + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + + def test_frame(self): + f, t = make_framer() + self.assertEqual(t.mock_calls, []) + + f.connectionMade() + self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) + t.mock_calls[:] = [] + self.assertEqual([Prologue()], + list(f.add_and_parse(b"inbound_prologue\n"))) + self.assertEqual(t.mock_calls, []) + + encoded_frame = b"\x00\x00\x00\x05frame" + self.assertEqual([], list(f.add_and_parse(encoded_frame[:2]))) + self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6]))) + self.assertEqual([Frame(frame=b"frame")], + list(f.add_and_parse(encoded_frame[6:]))) diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py new file mode 100644 index 0000000..d147283 --- /dev/null +++ b/src/wormhole/test/dilate/test_inbound.py @@ -0,0 +1,172 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from ..._interfaces import IDilationManager +from ..._dilation.connection import Open, Data, Close +from ..._dilation.inbound import (Inbound, DuplicateOpenError, + DataForMissingSubchannelError, + CloseForMissingSubchannelError) + +def make_inbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + host_addr = object() + i = Inbound(m, host_addr) + return i, m, host_addr + +class InboundTest(unittest.TestCase): + def test_seqnum(self): + i, m, host_addr = make_inbound() + r1 = Open(scid=513, seqnum=1) + r2 = Data(scid=513, seqnum=2, data=b"") + r3 = Close(scid=513, seqnum=3) + self.assertFalse(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r1) + self.assertTrue(i.is_record_old(r1)) + self.assertFalse(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + i.update_ack_watermark(r2) + self.assertTrue(i.is_record_old(r1)) + self.assertTrue(i.is_record_old(r2)) + self.assertFalse(i.is_record_old(r3)) + + def test_open_data_close(self): + i, m, host_addr = make_inbound() + scid1 = b"scid" + scid2 = b"scXX" + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + i.use_connection(c) + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)]) + self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)]) + self.assertEqual(sca.mock_calls, [mock.call(scid1)]) + lep.mock_calls[:] = [] + + # a subsequent duplicate OPEN should be ignored + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid1) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + i.handle_data(scid1, b"data") + self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")]) + sc1.mock_calls[:] = [] + + i.handle_data(scid2, b"for non-existent subchannel") + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(DataForMissingSubchannelError) + + i.handle_close(scid1) + self.assertEqual(sc1.mock_calls, [mock.call.remote_close()]) + sc1.mock_calls[:] = [] + + i.handle_close(scid2) + self.assertEqual(sc1.mock_calls, []) + self.flushLoggedErrors(CloseForMissingSubchannelError) + + # after the subchannel is closed, the Manager will notify Inbound + i.subchannel_closed(scid1, sc1) + + i.stop_using_connection() + + def test_control_channel(self): + i, m, host_addr = make_inbound() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + scid0 = b"scid" + sc0 = mock.Mock() + i.set_subchannel_zero(scid0, sc0) + + # OPEN on the control channel identifier should be ignored as a + # duplicate, since the control channel is already registered + sc1 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1]) as sc: + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + side_effect=[peer_addr]) as sca: + i.handle_open(scid0) + self.assertEqual(lep.mock_calls, []) + self.assertEqual(sc.mock_calls, []) + self.assertEqual(sca.mock_calls, []) + self.flushLoggedErrors(DuplicateOpenError) + + # and DATA to it should be delivered correctly + i.handle_data(scid0, b"data") + self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")]) + sc0.mock_calls[:] = [] + + def test_pause(self): + i, m, host_addr = make_inbound() + c = mock.Mock() + lep = mock.Mock() + i.set_listener_endpoint(lep) + + # add two subchannels, pause one, then add a connection + scid1 = b"sci1" + scid2 = b"sci2" + sc1 = mock.Mock() + sc2 = mock.Mock() + peer_addr = object() + with mock.patch("wormhole._dilation.inbound.SubChannel", + side_effect=[sc1, sc2]): + with mock.patch("wormhole._dilation.inbound._SubchannelAddress", + return_value=peer_addr): + i.handle_open(scid1) + i.handle_open(scid2) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] + + # consumers aren't really supposed to do this, but tolerate it + i.subchannel_resumeProducing(sc1) + self.assertEqual(c.mock_calls, []) + + i.subchannel_pauseProducing(sc1) + self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) + c.mock_calls[:] = [] + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) # was already paused + + # tolerate duplicate pauseProducing + i.subchannel_pauseProducing(sc2) + self.assertEqual(c.mock_calls, []) + + # stopProducing is treated like a terminal resumeProducing + i.subchannel_stopProducing(sc1) + self.assertEqual(c.mock_calls, []) + i.subchannel_stopProducing(sc2) + self.assertEqual(c.mock_calls, [mock.call.resumeProducing()]) + c.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py new file mode 100644 index 0000000..625039d --- /dev/null +++ b/src/wormhole/test/dilate/test_manager.py @@ -0,0 +1,205 @@ +from __future__ import print_function, unicode_literals +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.defer import Deferred +from twisted.internet.task import Clock, Cooperator +import mock +from ...eventual import EventualQueue +from ..._interfaces import ISend, IDilationManager +from ...util import dict_to_bytes +from ..._dilation.manager import (Dilator, + OldPeerCannotDilateError, + UnknownDilationMessageType) +from ..._dilation.subchannel import _WormholeAddress +from .common import clear_mock_calls + +def make_dilator(): + reactor = object() + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + send = mock.Mock() + alsoProvides(send, ISend) + dil = Dilator(reactor, eq, coop) + dil.wire(send) + return dil, send, reactor, eq, clock, coop + +class TestDilator(unittest.TestCase): + def test_leader(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + d2 = dil.dilate() + self.assertNoResult(d1) + self.assertNoResult(d2) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key) as dk: + dil.got_key(key) + self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)]) + self.assertIdentical(dil._transit_key, transit_key) + self.assertNoResult(d1) + self.assertNoResult(d2) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = wfc_d = Deferred() + # TODO: test missing can-dilate, and no-overlap + with mock.patch("wormhole._dilation.manager.ManagerLeader", + return_value=m) as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + # that should create the Manager. Because "us" > "them", we're + # the leader + self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key, + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.when_first_connected(), + ]) + clear_mock_calls(m) + self.assertNoResult(d1) + self.assertNoResult(d2) + + host_addr = _WormholeAddress() + m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress", + return_value=host_addr) + peer_addr = object() + m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress", + return_value=peer_addr) + ce = mock.Mock() + m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint", + return_value=ce) + sc = mock.Mock() + m_sc = mock.patch("wormhole._dilation.manager.SubChannel", + return_value=sc) + + lep = object() + m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint", + return_value=lep) + + with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m: + wfc_d.callback(None) + eq.flush_sync() + scid0 = b"\x00\x00\x00\x00" + self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)]) + self.assertEqual(m_sc_m.mock_calls, + [mock.call(scid0, m, host_addr, peer_addr)]) + self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)]) + self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)]) + self.assertEqual(m.mock_calls, + [mock.call.set_subchannel_zero(scid0, sc), + mock.call.set_listener_endpoint(lep), + ]) + clear_mock_calls(m) + + eps = self.successResultOf(d1) + self.assertEqual(eps, self.successResultOf(d2)) + d3 = dil.dilate() + eq.flush_sync() + self.assertEqual(eps, self.successResultOf(d3)) + + self.assertEqual(m.mock_calls, []) + dil.received_dilate(dict_to_bytes(dict(type="please"))) + self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE()]) + clear_mock_calls(m) + + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="dilate"))) + self.assertEqual(m.mock_calls, [mock.call.rx_DILATE()]) + clear_mock_calls(m) + + dil.received_dilate(dict_to_bytes(dict(type="unknown"))) + self.assertEqual(m.mock_calls, []) + self.flushLoggedErrors(UnknownDilationMessageType) + + def test_follower(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + key = b"key" + transit_key = object() + with mock.patch("wormhole._dilation.manager.derive_key", + return_value=transit_key): + dil.got_key(key) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = Deferred() + with mock.patch("wormhole._dilation.manager.ManagerFollower", + return_value=m) as mf: + dil.got_wormhole_versions("me", "you", {"can-dilate": [1]}) + # "me" < "you", so we're the follower + self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key, + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.when_first_connected(), + ]) + + def test_peer_cannot_dilate(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + + def test_disjoint_versions(self): + dil, send, reactor, eq, clock, coop = make_dilator() + d1 = dil.dilate() + self.assertNoResult(d1) + + dil.got_wormhole_versions("me", "you", {"can-dilate": [-1]}) + eq.flush_sync() + f = self.failureResultOf(d1) + f.check(OldPeerCannotDilateError) + + + def test_early_dilate_messages(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + d1 = dil.dilate() + self.assertNoResult(d1) + dil.received_dilate(dict_to_bytes(dict(type="please"))) + hintmsg = dict(type="connection-hints") + dil.received_dilate(dict_to_bytes(hintmsg)) + + m = mock.Mock() + alsoProvides(m, IDilationManager) + m.when_first_connected.return_value = Deferred() + + with mock.patch("wormhole._dilation.manager.ManagerLeader", + return_value=m) as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + None, reactor, eq, coop)]) + self.assertEqual(m.mock_calls, [mock.call.start(), + mock.call.rx_PLEASE(), + mock.call.rx_HINTS(hintmsg), + mock.call.when_first_connected()]) + + + + def test_transit_relay(self): + dil, send, reactor, eq, clock, coop = make_dilator() + dil._transit_key = b"key" + relay = object() + d1 = dil.dilate(transit_relay_location=relay) + self.assertNoResult(d1) + + with mock.patch("wormhole._dilation.manager.ManagerLeader") as ml: + dil.got_wormhole_versions("us", "them", {"can-dilate": [1]}) + self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key", + relay, reactor, eq, coop), + mock.call().start(), + mock.call().when_first_connected()]) diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py new file mode 100644 index 0000000..fab596a --- /dev/null +++ b/src/wormhole/test/dilate/test_outbound.py @@ -0,0 +1,645 @@ +from __future__ import print_function, unicode_literals +from collections import namedtuple +from itertools import cycle +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from twisted.internet.task import Clock, Cooperator +from twisted.internet.interfaces import IPullProducer +from ...eventual import EventualQueue +from ..._interfaces import IDilationManager +from ..._dilation.connection import KCM, Open, Data, Close, Ack +from ..._dilation.outbound import Outbound, PullToPush +from .common import clear_mock_calls + +Pauser = namedtuple("Pauser", ["seqnum"]) +NonPauser = namedtuple("NonPauser", ["seqnum"]) +Stopper = namedtuple("Stopper", ["sc"]) + +def make_outbound(): + m = mock.Mock() + alsoProvides(m, IDilationManager) + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + o = Outbound(m, coop) + c = mock.Mock() # Connection + def maybe_pause(r): + if isinstance(r, Pauser): + o.pauseProducing() + elif isinstance(r, Stopper): + o.subchannel_unregisterProducer(r.sc) + c.send_record = mock.Mock(side_effect=maybe_pause) + o._test_eq = eq + o._test_term = term + return o, m, c + +class OutboundTest(unittest.TestCase): + def test_build_record(self): + o, m, c = make_outbound() + scid1 = b"scid" + self.assertEqual(o.build_record(Open, scid1), + Open(seqnum=0, scid=b"scid")) + self.assertEqual(o.build_record(Data, scid1, b"dataaa"), + Data(seqnum=1, scid=b"scid", data=b"dataaa")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=2, scid=b"scid")) + self.assertEqual(o.build_record(Close, scid1), + Close(seqnum=3, scid=b"scid")) + + def test_outbound_queue(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + + # we would never normally receive an ACK without first getting a + # connection + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + + o.handle_ack(r3.seqnum) + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r3.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + o.handle_ack(r1.seqnum) # ignored + self.assertEqual(list(o._outbound_queue), []) + + def test_duplicate_registerProducer(self): + o, m, c = make_outbound() + sc1 = object() + p1 = mock.Mock() + o.subchannel_registerProducer(sc1, p1, True) + with self.assertRaises(ValueError) as ar: + o.subchannel_registerProducer(sc1, p1, True) + s = str(ar.exception) + self.assertIn("registering producer", s) + self.assertIn("before previous one", s) + self.assertIn("was unregistered", s) + + def test_connection_send_queued_unpaused(self): + o, m, c = make_outbound() + scid1 = b"scid" + r1 = o.build_record(Open, scid1) + r2 = o.build_record(Data, scid1, b"data1") + r3 = o.build_record(Data, scid1, b"data2") + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # as soon as the connection is established, everything is sent + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1), + mock.call.send_record(r2)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + + def test_connection_send_queued_paused(self): + o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + # pausing=True, so our mock Manager will pause the Outbound producer + # after each write. So only r1 should have been sent before getting + # paused + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + # Outbound is responsible for sending all records, so when Manager + # wants to send a new one, and Outbound is still in the middle of + # draining the beginning-of-connection queue, the new message gets + # queued behind the rest (in addition to being queued in + # _outbound_queue until an ACK retires it). + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r1.seqnum) + self.assertEqual(list(o._outbound_queue), [r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + def test_premptive_ack(self): + # one mode I have in mind is for each side to send an immediate ACK, + # with everything they've ever seen, as the very first message on each + # new connection. The idea is that you might preempt sending stuff from + # the _queued_unsent list if it arrives fast enough (in practice this + # is more likely to be delivered via the DILATE mailbox message, but + # the effects might be vaguely similar, so it seems worth testing + # here). A similar situation would be if each side sends ACKs with the + # highest seqnum they've ever seen, instead of merely ACKing the + # message which was just received. + o, m, c = make_outbound() + r1 = Pauser(seqnum=1) + r2 = Pauser(seqnum=2) + r3 = Pauser(seqnum=3) + o.queue_and_send_record(r1) + o.queue_and_send_record(r2) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), []) + + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(r1)]) + self.assertEqual(list(o._outbound_queue), [r1, r2]) + self.assertEqual(list(o._queued_unsent), [r2]) + clear_mock_calls(c) + + o.queue_and_send_record(r3) + self.assertEqual(list(o._outbound_queue), [r1, r2, r3]) + self.assertEqual(list(o._queued_unsent), [r2, r3]) + self.assertEqual(c.mock_calls, []) + + o.handle_ack(r2.seqnum) + self.assertEqual(list(o._outbound_queue), [r3]) + self.assertEqual(list(o._queued_unsent), [r3]) + self.assertEqual(c.mock_calls, []) + + def test_pause(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)]) + self.assertEqual(list(o._outbound_queue), []) + self.assertEqual(list(o._queued_unsent), []) + clear_mock_calls(c) + + sc1, sc2, sc3 = object(), object(), object() + p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3") + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + + r1 = Pauser(seqnum=1) + o.queue_and_send_record(r1) + # now we should be paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1, c) + + # so an IPushProducer will be paused right away + o.subchannel_registerProducer(sc2, p2, True) + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p2) + + o.subchannel_registerProducer(sc3, p3, True) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(o._paused_producers, set([p1, p2, p3])) + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + clear_mock_calls(p3) + + # one resumeProducing should cause p1 to get a turn, since p2 was added + # after we were paused and p1 was at the "end" of a one-element list. + # If it writes anything, it will get paused again immediately. + r2 = Pauser(seqnum=2) + p1.resumeProducing.side_effect = lambda: c.send_record(r2) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(p3.mock_calls, []) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2)]) + clear_mock_calls(p1, p2, p3, c) + # p2 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p2, p3, p1]) + + # next turn: p2 has nothing to send, but p3 does. we should see p3 + # called but not p1. The actual sequence of expected calls is: + # p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause) + r3 = Pauser(seqnum=3) + p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: c.send_record(r3) + o.resumeProducing() + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r3)]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # next turn: p1 has data to send, but not enough to cause a pause. same + # for p2. p3 causes a pause + r4 = NonPauser(seqnum=4) + r5 = NonPauser(seqnum=5) + r6 = Pauser(seqnum=6) + p1.resumeProducing.side_effect = lambda: c.send_record(r4) + p2.resumeProducing.side_effect = lambda: c.send_record(r5) + p3.resumeProducing.side_effect = lambda: c.send_record(r6) + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r4), + mock.call.send_record(r5), + mock.call.send_record(r6), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + + # now we let it catch up. p1 and p2 send non-pausing data, p3 sends + # nothing. + r7 = NonPauser(seqnum=4) + r8 = NonPauser(seqnum=5) + p1.resumeProducing.side_effect = lambda: c.send_record(r7) + p2.resumeProducing.side_effect = lambda: c.send_record(r8) + p3.resumeProducing.side_effect = lambda: None + + o.resumeProducing() + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(c.mock_calls, [mock.call.send_record(r7), + mock.call.send_record(r8), + ]) + clear_mock_calls(p1, p2, p3, c) + # p1 should now be at the head of the queue again + self.assertEqual(list(o._all_producers), [p1, p2, p3]) + self.assertFalse(o._paused) + + # now a producer disconnects itself (spontaneously, not from inside a + # resumeProducing) + o.subchannel_unregisterProducer(sc1) + self.assertEqual(list(o._all_producers), [p2, p3]) + self.assertEqual(p1.mock_calls, []) + self.assertFalse(o._paused) + + # and another disconnects itself when called + p2.resumeProducing.side_effect = lambda: None + p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3) + o.pauseProducing() + o.resumeProducing() + self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(), + mock.call.resumeProducing()]) + clear_mock_calls(p2, p3) + self.assertEqual(list(o._all_producers), [p2]) + self.assertFalse(o._paused) + + def test_subchannel_closed(self): + o, m, c = make_outbound() + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + clear_mock_calls(p1) + + o.subchannel_closed(sc1) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(list(o._all_producers), []) + + sc2 = mock.Mock() + o.subchannel_closed(sc2) + + def test_disconnect(self): + o, m, c = make_outbound() + o.use_connection(c) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + o.subchannel_registerProducer(sc1, p1, True) + self.assertEqual(p1.mock_calls, []) + o.stop_using_connection() + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + + def OFF_test_push_pull(self): + # use one IPushProducer and one IPullProducer. They should take turns + o, m, c = make_outbound() + o.use_connection(c) + clear_mock_calls(c) + + sc1, sc2 = object(), object() + p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2") + r1 = Pauser(seqnum=1) + r2 = NonPauser(seqnum=2) + + # we aren't paused yet, since we haven't sent any data + o.subchannel_registerProducer(sc1, p1, True) # push + o.queue_and_send_record(r1) + # now we're paused + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1)]) + self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()]) + self.assertEqual(p2.mock_calls, []) + clear_mock_calls(p1, p2, c) + + p1.resumeProducing.side_effect = lambda: c.send_record(r1) + p2.resumeProducing.side_effect = lambda: c.send_record(r2) + o.subchannel_registerProducer(sc2, p2, False) # pull: always ready + + # p1 is still first, since p2 was just added (at the end) + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, []) + self.assertEqual(p1.mock_calls, []) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p1, p2]) + clear_mock_calls(p1, p2, c) + + # resume should send r1, which should pause everything + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, []) + self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next + clear_mock_calls(p1, p2, c) + + # next should fire p2, then p1 + o.resumeProducing() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, [mock.call.send_record(r2), + mock.call.send_record(r1), + ]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(), + mock.call.pauseProducing(), + ]) + self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), + ]) + self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat + clear_mock_calls(p1, p2, c) + + def test_pull_producer(self): + # a single pull producer should write until it is paused, rate-limited + # by the cooperator (so we'll see back-to-back resumeProducing calls + # until the Connection is paused, or 10ms have passed, whichever comes + # first, and if it's stopped by the timer, then the next EventualQueue + # turn will start it off again) + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + + records = [NonPauser(seqnum=1)] * 10 + records.append(Pauser(seqnum=2)) + records.append(Stopper(sc1)) + it = iter(records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) + o.subchannel_registerProducer(sc1, p1, False) + eq.flush_sync() # fast forward into the glorious (paused) future + + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in records[:-1]]) + self.assertEqual(p1.mock_calls, + [mock.call.resumeProducing()]*(len(records)-1)) + clear_mock_calls(c, p1) + + # next resumeProducing should cause it to disconnect + o.resumeProducing() + eq.flush_sync() + self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])]) + self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()]) + self.assertEqual(len(o._all_producers), 0) + self.assertFalse(o._paused) + + def test_two_pull_producers(self): + # we should alternate between them until paused + p1_records = ([NonPauser(seqnum=i) for i in range(5)] + + [Pauser(seqnum=5)] + + [NonPauser(seqnum=i) for i in range(6, 10)]) + p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] + + [Pauser(seqnum=19)]) + expected1 = [NonPauser(0), NonPauser(10), + NonPauser(1), NonPauser(11), + NonPauser(2), NonPauser(12), + NonPauser(3), NonPauser(13), + NonPauser(4), NonPauser(14), + Pauser(5)] + expected2 = [ NonPauser(15), + NonPauser(6), NonPauser(16), + NonPauser(7), NonPauser(17), + NonPauser(8), NonPauser(18), + NonPauser(9), Pauser(19), + ] + + o, m, c = make_outbound() + eq = o._test_eq + o.use_connection(c) + clear_mock_calls(c) + self.assertFalse(o._paused) + + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + it1 = iter(p1_records) + p1.resumeProducing.side_effect = lambda: c.send_record(next(it1)) + o.subchannel_registerProducer(sc1, p1, False) + + sc2 = mock.Mock() + p2 = mock.Mock(name="p2") + alsoProvides(p2, IPullProducer) + it2 = iter(p2_records) + p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) + o.subchannel_registerProducer(sc2, p2, False) + + eq.flush_sync() # fast forward into the glorious (paused) future + + sends = [mock.call.resumeProducing()] + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected1]) + self.assertEqual(p1.mock_calls, 6*sends) + self.assertEqual(p2.mock_calls, 5*sends) + clear_mock_calls(c, p1, p2) + + o.resumeProducing() + eq.flush_sync() + self.assertTrue(o._paused) + self.assertEqual(c.mock_calls, + [mock.call.send_record(r) for r in expected2]) + self.assertEqual(p1.mock_calls, 4*sends) + self.assertEqual(p2.mock_calls, 5*sends) + clear_mock_calls(c, p1, p2) + + def test_send_if_connected(self): + o, m, c = make_outbound() + o.send_if_connected(Ack(1)) # not connected yet + + o.use_connection(c) + o.send_if_connected(KCM()) + self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True), + mock.call.send_record(KCM())]) + + def test_tolerate_duplicate_pause_resume(self): + o, m, c = make_outbound() + self.assertTrue(o._paused) # no connection + o.use_connection(c) + self.assertFalse(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.pauseProducing() + self.assertTrue(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + o.resumeProducing() + self.assertFalse(o._paused) + + def test_stopProducing(self): + o, m, c = make_outbound() + o.use_connection(c) + self.assertFalse(o._paused) + o.stopProducing() # connection does this before loss + self.assertTrue(o._paused) + o.stop_using_connection() + self.assertTrue(o._paused) + + def test_resume_error(self): + o, m, c = make_outbound() + o.use_connection(c) + sc1 = mock.Mock() + p1 = mock.Mock(name="p1") + alsoProvides(p1, IPullProducer) + p1.resumeProducing.side_effect = PretendResumptionError + o.subchannel_registerProducer(sc1, p1, False) + o._test_eq.flush_sync() + # the error is supposed to automatically unregister the producer + self.assertEqual(list(o._all_producers), []) + self.flushLoggedErrors(PretendResumptionError) + + +def make_pushpull(pauses): + p = mock.Mock() + alsoProvides(p, IPullProducer) + unregister = mock.Mock() + + clock = Clock() + eq = EventualQueue(clock) + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + term_factory = lambda: term + coop = Cooperator(terminationPredicateFactory=term_factory, + scheduler=eq.eventually) + pp = PullToPush(p, unregister, coop) + + it = cycle(pauses) + def action(i): + if isinstance(i, Exception): + raise i + elif i: + pp.pauseProducing() + p.resumeProducing.side_effect = lambda: action(next(it)) + return p, unregister, pp, eq + +class PretendResumptionError(Exception): + pass +class PretendUnregisterError(Exception): + pass + +class PushPull(unittest.TestCase): + # test our wrapper utility, which I copied from + # twisted.internet._producer_helpers since it isn't publically exposed + + def test_start_unpaused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + # if it starts unpaused, it gets one write before being halted + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + clear_mock_calls(p) + + # now each time we call resumeProducing, we should see one delivered to + # the underlying IPullProducer + pp.resumeProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + + pp.stopStreaming() + pp.stopStreaming() # should tolerate this + + def test_start_unpaused_two_writes(self): + p, unr, pp, eq = make_pushpull([False, True]) # pause every other time + # it should get two writes, since the first didn't pause + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2) + + def test_start_paused(self): + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + pp.startStreaming(True) + eq.flush_sync() + self.assertEqual(p.mock_calls, []) + pp.stopStreaming() + + def test_stop(self): + p, unr, pp, eq = make_pushpull([True]) + pp.startStreaming(True) + pp.stopProducing() + eq.flush_sync() + self.assertEqual(p.mock_calls, [mock.call.stopProducing()]) + + def test_error(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = lambda: pp.stopStreaming() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError) + + def test_error_during_unregister(self): + p, unr, pp, eq = make_pushpull([PretendResumptionError()]) + unr.side_effect = PretendUnregisterError() + pp.startStreaming(False) + eq.flush_sync() + self.assertEqual(unr.mock_calls, [mock.call()]) + self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) + + + + + + # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I + # could capture the inter-call ordering that way diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py new file mode 100644 index 0000000..8365e62 --- /dev/null +++ b/src/wormhole/test/dilate/test_parse.py @@ -0,0 +1,43 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from ..._dilation.connection import (parse_record, encode_record, + KCM, Ping, Pong, Open, Data, Close, Ack) + +class Parse(unittest.TestCase): + def test_parse(self): + self.assertEqual(parse_record(b"\x00"), KCM()) + self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"), + Ping(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"), + Pong(ping_id=b"\x55\x44\x33\x22")) + self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"), + Open(scid=513, seqnum=256)) + self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"), + Data(scid=514, seqnum=257, data=b"dataaa")) + self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"), + Close(scid=515, seqnum=258)) + self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"), + Ack(resp_seqnum=259)) + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(ValueError): + parse_record(b"\x07unknown") + self.assertEqual(le.mock_calls, + [mock.call("received unknown message type: {}".format( + b"\x07unknown"))]) + + def test_encode(self): + self.assertEqual(encode_record(KCM()), b"\x00") + self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping") + self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong") + self.assertEqual(encode_record(Open(scid=65536, seqnum=16)), + b"\x03\x00\x01\x00\x00\x00\x00\x00\x10") + self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")), + b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa") + self.assertEqual(encode_record(Close(scid=65538, seqnum=18)), + b"\x05\x00\x01\x00\x02\x00\x00\x00\x12") + self.assertEqual(encode_record(Ack(resp_seqnum=19)), + b"\x06\x00\x00\x00\x13") + with self.assertRaises(TypeError) as ar: + encode_record("not a record") + self.assertEqual(str(ar.exception), "not a record") diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py new file mode 100644 index 0000000..810396c --- /dev/null +++ b/src/wormhole/test/dilate/test_record.py @@ -0,0 +1,268 @@ +from __future__ import print_function, unicode_literals +import mock +from zope.interface import alsoProvides +from twisted.trial import unittest +from noise.exceptions import NoiseInvalidMessage +from ..._dilation.connection import (IFramer, Frame, Prologue, + _Record, Handshake, + Disconnect, Ping) + +def make_record(): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() # pretends to be a Noise object + r = _Record(f, n) + return r, f, n + +class Record(unittest.TestCase): + def test_good2(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(side_effect=[ + [], + [Prologue()], + [Frame(frame=b"rx-handshake")], + [Frame(frame=b"frame1"), Frame(frame=b"frame2")], + ]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + p1, p2 = object(), object() + n.decrypt = mock.Mock(side_effect=[p1, p2]) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # Pretend to deliver the prologue in two parts. The text we send in + # doesn't matter: the side_effect= is what causes the prologue to be + # recognized by the second call. + self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, []) + + # recognizing the prologue causes a handshake frame to be sent + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(b"tx-handshake")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + n.mock_calls[:] = [] + + # next add_and_unframe is recognized as the Handshake + self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")]) + n.mock_calls[:] = [] + + # next is a pair of Records + r1, r2 = object() , object() + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[r1,r2]) as pr: + self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) + self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"), + mock.call.decrypt(b"frame2")]) + self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)]) + + def test_bad_handshake(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.read_message = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad inbound noise handshake")]) + + def test_bad_message(self): + f = mock.Mock() + alsoProvides(f, IFramer) + f.add_and_parse = mock.Mock(return_value=[Prologue(), + Frame(frame=b"rx-handshake"), + Frame(frame=b"bad-message")]) + n = mock.Mock() + n.write_message = mock.Mock(return_value=b"tx-handshake") + nvm = NoiseInvalidMessage() + n.decrypt = mock.Mock(side_effect=nvm) + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + f.mock_calls[:] = [] + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + with mock.patch("wormhole._dilation.connection.log.err") as le: + with self.assertRaises(Disconnect): + list(r.add_and_unframe(b"data")) + self.assertEqual(le.mock_calls, + [mock.call(nvm, "bad inbound noise frame")]) + + def test_send_record(self): + f = mock.Mock() + alsoProvides(f, IFramer) + n = mock.Mock() + f1 = object() + n.encrypt = mock.Mock(return_value=f1) + r1 = Ping(b"pingid") + r = _Record(f, n) + self.assertEqual(f.mock_calls, []) + m1 = object() + with mock.patch("wormhole._dilation.connection.encode_record", + return_value=m1) as er: + r.send_record(r1) + self.assertEqual(er.mock_calls, [mock.call(r1)]) + self.assertEqual(n.mock_calls, [mock.call.start_handshake(), + mock.call.encrypt(m1)]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)]) + + def test_good(self): + # Exercise the success path. The Record instance is given each chunk + # of data as it arrives on Protocol.dataReceived, and is supposed to + # return a series of Tokens (maybe none, if the chunk was incomplete, + # or more than one, if the chunk was larger). Internally, it delivers + # the chunks to the Framer for unframing (which returns 0 or more + # frames), manages the Noise decryption object, and parses any + # decrypted messages into tokens (some of which are consumed + # internally, others for delivery upstairs). + # + # in the normal flow, we get: + # + # | | Inbound | NoiseAction | Outbound | ToUpstairs | + # | | - | - | - | - | + # | 1 | | | prologue | | + # | 2 | prologue | | | | + # | 3 | | write_message | handshake | | + # | 4 | handshake | read_message | | Handshake | + # | 5 | | encrypt | KCM | | + # | 6 | KCM | decrypt | | KCM | + # | 7 | msg1 | decrypt | | msg1 | + + # 1: instantiating the Record instance causes the outbound prologue + # to be sent + + # 2+3: receipt of the inbound prologue triggers creation of the + # ephemeral key (the "handshake") by calling noise.write_message() + # and then writes the handshake to the outbound transport + + # 4: when the peer's handshake is received, it is delivered to + # noise.read_message(), which generates the shared key (enabling + # noise.send() and noise.decrypt()). It also delivers the Handshake + # token upstairs, which might (on the Follower) trigger immediate + # transmission of the Key Confirmation Message (KCM) + + # 5: the outbound KCM is framed and fed into noise.encrypt(), then + # sent outbound + + # 6: the peer's KCM is decrypted then delivered upstairs. The + # Follower treats this as a signal that it should use this connection + # (and drop all others). + + # 7: the peer's first message is decrypted, parsed, and delivered + # upstairs. This might be an Open or a Data, depending upon what + # queued messages were left over from the previous connection + + r, f, n = make_record() + outbound_handshake = object() + kcm, msg1 = object(), object() + f_kcm, f_msg1 = object(), object() + n.write_message = mock.Mock(return_value=outbound_handshake) + n.decrypt = mock.Mock(side_effect=[kcm, msg1]) + n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) + f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet + [Prologue()], + [Frame("f_handshake")], + [Frame("f_kcm"), + Frame("f_msg1")], + ]) + + self.assertEqual(f.mock_calls, []) + self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) + n.mock_calls[:] = [] + + # 1. The Framer is responsible for sending the prologue, so we don't + # have to check that here, we just check that the Framer was told + # about connectionMade properly. + r.connectionMade() + self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + # 2 + # we dribble the prologue in over two messages, to make sure we can + # handle a dataReceived that doesn't complete the token + + # remember, add_and_unframe is a generator + self.assertEqual(list(r.add_and_unframe(b"pro")), []) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")]) + self.assertEqual(n.mock_calls, []) + f.mock_calls[:] = [] + + self.assertEqual(list(r.add_and_unframe(b"logue")), []) + # 3: write_message, send outbound handshake + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"), + mock.call.send_frame(outbound_handshake), + ]) + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + # 4 + # Now deliver the Noise "handshake", the ephemeral public key. This + # is framed, but not a record, so it shouldn't decrypt or parse + # anything, but the handshake is delivered to the Noise object, and + # it does return a Handshake token so we can let the next layer up + # react (by sending the KCM frame if we're a Follower, or not if + # we're the Leader) + + self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()]) + self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")]) + self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")]) + f.mock_calls[:] = [] + n.mock_calls[:] = [] + + + # 5: at this point we ought to be able to send a messge, the KCM + with mock.patch("wormhole._dilation.connection.encode_record", + side_effect=[b"r-kcm"]) as er: + r.send_record(kcm) + self.assertEqual(er.mock_calls, [mock.call(kcm)]) + self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")]) + self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] + + # 6: Now we deliver two messages stacked up: the KCM (Key + # Confirmation Message) and the first real message. Concatenating + # them tests that we can handle more than one token in a single + # chunk. We need to mock parse_record() because everything past the + # handshake is decrypted and parsed. + + with mock.patch("wormhole._dilation.connection.parse_record", + side_effect=[kcm, msg1]) as pr: + self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")), + [kcm, msg1]) + self.assertEqual(f.mock_calls, + [mock.call.add_and_parse(b"kcm,msg1")]) + self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"), + mock.call.decrypt("f_msg1")]) + self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)]) + n.mock_calls[:] = [] + f.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py new file mode 100644 index 0000000..69fa001 --- /dev/null +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -0,0 +1,142 @@ +from __future__ import print_function, unicode_literals +import mock +from twisted.trial import unittest +from twisted.internet.interfaces import ITransport +from twisted.internet.error import ConnectionDone +from ..._dilation.subchannel import (Once, SubChannel, + _WormholeAddress, _SubchannelAddress, + AlreadyClosedError) +from .common import mock_manager + +def make_sc(set_protocol=True): + scid = b"scid" + hostaddr = _WormholeAddress() + peeraddr = _SubchannelAddress(scid) + m = mock_manager() + sc = SubChannel(scid, m, hostaddr, peeraddr) + p = mock.Mock() + if set_protocol: + sc._set_protocol(p) + return sc, m, scid, hostaddr, peeraddr, p + +class SubChannelAPI(unittest.TestCase): + def test_once(self): + o = Once(ValueError) + o() + with self.assertRaises(ValueError): + o() + + def test_create(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + self.assert_(ITransport.providedBy(sc)) + self.assertEqual(m.mock_calls, []) + self.assertIdentical(sc.getHost(), hostaddr) + self.assertIdentical(sc.getPeer(), peeraddr) + + def test_write(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.write(b"data") + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")]) + m.mock_calls[:] = [] + sc.writeSequence([b"more", b"data"]) + self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")]) + + def test_write_when_closing(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.write(b"data") + self.assertEqual(str(e.exception), + "write not allowed on closed subchannel") + + def test_local_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + # late arriving data is still delivered + sc.remote_data(b"late") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")]) + p.mock_calls[:] = [] + + sc.remote_close() + self.assert_connectionDone(p.mock_calls) + + def test_local_double_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.loseConnection() + self.assertEqual(m.mock_calls, [mock.call.send_close(scid)]) + m.mock_calls[:] = [] + + with self.assertRaises(AlreadyClosedError) as e: + sc.loseConnection() + self.assertEqual(str(e.exception), + "loseConnection not allowed on closed subchannel") + + def assert_connectionDone(self, mock_calls): + self.assertEqual(len(mock_calls), 1) + self.assertEqual(mock_calls[0][0], "connectionLost") + self.assertEqual(len(mock_calls[0][1]), 1) + self.assertIsInstance(mock_calls[0][1][0], ConnectionDone) + + def test_remote_close(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_close() + self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)]) + self.assert_connectionDone(p.mock_calls) + + def test_data(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"not") + sc.remote_data(b"coalesced") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"), + mock.call.dataReceived(b"coalesced"), + ]) + + def test_data_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_data(b"data") + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")]) + p.mock_calls[:] = [] + sc.remote_data(b"more") + self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")]) + + def test_close_before_open(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False) + sc.remote_close() + self.assertEqual(p.mock_calls, []) + sc._set_protocol(p) + self.assert_connectionDone(p.mock_calls) + + def test_producer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + sc.pauseProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)]) + m.mock_calls[:] = [] + sc.resumeProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)]) + m.mock_calls[:] = [] + sc.stopProducing() + self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)]) + m.mock_calls[:] = [] + + def test_consumer(self): + sc, m, scid, hostaddr, peeraddr, p = make_sc() + + # TODO: more, once this is implemented + sc.registerProducer(None, True) + sc.unregisterProducer() diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 5b5e9f1..35854ae 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -12,7 +12,7 @@ import mock from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister, _mailbox, _nameplate, _order, _receive, _rendezvous, _send, _terminator, errors, timing) -from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister, +from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister, IMailbox, INameplate, IOrder, IReceive, IRendezvousConnector, ISend, ITerminator, IWordlist) from .._key import derive_key, derive_phase_key, encrypt_data @@ -1300,6 +1300,7 @@ class Boss(unittest.TestCase): b._RC = Dummy("rc", events, IRendezvousConnector, "start") b._C = Dummy("c", events, ICode, "allocate_code", "input_code", "set_code") + b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key") return b, events def test_basic(self): @@ -1327,7 +1328,9 @@ class Boss(unittest.TestCase): b.got_message("side", "0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), + ("d.got_key", b"key"), ("w.got_verifier", b"verifier"), + ("d.got_wormhole_versions", "side", "side", {}), ("w.got_versions", {}), ("w.received", b"msg1"), ]) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index c02aa60..f967c25 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -9,6 +9,7 @@ from twisted.internet.task import Cooperator from zope.interface import implementer from ._boss import Boss +from ._dilation.connector import Connector from ._interfaces import IDeferredWormhole, IWormhole from ._key import derive_key from .errors import NoKeyError, WormholeClosed @@ -189,6 +190,9 @@ class _DeferredWormhole(object): raise NoKeyError() return derive_key(self._key, to_bytes(purpose), length) + def dilate(self): + return self._boss.dilate() # fires with (endpoints) + def close(self): # fails with WormholeError unless we established a connection # (state=="happy"). Fails with WrongPasswordError (a subclass of @@ -265,8 +269,12 @@ def create( w = _DelegatedWormhole(delegate) else: w = _DeferredWormhole(reactor, eq) - wormhole_versions = {} # will be used to indicate Wormhole capabilities - wormhole_versions["app_versions"] = versions # app-specific capabilities + # this indicates Wormhole capabilities + wormhole_versions = { + "can-dilate": [1], + "dilation-abilities": Connector.get_connection_abilities(), + } + wormhole_versions["app_versions"] = versions # app-specific capabilities v = __version__ if isinstance(v, type(b"")): v = v.decode("utf-8", errors="replace")