diff --git a/docs/api.md b/docs/api.md
index 39a4e43..43cc318 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -524,25 +524,31 @@ object twice.
## Dilation
-(NOTE: this section is speculative: this code has not yet been written)
-
-In the longer term, the Wormhole object will incorporate the "Transit"
-functionality (see transit.md) directly, removing the need to instantiate a
-second object. A Wormhole can be "dilated" into a form that is suitable for
-bulk data transfer.
+To send bulk data, or anything more than a handful of messages, a Wormhole
+can be "dilated" into a form that uses a direct TCP connection between the
+two endpoints.
All wormholes start out "undilated". In this state, all messages are queued
on the Rendezvous Server for the lifetime of the wormhole, and server-imposed
number/size/rate limits apply. Calling `w.dilate()` initiates the dilation
-process, and success is signalled via either `d=w.when_dilated()` firing, or
-`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used
-as an IConsumer/IProducer, and messages will be sent on a direct connection
-(if possible) or through the transit relay (if not).
+process, and eventually yields a set of Endpoints. Once dilated, the usual
+`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and
+these endpoints can be used to establish multiple (encrypted) "subchannel"
+connections to the other side.
+
+Each subchannel behaves like a regular Twisted `ITransport`, so they can be
+glued to the Protocol instance of your choice. They also implement the
+IConsumer/IProducer interfaces.
+
+These subchannels are *durable*: as long as the processes on both sides keep
+running, the subchannel will survive the network connection being dropped.
+For example, a file transfer can be started from a laptop, then while it is
+running, the laptop can be closed, moved to a new wifi network, opened back
+up, and the transfer will resume from the new IP address.
What's good about a non-dilated wormhole?:
* setup is faster: no delay while it tries to make a direct connection
-* survives temporary network outages, since messages are queued
* works with "journaled mode", allowing progress to be made even when both
sides are never online at the same time, by serializing the wormhole
@@ -556,21 +562,103 @@ Use non-dilated wormholes when your application only needs to exchange a
couple of messages, for example to set up public keys or provision access
tokens. Use a dilated wormhole to move files.
-Dilated wormholes can provide multiple "channels": these are multiplexed
-through the single (encrypted) TCP connection. Each channel is a separate
-stream (offering IProducer/IConsumer)
+Dilated wormholes can provide multiple "subchannels": these are multiplexed
+through the single (encrypted) TCP connection. Each subchannel is a separate
+stream (offering IProducer/IConsumer for flow control), and is opened and
+closed independently. A special "control channel" is available to both sides
+so they can coordinate how they use the subchannels.
-To create a channel, call `c = w.create_channel()` on a dilated wormhole. The
-"channel ID" can be obtained with `c.get_id()`. This ID will be a short
-(unicode) string, which can be sent to the other side via a normal
-`w.send()`, or any other means. On the other side, use `c =
-w.open_channel(channel_id)` to get a matching channel object.
+The `d = w.dilate()` Deferred fires with a triple of Endpoints:
-Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire
-them up with `c.registerProducer()`. Note that channels do not close until
-the wormhole connection is closed, so they do not have separate `close()`
-methods or events. Therefore if you plan to send files through them, you'll
-need to inform the recipient ahead of time about how many bytes to expect.
+```python
+d = w.dilate()
+def _dilated(res):
+ (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
+d.addCallback(_dilated)
+```
+
+The `control_channel_ep` endpoint is a client-style endpoint, so both sides
+will connect to it with `ep.connect(factory)`. This endpoint is single-use:
+calling `.connect()` a second time will fail. The control channel is
+symmetric: it doesn't matter which side is the application-level
+client/server or initiator/responder, if the application even has such
+concepts. The two applications can use the control channel to negotiate who
+goes first, if necessary.
+
+The subchannel endpoints are *not* symmetric: for each subchannel, one side
+must listen as a server, and the other must connect as a client. Subchannels
+can be established by either side at any time. This supports e.g.
+bidirectional file transfer, where either user of a GUI app can drop files
+into the "wormhole" whenever they like.
+
+The `subchannel_client_ep` on one side is used to connect to the other side's
+`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The
+server endpoint is single-use: `.listen(factory)` may only be called once.
+
+Applications are under no obligation to use subchannels: for many use cases,
+the control channel is enough.
+
+To use subchannels, once the wormhole is dilated and the endpoints are
+available, the listening-side application should attach a listener to the
+`subchannel_server_ep` endpoint:
+
+```python
+def _dilated(res):
+ (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
+ f = Factory(MyListeningProtocol)
+ subchannel_server_ep.listen(f)
+```
+
+When the connecting-side application wants to connect to that listening
+protocol, it should use `.connect()` with a suitable connecting protocol
+factory:
+
+```python
+def _connect():
+ f = Factory(MyConnectingProtocol)
+ subchannel_client_ep.connect(f)
+```
+
+For a bidirectional file-transfer application, both sides will establish a
+listening protocol. Later, if/when the user drops a file on the application
+window, that side will initiate a connection, use the resulting subchannel to
+transfer the single file, and then close the subchannel.
+
+```python
+def FileSendingProtocol(internet.Protocol):
+ def __init__(self, metadata, filename):
+ self.file_metadata = metadata
+ self.file_name = filename
+ def connectionMade(self):
+ self.transport.write(self.file_metadata)
+ sender = protocols.basic.FileSender()
+ f = open(self.file_name,"rb")
+ d = sender.beginFileTransfer(f, self.transport)
+ d.addBoth(self._done, f)
+ def _done(res, f):
+ self.transport.loseConnection()
+ f.close()
+def _send(metadata, filename):
+ f = protocol.ClientCreator(reactor,
+ FileSendingProtocol, metadata, filename)
+ subchannel_client_ep.connect(f)
+def FileReceivingProtocol(internet.Protocol):
+ state = INITIAL
+ def dataReceived(self, data):
+ if state == INITIAL:
+ self.state = DATA
+ metadata = parse(data)
+ self.f = open(metadata.filename, "wb")
+ else:
+ # local file writes are blocking, so don't bother with IConsumer
+ self.f.write(data)
+ def connectionLost(self, reason):
+ self.f.close()
+def _dilated(res):
+ (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
+ f = Factory(FileReceivingProtocol)
+ subchannel_server_ep.listen(f)
+```
## Bytes, Strings, Unicode, and Python 3
diff --git a/docs/dilation-protocol.md b/docs/dilation-protocol.md
new file mode 100644
index 0000000..4f0dac1
--- /dev/null
+++ b/docs/dilation-protocol.md
@@ -0,0 +1,500 @@
+# Dilation Internals
+
+Wormhole dilation involves several moving parts. Both sides exchange messages
+through the Mailbox server to coordinate the establishment of a more direct
+connection. This connection might flow in either direction, so they trade
+"connection hints" to point at potential listening ports. This process might
+succeed in making multiple connections at about the same time, so one side
+must select the best one to use, and cleanly shut down the others. To make
+the dilated connection *durable*, this side must also decide when the
+connection has been lost, and then coordinate the construction of a
+replacement. Within this connection, a series of queued-and-acked subchannel
+messages are used to open/use/close the application-visible subchannels.
+
+## Leaders and Followers
+
+Each side of a Wormhole has a randomly-generated "side" string. When the
+wormhole is dilated, the side with the lexicographically-higher "side" value
+is named the "Leader", and the other side is named the "Follower". The general
+wormhole protocol treats both sides identically, but the distinction matters
+for the dilation protocol.
+
+Either side can trigger dilation, but the Follower does so by asking the
+Leader to start the process, whereas the Leader just starts the process
+unilaterally. The Leader has exclusive control over whether a given
+connection is considered established or not: if there are multiple potential
+connections to use, the Leader decides which one to use, and the Leader gets
+to decide when the connection is no longer viable (and triggers the
+establishment of a new one).
+
+## Connection Layers
+
+We describe the protocol as a series of layers. Messages sent on one layer
+may be encoded or transformed before being delivered on some other layer.
+
+L1 is the mailbox channel (queued store-and-forward messages that always go
+to the mailbox server, and then are forwarded to other clients subscribed to
+the same mailbox). Both clients remain connected to the mailbox server until
+the Wormhole is closed. They send DILATE-n messages to each other to manage
+the dilation process, including records like `please-dilate`,
+`start-dilation`, `ok-dilation`, and `connection-hints`
+
+L2 is the set of competing connection attempts for a given generation of
+connection. Each time the Leader decides to establish a new connection, a new
+generation number is used. Hopefully these are direct TCP connections between
+the two peers, but they may also include connections through the transit
+relay. Each connection must go through an encrypted handshake process before
+it is considered viable. Viable connections are then submitted to a selection
+process (on the Leader side), which chooses exactly one to use, and drops the
+others. It may wait an extra few seconds in the hopes of getting a "better"
+connection (faster, cheaper, etc), but eventually it will select one.
+
+L3 is the current selected connection. There is one L3 for each generation.
+At all times, the wormhole will have exactly zero or one L3 connection. L3 is
+responsible for the selection process, connection monitoring/keepalives, and
+serialization/deserialization of the plaintext frames. L3 delivers decoded
+frames and connection-establishment events up to L4.
+
+L4 is the persistent higher-level channel. It is created as soon as the first
+L3 connection is selected, and lasts until wormhole is closed entirely. L4
+contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number
+(scoped to the L4 connection and the direction of travel), and the ACK
+messages reference those sequence numbers. When a message is given to the L4
+channel for delivery to the remote side, it is always queued, then
+transmitted if there is an L3 connection available. This message remains in
+the queue until an ACK is received to release it. If a new L3 connection is
+made, all queued messages will be re-sent (in seqnum order).
+
+L5 are subchannels. There is one pre-established subchannel 0 known as the
+"control channel", which does not require an OPEN message. All other
+subchannels are created by the receipt of an OPEN message with the subchannel
+number. DATA frames are delivered to a specific subchannel. When the
+subchannel is no longer needed, one side will invoke the ``close()`` API
+(``loseConnection()`` in Twisted), which will cause a CLOSE message to be
+sent, and the local L5 object will be put into the "closing "state. When the
+other side receives the CLOSE, it will send its own CLOSE for the same
+subchannel, and fully close its local object (``connectionLost()``). When the
+first side receives CLOSE in the "closing" state, it will fully close its
+local object too.
+
+All L5 subchannels will be paused (``pauseProducing()``) when the L3
+connection is paused or lost. They are resumed when the L3 connection is
+resumed or reestablished.
+
+## Initiating Dilation
+
+Dilation is triggered by calling the `w.dilate()` API. This returns a
+Deferred that will fire once the first L3 connection is established. It fires
+with a 3-tuple of endpoints that can be used to establish subchannels.
+
+For dilation to succeed, both sides must call `w.dilate()`, since the
+resulting endpoints are the only way to access the subchannels. If the other
+side never calls `w.dilate()`, the Deferred will never fire.
+
+The L1 (mailbox) path is used to deliver dilation requests and connection
+hints. The current mailbox protocol uses named "phases" to distinguish
+messages (rather than behaving like a regular ordered channel of arbitrary
+frames or bytes), and all-number phase names are reserved for application
+data (sent via `w.send_message()`). Therefore the dilation control messages
+use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own
+counter, so one side might be up to e.g. `DILATE-5` while the other has only
+gotten as far as `DILATE-2`. This effectively creates a unidirectional stream
+of `DILATE-n` messages, each containing one or more dilation record, of
+various types described below. Note that all phases beyond the initial
+VERSION and PAKE phases are encrypted by the shared session key.
+
+A future mailbox protocol might provide a simple ordered stream of messages,
+with application records and dilation records mixed together.
+
+Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that
+has a string value. The dictionary will have other keys that depend upon the
+type.
+
+`w.dilate()` triggers a `please-dilate` record with a set of versions that
+can be accepted. Both Leader and Follower emit this record, although the
+Leader is responsible for version decisions. Versions use strings, rather
+than integers, to support experimental protocols, however there is still a
+total ordering of preferability.
+
+```
+{ "type": "please-dilate",
+ "accepted-versions": ["1"]
+}
+```
+
+The Leader then sends a `start-dilation` message with a `version` field (the
+"best" mutually-supported value) and the new "L2 generation" number in the
+`generation` field. Generation numbers are integers, monotonically increasing
+by 1 each time.
+
+```
+{ "type": start-dilation,
+ "version": "1",
+ "generation": 1,
+}
+```
+
+The Follower responds with a `ok-dilation` message with matching `version`
+and `generation` fields.
+
+The Leader decides when a new dilation connection is necessary, both for the
+initial connection and any subsequent reconnects. Therefore the Leader has
+the exclusive right to send the `start-dilation` record. It won't send this
+until after it has sent its own `please-dilate`, and after it has received
+the Follower's `please-dilate`. As a result, local preparations may begin as
+soon as `w.dilate()` is called, but L2 connections do not begin until the
+Leader declares the start of a new L2 generation with the `start-dilation`
+message.
+
+Generations are non-overlapping. The Leader will drop all connections from
+generation 1 before sending the `start-dilation` for generation 2, and will
+not initiate any gen-2 connections until it receives the matching
+`ok-dilation` from the Follower. The Follower must drop all gen-1 connections
+before it sends the `ok-dilation` response (even if it thinks they are still
+functioning: if the Leader thought the gen-1 connection still worked, it
+wouldn't have started gen-2). Listening sockets can be retained, but any
+previous connection made through them must be dropped. This should avoid a
+race.
+
+(TODO: what about a follower->leader connection that was started before
+start-dilation is received, and gets established on the Leader side after
+start-dilation is sent? the follower will drop it after it receives
+start-dilation, but meanwhile the leader may accept it as gen2)
+
+(probably need to include the generation number in the handshake, or in the
+derived key)
+
+(TODO: reduce the number of round-trip stalls here, I've added too many)
+
+"Connection hints" are type/address/port records that tell the other side of
+likely targets for L2 connections. Both sides will try to determine their
+external IP addresses, listen on a TCP port, and advertise `(tcp,
+external-IP, port)` as a connection hint. The Transit Relay is also used as a
+(lower-priority) hint. These are sent in `connection-hint` records, which can
+be sent by the Leader any time after the `start-dilation` record, and by the
+Follower after the `ok-dilation` record. Each side will initiate connections
+upon receipt of the hints.
+
+```
+{ "type": "connection-hints",
+ "hints": [ ... ]
+}
+```
+
+Hints can arrive at any time. One side might immediately send hints that can
+be computed quickly, then send additional hints later as they become
+available. For example, it might enumerate the local network interfaces and
+send hints for all of the LAN addresses first, then send port-forwarding
+(UPnP) requests to the local router. When the forwarding is established
+(providing an externally-visible IP address and port), it can send additional
+hints for that new endpoint. If the other peer happens to be on the same LAN,
+the local connection can be established without waiting for the router's
+response.
+
+
+### Connection Hint Format
+
+Each member of the `hints` field describes a potential L2 connection target
+endpoint, with an associated priority and a set of hints.
+
+The priority is a number (positive or negative float), where larger numbers
+indicate that the client supplying that hint would prefer to use this
+connection over others of lower number. This indicates a sense of cost or
+performance. For example, the Transit Relay is lower priority than a direct
+TCP connection, because it incurs a bandwidth cost (on the relay operator),
+as well as adding latency.
+
+Each endpoint has a set of hints, because the same target might be reachable
+by multiple hints. Once one hint succeeds, there is no point in using the
+other hints.
+
+TODO: think this through some more. What's the example of a single endpoint
+reachable by multiple hints? Should each hint have its own priority, or just
+each endpoint?
+
+## L2 protocol
+
+Upon ``connectionMade()``, both sides send their handshake message. The
+Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower
+sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should
+trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP
+servers that were contacted by accident). If the wrong handshake is received,
+the connection will be dropped. For debugging purposes, the node might want
+to keep looking at data beyond the first incorrect character and log
+everything until the first newline.
+
+Everything beyond that point is a Noise protocol message, which consists of a
+4-byte big-endian length field, followed by the indicated number of bytes.
+This ises the `NNpsk0` pattern with the Leader as the first party ("-> psk,
+e" in the Noise spec), and the Follower as the second ("<- e, ee"). The
+pre-shared-key is the "dilation key", which is statically derived from the
+master PAKE key using HKDF. Each L2 connection uses the same dilation key,
+but different ephemeral keys, so each gets a different session key.
+
+The Leader sends the first message, which is a psk-encrypted ephemeral key.
+The Follower sends the next message, its own psk-encrypted ephemeral key. The
+Follower then sends an empty packet as the "key confirmation message", which
+will be encrypted by the shared key.
+
+The Leader sees the KCM and knows the connection is viable. It delivers the
+protocol object to the L3 manager, which will decide which connection to
+select. When the L2 connection is selected to be the new L3, it will send an
+empty KCM of its own, to let the Follower know the connection being selected.
+All other L2 connections (either viable or still in handshake) are dropped,
+all other connection attempts are cancelled, and all listening sockets are
+shut down.
+
+The Follower will wait for either an empty KCM (at which point the L2
+connection is delivered to the Dilation manager as the new L3), a
+disconnection, or an invalid message (which causes the connection to be
+dropped). Other connections and/or listening sockets are stopped.
+
+Internally, the L2Protocol object manages the Noise session itself. It knows
+(via a constructor argument) whether it is on the Leader or Follower side,
+which affects both the role is plays in the Noise pattern, and the reaction
+to receiving the ephemeral key (for which only the Follower sends an empty
+KCM message). After that, the L2Protocol notifies the L3 object in three
+situations:
+
+* the Noise session produces a valid decrypted frame (for Leader, this
+ includes the Follower's KCM, and thus indicates a viable candidate for
+ connection selection)
+* the Noise session reports a failed decryption
+* the TCP session is lost
+
+All notifications include a reference to the L2Protocol object (`self`). The
+L3 object uses this reference to either close the connection (for errors or
+when the selection process chooses someone else), to send the KCM message
+(after selection, only for the Leader), or to send other L4 messages. The L3
+object will retain a reference to the winning L2 object.
+
+## L3 protocol
+
+The L3 layer is responsible for connection selection, monitoring/keepalives,
+and message (de)serialization. Framing is handled by L2, so the inbound L3
+codepath receives single-message bytestrings, and delivers the same down to
+L2 for encryption, framing, and transmission.
+
+Connection selection takes place exclusively on the Leader side, and includes
+the following:
+
+* receipt of viable L2 connections from below (indicated by the first valid
+ decrypted frame received for any given connection)
+* expiration of a timer
+* comparison of TBD quality/desirability/cost metrics of viable connections
+* selection of winner
+* instructions to losing connections to disconnect
+* delivery of KCM message through winning connection
+* retain reference to winning connection
+
+On the Follower side, the L3 manager just waits for the first connection to
+receive the Leader's KCM, at which point it is retained and all others are
+dropped.
+
+The L3 manager knows which "generation" of connection is being established.
+Each generation uses a different dilation key (?), and is triggered by a new
+set of L1 messages. Connections from one generation should not be confused
+with those of a different generation.
+
+Each time a new L3 connection is established, the L4 protocol is notified. It
+will will immediately send all the L4 messages waiting in its outbound queue.
+The L3 protocol simply wraps these in Noise frames and sends them to the
+other side.
+
+The L3 manager monitors the viability of the current connection, and declares
+it as lost when bidirectional traffic cannot be maintained. It uses PING and
+PONG messages to detect this. These also serve to keep NAT entries alive,
+since many firewalls will stop forwarding packets if they don't observe any
+traffic for e.g. 5 minutes.
+
+Our goals are:
+
+* don't allow more than 30? seconds to pass without at least *some* data
+ being sent along each side of the connection
+* allow the Leader to detect silent connection loss within 60? seconds
+* minimize overhead
+
+We need both sides to:
+
+* maintain a 30-second repeating timer
+* set a flag each time we write to the connection
+* each time the timer fires, if the flag was clear then send a PONG,
+ otherwise clear the flag
+
+In addition, the Leader must:
+
+* run a 60-second repeating timer (ideally somewhat offset from the other)
+* set a flag each time we receive data from the connection
+* each time the timer fires, if the flag was clear then drop the connection,
+ otherwise clear the flag
+
+In the future, we might have L2 links that are less connection-oriented,
+which might have a unidirectional failure mode, at which point we'll need to
+monitor full roundtrips. To accomplish this, the Leader will send periodic
+unconditional PINGs, and the Follower will respond with PONGs. If the
+Leader->Follower connection is down, the PINGs won't arrive and no PONGs will
+be produced. If the Follower->Leader direction has failed, the PONGs won't
+arrive. The delivery of both will be delayed by actual data, so the timeouts
+should be adjusted if we see regular data arriving.
+
+If the connection is dropped before the wormhole is closed (either the other
+end explicitly dropped it, we noticed a problem and told TCP to drop it, or
+TCP noticed a problem itself), the Leader-side L3 manager will initiate a
+reconnection attempt. This uses L1 to send a new DILATE message through the
+mailbox server, along with new connection hints. Eventually this will result
+in a new L3 connection being established.
+
+Finally, L3 is responsible for message serialization and deserialization. L2
+performs decryption and delivers plaintext frames to L3. Each frame starts
+with a one-byte type indicator. The rest of the message depends upon the
+type:
+
+* 0x00 PING, 4-byte ping-id
+* 0x01 PONG, 4-byte ping-id
+* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum
+* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload
+* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum
+* 0x05 ACK, 4-byte response-seqnum
+
+All seqnums are big-endian, and are provided by the L4 protocol. The other
+fields are arbitrary and not interpreted as integers. The subchannel-ids must
+be allocated by both sides without collision, but otherwise they are only
+used to look up L5 objects for dispatch. The response-seqnum is always copied
+from the OPEN/DATA/CLOSE packet being acknowledged.
+
+L3 consumes the PING and PONG messages. Receiving any PING will provoke a
+PONG in response, with a copy of the ping-id field. The 30-second timer will
+produce unprovoked PONGs with a ping-id of all zeros. A future viability
+protocol will use PINGs to test for roundtrip functionality.
+
+All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered
+"upstairs" to the L4 protocol handler.
+
+The current L3 connection's `IProducer`/`IConsumer` interface is made
+available to the L4 flow-control manager.
+
+## L4 protocol
+
+The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages.
+Since each will be enclosed in a Noise frame before they pass to L3, they do
+not need length fields or other framing.
+
+Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically
+increasing by 1 for each message. Each direction has a separate number space.
+
+The L4 manager maintains a double-ended queue of unacknowledged outbound
+messages. Subchannel activity (opening, closing, sending data) cause messages
+to be added to this queue. If an L3 connection is available, these messages
+are also sent over that connection, but they remain in the queue in case the
+connection is lost and they must be retransmitted on some future replacement
+connection. Messages stay in the queue until they can be retired by the
+receipt of an ACK with a matching response-sequence-number. This provides
+reliable message delivery that survives the L3 connection being replaced.
+
+ACKs are not acked, nor do they have seqnums of their own. Each inbound side
+remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE
+messages with that sequence number or higher. This ensures in-order
+at-most-once processing of OPEN/DATA/CLOSE messages.
+
+Each inbound OPEN message causes a new L5 subchannel object to be created.
+Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to
+that object.
+
+Each time an L3 connection is established, the side will immediately send all
+L4 messages waiting in the outbound queue. A future protocol might reduce
+this duplication by including the highest received sequence number in the L1
+PLEASE-DILATE message, which would effectively retire queued messages before
+initiating the L2 connection process. On any given L3 connection, all
+messages are sent in-order. The receipt of an ACK for seqnum `N` allows all
+messages with `seqnum <= N` to be retired.
+
+The L4 layer is also responsible for managing flow control among the L3
+connection and the various L5 subchannels.
+
+## L5 subchannels
+
+The L5 layer consists of a collection of "subchannel" objects, a dispatcher,
+and the endpoints that provide the Twisted-flavored API.
+
+Other than the "control channel", all subchannels are created by a client
+endpoint connection API. The side that calls this API is named the Initiator,
+and the other side is named the Acceptor. Subchannels can be initiated in
+either direction, independent of the Leader/Follower distinction. For a
+typical file-transfer application, the subchannel would be initiated by the
+side seeking to send a file.
+
+Each subchannel uses a distinct subchannel-id, which is a four-byte
+identifier. Both directions share a number space (unlike L4 seqnums), so the
+rule is that the Leader side sets the last bit of the last byte to a 0, while
+the Follower sets it to a 1. These are not generally treated as integers,
+however for the sake of debugging, the implementation generates them with a
+simple big-endian-encoded counter (`next(counter)*2` for the Leader,
+`next(counter)*2+1` for the Follower).
+
+When the `client_ep.connect()` API is called, the Initiator allocates a
+subchannel-id and sends an OPEN. It can then immediately send DATA messages
+with the outbound data (there is no special response to an OPEN, so there is
+no need to wait). The Acceptor will trigger their `.connectionMade` handler
+upon receipt of the OPEN.
+
+Subchannels are durable: they do not close until one side calls
+`.loseConnection` on the subchannel object (or the enclosing Wormhole is
+closed). Either the Initiator or the Acceptor can call `.loseConnection`.
+This causes a CLOSE message to be sent (with the subchannel-id). The other
+side will send its own CLOSE message in response. Each side will signal the
+`.connectionLost()` event upon receipt of a CLOSE.
+
+There is no equivalent to TCP's "half-closed" state, however if only one side
+calls `close()`, then all data written before that call will be delivered
+before the other side observes `.connectionLost()`. Any inbound data that was
+queued for delivery before the other side sees the CLOSE will still be
+delivered to the side that called `close()` before it sees
+`.connectionLost()`. Internally, the side which called `.loseConnection` will
+remain in a special "closing" state until the CLOSE response arrives, during
+which time DATA payloads are still delivered. After calling `close()` (or
+receiving CLOSE), any outbound `.write()` calls will trigger an error.
+
+DATA payloads that arrive for a non-open subchannel are logged and discarded.
+
+This protocol calls for one OPEN and two CLOSE messages for each subchannel,
+with some arbitrary number of DATA messages in between. Subchannel-ids should
+not be reused (it would probably work, the protocol hasn't been analyzed
+enough to be sure).
+
+The "control channel" is special. It uses a subchannel-id of all zeros, and
+is opened implicitly by both sides as soon as the first L3 connection is
+selected. It is routed to a special client-on-both-sides endpoint, rather
+than causing the listening endpoint to accept a new connection. This avoids
+the need for application-level code to negotiate who should be the one to
+open it (the Leader/Follower distinction is private to the Wormhole
+internals: applications are not obligated to pick a side).
+
+OPEN and CLOSE messages for the control channel are logged and discarded. The
+control-channel client endpoints can only be used once, and does not close
+until the Wormhole itself is closed.
+
+Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing,
+delivery, and eventual retirement. The L5 layer does not keep track of old
+messages.
+
+### Flow Control
+
+Subchannels are flow-controlled by pausing their writes when the L3
+connection is paused, and pausing the L3 connection when the subchannel
+signals a pause. When the outbound L3 connection is full, *all* subchannels
+are paused. Likewise the inbound connection is paused if *any* of the
+subchannels asks for a pause. This is much easier to implement and improves
+our utilization factor (we can use TCP's window-filling algorithm, instead of
+rolling our own), but will block all subchannels even if only one of them
+gets full. This shouldn't matter for many applications, but might be
+noticeable when combining very different kinds of traffic (e.g. a chat
+conversation sharing a wormhole with file-transfer might prefer the IM text
+to take priority).
+
+Each subchannel implements Twisted's `ITransport`, `IProducer`, and
+`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to
+be created (by the caller's factory) and glued to the subchannel object in
+the `.transport` property, as is standard in Twisted-based applications.
+
+All subchannels are also paused when the L3 connection is lost, and are
+unpaused when a new replacement connection is selected.
diff --git a/docs/new-protocol.svg b/docs/new-protocol.svg
new file mode 100644
index 0000000..a69b79a
--- /dev/null
+++ b/docs/new-protocol.svg
@@ -0,0 +1,2000 @@
+
+
+
+
diff --git a/docs/state-machines/machines.dot b/docs/state-machines/machines.dot
index ecfdf9a..09cc02c 100644
--- a/docs/state-machines/machines.dot
+++ b/docs/state-machines/machines.dot
@@ -23,12 +23,13 @@ digraph {
Terminator [shape="box" color="blue" fontcolor="blue"]
InputHelperAPI [shape="oval" label="input\nhelper\nAPI"
color="blue" fontcolor="blue"]
+ Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"]
#Connection -> websocket [color="blue"]
#Connection -> Order [color="blue"]
Wormhole -> Boss [style="dashed"
- label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)"
+ label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)"
color="red" fontcolor="red"]
#Wormhole -> Boss [color="blue"]
Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"]
@@ -112,4 +113,7 @@ digraph {
Terminator -> Boss [style="dashed" label="closed\n(once)"]
Boss -> Terminator [style="dashed" color="red" fontcolor="red"
label="close"]
+
+ Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"]
+ Dilator -> Send [style="dashed" label="send(dilate-N)"]
}
diff --git a/setup.py b/setup.py
index a99a2d2..c275e8b 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ setup(name="magic-wormhole",
"click",
"humanize",
"txtorcon >= 18.0.2", # 18.0.2 fixes py3.4 support
+ "noiseprotocol",
],
extras_require={
':sys_platform=="win32"': ["pywin32"],
diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py
index 373c0f4..cc6f4fa 100644
--- a/src/wormhole/_boss.py
+++ b/src/wormhole/_boss.py
@@ -12,6 +12,7 @@ from zope.interface import implementer
from . import _interfaces
from ._allocator import Allocator
from ._code import Code, validate_code
+from ._dilation.manager import Dilator
from ._input import Input
from ._key import Key
from ._lister import Lister
@@ -66,6 +67,7 @@ class Boss(object):
self._I = Input(self._timing)
self._C = Code(self._timing)
self._T = Terminator()
+ self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator)
self._N.wire(self._M, self._I, self._RC, self._T)
self._M.wire(self._N, self._RC, self._O, self._T)
@@ -79,6 +81,7 @@ class Boss(object):
self._I.wire(self._C, self._L)
self._C.wire(self, self._A, self._N, self._K, self._I)
self._T.wire(self, self._RC, self._N, self._M)
+ self._D.wire(self._S)
def _init_other_state(self):
self._did_start_code = False
@@ -86,6 +89,9 @@ class Boss(object):
self._next_rx_phase = 0
self._rx_phases = {} # phase -> plaintext
+ self._next_rx_dilate_seqnum = 0
+ self._rx_dilate_seqnums = {} # seqnum -> plaintext
+
self._result = "empty"
# these methods are called from outside
@@ -198,6 +204,9 @@ class Boss(object):
self._did_start_code = True
self._C.set_code(code)
+ def dilate(self):
+ return self._D.dilate() # fires with endpoints
+
@m.input()
def send(self, plaintext):
pass
@@ -258,8 +267,11 @@ class Boss(object):
# this is only called for side != ours
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
+ d_mo = re.search(r'^dilate-(\d+)$', phase)
if phase == "version":
self._got_version(side, plaintext)
+ elif d_mo:
+ self._got_dilate(int(d_mo.group(1)), plaintext)
elif re.search(r'^\d+$', phase):
self._got_phase(int(phase), plaintext)
else:
@@ -275,6 +287,10 @@ class Boss(object):
def _got_phase(self, phase, plaintext):
pass
+ @m.input()
+ def _got_dilate(self, seqnum, plaintext):
+ pass
+
@m.input()
def got_key(self, key):
pass
@@ -298,6 +314,8 @@ class Boss(object):
# in the future, this is how Dilation is signalled
self._their_side = side
self._their_versions = bytes_to_dict(plaintext)
+ self._D.got_wormhole_versions(self._side, self._their_side,
+ self._their_versions)
# but this part is app-to-app
app_versions = self._their_versions.get("app_versions", {})
self._W.got_versions(app_versions)
@@ -339,6 +357,10 @@ class Boss(object):
def W_got_key(self, key):
self._W.got_key(key)
+ @m.output()
+ def D_got_key(self, key):
+ self._D.got_key(key)
+
@m.output()
def W_got_verifier(self, verifier):
self._W.got_verifier(verifier)
@@ -352,6 +374,16 @@ class Boss(object):
self._W.received(self._rx_phases.pop(self._next_rx_phase))
self._next_rx_phase += 1
+ @m.output()
+ def D_received_dilate(self, seqnum, plaintext):
+ assert isinstance(seqnum, six.integer_types), type(seqnum)
+ # strict phase order, no gaps
+ self._rx_dilate_seqnums[seqnum] = plaintext
+ while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums:
+ m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum)
+ self._D.received_dilate(m)
+ self._next_rx_dilate_seqnum += 1
+
@m.output()
def W_close_with_error(self, err):
self._result = err # exception
@@ -374,7 +406,7 @@ class Boss(object):
S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared])
S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely])
S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send])
- S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key])
+ S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key])
S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error])
S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])
@@ -382,6 +414,7 @@ class Boss(object):
S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier])
S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
+ S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate])
S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
S2_happy.upon(send, enter=S2_happy, outputs=[S_send])
@@ -393,6 +426,7 @@ class Boss(object):
S3_closing.upon(got_verifier, enter=S3_closing, outputs=[])
S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
+ S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[])
S3_closing.upon(happy, enter=S3_closing, outputs=[])
S3_closing.upon(scared, enter=S3_closing, outputs=[])
S3_closing.upon(close, enter=S3_closing, outputs=[])
@@ -404,6 +438,7 @@ class Boss(object):
S4_closed.upon(got_verifier, enter=S4_closed, outputs=[])
S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
+ S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[])
S4_closed.upon(happy, enter=S4_closed, outputs=[])
S4_closed.upon(scared, enter=S4_closed, outputs=[])
S4_closed.upon(close, enter=S4_closed, outputs=[])
diff --git a/src/wormhole/_dilation/__init__.py b/src/wormhole/_dilation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py
new file mode 100644
index 0000000..f0c8e35
--- /dev/null
+++ b/src/wormhole/_dilation/connection.py
@@ -0,0 +1,482 @@
+from __future__ import print_function, unicode_literals
+from collections import namedtuple
+import six
+from attr import attrs, attrib
+from attr.validators import instance_of, provides
+from automat import MethodicalMachine
+from zope.interface import Interface, implementer
+from twisted.python import log
+from twisted.internet.protocol import Protocol
+from twisted.internet.interfaces import ITransport
+from .._interfaces import IDilationConnector
+from ..observer import OneShotObserver
+from .encode import to_be4, from_be4
+from .roles import FOLLOWER
+
+# InboundFraming is given data and returns Frames (Noise wire-side
+# bytestrings). It handles the relay handshake and the prologue. The Frames it
+# returns are either the ephemeral key (the Noise "handshake") or ciphertext
+# messages.
+
+# The next object up knows whether it's expecting a Handshake or a message. It
+# feeds the first into Noise as a handshake, it feeds the rest into Noise as a
+# message (which produces a plaintext stream). It emits tokens that are either
+# "i've finished with the handshake (so you can send the KCM if you want)", or
+# "here is a decrypted message (which might be the KCM)".
+
+# the transmit direction goes directly to transport.write, and doesn't touch
+# the state machine. we can do this because the way we encode/encrypt/frame
+# things doesn't depend upon the receiver state. It would be more safe to e.g.
+# prohibit sending ciphertext frames unless we're in the received-handshake
+# state, but then we'll be in the middle of an inbound state transition ("we
+# just received the handshake, so you can send a KCM now") when we perform an
+# operation that depends upon the state (send_plaintext(kcm)), which is not a
+# coherent/safe place to touch the state machine.
+
+# we could set a flag and test it from inside send_plaintext, which kind of
+# violates the state machine owning the state (ideally all "if" statements
+# would be translated into same-input transitions from different starting
+# states). For the specific question of sending plaintext frames, Noise will
+# refuse us unless it's ready anyways, so the question is probably moot.
+
+class IFramer(Interface):
+ pass
+class IRecord(Interface):
+ pass
+
+def first(l):
+ return l[0]
+
+class Disconnect(Exception):
+ pass
+RelayOK = namedtuple("RelayOk", [])
+Prologue = namedtuple("Prologue", [])
+Frame = namedtuple("Frame", ["frame"])
+
+@attrs
+@implementer(IFramer)
+class _Framer(object):
+ _transport = attrib(validator=provides(ITransport))
+ _outbound_prologue = attrib(validator=instance_of(bytes))
+ _inbound_prologue = attrib(validator=instance_of(bytes))
+ _buffer = b""
+ _can_send_frames = False
+
+ # in: use_relay
+ # in: connectionMade, dataReceived
+ # out: prologue_received, frame_received
+ # out (shared): transport.loseConnection
+ # out (shared): transport.write (relay handshake, prologue)
+ # states: want_relay, want_prologue, want_frame
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
+
+ @m.state()
+ def want_relay(self): pass # pragma: no cover
+ @m.state(initial=True)
+ def want_prologue(self): pass # pragma: no cover
+ @m.state()
+ def want_frame(self): pass # pragma: no cover
+
+ @m.input()
+ def use_relay(self, relay_handshake): pass
+ @m.input()
+ def connectionMade(self): pass
+ @m.input()
+ def parse(self): pass
+ @m.input()
+ def got_relay_ok(self): pass
+ @m.input()
+ def got_prologue(self): pass
+
+ @m.output()
+ def store_relay_handshake(self, relay_handshake):
+ self._outbound_relay_handshake = relay_handshake
+ self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
+ @m.output()
+ def send_relay_handshake(self):
+ self._transport.write(self._outbound_relay_handshake)
+
+ @m.output()
+ def send_prologue(self):
+ self._transport.write(self._outbound_prologue)
+
+ @m.output()
+ def parse_relay_ok(self):
+ if self._get_expected("relay_ok", self._expected_relay_handshake):
+ return RelayOK()
+
+ @m.output()
+ def parse_prologue(self):
+ if self._get_expected("prologue", self._inbound_prologue):
+ return Prologue()
+
+ @m.output()
+ def can_send_frames(self):
+ self._can_send_frames = True # for assertion in send_frame()
+
+ @m.output()
+ def parse_frame(self):
+ if len(self._buffer) < 4:
+ return None
+ frame_length = from_be4(self._buffer[0:4])
+ if len(self._buffer) < 4+frame_length:
+ return None
+ frame = self._buffer[4:4+frame_length]
+ self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy
+ return Frame(frame=frame)
+
+ want_prologue.upon(use_relay, outputs=[store_relay_handshake],
+ enter=want_relay)
+
+ want_relay.upon(connectionMade, outputs=[send_relay_handshake],
+ enter=want_relay)
+ want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay,
+ collector=first)
+ want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue)
+
+ want_prologue.upon(connectionMade, outputs=[send_prologue],
+ enter=want_prologue)
+ want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue,
+ collector=first)
+ want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame)
+
+ want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
+ collector=first)
+
+
+ def _get_expected(self, name, expected):
+ lb = len(self._buffer)
+ le = len(expected)
+ if self._buffer.startswith(expected):
+ # if the buffer starts with the expected string, consume it and
+ # return True
+ self._buffer = self._buffer[le:]
+ return True
+ if not expected.startswith(self._buffer):
+ # we're not on track: the data we've received so far does not
+ # match the expected value, so this can't possibly be right.
+ # Don't complain until we see the expected length, or a newline,
+ # so we can capture the weird input in the log for debugging.
+ if (b"\n" in self._buffer or lb >= le):
+ log.msg("bad {}: {}".format(name, self._buffer[:le]))
+ raise Disconnect()
+ return False # wait a bit longer
+ # good so far, just waiting for the rest
+ return False
+
+ # external API is: connectionMade, add_and_parse, and send_frame
+
+ def add_and_parse(self, data):
+ # we can't make this an @m.input because we can't change the state
+ # from within an input. Instead, let the state choose the parser to
+ # use, and use the parsed token drive a state transition.
+ self._buffer += data
+ while True:
+ # it'd be nice to use an iterator here, but since self.parse()
+ # dispatches to a different parser (depending upon the current
+ # state), we'd be using multiple iterators
+ token = self.parse()
+ if isinstance(token, RelayOK):
+ self.got_relay_ok()
+ elif isinstance(token, Prologue):
+ self.got_prologue()
+ yield token # triggers send_handshake
+ elif isinstance(token, Frame):
+ yield token
+ else:
+ break
+
+ def send_frame(self, frame):
+ assert self._can_send_frames
+ self._transport.write(to_be4(len(frame)) + frame)
+
+# RelayOK: Newline-terminated buddy-is-connected response from Relay.
+# First data received from relay.
+# Prologue: double-newline-terminated this-is-really-wormhole response
+# from peer. First data received from peer.
+# Frame: Either handshake or encrypted message. Length-prefixed on wire.
+# Handshake: the Noise ephemeral key, first framed message
+# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK
+# KCM: Key Confirmation Message (encrypted b"\x00"). First frame
+# from peer. Sent immediately by Follower, after Selection by Leader.
+# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
+
+Handshake = namedtuple("Handshake", [])
+# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
+KCM = namedtuple("KCM", [])
+Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
+Pong = namedtuple("Pong", ["ping_id"])
+Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
+Data = namedtuple("Data", ["seqnum", "scid", "data"])
+Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
+Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
+Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
+Handshake_or_Records = (Handshake,) + Records
+
+T_KCM = b"\x00"
+T_PING = b"\x01"
+T_PONG = b"\x02"
+T_OPEN = b"\x03"
+T_DATA = b"\x04"
+T_CLOSE = b"\x05"
+T_ACK = b"\x06"
+
+def parse_record(plaintext):
+ msgtype = plaintext[0:1]
+ if msgtype == T_KCM:
+ return KCM()
+ if msgtype == T_PING:
+ ping_id = plaintext[1:5]
+ return Ping(ping_id)
+ if msgtype == T_PONG:
+ ping_id = plaintext[1:5]
+ return Pong(ping_id)
+ if msgtype == T_OPEN:
+ scid = from_be4(plaintext[1:5])
+ seqnum = from_be4(plaintext[5:9])
+ return Open(seqnum, scid)
+ if msgtype == T_DATA:
+ scid = from_be4(plaintext[1:5])
+ seqnum = from_be4(plaintext[5:9])
+ data = plaintext[9:]
+ return Data(seqnum, scid, data)
+ if msgtype == T_CLOSE:
+ scid = from_be4(plaintext[1:5])
+ seqnum = from_be4(plaintext[5:9])
+ return Close(seqnum, scid)
+ if msgtype == T_ACK:
+ resp_seqnum = from_be4(plaintext[1:5])
+ return Ack(resp_seqnum)
+ log.err("received unknown message type: {}".format(plaintext))
+ raise ValueError()
+
+def encode_record(r):
+ if isinstance(r, KCM):
+ return b"\x00"
+ if isinstance(r, Ping):
+ return b"\x01" + r.ping_id
+ if isinstance(r, Pong):
+ return b"\x02" + r.ping_id
+ if isinstance(r, Open):
+ assert isinstance(r.scid, six.integer_types)
+ assert isinstance(r.seqnum, six.integer_types)
+ return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
+ if isinstance(r, Data):
+ assert isinstance(r.scid, six.integer_types)
+ assert isinstance(r.seqnum, six.integer_types)
+ return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
+ if isinstance(r, Close):
+ assert isinstance(r.scid, six.integer_types)
+ assert isinstance(r.seqnum, six.integer_types)
+ return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
+ if isinstance(r, Ack):
+ assert isinstance(r.resp_seqnum, six.integer_types)
+ return b"\x06" + to_be4(r.resp_seqnum)
+ raise TypeError(r)
+
+@attrs
+@implementer(IRecord)
+class _Record(object):
+ _framer = attrib(validator=provides(IFramer))
+ _noise = attrib()
+
+ n = MethodicalMachine()
+ # TODO: set_trace
+
+ def __attrs_post_init__(self):
+ self._noise.start_handshake()
+
+ # in: role=
+ # in: prologue_received, frame_received
+ # out: handshake_received, record_received
+ # out: transport.write (noise handshake, encrypted records)
+ # states: want_prologue, want_handshake, want_record
+
+ @n.state(initial=True)
+ def want_prologue(self): pass # pragma: no cover
+ @n.state()
+ def want_handshake(self): pass # pragma: no cover
+ @n.state()
+ def want_message(self): pass # pragma: no cover
+
+ @n.input()
+ def got_prologue(self):
+ pass
+ @n.input()
+ def got_frame(self, frame):
+ pass
+
+ @n.output()
+ def send_handshake(self):
+ handshake = self._noise.write_message() # generate the ephemeral key
+ self._framer.send_frame(handshake)
+
+ @n.output()
+ def process_handshake(self, frame):
+ from noise.exceptions import NoiseInvalidMessage
+ try:
+ payload = self._noise.read_message(frame)
+ # Noise can include unencrypted data in the handshake, but we don't
+ # use it
+ del payload
+ except NoiseInvalidMessage as e:
+ log.err(e, "bad inbound noise handshake")
+ raise Disconnect()
+ return Handshake()
+
+ @n.output()
+ def decrypt_message(self, frame):
+ from noise.exceptions import NoiseInvalidMessage
+ try:
+ message = self._noise.decrypt(frame)
+ except NoiseInvalidMessage as e:
+ # if this happens during tests, flunk the test
+ log.err(e, "bad inbound noise frame")
+ raise Disconnect()
+ return parse_record(message)
+
+ want_prologue.upon(got_prologue, outputs=[send_handshake],
+ enter=want_handshake)
+ want_handshake.upon(got_frame, outputs=[process_handshake],
+ collector=first, enter=want_message)
+ want_message.upon(got_frame, outputs=[decrypt_message],
+ collector=first, enter=want_message)
+
+ # external API is: connectionMade, dataReceived, send_record
+
+ def connectionMade(self):
+ self._framer.connectionMade()
+
+ def add_and_unframe(self, data):
+ for token in self._framer.add_and_parse(data):
+ if isinstance(token, Prologue):
+ self.got_prologue() # triggers send_handshake
+ else:
+ assert isinstance(token, Frame)
+ yield self.got_frame(token.frame) # Handshake or a Record type
+
+ def send_record(self, r):
+ message = encode_record(r)
+ frame = self._noise.encrypt(message)
+ self._framer.send_frame(frame)
+
+
+@attrs
+class DilatedConnectionProtocol(Protocol, object):
+ """I manage an L2 connection.
+
+ When a new L2 connection is needed (as determined by the Leader),
+ both Leader and Follower will initiate many simultaneous connections
+ (probably TCP, but conceivably others). A subset will actually
+ connect. A subset of those will successfully pass negotiation by
+ exchanging handshakes to demonstrate knowledge of the session key.
+ One of the negotiated connections will be selected by the Leader for
+ active use, and the others will be dropped.
+
+ At any given time, there is at most one active L2 connection.
+ """
+
+ _eventual_queue = attrib()
+ _role = attrib()
+ _connector = attrib(validator=provides(IDilationConnector))
+ _noise = attrib()
+ _outbound_prologue = attrib(validator=instance_of(bytes))
+ _inbound_prologue = attrib(validator=instance_of(bytes))
+
+ _use_relay = False
+ _relay_handshake = None
+
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
+
+ def __attrs_post_init__(self):
+ self._manager = None # set if/when we are selected
+ self._disconnected = OneShotObserver(self._eventual_queue)
+ self._can_send_records = False
+
+ @m.state(initial=True)
+ def unselected(self): pass # pragma: no cover
+ @m.state()
+ def selecting(self): pass # pragma: no cover
+ @m.state()
+ def selected(self): pass # pragma: no cover
+
+ @m.input()
+ def got_kcm(self):
+ pass
+ @m.input()
+ def select(self, manager):
+ pass # fires set_manager()
+ @m.input()
+ def got_record(self, record):
+ pass
+
+ @m.output()
+ def add_candidate(self):
+ self._connector.add_candidate(self)
+
+ @m.output()
+ def set_manager(self, manager):
+ self._manager = manager
+
+ @m.output()
+ def can_send_records(self, manager):
+ self._can_send_records = True
+
+ @m.output()
+ def deliver_record(self, record):
+ self._manager.got_record(record)
+
+ unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting)
+ selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected)
+ selected.upon(got_record, outputs=[deliver_record], enter=selected)
+
+ # called by Connector
+
+ def use_relay(self, relay_handshake):
+ assert isinstance(relay_handshake, bytes)
+ self._use_relay = True
+ self._relay_handshake = relay_handshake
+
+ def when_disconnected(self):
+ return self._disconnected.when_fired()
+
+ def disconnect(self):
+ self.transport.loseConnection()
+
+ # select() called by Connector
+
+ # called by Manager
+ def send_record(self, record):
+ assert self._can_send_records
+ self._record.send_record(record)
+
+ # IProtocol methods
+
+ def connectionMade(self):
+ framer = _Framer(self.transport,
+ self._outbound_prologue, self._inbound_prologue)
+ if self._use_relay:
+ framer.use_relay(self._relay_handshake)
+ self._record = _Record(framer, self._noise)
+ self._record.connectionMade()
+
+ def dataReceived(self, data):
+ try:
+ for token in self._record.add_and_unframe(data):
+ assert isinstance(token, Handshake_or_Records)
+ if isinstance(token, Handshake):
+ if self._role is FOLLOWER:
+ self._record.send_record(KCM())
+ elif isinstance(token, KCM):
+ # if we're the leader, add this connection as a candiate.
+ # if we're the follower, accept this connection.
+ self.got_kcm() # connector.add_candidate()
+ else:
+ self.got_record(token) # manager.got_record()
+ except Disconnect:
+ self.transport.loseConnection()
+
+ def connectionLost(self, why=None):
+ self._disconnected.fire(self)
diff --git a/src/wormhole/_dilation/connector.py b/src/wormhole/_dilation/connector.py
new file mode 100644
index 0000000..86f2a72
--- /dev/null
+++ b/src/wormhole/_dilation/connector.py
@@ -0,0 +1,482 @@
+from __future__ import print_function, unicode_literals
+import sys, re
+from collections import defaultdict, namedtuple
+from binascii import hexlify
+import six
+from attr import attrs, attrib
+from attr.validators import instance_of, provides, optional
+from automat import MethodicalMachine
+from zope.interface import implementer
+from twisted.internet.task import deferLater
+from twisted.internet.defer import DeferredList
+from twisted.internet.endpoints import HostnameEndpoint, serverFromString
+from twisted.internet.protocol import ClientFactory, ServerFactory
+from twisted.python import log
+from hkdf import Hkdf
+from .. import ipaddrs # TODO: move into _dilation/
+from .._interfaces import IDilationConnector, IDilationManager
+from ..timing import DebugTiming
+from ..observer import EmptyableSet
+from .connection import DilatedConnectionProtocol, KCM
+from .roles import LEADER
+
+
+# These namedtuples are "hint objects". The JSON-serializable dictionaries
+# are "hint dicts".
+
+# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
+# * make a TCP connection (possibly via Tor)
+# * send the sender/receiver handshake bytes first
+# * expect to see the receiver/sender handshake bytes from the other side
+# * the sender writes "go\n", the receiver waits for "go\n"
+# * the rest of the connection contains transit data
+DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
+TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
+# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
+# use a tuple rather than a list so they'll be hashable into a set). For each
+# one, make the TCP connection, send the relay handshake, then complete the
+# rest of the V1 protocol. Only one hint per relay is useful.
+RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
+
+def describe_hint_obj(hint, relay, tor):
+ prefix = "tor->" if tor else "->"
+ if relay:
+ prefix = prefix + "relay:"
+ if isinstance(hint, DirectTCPV1Hint):
+ return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
+ elif isinstance(hint, TorTCPV1Hint):
+ return prefix+"tor:%s:%d" % (hint.hostname, hint.port)
+ else:
+ return prefix+str(hint)
+
+def parse_hint_argv(hint, stderr=sys.stderr):
+ assert isinstance(hint, type(""))
+ # return tuple or None for an unparseable hint
+ priority = 0.0
+ mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
+ if not mo:
+ print("unparseable hint '%s'" % (hint,), file=stderr)
+ return None
+ hint_type = mo.group(1)
+ if hint_type != "tcp":
+ print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
+ return None
+ hint_value = mo.group(2)
+ pieces = hint_value.split(":")
+ if len(pieces) < 2:
+ print("unparseable TCP hint (need more colons) '%s'" % (hint,),
+ file=stderr)
+ return None
+ mo = re.search(r'^(\d+)$', pieces[1])
+ if not mo:
+ print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
+ return None
+ hint_host = pieces[0]
+ hint_port = int(pieces[1])
+ for more in pieces[2:]:
+ if more.startswith("priority="):
+ more_pieces = more.split("=")
+ try:
+ priority = float(more_pieces[1])
+ except ValueError:
+ print("non-float priority= in TCP hint '%s'" % (hint,),
+ file=stderr)
+ return None
+ return DirectTCPV1Hint(hint_host, hint_port, priority)
+
+def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
+ hint_type = hint.get("type", "")
+ if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
+ log.msg("unknown hint type: %r" % (hint,))
+ return None
+ if not("hostname" in hint
+ and isinstance(hint["hostname"], type(""))):
+ log.msg("invalid hostname in hint: %r" % (hint,))
+ return None
+ if not("port" in hint
+ and isinstance(hint["port"], six.integer_types)):
+ log.msg("invalid port in hint: %r" % (hint,))
+ return None
+ priority = hint.get("priority", 0.0)
+ if hint_type == "direct-tcp-v1":
+ return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
+ else:
+ return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
+
+def parse_hint(hint_struct):
+ hint_type = hint_struct.get("type", "")
+ if hint_type == "relay-v1":
+ # the struct can include multiple ways to reach the same relay
+ rhints = filter(lambda h: h, # drop None (unrecognized)
+ [parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
+ return RelayV1Hint(rhints)
+ return parse_tcp_v1_hint(hint_struct)
+
+def encode_hint(h):
+ if isinstance(h, DirectTCPV1Hint):
+ return {"type": "direct-tcp-v1",
+ "priority": h.priority,
+ "hostname": h.hostname,
+ "port": h.port, # integer
+ }
+ elif isinstance(h, RelayV1Hint):
+ rhint = {"type": "relay-v1", "hints": []}
+ for rh in h.hints:
+ rhint["hints"].append({"type": "direct-tcp-v1",
+ "priority": rh.priority,
+ "hostname": rh.hostname,
+ "port": rh.port})
+ return rhint
+ elif isinstance(h, TorTCPV1Hint):
+ return {"type": "tor-tcp-v1",
+ "priority": h.priority,
+ "hostname": h.hostname,
+ "port": h.port, # integer
+ }
+ raise ValueError("unknown hint type", h)
+
+def HKDF(skm, outlen, salt=None, CTXinfo=b""):
+ return Hkdf(salt, skm).expand(CTXinfo, outlen)
+
+def build_sided_relay_handshake(key, side):
+ assert isinstance(side, type(u""))
+ assert len(side) == 8*2
+ token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
+ return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
+
+PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
+PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
+NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
+
+@attrs
+@implementer(IDilationConnector)
+class Connector(object):
+ _dilation_key = attrib(validator=instance_of(type(b"")))
+ _transit_relay_location = attrib(validator=optional(instance_of(str)))
+ _manager = attrib(validator=provides(IDilationManager))
+ _reactor = attrib()
+ _eventual_queue = attrib()
+ _no_listen = attrib(validator=instance_of(bool))
+ _tor = attrib()
+ _timing = attrib()
+ _side = attrib(validator=instance_of(type(u"")))
+ # was self._side = bytes_to_hexstr(os.urandom(8)) # unicode
+ _role = attrib()
+
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None)
+
+ RELAY_DELAY = 2.0
+
+ def __attrs_post_init__(self):
+ if self._transit_relay_location:
+ # TODO: allow multiple hints for a single relay
+ relay_hint = parse_hint_argv(self._transit_relay_location)
+ relay = RelayV1Hint(hints=(relay_hint,))
+ self._transit_relays = [relay]
+ else:
+ self._transit_relays = []
+ self._listeners = set() # IListeningPorts that can be stopped
+ self._pending_connectors = set() # Deferreds that can be cancelled
+ self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped
+ self._contenders = set() # viable connections
+ self._winning_connection = None
+ self._timing = self._timing or DebugTiming()
+ self._timing.add("transit")
+
+ # this describes what our Connector can do, for the initial advertisement
+ @classmethod
+ def get_connection_abilities(klass):
+ return [{"type": "direct-tcp-v1"},
+ {"type": "relay-v1"},
+ ]
+
+ def build_protocol(self, addr):
+ # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
+ # ephemeral keys plus a pre-shared symmetric key (the Transit key), a
+ # different one for each potential connection.
+ from noise.connection import NoiseConnection
+ noise = NoiseConnection.from_name(NOISEPROTO)
+ noise.set_psks(self._dilation_key)
+ if self._role is LEADER:
+ noise.set_as_initiator()
+ outbound_prologue = PROLOGUE_LEADER
+ inbound_prologue = PROLOGUE_FOLLOWER
+ else:
+ noise.set_as_responder()
+ outbound_prologue = PROLOGUE_FOLLOWER
+ inbound_prologue = PROLOGUE_LEADER
+ p = DilatedConnectionProtocol(self._eventual_queue, self._role,
+ self, noise,
+ outbound_prologue, inbound_prologue)
+ return p
+
+ @m.state(initial=True)
+ def connecting(self): pass # pragma: no cover
+ @m.state()
+ def connected(self): pass # pragma: no cover
+ @m.state(terminal=True)
+ def stopped(self): pass # pragma: no cover
+
+ # TODO: unify the tense of these method-name verbs
+ @m.input()
+ def listener_ready(self, hint_objs): pass
+ @m.input()
+ def add_relay(self, hint_objs): pass
+ @m.input()
+ def got_hints(self, hint_objs): pass
+ @m.input()
+ def add_candidate(self, c): # called by DilatedConnectionProtocol
+ pass
+ @m.input()
+ def accept(self, c): pass
+ @m.input()
+ def stop(self): pass
+
+ @m.output()
+ def use_hints(self, hint_objs):
+ self._use_hints(hint_objs)
+
+ @m.output()
+ def publish_hints(self, hint_objs):
+ self._manager.send_hints([encode_hint(h) for h in hint_objs])
+
+ @m.output()
+ def consider(self, c):
+ self._contenders.add(c)
+ if self._role is LEADER:
+ # for now, just accept the first one. TODO: be clever.
+ self._eventual_queue.eventually(self.accept, c)
+ else:
+ # the follower always uses the first contender, since that's the
+ # only one the leader picked
+ self._eventual_queue.eventually(self.accept, c)
+
+ @m.output()
+ def select_and_stop_remaining(self, c):
+ self._winning_connection = c
+ self._contenders.clear() # we no longer care who else came close
+ # remove this winner from the losers, so we don't shut it down
+ self._pending_connections.discard(c)
+ # shut down losing connections
+ self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
+ self.stop_pending_connectors()
+ self.stop_pending_connections()
+
+ c.select(self._manager) # subsequent frames go directly to the manager
+ if self._role is LEADER:
+ # TODO: this should live in Connection
+ c.send_record(KCM()) # leader sends KCM now
+ self._manager.use_connection(c) # manager sends frames to Connection
+
+ @m.output()
+ def stop_everything(self):
+ self.stop_listeners()
+ self.stop_pending_connectors()
+ self.stop_pending_connections()
+ self.break_cycles()
+
+ def stop_listeners(self):
+ d = DeferredList([l.stopListening() for l in self._listeners])
+ self._listeners.clear()
+ return d # synchronization for tests
+
+ def stop_pending_connectors(self):
+ return DeferredList([d.cancel() for d in self._pending_connectors])
+
+ def stop_pending_connections(self):
+ d = self._pending_connections.when_next_empty()
+ [c.loseConnection() for c in self._pending_connections]
+ return d
+
+ def stop_winner(self):
+ d = self._winner.when_disconnected()
+ self._winner.disconnect()
+ return d
+
+ def break_cycles(self):
+ # help GC by forgetting references to things that reference us
+ self._listeners.clear()
+ self._pending_connectors.clear()
+ self._pending_connections.clear()
+ self._winner = None
+
+ connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints])
+ connecting.upon(add_relay, enter=connecting, outputs=[use_hints,
+ publish_hints])
+ connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
+ connecting.upon(add_candidate, enter=connecting, outputs=[consider])
+ connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining])
+ connecting.upon(stop, enter=stopped, outputs=[stop_everything])
+
+ # once connected, we ignore everything except stop
+ connected.upon(listener_ready, enter=connected, outputs=[])
+ connected.upon(add_relay, enter=connected, outputs=[])
+ connected.upon(got_hints, enter=connected, outputs=[])
+ connected.upon(add_candidate, enter=connected, outputs=[])
+ connected.upon(accept, enter=connected, outputs=[])
+ connected.upon(stop, enter=stopped, outputs=[stop_everything])
+
+
+ # from Manager: start, got_hints, stop
+ # maybe add_candidate, accept
+ def start(self):
+ self._start_listener()
+ if self._transit_relays:
+ self.publish_hints(self._transit_relays)
+ self._use_hints(self._transit_relays)
+
+ def _start_listener(self):
+ if self._no_listen or self._tor:
+ return
+ addresses = ipaddrs.find_addresses()
+ non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
+ if non_loopback_addresses:
+ # some test hosts, including the appveyor VMs, *only* have
+ # 127.0.0.1, and the tests will hang badly if we remove it.
+ addresses = non_loopback_addresses
+ # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also
+ # to make firewall configs easier
+ # TODO: retain listening port between connection generations?
+ ep = serverFromString(self._reactor, "tcp:0")
+ f = InboundConnectionFactory(self)
+ d = ep.listen(f)
+ def _listening(lp):
+ # lp is an IListeningPort
+ self._listeners.add(lp) # for shutdown and tests
+ portnum = lp.getHost().port
+ direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
+ for addr in addresses]
+ self.listener_ready(direct_hints)
+ d.addCallback(_listening)
+ d.addErrback(log.err)
+
+ def _use_hints(self, hints):
+ # first, pull out all the relays, we'll connect to them later
+ relays = defaultdict(list)
+ direct = defaultdict(list)
+ for h in hints:
+ if isinstance(h, RelayV1Hint):
+ relays[h.priority].append(h)
+ else:
+ direct[h.priority].append(h)
+ delay = 0.0
+ priorities = sorted(set(direct.keys()), reverse=True)
+ for p in priorities:
+ for h in direct[p]:
+ if isinstance(h, TorTCPV1Hint) and not self._tor:
+ continue
+ ep = self._endpoint_from_hint_obj(h)
+ desc = describe_hint_obj(h, False, self._tor)
+ d = deferLater(self._reactor, delay,
+ self._connect, ep, desc, is_relay=False)
+ self._pending_connectors.add(d)
+ # Make all direct connections immediately. Later, we'll change
+ # the add_candidate() function to look at the priority when
+ # deciding whether to accept a successful connection or not,
+ # and it can wait for more options if it sees a higher-priority
+ # one still running. But if we bail on that, we might consider
+ # putting an inter-direct-hint delay here to influence the
+ # process.
+ #delay += 1.0
+ if delay > 0.0:
+ # Start trying the relays a few seconds after we start to try the
+ # direct hints. The idea is to prefer direct connections, but not
+ # be afraid of using a relay when we have direct hints that don't
+ # resolve quickly. Many direct hints will be to unused
+ # local-network IP addresses, which won't answer, and would take
+ # the full TCP timeout (30s or more) to fail. If there were no
+ # direct hints, don't delay at all.
+ delay += self.RELAY_DELAY
+
+ # prefer direct connections by stalling relay connections by a few
+ # seconds, unless we're using --no-listen in which case we're probably
+ # going to have to use the relay
+ delay = self.RELAY_DELAY if self._no_listen else 0.0
+
+ # It might be nice to wire this so that a failure in the direct hints
+ # causes the relay hints to be used right away (fast failover). But
+ # none of our current use cases would take advantage of that: if we
+ # have any viable direct hints, then they're either going to succeed
+ # quickly or hang for a long time.
+ for p in priorities:
+ for r in relays[p]:
+ for h in r.hints:
+ ep = self._endpoint_from_hint_obj(h)
+ desc = describe_hint_obj(h, True, self._tor)
+ d = deferLater(self._reactor, delay,
+ self._connect, ep, desc, is_relay=True)
+ self._pending_connectors.add(d)
+ # TODO:
+ #if not contenders:
+ # raise TransitError("No contenders for connection")
+
+ # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
+ # the initial connection
+
+ def _connect(self, h, ep, description, is_relay=False):
+ relay_handshake = None
+ if is_relay:
+ relay_handshake = build_sided_relay_handshake(self._dilation_key,
+ self._side)
+ f = OutboundConnectionFactory(self, relay_handshake)
+ d = ep.connect(f)
+ # fires with protocol, or ConnectError
+ def _connected(p):
+ self._pending_connections.add(p)
+ # c might not be in _pending_connections, if it turned out to be a
+ # winner, which is why we use discard() and not remove()
+ p.when_disconnected().addCallback(self._pending_connections.discard)
+ d.addCallback(_connected)
+ return d
+
+ def _endpoint_from_hint_obj(self, hint):
+ if self._tor:
+ if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
+ # this Tor object will throw ValueError for non-public IPv4
+ # addresses and any IPv6 address
+ try:
+ return self._tor.stream_via(hint.hostname, hint.port)
+ except ValueError:
+ return None
+ return None
+ if isinstance(hint, DirectTCPV1Hint):
+ return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
+ return None
+
+
+ # Connection selection. All instances of DilatedConnectionProtocol which
+ # look viable get passed into our add_contender() method.
+
+ # On the Leader side, "viable" means we've seen their KCM frame, which is
+ # the first Noise-encrypted packet on any given connection, and it has an
+ # empty body. We gather viable connections until we see one that we like,
+ # or a timer expires. Then we "select" it, close the others, and tell our
+ # Manager to use it.
+
+ # On the Follower side, we'll only see a KCM on the one connection selected
+ # by the Leader, so the first viable connection wins.
+
+ # our Connection protocols call: add_candidate
+
+@attrs
+class OutboundConnectionFactory(ClientFactory, object):
+ _connector = attrib(validator=provides(IDilationConnector))
+ _relay_handshake = attrib(validator=optional(instance_of(bytes)))
+
+ def buildProtocol(self, addr):
+ p = self._connector.build_protocol(addr)
+ p.factory = self
+ if self._relay_handshake is not None:
+ p.use_relay(self._relay_handshake)
+ return p
+
+@attrs
+class InboundConnectionFactory(ServerFactory, object):
+ _connector = attrib(validator=provides(IDilationConnector))
+ protocol = DilatedConnectionProtocol
+
+ def buildProtocol(self, addr):
+ p = self._connector.build_protocol(addr)
+ p.factory = self
+ return p
diff --git a/src/wormhole/_dilation/encode.py b/src/wormhole/_dilation/encode.py
new file mode 100644
index 0000000..eb1b0b2
--- /dev/null
+++ b/src/wormhole/_dilation/encode.py
@@ -0,0 +1,16 @@
+from __future__ import print_function, unicode_literals
+import struct
+
+assert len(struct.pack("L", value)
+def from_be4(b):
+ if not isinstance(b, bytes):
+ raise TypeError(repr(b))
+ if len(b) != 4:
+ raise ValueError
+ return struct.unpack(">L", b)[0]
diff --git a/src/wormhole/_dilation/inbound.py b/src/wormhole/_dilation/inbound.py
new file mode 100644
index 0000000..9235681
--- /dev/null
+++ b/src/wormhole/_dilation/inbound.py
@@ -0,0 +1,127 @@
+from __future__ import print_function, unicode_literals
+from attr import attrs, attrib
+from attr.validators import provides
+from zope.interface import implementer
+from twisted.python import log
+from .._interfaces import IDilationManager, IInbound
+from .subchannel import (SubChannel, _SubchannelAddress)
+
+class DuplicateOpenError(Exception):
+ pass
+class DataForMissingSubchannelError(Exception):
+ pass
+class CloseForMissingSubchannelError(Exception):
+ pass
+
+@attrs
+@implementer(IInbound)
+class Inbound(object):
+ # Inbound flow control: TCP delivers data to Connection.dataReceived,
+ # Connection delivers to our handle_data, we deliver to
+ # SubChannel.remote_data, subchannel delivers to proto.dataReceived
+ _manager = attrib(validator=provides(IDilationManager))
+ _host_addr = attrib()
+
+ def __attrs_post_init__(self):
+ # we route inbound Data records to Subchannels .dataReceived
+ self._open_subchannels = {} # scid -> Subchannel
+ self._paused_subchannels = set() # Subchannels that have paused us
+ # the set is non-empty, we pause the transport
+ self._highest_inbound_acked = -1
+ self._connection = None
+
+ # from our Manager
+ def set_listener_endpoint(self, listener_endpoint):
+ self._listener_endpoint = listener_endpoint
+
+ def set_subchannel_zero(self, scid0, sc0):
+ self._open_subchannels[scid0] = sc0
+
+
+ def use_connection(self, c):
+ self._connection = c
+ # We can pause the connection's reads when we receive too much data. If
+ # this is a non-initial connection, then we might already have
+ # subchannels that are paused from before, so we might need to pause
+ # the new connection before it can send us any data
+ if self._paused_subchannels:
+ self._connection.pauseProducing()
+
+ # Inbound is responsible for tracking the high watermark and deciding
+ # whether to ignore inbound messages or not
+
+ def is_record_old(self, r):
+ if r.seqnum <= self._highest_inbound_acked:
+ return True
+ return False
+
+ def update_ack_watermark(self, r):
+ self._highest_inbound_acked = max(self._highest_inbound_acked,
+ r.seqnum)
+
+ def handle_open(self, scid):
+ if scid in self._open_subchannels:
+ log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid)))
+ return
+ peer_addr = _SubchannelAddress(scid)
+ sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
+ self._open_subchannels[scid] = sc
+ self._listener_endpoint._got_open(sc, peer_addr)
+
+ def handle_data(self, scid, data):
+ sc = self._open_subchannels.get(scid)
+ if sc is None:
+ log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid)))
+ return
+ sc.remote_data(data)
+
+ def handle_close(self, scid):
+ sc = self._open_subchannels.get(scid)
+ if sc is None:
+ log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid)))
+ return
+ sc.remote_close()
+
+ def subchannel_closed(self, scid, sc):
+ # connectionLost has just been signalled
+ assert self._open_subchannels[scid] is sc
+ del self._open_subchannels[scid]
+
+ def stop_using_connection(self):
+ self._connection = None
+
+
+ # from our Subchannel, or rather from the Protocol above it and sent
+ # through the subchannel
+
+ # The subchannel is an IProducer, and application protocols can always
+ # thell them to pauseProducing if we're delivering inbound data too
+ # quickly. They don't need to register anything.
+
+ def subchannel_pauseProducing(self, sc):
+ was_paused = bool(self._paused_subchannels)
+ self._paused_subchannels.add(sc)
+ if self._connection and not was_paused:
+ self._connection.pauseProducing()
+
+ def subchannel_resumeProducing(self, sc):
+ was_paused = bool(self._paused_subchannels)
+ self._paused_subchannels.discard(sc)
+ if self._connection and was_paused and not self._paused_subchannels:
+ self._connection.resumeProducing()
+
+ def subchannel_stopProducing(self, sc):
+ # This protocol doesn't want any additional data. If we were a normal
+ # (single-owner) Transport, we'd call .loseConnection now. But our
+ # Connection is shared among many subchannels, so instead we just
+ # stop letting them pause the connection.
+ was_paused = bool(self._paused_subchannels)
+ self._paused_subchannels.discard(sc)
+ if self._connection and was_paused and not self._paused_subchannels:
+ self._connection.resumeProducing()
+
+ # TODO: we might refactor these pause/resume/stop methods by building a
+ # context manager that look at the paused/not-paused state first, then
+ # lets the caller modify self._paused_subchannels, then looks at it a
+ # second time, and calls c.pauseProducing/c.resumeProducing as
+ # appropriate. I'm not sure it would be any cleaner, though.
diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py
new file mode 100644
index 0000000..860665a
--- /dev/null
+++ b/src/wormhole/_dilation/manager.py
@@ -0,0 +1,514 @@
+from __future__ import print_function, unicode_literals
+from collections import deque
+from attr import attrs, attrib
+from attr.validators import provides, instance_of, optional
+from automat import MethodicalMachine
+from zope.interface import implementer
+from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
+from twisted.python import log
+from .._interfaces import IDilator, IDilationManager, ISend
+from ..util import dict_to_bytes, bytes_to_dict
+from ..observer import OneShotObserver
+from .._key import derive_key
+from .encode import to_be4
+from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress,
+ ControlEndpoint, SubchannelConnectorEndpoint,
+ SubchannelListenerEndpoint)
+from .connector import Connector, parse_hint
+from .roles import LEADER, FOLLOWER
+from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
+from .inbound import Inbound
+from .outbound import Outbound
+
+class OldPeerCannotDilateError(Exception):
+ pass
+class UnknownDilationMessageType(Exception):
+ pass
+class ReceivedHintsTooEarly(Exception):
+ pass
+
+@attrs
+@implementer(IDilationManager)
+class _ManagerBase(object):
+ _S = attrib(validator=provides(ISend))
+ _my_side = attrib(validator=instance_of(type(u"")))
+ _transit_key = attrib(validator=instance_of(bytes))
+ _transit_relay_location = attrib(validator=optional(instance_of(str)))
+ _reactor = attrib()
+ _eventual_queue = attrib()
+ _cooperator = attrib()
+ _no_listen = False # TODO
+ _tor = None # TODO
+ _timing = None # TODO
+
+ def __attrs_post_init__(self):
+ self._got_versions_d = Deferred()
+
+ self._my_role = None # determined upon rx_PLEASE
+
+ self._connection = None
+ self._made_first_connection = False
+ self._first_connected = OneShotObserver(self._eventual_queue)
+ self._host_addr = _WormholeAddress()
+
+ self._next_dilation_phase = 0
+
+ self._next_subchannel_id = 0 # increments by 2
+
+ # I kept getting confused about which methods were for inbound data
+ # (and thus flow-control methods go "out") and which were for
+ # outbound data (with flow-control going "in"), so I split them up
+ # into separate pieces.
+ self._inbound = Inbound(self, self._host_addr)
+ self._outbound = Outbound(self, self._cooperator) # from us to peer
+
+ def set_listener_endpoint(self, listener_endpoint):
+ self._inbound.set_listener_endpoint(listener_endpoint)
+ def set_subchannel_zero(self, scid0, sc0):
+ self._inbound.set_subchannel_zero(scid0, sc0)
+
+ def when_first_connected(self):
+ return self._first_connected.when_fired()
+
+
+ def send_dilation_phase(self, **fields):
+ dilation_phase = self._next_dilation_phase
+ self._next_dilation_phase += 1
+ self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))
+
+ def send_hints(self, hints): # from Connector
+ self.send_dilation_phase(type="connection-hints", hints=hints)
+
+
+ # forward inbound-ish things to _Inbound
+ def subchannel_pauseProducing(self, sc):
+ self._inbound.subchannel_pauseProducing(sc)
+ def subchannel_resumeProducing(self, sc):
+ self._inbound.subchannel_resumeProducing(sc)
+ def subchannel_stopProducing(self, sc):
+ self._inbound.subchannel_stopProducing(sc)
+
+ # forward outbound-ish things to _Outbound
+ def subchannel_registerProducer(self, sc, producer, streaming):
+ self._outbound.subchannel_registerProducer(sc, producer, streaming)
+ def subchannel_unregisterProducer(self, sc):
+ self._outbound.subchannel_unregisterProducer(sc)
+
+ def send_open(self, scid):
+ self._queue_and_send(Open, scid)
+ def send_data(self, scid, data):
+ self._queue_and_send(Data, scid, data)
+ def send_close(self, scid):
+ self._queue_and_send(Close, scid)
+
+ def _queue_and_send(self, record_type, *args):
+ r = self._outbound.build_record(record_type, *args)
+ # Outbound owns the send_record() pipe, so that it can stall new
+ # writes after a new connection is made until after all queued
+ # messages are written (to preserve ordering).
+ self._outbound.queue_and_send_record(r) # may trigger pauseProducing
+
+ def subchannel_closed(self, scid, sc):
+ # let everyone clean up. This happens just after we delivered
+ # connectionLost to the Protocol, except for the control channel,
+ # which might get connectionLost later after they use ep.connect.
+ # TODO: is this inversion a problem?
+ self._inbound.subchannel_closed(scid, sc)
+ self._outbound.subchannel_closed(scid, sc)
+
+
+ def _start_connecting(self, role):
+ assert self._my_role is not None
+ self._connector = Connector(self._transit_key,
+ self._transit_relay_location,
+ self,
+ self._reactor, self._eventual_queue,
+ self._no_listen, self._tor,
+ self._timing,
+ self._side, # needed for relay handshake
+ self._my_role)
+ self._connector.start()
+
+ # our Connector calls these
+
+ def connector_connection_made(self, c):
+ self.connection_made() # state machine update
+ self._connection = c
+ self._inbound.use_connection(c)
+ self._outbound.use_connection(c) # does c.registerProducer
+ if not self._made_first_connection:
+ self._made_first_connection = True
+ self._first_connected.fire(None)
+ pass
+ def connector_connection_lost(self):
+ self._stop_using_connection()
+ if self.role is LEADER:
+ self.connection_lost_leader() # state machine
+ else:
+ self.connection_lost_follower()
+
+
+ def _stop_using_connection(self):
+ # the connection is already lost by this point
+ self._connection = None
+ self._inbound.stop_using_connection()
+ self._outbound.stop_using_connection() # does c.unregisterProducer
+
+ # from our active Connection
+
+ def got_record(self, r):
+ # records with sequence numbers: always ack, ignore old ones
+ if isinstance(r, (Open, Data, Close)):
+ self.send_ack(r.seqnum) # always ack, even for old ones
+ if self._inbound.is_record_old(r):
+ return
+ self._inbound.update_ack_watermark(r.seqnum)
+ if isinstance(r, Open):
+ self._inbound.handle_open(r.scid)
+ elif isinstance(r, Data):
+ self._inbound.handle_data(r.scid, r.data)
+ else: # isinstance(r, Close)
+ self._inbound.handle_close(r.scid)
+ if isinstance(r, KCM):
+ log.err("got unexpected KCM")
+ elif isinstance(r, Ping):
+ self.handle_ping(r.ping_id)
+ elif isinstance(r, Pong):
+ self.handle_pong(r.ping_id)
+ elif isinstance(r, Ack):
+ self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
+ else:
+ log.err("received unknown message type {}".format(r))
+
+ # pings, pongs, and acks are not queued
+ def send_ping(self, ping_id):
+ self._outbound.send_if_connected(Ping(ping_id))
+
+ def send_pong(self, ping_id):
+ self._outbound.send_if_connected(Pong(ping_id))
+
+ def send_ack(self, resp_seqnum):
+ self._outbound.send_if_connected(Ack(resp_seqnum))
+
+
+ def handle_ping(self, ping_id):
+ self.send_pong(ping_id)
+
+ def handle_pong(self, ping_id):
+ # TODO: update is-alive timer
+ pass
+
+ # subchannel maintenance
+ def allocate_subchannel_id(self):
+ raise NotImplemented # subclass knows if we're leader or follower
+
+# new scheme:
+# * both sides send PLEASE as soon as they have an unverified key and
+# w.dilate has been called,
+# * PLEASE includes a dilation-specific "side" (independent of the "side"
+# used by mailbox messages)
+# * higher "side" is Leader, lower is Follower
+# * PLEASE includes can-dilate list of version integers, requires overlap
+# "1" is current
+# * dilation starts as soon as we've sent PLEASE and received PLEASE
+# (four-state two-variable IDLE/WANTING/WANTED/STARTED diamond FSM)
+# * HINTS sent after dilation starts
+# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This
+# is the only difference between the two sides, and is not enforced
+# by the protocol (i.e. if the Follower sends RECONNECT to the Leader,
+# the Leader will obey, although TODO how confusing will this get?)
+# * upon receiving RECONNECT: drop Connector, start new Connector, send
+# RECONNECTING, start sending HINTS
+# * upon sending CONNECT: go into FLUSHING state and ignore all HINTS until
+# RECONNECTING received. The new Connector can be spun up earlier, and it
+# can send HINTS, but it must not be given any HINTS that arrive before
+# RECONNECTING (since they're probably stale)
+
+# * after VERSIONS(KCM) received, we might learn that they other side cannot
+# dilate. w.dilate errbacks at this point
+
+# * maybe signal warning if we stay in a "want" state for too long
+# * nobody sends HINTS until they're ready to receive
+# * nobody sends HINTS unless they've called w.dilate() and received PLEASE
+# * nobody connects to inbound hints unless they've called w.dilate()
+# * if leader calls w.dilate() but not follower, leader waits forever in
+# "want" (doesn't send anything)
+# * if follower calls w.dilate() but not leader, follower waits forever
+# in "want", leader waits forever in "wanted"
+
+class ManagerShared(_ManagerBase):
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None)
+
+ @m.state(initial=True)
+ def IDLE(self): pass # pragma: no cover
+
+ @m.state()
+ def WANTING(self): pass # pragma: no cover
+ @m.state()
+ def WANTED(self): pass # pragma: no cover
+ @m.state()
+ def CONNECTING(self): pass # pragma: no cover
+ @m.state()
+ def CONNECTED(self): pass # pragma: no cover
+ @m.state()
+ def FLUSHING(self): pass # pragma: no cover
+ @m.state()
+ def ABANDONING(self): pass # pragma: no cover
+ @m.state()
+ def LONELY(self): pass # pragme: no cover
+ @m.state()
+ def STOPPING(self): pass # pragma: no cover
+ @m.state(terminal=True)
+ def STOPPED(self): pass # pragma: no cover
+
+ @m.input()
+ def start(self): pass # pragma: no cover
+ @m.input()
+ def rx_PLEASE(self, message): pass # pragma: no cover
+ @m.input() # only sent by Follower
+ def rx_HINTS(self, hint_message): pass # pragma: no cover
+ @m.input() # only Leader sends RECONNECT, so only Follower receives it
+ def rx_RECONNECT(self): pass # pragma: no cover
+ @m.input() # only Follower sends RECONNECTING, so only Leader receives it
+ def rx_RECONNECTING(self): pass # pragma: no cover
+
+ # Connector gives us connection_made()
+ @m.input()
+ def connection_made(self, c): pass # pragma: no cover
+
+ # our connection_lost() fires connection_lost_leader or
+ # connection_lost_follower depending upon our role. If either side sees a
+ # problem with the connection (timeouts, bad authentication) then they
+ # just drop it and let connection_lost() handle the cleanup.
+ @m.input()
+ def connection_lost_leader(self): pass # pragma: no cover
+ @m.input()
+ def connection_lost_follower(self): pass
+
+ @m.input()
+ def stop(self): pass # pragma: no cover
+
+ @m.output()
+ def stash_side(self, message):
+ their_side = message["side"]
+ self.my_role = LEADER if self._my_side > their_side else FOLLOWER
+
+ # these Outputs behave differently for the Leader vs the Follower
+ @m.output()
+ def send_please(self):
+ self.send_dilation_phase(type="please", side=self._my_side)
+
+ @m.output()
+ def start_connecting(self):
+ self._start_connecting() # TODO: merge
+ @m.output()
+ def ignore_message_start_connecting(self, message):
+ self.start_connecting()
+
+ @m.output()
+ def send_reconnect(self):
+ self.send_dilation_phase(type="reconnect") # TODO: generation number?
+ @m.output()
+ def send_reconnecting(self):
+ self.send_dilation_phase(type="reconnecting") # TODO: generation?
+
+ @m.output()
+ def use_hints(self, hint_message):
+ hint_objs = filter(lambda h: h, # ignore None, unrecognizable
+ [parse_hint(hs) for hs in hint_message["hints"]])
+ hint_objs = list(hint_objs)
+ self._connector.got_hints(hint_objs)
+ @m.output()
+ def stop_connecting(self):
+ self._connector.stop()
+ @m.output()
+ def abandon_connection(self):
+ # we think we're still connected, but the Leader disagrees. Or we've
+ # been told to shut down.
+ self._connection.disconnect() # let connection_lost do cleanup
+
+
+ # we don't start CONNECTING until a local start() plus rx_PLEASE
+ IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
+ IDLE.upon(start, enter=WANTING, outputs=[send_please])
+ WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting])
+ WANTING.upon(rx_PLEASE, enter=CONNECTING,
+ outputs=[stash_side,
+ ignore_message_start_connecting])
+
+ CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[])
+
+ # Leader
+ CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
+ outputs=[send_reconnect])
+ FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting])
+
+ # Follower
+ # if we notice a lost connection, just wait for the Leader to notice too
+ CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[])
+ LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting])
+ # but if they notice it first, abandon our (seemingly functional)
+ # connection, then tell them that we're ready to try again
+ CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss
+ outputs=[abandon_connection])
+ ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
+ outputs=[send_reconnecting, start_connecting])
+ # and if they notice a problem while we're still connecting, abandon our
+ # incomplete attempt and try again. in this case we don't have to wait
+ # for a connection to finish shutdown
+ CONNECTING.upon(rx_RECONNECT, enter=CONNECTING,
+ outputs=[stop_connecting,
+ send_reconnecting,
+ start_connecting])
+
+
+ # rx_HINTS never changes state, they're just accepted or ignored
+ IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
+ WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
+ WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
+ CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
+ CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
+ FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
+ LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
+ ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
+ STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
+
+ IDLE.upon(stop, enter=STOPPED, outputs=[])
+ WANTED.upon(stop, enter=STOPPED, outputs=[])
+ WANTING.upon(stop, enter=STOPPED, outputs=[])
+ CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
+ CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
+ ABANDONING.upon(stop, enter=STOPPING, outputs=[])
+ FLUSHING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
+ LONELY.upon(stop, enter=STOPPED, outputs=[])
+ STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
+ STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
+
+
+ def allocate_subchannel_id(self):
+ # scid 0 is reserved for the control channel. the leader uses odd
+ # numbers starting with 1
+ scid_num = self._next_outbound_seqnum + 1
+ self._next_outbound_seqnum += 2
+ return to_be4(scid_num)
+
+@attrs
+@implementer(IDilator)
+class Dilator(object):
+ """I launch the dilation process.
+
+ I am created with every Wormhole (regardless of whether .dilate()
+ was called or not), and I handle the initial phase of dilation,
+ before we know whether we'll be the Leader or the Follower. Once we
+ hear the other side's VERSION message (which tells us that we have a
+ connection, they are capable of dilating, and which side we're on),
+ then we build a DilationManager and hand control to it.
+ """
+
+ _reactor = attrib()
+ _eventual_queue = attrib()
+ _cooperator = attrib()
+
+ def __attrs_post_init__(self):
+ self._got_versions_d = Deferred()
+ self._started = False
+ self._endpoints = OneShotObserver(self._eventual_queue)
+ self._pending_inbound_dilate_messages = deque()
+ self._manager = None
+
+ def wire(self, sender):
+ self._S = ISend(sender)
+
+ # this is the primary entry point, called when w.dilate() is invoked
+ def dilate(self, transit_relay_location=None):
+ self._transit_relay_location = transit_relay_location
+ if not self._started:
+ self._started = True
+ self._start().addBoth(self._endpoints.fire)
+ return self._endpoints.when_fired()
+
+ @inlineCallbacks
+ def _start(self):
+ # first, we wait until we hear the VERSION message, which tells us 1:
+ # the PAKE key works, so we can talk securely, 2: their side, so we
+ # know who will lead, and 3: that they can do dilation at all
+
+ dilation_version = yield self._got_versions_d
+
+ if not dilation_version: # 1 or None
+ raise OldPeerCannotDilateError()
+
+ my_dilation_side = TODO # random
+ self._manager = Manager(self._S, my_dilation_side,
+ self._transit_key,
+ self._transit_relay_location,
+ self._reactor, self._eventual_queue,
+ self._cooperator)
+ self._manager.start()
+
+ while self._pending_inbound_dilate_messages:
+ plaintext = self._pending_inbound_dilate_messages.popleft()
+ self.received_dilate(plaintext)
+
+ # we could probably return the endpoints earlier
+ yield self._manager.when_first_connected()
+ # we can open subchannels as soon as we get our first connection
+ scid0 = b"\x00\x00\x00\x00"
+ self._host_addr = _WormholeAddress() # TODO: share with Manager
+ peer_addr0 = _SubchannelAddress(scid0)
+ control_ep = ControlEndpoint(peer_addr0)
+ sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
+ control_ep._subchannel_zero_opened(sc0)
+ self._manager.set_subchannel_zero(scid0, sc0)
+
+ connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
+
+ listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
+ self._manager.set_listener_endpoint(listen_ep)
+
+ endpoints = (control_ep, connect_ep, listen_ep)
+ returnValue(endpoints)
+
+ # from Boss
+
+ def got_key(self, key):
+ # TODO: verify this happens before got_wormhole_versions, or add a gate
+ # to tolerate either ordering
+ purpose = b"dilation-v1"
+ LENGTH =32 # TODO: whatever Noise wants, I guess
+ self._transit_key = derive_key(key, purpose, LENGTH)
+
+ def got_wormhole_versions(self, our_side, their_side,
+ their_wormhole_versions):
+ # TODO: remove our_side, their_side
+ assert isinstance(our_side, str), str
+ assert isinstance(their_side, str), str
+ # this always happens before received_dilate
+ dilation_version = None
+ their_dilation_versions = their_wormhole_versions.get("can-dilate", [])
+ if 1 in their_dilation_versions:
+ dilation_version = 1
+ self._got_versions_d.callback(dilation_version)
+
+ def received_dilate(self, plaintext):
+ # this receives new in-order DILATE-n payloads, decrypted but not
+ # de-JSONed.
+
+ # this can appear before our .dilate() method is called, in which case
+ # we queue them for later
+ if not self._manager:
+ self._pending_inbound_dilate_messages.append(plaintext)
+ return
+
+ message = bytes_to_dict(plaintext)
+ type = message["type"]
+ if type == "please":
+ self._manager.rx_PLEASE() # message)
+ elif type == "dilate":
+ self._manager.rx_DILATE() #message)
+ elif type == "connection-hints":
+ self._manager.rx_HINTS(message)
+ else:
+ log.err(UnknownDilationMessageType(message))
+ return
diff --git a/src/wormhole/_dilation/old-follower.py b/src/wormhole/_dilation/old-follower.py
new file mode 100644
index 0000000..68e3a38
--- /dev/null
+++ b/src/wormhole/_dilation/old-follower.py
@@ -0,0 +1,106 @@
+
+class ManagerFollower(_ManagerBase):
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None)
+
+ @m.state(initial=True)
+ def IDLE(self): pass # pragma: no cover
+
+ @m.state()
+ def WANTING(self): pass # pragma: no cover
+ @m.state()
+ def CONNECTING(self): pass # pragma: no cover
+ @m.state()
+ def CONNECTED(self): pass # pragma: no cover
+ @m.state(terminal=True)
+ def STOPPED(self): pass # pragma: no cover
+
+ @m.input()
+ def start(self): pass # pragma: no cover
+ @m.input()
+ def rx_PLEASE(self): pass # pragma: no cover
+ @m.input()
+ def rx_DILATE(self): pass # pragma: no cover
+ @m.input()
+ def rx_HINTS(self, hint_message): pass # pragma: no cover
+
+ @m.input()
+ def connection_made(self): pass # pragma: no cover
+ @m.input()
+ def connection_lost(self): pass # pragma: no cover
+ # follower doesn't react to connection_lost, but waits for a new LETS_DILATE
+
+ @m.input()
+ def stop(self): pass # pragma: no cover
+
+ # these Outputs behave differently for the Leader vs the Follower
+ @m.output()
+ def send_please(self):
+ self.send_dilation_phase(type="please")
+
+ @m.output()
+ def start_connecting(self):
+ self._start_connecting(FOLLOWER)
+
+ # these Outputs delegate to the same code in both the Leader and the
+ # Follower, but they must be replicated here because the Automat instance
+ # is on the subclass, not the shared superclass
+
+ @m.output()
+ def use_hints(self, hint_message):
+ hint_objs = filter(lambda h: h, # ignore None, unrecognizable
+ [parse_hint(hs) for hs in hint_message["hints"]])
+ self._connector.got_hints(hint_objs)
+ @m.output()
+ def stop_connecting(self):
+ self._connector.stop()
+ @m.output()
+ def use_connection(self, c):
+ self._use_connection(c)
+ @m.output()
+ def stop_using_connection(self):
+ self._stop_using_connection()
+ @m.output()
+ def signal_error(self):
+ pass # TODO
+ @m.output()
+ def signal_error_hints(self, hint_message):
+ pass # TODO
+
+ IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early
+ IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early
+ # leader shouldn't send us DILATE before receiving our PLEASE
+ IDLE.upon(stop, enter=STOPPED, outputs=[])
+ IDLE.upon(start, enter=WANTING, outputs=[send_please])
+ WANTING.upon(rx_DILATE, enter=CONNECTING, outputs=[start_connecting])
+ WANTING.upon(stop, enter=STOPPED, outputs=[])
+
+ CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
+ CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection])
+ # shouldn't happen: connection_lost
+ #CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?])
+ CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting,
+ start_connecting])
+ # receiving rx_DILATE while we're still working on the last one means the
+ # leader thought we'd connected, then thought we'd been disconnected, all
+ # before we heard about that connection
+ CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
+
+ CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection])
+ CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection,
+ start_connecting])
+ CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
+ CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection])
+ # shouldn't happen: connection_made
+
+ # we should never receive PLEASE, we're the follower
+ IDLE.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
+ WANTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
+ CONNECTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
+ CONNECTED.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
+
+ def allocate_subchannel_id(self):
+ # the follower uses even numbers starting with 2
+ scid_num = self._next_outbound_seqnum + 2
+ self._next_outbound_seqnum += 2
+ return to_be4(scid_num)
diff --git a/src/wormhole/_dilation/outbound.py b/src/wormhole/_dilation/outbound.py
new file mode 100644
index 0000000..6538ffe
--- /dev/null
+++ b/src/wormhole/_dilation/outbound.py
@@ -0,0 +1,392 @@
+from __future__ import print_function, unicode_literals
+from collections import deque
+from attr import attrs, attrib
+from attr.validators import provides
+from zope.interface import implementer
+from twisted.internet.interfaces import IPushProducer, IPullProducer
+from twisted.python import log
+from twisted.python.reflect import safe_str
+from .._interfaces import IDilationManager, IOutbound
+from .connection import KCM, Ping, Pong, Ack
+
+
+# Outbound flow control: app writes to subchannel, we write to Connection
+
+# The app can register an IProducer of their choice, to let us throttle their
+# outbound data. Not all subchannels will have producers registered, and the
+# producer probably won't be the IProtocol instance (it'll be something else
+# which feeds data out through the protocol, like a t.p.basic.FileSender). If
+# a producerless subchannel writes too much, we won't be able to stop them,
+# and we'll keep writing records into the Connection even though it's asked
+# us to pause. Likewise, when the connection is down (and we're busily trying
+# to reestablish a new one), registered subchannels will be paused, but
+# unregistered ones will just dump everything in _outbound_queue, and we'll
+# consume memory without bound until they stop.
+
+# We need several things:
+#
+# * Add each registered IProducer to a list, whose order remains stable. We
+# want fairness under outbound throttling: each time the outbound
+# connection opens up (our resumeProducing method is called), we should let
+# just one producer have an opportunity to do transport.write, and then we
+# should pause them again, and not come back to them until everyone else
+# has gotten a turn. So we want an ordered list of producers to track this
+# rotation.
+#
+# * Remove the IProducer if/when the protocol uses unregisterProducer
+#
+# * Remove any registered IProducer when the associated Subchannel is closed.
+# This isn't a problem for normal transports, because usually there's a
+# one-to-one mapping from Protocol to Transport, so when the Transport you
+# forget the only reference to the Producer anyways. Our situation is
+# unusual because we have multiple Subchannels that get merged into the
+# same underlying Connection: each Subchannel's Protocol can register a
+# producer on the Subchannel (which is an ITransport), but that adds it to
+# a set of Producers for the Connection (which is also an ITransport). So
+# if the Subchannel is closed, we need to remove its Producer (if any) even
+# though the Connection remains open.
+#
+# * Register ourselves as an IPushProducer with each successive Connection
+# object. These connections will come and go, but there will never be more
+# than one. When the connection goes away, pause all our producers. When a
+# new one is established, write all our queued messages, then unpause our
+# producers as we would in resumeProducing.
+#
+# * Inside our resumeProducing call, we'll cycle through all producers,
+# calling their individual resumeProducing methods one at a time. If they
+# write so much data that the Connection pauses us again, we'll find out
+# because our pauseProducing will be called inside that loop. When that
+# happens, we need to stop looping. If we make it through the whole loop
+# without being paused, then all subchannel Producers are left unpaused,
+# and are free to write whenever they want. During this loop, some
+# Producers will be paused, and others will be resumed
+#
+# * If our pauseProducing is called, all Producers must be paused, and a flag
+# should be set to notify the resumeProducing loop to exit
+#
+# * In between calls to our resumeProducing method, we're in one of two
+# states.
+# * If we're writing data too fast, then we'll be left in the "paused"
+# state, in which all Subchannel producers are paused, and the aggregate
+# is paused too (our Connection told us to pauseProducing and hasn't yet
+# told us to resumeProducing). In this state, activity is driven by the
+# outbound TCP window opening up, which calls resumeProducing and allows
+# (probably just) one message to be sent. We receive pauseProducing in
+# the middle of their transport.write, so the loop exits early, and the
+# only state change is that some other Producer should get to go next
+# time.
+# * If we're writing too slowly, we'll be left in the "unpaused" state: all
+# Subchannel producers are unpaused, and the aggregate is unpaused too
+# (resumeProducing is the last thing we've been told). In this satte,
+# activity is driven by the Subchannels doing a transport.write, which
+# queues some data on the TCP connection (and then might call
+# pauseProducing if it's now full).
+#
+# * We want to guard against:
+#
+# * application protocol registering a Producer without first unregistering
+# the previous one
+#
+# * application protocols writing data despite being told to pause
+# (Subchannels without a registered Producer cannot be throttled, and we
+# can't do anything about that, but we must also handle the case where
+# they give us a pause switch and then proceed to ignore it)
+#
+# * our Connection calling resumeProducing or pauseProducing without an
+# intervening call of the other kind
+#
+# * application protocols that don't handle a resumeProducing or
+# pauseProducing call without an intervening call of the other kind (i.e.
+# we should keep track of the last thing we told them, and not repeat
+# ourselves)
+#
+# * If the Wormhole is closed, all Subchannels should close. This is not our
+# responsibility: it lives in (Manager? Inbound?)
+#
+# * If we're given an IPullProducer, we should keep calling its
+# resumeProducing until it runs out of data. We still want fairness, so we
+# won't call it a second time until everyone else has had a turn.
+
+
+# There are a couple of different ways to approach this. The one I've
+# selected is:
+#
+# * keep a dict that maps from Subchannel to Producer, which only contains
+# entries for Subchannels that have registered a producer. We use this to
+# remove Producers when Subchannels are closed
+#
+# * keep a Deque of Producers. This represents the fair-throttling rotation:
+# the left-most item gets the next upcoming turn, and then they'll be moved
+# to the end of the queue.
+#
+# * keep a set of IPushProducers which are paused, a second set of
+# IPushProducers which are unpaused, and a third set of IPullProducers
+# (which are always left paused) Enforce the invariant that these three
+# sets are disjoint, and that their union equals the contents of the deque.
+#
+# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and
+# set upon entry to pauseProducing. The loop inside resumeProducing checks
+# this flag after each call to producer.resumeProducing, to sense whether
+# they used their turn to write data, and if that write was large enough to
+# fill the TCP window. If set, we break out of the loop. If not, we look
+# for the next producer to unpause. The loop finishes when all producers
+# are unpaused (evidenced by the two sets of paused producers being empty)
+#
+# * the "paused" flag also determines whether new IPushProducers are added to
+# the paused or unpaused set (IPullProducers are always added to the
+# pull+paused set). If we have any IPullProducers, we're always in the
+# "writing data too fast" state.
+
+# other approaches that I didn't decide to do at this time (but might use in
+# the future):
+#
+# * use one set instead of two. pros: fewer moving parts. cons: harder to
+# spot decoherence bugs like adding a producer to the deque but forgetting
+# to add it to one of the
+#
+# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as
+# a visible boolean flag. This conflates Subchannels with their associated
+# Producer (so if we went this way, we should also let them track their own
+# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc
+# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer
+# mappings lying around to disagree with one another. Cons: exposes a bit
+# too much of the Subchannel internals
+
+
+@attrs
+@implementer(IOutbound)
+class Outbound(object):
+ # Manage outbound data: subchannel writes to us, we write to transport
+ _manager = attrib(validator=provides(IDilationManager))
+ _cooperator = attrib()
+
+ def __attrs_post_init__(self):
+ # _outbound_queue holds all messages we've ever sent but not retired
+ self._outbound_queue = deque()
+ self._next_outbound_seqnum = 0
+ # _queued_unsent are messages to retry with our new connection
+ self._queued_unsent = deque()
+
+ # outbound flow control: the Connection throttles our writes
+ self._subchannel_producers = {} # Subchannel -> IProducer
+ self._paused = True # our Connection called our pauseProducing
+ self._all_producers = deque() # rotates, left-is-next
+ self._paused_producers = set()
+ self._unpaused_producers = set()
+ self._check_invariants()
+
+ self._connection = None
+
+ def _check_invariants(self):
+ assert self._unpaused_producers.isdisjoint(self._paused_producers)
+ assert (self._paused_producers.union(self._unpaused_producers) ==
+ set(self._all_producers))
+
+ def build_record(self, record_type, *args):
+ seqnum = self._next_outbound_seqnum
+ self._next_outbound_seqnum += 1
+ r = record_type(seqnum, *args)
+ assert hasattr(r, "seqnum"), r # only Open/Data/Close
+ return r
+
+ def queue_and_send_record(self, r):
+ # we always queue it, to resend on a subsequent connection if
+ # necessary
+ self._outbound_queue.append(r)
+
+ if self._connection:
+ if self._queued_unsent:
+ # to maintain correct ordering, queue this instead of sending it
+ self._queued_unsent.append(r)
+ else:
+ # we're allowed to send it immediately
+ self._connection.send_record(r)
+
+ def send_if_connected(self, r):
+ assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
+ if self._connection:
+ self._connection.send_record(r)
+
+ # our subchannels call these to register a producer
+
+ def subchannel_registerProducer(self, sc, producer, streaming):
+ # streaming==True: IPushProducer (pause/resume)
+ # streaming==False: IPullProducer (just resume)
+ if sc in self._subchannel_producers:
+ raise ValueError(
+ "registering producer %s before previous one (%s) was "
+ "unregistered" % (producer,
+ self._subchannel_producers[sc]))
+ # our underlying Connection uses streaming==True, so to make things
+ # easier, use an adapter when the Subchannel asks for streaming=False
+ if not streaming:
+ def unregister():
+ self.subchannel_unregisterProducer(sc)
+ producer = PullToPush(producer, unregister, self._cooperator)
+
+ self._subchannel_producers[sc] = producer
+ self._all_producers.append(producer)
+ if self._paused:
+ self._paused_producers.add(producer)
+ else:
+ self._unpaused_producers.add(producer)
+ self._check_invariants()
+ if streaming:
+ if self._paused:
+ # IPushProducers need to be paused immediately, before they
+ # speak
+ producer.pauseProducing() # you wake up sleeping
+ else:
+ # our PullToPush adapter must be started, but if we're paused then
+ # we tell it to pause before it gets a chance to write anything
+ producer.startStreaming(self._paused)
+
+ def subchannel_unregisterProducer(self, sc):
+ # TODO: what if the subchannel closes, so we unregister their
+ # producer for them, then the application reacts to connectionLost
+ # with a duplicate unregisterProducer?
+ p = self._subchannel_producers.pop(sc)
+ if isinstance(p, PullToPush):
+ p.stopStreaming()
+ self._all_producers.remove(p)
+ self._paused_producers.discard(p)
+ self._unpaused_producers.discard(p)
+ self._check_invariants()
+
+ def subchannel_closed(self, sc):
+ self._check_invariants()
+ if sc in self._subchannel_producers:
+ self.subchannel_unregisterProducer(sc)
+
+ # our Manager tells us when we've got a new Connection to work with
+
+ def use_connection(self, c):
+ self._connection = c
+ assert not self._queued_unsent
+ self._queued_unsent.extend(self._outbound_queue)
+ # the connection can tell us to pause when we send too much data
+ c.registerProducer(self, True) # IPushProducer: pause+resume
+ # send our queued messages
+ self.resumeProducing()
+
+ def stop_using_connection(self):
+ self._connection.unregisterProducer()
+ self._connection = None
+ self._queued_unsent.clear()
+ self.pauseProducing()
+ # TODO: I expect this will call pauseProducing twice: the first time
+ # when we get stopProducing (since we're registere with the
+ # underlying connection as the producer), and again when the manager
+ # notices the connectionLost and calls our _stop_using_connection
+
+ def handle_ack(self, resp_seqnum):
+ # we've received an inbound ack, so retire something
+ while (self._outbound_queue and
+ self._outbound_queue[0].seqnum <= resp_seqnum):
+ self._outbound_queue.popleft()
+ while (self._queued_unsent and
+ self._queued_unsent[0].seqnum <= resp_seqnum):
+ self._queued_unsent.popleft()
+ # Inbound is responsible for tracking the high watermark and deciding
+ # whether to ignore inbound messages or not
+
+
+ # IProducer: the active connection calls these because we used
+ # c.registerProducer to ask for them
+ def pauseProducing(self):
+ if self._paused:
+ return # someone is confused and called us twice
+ self._paused = True
+ for p in self._all_producers:
+ if p in self._unpaused_producers:
+ self._unpaused_producers.remove(p)
+ self._paused_producers.add(p)
+ p.pauseProducing()
+
+ def resumeProducing(self):
+ if not self._paused:
+ return # someone is confused and called us twice
+ self._paused = False
+
+ while not self._paused:
+ if self._queued_unsent:
+ r = self._queued_unsent.popleft()
+ self._connection.send_record(r)
+ continue
+ p = self._get_next_unpaused_producer()
+ if not p:
+ break
+ self._paused_producers.remove(p)
+ self._unpaused_producers.add(p)
+ p.resumeProducing()
+
+ def _get_next_unpaused_producer(self):
+ self._check_invariants()
+ if not self._paused_producers:
+ return None
+ while True:
+ p = self._all_producers[0]
+ self._all_producers.rotate(-1) # p moves to the end of the line
+ # the only unpaused Producers are at the end of the list
+ assert p in self._paused_producers
+ return p
+
+ def stopProducing(self):
+ # we'll hopefully have a new connection to work with in the future,
+ # so we don't shut anything down. We do pause everyone, though.
+ self.pauseProducing()
+
+
+# modelled after twisted.internet._producer_helper._PullToPush , but with a
+# configurable Cooperator, a pause-immediately argument to startStreaming()
+@implementer(IPushProducer)
+@attrs(cmp=False)
+class PullToPush(object):
+ _producer = attrib(validator=provides(IPullProducer))
+ _unregister = attrib(validator=lambda _a,_b,v: callable(v))
+ _cooperator = attrib()
+ _finished = False
+
+ def _pull(self):
+ while True:
+ try:
+ self._producer.resumeProducing()
+ except:
+ log.err(None, "%s failed, producing will be stopped:" %
+ (safe_str(self._producer),))
+ try:
+ self._unregister()
+ # The consumer should now call stopStreaming() on us,
+ # thus stopping the streaming.
+ except:
+ # Since the consumer blew up, we may not have had
+ # stopStreaming() called, so we just stop on our own:
+ log.err(None, "%s failed to unregister producer:" %
+ (safe_str(self._unregister),))
+ self._finished = True
+ return
+ yield None
+
+ def startStreaming(self, paused):
+ self._coopTask = self._cooperator.cooperate(self._pull())
+ if paused:
+ self.pauseProducing() # timer is scheduled, but task is removed
+
+ def stopStreaming(self):
+ if self._finished:
+ return
+ self._finished = True
+ self._coopTask.stop()
+
+
+ def pauseProducing(self):
+ self._coopTask.pause()
+
+
+ def resumeProducing(self):
+ self._coopTask.resume()
+
+
+ def stopProducing(self):
+ self.stopStreaming()
+ self._producer.stopProducing()
diff --git a/src/wormhole/_dilation/roles.py b/src/wormhole/_dilation/roles.py
new file mode 100644
index 0000000..8f9adac
--- /dev/null
+++ b/src/wormhole/_dilation/roles.py
@@ -0,0 +1 @@
+LEADER, FOLLOWER = object(), object()
diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py
new file mode 100644
index 0000000..94b4a03
--- /dev/null
+++ b/src/wormhole/_dilation/subchannel.py
@@ -0,0 +1,269 @@
+from attr import attrs, attrib
+from attr.validators import instance_of, provides
+from zope.interface import implementer
+from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed
+from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
+ IAddress, IListeningPort,
+ IStreamClientEndpoint,
+ IStreamServerEndpoint)
+from twisted.internet.error import ConnectionDone
+from automat import MethodicalMachine
+from .._interfaces import ISubChannel, IDilationManager
+
+@attrs
+class Once(object):
+ _errtype = attrib()
+ def __attrs_post_init__(self):
+ self._called = False
+
+ def __call__(self):
+ if self._called:
+ raise self._errtype()
+ self._called = True
+
+class SingleUseEndpointError(Exception):
+ pass
+
+# created in the (OPEN) state, by either:
+# * receipt of an OPEN message
+# * or local client_endpoint.connect()
+# then transitions are:
+# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN)
+# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED)
+# (OPEN) local .write(): send DATA, -> (OPEN)
+# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING)
+# (CLOSING) local .write(): error
+# (CLOSING) local .loseConnection(): error
+# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING)
+# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
+# object is deleted upon transition to (CLOSED)
+
+class AlreadyClosedError(Exception):
+ pass
+
+@implementer(IAddress)
+class _WormholeAddress(object):
+ pass
+
+@implementer(IAddress)
+@attrs
+class _SubchannelAddress(object):
+ _scid = attrib()
+
+
+@attrs
+@implementer(ITransport)
+@implementer(IProducer)
+@implementer(IConsumer)
+@implementer(ISubChannel)
+class SubChannel(object):
+ _id = attrib(validator=instance_of(bytes))
+ _manager = attrib(validator=provides(IDilationManager))
+ _host_addr = attrib(validator=instance_of(_WormholeAddress))
+ _peer_addr = attrib(validator=instance_of(_SubchannelAddress))
+
+ m = MethodicalMachine()
+ set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
+
+ def __attrs_post_init__(self):
+ #self._mailbox = None
+ #self._pending_outbound = {}
+ #self._processed = set()
+ self._protocol = None
+ self._pending_dataReceived = []
+ self._pending_connectionLost = (False, None)
+
+ @m.state(initial=True)
+ def open(self): pass # pragma: no cover
+
+ @m.state()
+ def closing(): pass # pragma: no cover
+
+ @m.state()
+ def closed(): pass # pragma: no cover
+
+ @m.input()
+ def remote_data(self, data): pass
+ @m.input()
+ def remote_close(self): pass
+
+ @m.input()
+ def local_data(self, data): pass
+ @m.input()
+ def local_close(self): pass
+
+
+ @m.output()
+ def send_data(self, data):
+ self._manager.send_data(self._id, data)
+
+ @m.output()
+ def send_close(self):
+ self._manager.send_close(self._id)
+
+ @m.output()
+ def signal_dataReceived(self, data):
+ if self._protocol:
+ self._protocol.dataReceived(data)
+ else:
+ self._pending_dataReceived.append(data)
+
+ @m.output()
+ def signal_connectionLost(self):
+ if self._protocol:
+ self._protocol.connectionLost(ConnectionDone())
+ else:
+ self._pending_connectionLost = (True, ConnectionDone())
+ self._manager.subchannel_closed(self)
+ # we're deleted momentarily
+
+ @m.output()
+ def error_closed_write(self, data):
+ raise AlreadyClosedError("write not allowed on closed subchannel")
+ @m.output()
+ def error_closed_close(self):
+ raise AlreadyClosedError("loseConnection not allowed on closed subchannel")
+
+ # primary transitions
+ open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
+ open.upon(local_data, enter=open, outputs=[send_data])
+ open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
+ open.upon(local_close, enter=closing, outputs=[send_close])
+ closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
+ closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
+
+ # error cases
+ # we won't ever see an OPEN, since L4 will log+ignore those for us
+ closing.upon(local_data, enter=closing, outputs=[error_closed_write])
+ closing.upon(local_close, enter=closing, outputs=[error_closed_close])
+ # the CLOSED state won't ever see messages, since we'll be deleted
+
+ # our endpoints use this
+
+ def _set_protocol(self, protocol):
+ assert not self._protocol
+ self._protocol = protocol
+ if self._pending_dataReceived:
+ for data in self._pending_dataReceived:
+ self._protocol.dataReceived(data)
+ self._pending_dataReceived = []
+ cl, what = self._pending_connectionLost
+ if cl:
+ self._protocol.connectionLost(what)
+
+ # ITransport
+ def write(self, data):
+ assert isinstance(data, type(b""))
+ self.local_data(data)
+ def writeSequence(self, iovec):
+ self.write(b"".join(iovec))
+ def loseConnection(self):
+ self.local_close()
+ def getHost(self):
+ # we define "host addr" as the overall wormhole
+ return self._host_addr
+ def getPeer(self):
+ # and "peer addr" as the subchannel within that wormhole
+ return self._peer_addr
+
+ # IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
+ def stopProducing(self):
+ self._manager.subchannel_stopProducing(self)
+ def pauseProducing(self):
+ self._manager.subchannel_pauseProducing(self)
+ def resumeProducing(self):
+ self._manager.subchannel_resumeProducing(self)
+
+ # IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
+ def registerProducer(self, producer, streaming):
+ self._manager.subchannel_registerProducer(self, producer, streaming)
+ def unregisterProducer(self):
+ self._manager.subchannel_unregisterProducer(self)
+
+
+@implementer(IStreamClientEndpoint)
+class ControlEndpoint(object):
+ _used = False
+ def __init__(self, peer_addr):
+ self._subchannel_zero = Deferred()
+ self._peer_addr = peer_addr
+ self._once = Once(SingleUseEndpointError)
+
+ # from manager
+ def _subchannel_zero_opened(self, subchannel):
+ assert ISubChannel.providedBy(subchannel), subchannel
+ self._subchannel_zero.callback(subchannel)
+
+ @inlineCallbacks
+ def connect(self, protocolFactory):
+ # return Deferred that fires with IProtocol or Failure(ConnectError)
+ self._once()
+ t = yield self._subchannel_zero
+ p = protocolFactory.buildProtocol(self._peer_addr)
+ t._set_protocol(p)
+ p.makeConnection(t) # set p.transport = t and call connectionMade()
+ returnValue(p)
+
+@implementer(IStreamClientEndpoint)
+@attrs
+class SubchannelConnectorEndpoint(object):
+ _manager = attrib(validator=provides(IDilationManager))
+ _host_addr = attrib(validator=instance_of(_WormholeAddress))
+
+ def connect(self, protocolFactory):
+ # return Deferred that fires with IProtocol or Failure(ConnectError)
+ scid = self._manager.allocate_subchannel_id()
+ self._manager.send_open(scid)
+ peer_addr = _SubchannelAddress(scid)
+ # ? f.doStart()
+ # ? f.startedConnecting(CONNECTOR) # ??
+ t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
+ p = protocolFactory.buildProtocol(peer_addr)
+ t._set_protocol(p)
+ p.makeConnection(t) # set p.transport = t and call connectionMade()
+ return succeed(p)
+
+@implementer(IStreamServerEndpoint)
+@attrs
+class SubchannelListenerEndpoint(object):
+ _manager = attrib(validator=provides(IDilationManager))
+ _host_addr = attrib(validator=provides(IAddress))
+
+ def __attrs_post_init__(self):
+ self._factory = None
+ self._pending_opens = []
+
+ # from manager
+ def _got_open(self, t, peer_addr):
+ if self._factory:
+ self._connect(t, peer_addr)
+ else:
+ self._pending_opens.append( (t, peer_addr) )
+
+ def _connect(self, t, peer_addr):
+ p = self._factory.buildProtocol(peer_addr)
+ t._set_protocol(p)
+ p.makeConnection(t)
+
+ # IStreamServerEndpoint
+
+ def listen(self, protocolFactory):
+ self._factory = protocolFactory
+ for (t, peer_addr) in self._pending_opens:
+ self._connect(t, peer_addr)
+ self._pending_opens = []
+ lp = SubchannelListeningPort(self._host_addr)
+ return succeed(lp)
+
+@implementer(IListeningPort)
+@attrs
+class SubchannelListeningPort(object):
+ _host_addr = attrib(validator=provides(IAddress))
+
+ def startListening(self):
+ pass
+ def stopListening(self):
+ # TODO
+ pass
+ def getHost(self):
+ return self._host_addr
diff --git a/src/wormhole/_interfaces.py b/src/wormhole/_interfaces.py
index 9f59ff4..52bde35 100644
--- a/src/wormhole/_interfaces.py
+++ b/src/wormhole/_interfaces.py
@@ -433,3 +433,16 @@ class IInputHelper(Interface):
class IJournal(Interface): # TODO: this needs to be public
pass
+
+class IDilator(Interface):
+ pass
+class IDilationManager(Interface):
+ pass
+class IDilationConnector(Interface):
+ pass
+class ISubChannel(Interface):
+ pass
+class IInbound(Interface):
+ pass
+class IOutbound(Interface):
+ pass
diff --git a/src/wormhole/test/dilate/__init__.py b/src/wormhole/test/dilate/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py
new file mode 100644
index 0000000..2ddacfb
--- /dev/null
+++ b/src/wormhole/test/dilate/common.py
@@ -0,0 +1,18 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from ..._interfaces import IDilationManager, IWormhole
+
+def mock_manager():
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ return m
+
+def mock_wormhole():
+ m = mock.Mock()
+ alsoProvides(m, IWormhole)
+ return m
+
+def clear_mock_calls(*args):
+ for a in args:
+ a.mock_calls[:] = []
diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py
new file mode 100644
index 0000000..406e42a
--- /dev/null
+++ b/src/wormhole/test/dilate/test_connection.py
@@ -0,0 +1,216 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from twisted.internet.task import Clock
+from twisted.internet.interfaces import ITransport
+from ...eventual import EventualQueue
+from ..._interfaces import IDilationConnector
+from ..._dilation.roles import LEADER, FOLLOWER
+from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
+ KCM, Open, Ack)
+from .common import clear_mock_calls
+
+def make_con(role, use_relay=False):
+ clock = Clock()
+ eq = EventualQueue(clock)
+ connector = mock.Mock()
+ alsoProvides(connector, IDilationConnector)
+ n = mock.Mock() # pretends to be a Noise object
+ n.write_message = mock.Mock(side_effect=[b"handshake"])
+ c = DilatedConnectionProtocol(eq, role, connector, n,
+ b"outbound_prologue\n", b"inbound_prologue\n")
+ if use_relay:
+ c.use_relay(b"relay_handshake\n")
+ t = mock.Mock()
+ alsoProvides(t, ITransport)
+ return c, n, connector, t, eq
+
+class Connection(unittest.TestCase):
+ def test_bad_prologue(self):
+ c, n, connector, t, eq = make_con(LEADER)
+ c.makeConnection(t)
+ d = c.when_disconnected()
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ clear_mock_calls(n, connector, t)
+
+ c.dataReceived(b"prologue\n")
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
+
+ eq.flush_sync()
+ self.assertNoResult(d)
+ c.connectionLost(b"why")
+ eq.flush_sync()
+ self.assertIdentical(self.successResultOf(d), c)
+
+ def _test_no_relay(self, role):
+ c, n, connector, t, eq = make_con(role)
+ t_kcm = KCM()
+ t_open = Open(seqnum=1, scid=0x11223344)
+ t_ack = Ack(resp_seqnum=2)
+ n.decrypt = mock.Mock(side_effect=[
+ encode_record(t_kcm),
+ encode_record(t_open),
+ ])
+ exp_kcm = b"\x00\x00\x00\x03kcm"
+ n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"])
+ m = mock.Mock() # Manager
+
+ c.makeConnection(t)
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ clear_mock_calls(n, connector, t, m)
+
+ c.dataReceived(b"inbound_prologue\n")
+ self.assertEqual(n.mock_calls, [mock.call.write_message()])
+ self.assertEqual(connector.mock_calls, [])
+ exp_handshake = b"\x00\x00\x00\x09handshake"
+ self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
+ clear_mock_calls(n, connector, t, m)
+
+ c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
+ if role is LEADER:
+ # we're the leader, so we don't send the KCM right away
+ self.assertEqual(n.mock_calls, [
+ mock.call.read_message(b"handshake2")])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [])
+ self.assertEqual(c._manager, None)
+ else:
+ # we're the follower, so we encrypt and send the KCM immediately
+ self.assertEqual(n.mock_calls, [
+ mock.call.read_message(b"handshake2"),
+ mock.call.encrypt(encode_record(t_kcm)),
+ ])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [
+ mock.call.write(exp_kcm)])
+ self.assertEqual(c._manager, None)
+ clear_mock_calls(n, connector, t, m)
+
+ c.dataReceived(b"\x00\x00\x00\x03KCM")
+ # leader: inbound KCM means we add the candidate
+ # follower: inbound KCM means we've been selected.
+ # in both cases we notify Connector.add_candidate(), and the Connector
+ # decides if/when to call .select()
+
+ self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")])
+ self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)])
+ self.assertEqual(t.mock_calls, [])
+ clear_mock_calls(n, connector, t, m)
+
+ # now pretend this connection wins (either the Leader decides to use
+ # this one among all the candiates, or we're the Follower and the
+ # Connector is reacting to add_candidate() by recognizing we're the
+ # only candidate there is)
+ c.select(m)
+ self.assertIdentical(c._manager, m)
+ if role is LEADER:
+ # TODO: currently Connector.select_and_stop_remaining() is
+ # responsible for sending the KCM just before calling c.select()
+ # iff we're the LEADER, therefore Connection.select won't send
+ # anything. This should be moved to c.select().
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [])
+ self.assertEqual(m.mock_calls, [])
+
+ c.send_record(KCM())
+ self.assertEqual(n.mock_calls, [
+ mock.call.encrypt(encode_record(t_kcm)),
+ ])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)])
+ self.assertEqual(m.mock_calls, [])
+ else:
+ # follower: we already sent the KCM, do nothing
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [])
+ self.assertEqual(m.mock_calls, [])
+ clear_mock_calls(n, connector, t, m)
+
+ c.dataReceived(b"\x00\x00\x00\x04msg1")
+ self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [])
+ self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)])
+ clear_mock_calls(n, connector, t, m)
+
+ c.send_record(t_ack)
+ exp_ack = b"\x06\x00\x00\x00\x02"
+ self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")])
+ self.assertEqual(m.mock_calls, [])
+ clear_mock_calls(n, connector, t, m)
+
+ c.disconnect()
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
+ self.assertEqual(m.mock_calls, [])
+ clear_mock_calls(n, connector, t, m)
+
+ def test_no_relay_leader(self):
+ return self._test_no_relay(LEADER)
+
+ def test_no_relay_follower(self):
+ return self._test_no_relay(FOLLOWER)
+
+
+ def test_relay(self):
+ c, n, connector, t, eq = make_con(LEADER, use_relay=True)
+
+ c.makeConnection(t)
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
+ clear_mock_calls(n, connector, t)
+
+ c.dataReceived(b"ok\n")
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ clear_mock_calls(n, connector, t)
+
+ c.dataReceived(b"inbound_prologue\n")
+ self.assertEqual(n.mock_calls, [mock.call.write_message()])
+ self.assertEqual(connector.mock_calls, [])
+ exp_handshake = b"\x00\x00\x00\x09handshake"
+ self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
+ clear_mock_calls(n, connector, t)
+
+ def test_relay_jilted(self):
+ c, n, connector, t, eq = make_con(LEADER, use_relay=True)
+ d = c.when_disconnected()
+
+ c.makeConnection(t)
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
+ clear_mock_calls(n, connector, t)
+
+ c.connectionLost(b"why")
+ eq.flush_sync()
+ self.assertIdentical(self.successResultOf(d), c)
+
+ def test_relay_bad_response(self):
+ c, n, connector, t, eq = make_con(LEADER, use_relay=True)
+
+ c.makeConnection(t)
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
+ clear_mock_calls(n, connector, t)
+
+ c.dataReceived(b"not ok\n")
+ self.assertEqual(n.mock_calls, [])
+ self.assertEqual(connector.mock_calls, [])
+ self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
+ clear_mock_calls(n, connector, t)
diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py
new file mode 100644
index 0000000..e2c854e
--- /dev/null
+++ b/src/wormhole/test/dilate/test_encoding.py
@@ -0,0 +1,25 @@
+from __future__ import print_function, unicode_literals
+from twisted.trial import unittest
+from ..._dilation.encode import to_be4, from_be4
+
+class Encoding(unittest.TestCase):
+
+ def test_be4(self):
+ self.assertEqual(to_be4(0), b"\x00\x00\x00\x00")
+ self.assertEqual(to_be4(1), b"\x00\x00\x00\x01")
+ self.assertEqual(to_be4(256), b"\x00\x00\x01\x00")
+ self.assertEqual(to_be4(257), b"\x00\x00\x01\x01")
+ with self.assertRaises(ValueError):
+ to_be4(-1)
+ with self.assertRaises(ValueError):
+ to_be4(2**32)
+
+ self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0)
+ self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1)
+ self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256)
+ self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257)
+
+ with self.assertRaises(TypeError):
+ from_be4(0)
+ with self.assertRaises(ValueError):
+ from_be4(b"\x01\x00\x00\x00\x00")
diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py
new file mode 100644
index 0000000..bd8f995
--- /dev/null
+++ b/src/wormhole/test/dilate/test_endpoints.py
@@ -0,0 +1,97 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from ..._interfaces import ISubChannel
+from ..._dilation.subchannel import (ControlEndpoint,
+ SubchannelConnectorEndpoint,
+ SubchannelListenerEndpoint,
+ SubchannelListeningPort,
+ _WormholeAddress, _SubchannelAddress,
+ SingleUseEndpointError)
+from .common import mock_manager
+
+class Endpoints(unittest.TestCase):
+ def test_control(self):
+ scid0 = b"scid0"
+ peeraddr = _SubchannelAddress(scid0)
+ ep = ControlEndpoint(peeraddr)
+
+ f = mock.Mock()
+ p = mock.Mock()
+ f.buildProtocol = mock.Mock(return_value=p)
+ d = ep.connect(f)
+ self.assertNoResult(d)
+
+ t = mock.Mock()
+ alsoProvides(t, ISubChannel)
+ ep._subchannel_zero_opened(t)
+ self.assertIdentical(self.successResultOf(d), p)
+ self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
+ self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
+ self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
+
+ d = ep.connect(f)
+ self.failureResultOf(d, SingleUseEndpointError)
+
+ def assert_makeConnection(self, mock_calls):
+ self.assertEqual(len(mock_calls), 1)
+ self.assertEqual(mock_calls[0][0], "makeConnection")
+ self.assertEqual(len(mock_calls[0][1]), 1)
+ return mock_calls[0][1][0]
+
+ def test_connector(self):
+ m = mock_manager()
+ m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
+ hostaddr = _WormholeAddress()
+ peeraddr = _SubchannelAddress(b"scid")
+ ep = SubchannelConnectorEndpoint(m, hostaddr)
+
+ f = mock.Mock()
+ p = mock.Mock()
+ t = mock.Mock()
+ f.buildProtocol = mock.Mock(return_value=p)
+ with mock.patch("wormhole._dilation.subchannel.SubChannel",
+ return_value=t) as sc:
+ d = ep.connect(f)
+ self.assertIdentical(self.successResultOf(d), p)
+ self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
+ self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)])
+ self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
+ self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
+
+ def test_listener(self):
+ m = mock_manager()
+ m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
+ hostaddr = _WormholeAddress()
+ ep = SubchannelListenerEndpoint(m, hostaddr)
+
+ f = mock.Mock()
+ p1 = mock.Mock()
+ p2 = mock.Mock()
+ f.buildProtocol = mock.Mock(side_effect=[p1, p2])
+
+ # OPEN that arrives before we ep.listen() should be queued
+
+ t1 = mock.Mock()
+ peeraddr1 = _SubchannelAddress(b"peer1")
+ ep._got_open(t1, peeraddr1)
+
+ d = ep.listen(f)
+ lp = self.successResultOf(d)
+ self.assertIsInstance(lp, SubchannelListeningPort)
+
+ self.assertEqual(lp.getHost(), hostaddr)
+ lp.startListening()
+
+ self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
+ self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
+
+ t2 = mock.Mock()
+ peeraddr2 = _SubchannelAddress(b"peer2")
+ ep._got_open(t2, peeraddr2)
+
+ self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
+ self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
+
+ lp.stopListening() # TODO: should this do more?
diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py
new file mode 100644
index 0000000..81d4cf9
--- /dev/null
+++ b/src/wormhole/test/dilate/test_framer.py
@@ -0,0 +1,110 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from twisted.internet.interfaces import ITransport
+from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect
+
+def make_framer():
+ t = mock.Mock()
+ alsoProvides(t, ITransport)
+ f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n")
+ return f, t
+
+class Framer(unittest.TestCase):
+ def test_bad_prologue_length(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ t.mock_calls[:] = []
+ self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
+ self.assertEqual(t.mock_calls, [])
+
+ with mock.patch("wormhole._dilation.connection.log.msg") as m:
+ with self.assertRaises(Disconnect):
+ list(f.add_and_parse(b"not the prologue after all"))
+ self.assertEqual(m.mock_calls,
+ [mock.call("bad prologue: {}".format(
+ b"inbound_not the p"))])
+ self.assertEqual(t.mock_calls, [])
+
+ def test_bad_prologue_newline(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ t.mock_calls[:] = []
+ self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
+
+ self.assertEqual([], list(f.add_and_parse(b"not")))
+ with mock.patch("wormhole._dilation.connection.log.msg") as m:
+ with self.assertRaises(Disconnect):
+ list(f.add_and_parse(b"\n"))
+ self.assertEqual(m.mock_calls,
+ [mock.call("bad prologue: {}".format(
+ b"inbound_not\n"))])
+ self.assertEqual(t.mock_calls, [])
+
+ def test_good_prologue(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ t.mock_calls[:] = []
+ self.assertEqual([Prologue()],
+ list(f.add_and_parse(b"inbound_prologue\n")))
+ self.assertEqual(t.mock_calls, [])
+
+ # now send_frame should work
+ f.send_frame(b"frame")
+ self.assertEqual(t.mock_calls,
+ [mock.call.write(b"\x00\x00\x00\x05frame")])
+
+ def test_bad_relay(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+ f.use_relay(b"relay handshake\n")
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
+ t.mock_calls[:] = []
+ with mock.patch("wormhole._dilation.connection.log.msg") as m:
+ with self.assertRaises(Disconnect):
+ list(f.add_and_parse(b"goodbye\n"))
+ self.assertEqual(m.mock_calls,
+ [mock.call("bad relay_ok: {}".format(b"goo"))])
+ self.assertEqual(t.mock_calls, [])
+
+ def test_good_relay(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+ f.use_relay(b"relay handshake\n")
+ self.assertEqual(t.mock_calls, [])
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
+ t.mock_calls[:] = []
+
+ self.assertEqual([], list(f.add_and_parse(b"ok\n")))
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+
+ def test_frame(self):
+ f, t = make_framer()
+ self.assertEqual(t.mock_calls, [])
+
+ f.connectionMade()
+ self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
+ t.mock_calls[:] = []
+ self.assertEqual([Prologue()],
+ list(f.add_and_parse(b"inbound_prologue\n")))
+ self.assertEqual(t.mock_calls, [])
+
+ encoded_frame = b"\x00\x00\x00\x05frame"
+ self.assertEqual([], list(f.add_and_parse(encoded_frame[:2])))
+ self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6])))
+ self.assertEqual([Frame(frame=b"frame")],
+ list(f.add_and_parse(encoded_frame[6:])))
diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py
new file mode 100644
index 0000000..d147283
--- /dev/null
+++ b/src/wormhole/test/dilate/test_inbound.py
@@ -0,0 +1,172 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from ..._interfaces import IDilationManager
+from ..._dilation.connection import Open, Data, Close
+from ..._dilation.inbound import (Inbound, DuplicateOpenError,
+ DataForMissingSubchannelError,
+ CloseForMissingSubchannelError)
+
+def make_inbound():
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ host_addr = object()
+ i = Inbound(m, host_addr)
+ return i, m, host_addr
+
+class InboundTest(unittest.TestCase):
+ def test_seqnum(self):
+ i, m, host_addr = make_inbound()
+ r1 = Open(scid=513, seqnum=1)
+ r2 = Data(scid=513, seqnum=2, data=b"")
+ r3 = Close(scid=513, seqnum=3)
+ self.assertFalse(i.is_record_old(r1))
+ self.assertFalse(i.is_record_old(r2))
+ self.assertFalse(i.is_record_old(r3))
+
+ i.update_ack_watermark(r1)
+ self.assertTrue(i.is_record_old(r1))
+ self.assertFalse(i.is_record_old(r2))
+ self.assertFalse(i.is_record_old(r3))
+
+ i.update_ack_watermark(r2)
+ self.assertTrue(i.is_record_old(r1))
+ self.assertTrue(i.is_record_old(r2))
+ self.assertFalse(i.is_record_old(r3))
+
+ def test_open_data_close(self):
+ i, m, host_addr = make_inbound()
+ scid1 = b"scid"
+ scid2 = b"scXX"
+ c = mock.Mock()
+ lep = mock.Mock()
+ i.set_listener_endpoint(lep)
+ i.use_connection(c)
+ sc1 = mock.Mock()
+ peer_addr = object()
+ with mock.patch("wormhole._dilation.inbound.SubChannel",
+ side_effect=[sc1]) as sc:
+ with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
+ side_effect=[peer_addr]) as sca:
+ i.handle_open(scid1)
+ self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)])
+ self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)])
+ self.assertEqual(sca.mock_calls, [mock.call(scid1)])
+ lep.mock_calls[:] = []
+
+ # a subsequent duplicate OPEN should be ignored
+ with mock.patch("wormhole._dilation.inbound.SubChannel",
+ side_effect=[sc1]) as sc:
+ with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
+ side_effect=[peer_addr]) as sca:
+ i.handle_open(scid1)
+ self.assertEqual(lep.mock_calls, [])
+ self.assertEqual(sc.mock_calls, [])
+ self.assertEqual(sca.mock_calls, [])
+ self.flushLoggedErrors(DuplicateOpenError)
+
+ i.handle_data(scid1, b"data")
+ self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")])
+ sc1.mock_calls[:] = []
+
+ i.handle_data(scid2, b"for non-existent subchannel")
+ self.assertEqual(sc1.mock_calls, [])
+ self.flushLoggedErrors(DataForMissingSubchannelError)
+
+ i.handle_close(scid1)
+ self.assertEqual(sc1.mock_calls, [mock.call.remote_close()])
+ sc1.mock_calls[:] = []
+
+ i.handle_close(scid2)
+ self.assertEqual(sc1.mock_calls, [])
+ self.flushLoggedErrors(CloseForMissingSubchannelError)
+
+ # after the subchannel is closed, the Manager will notify Inbound
+ i.subchannel_closed(scid1, sc1)
+
+ i.stop_using_connection()
+
+ def test_control_channel(self):
+ i, m, host_addr = make_inbound()
+ lep = mock.Mock()
+ i.set_listener_endpoint(lep)
+
+ scid0 = b"scid"
+ sc0 = mock.Mock()
+ i.set_subchannel_zero(scid0, sc0)
+
+ # OPEN on the control channel identifier should be ignored as a
+ # duplicate, since the control channel is already registered
+ sc1 = mock.Mock()
+ peer_addr = object()
+ with mock.patch("wormhole._dilation.inbound.SubChannel",
+ side_effect=[sc1]) as sc:
+ with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
+ side_effect=[peer_addr]) as sca:
+ i.handle_open(scid0)
+ self.assertEqual(lep.mock_calls, [])
+ self.assertEqual(sc.mock_calls, [])
+ self.assertEqual(sca.mock_calls, [])
+ self.flushLoggedErrors(DuplicateOpenError)
+
+ # and DATA to it should be delivered correctly
+ i.handle_data(scid0, b"data")
+ self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")])
+ sc0.mock_calls[:] = []
+
+ def test_pause(self):
+ i, m, host_addr = make_inbound()
+ c = mock.Mock()
+ lep = mock.Mock()
+ i.set_listener_endpoint(lep)
+
+ # add two subchannels, pause one, then add a connection
+ scid1 = b"sci1"
+ scid2 = b"sci2"
+ sc1 = mock.Mock()
+ sc2 = mock.Mock()
+ peer_addr = object()
+ with mock.patch("wormhole._dilation.inbound.SubChannel",
+ side_effect=[sc1, sc2]):
+ with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
+ return_value=peer_addr):
+ i.handle_open(scid1)
+ i.handle_open(scid2)
+ self.assertEqual(c.mock_calls, [])
+
+ i.subchannel_pauseProducing(sc1)
+ self.assertEqual(c.mock_calls, [])
+ i.subchannel_resumeProducing(sc1)
+ self.assertEqual(c.mock_calls, [])
+ i.subchannel_pauseProducing(sc1)
+ self.assertEqual(c.mock_calls, [])
+
+ i.use_connection(c)
+ self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
+ c.mock_calls[:] = []
+
+ i.subchannel_resumeProducing(sc1)
+ self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
+ c.mock_calls[:] = []
+
+ # consumers aren't really supposed to do this, but tolerate it
+ i.subchannel_resumeProducing(sc1)
+ self.assertEqual(c.mock_calls, [])
+
+ i.subchannel_pauseProducing(sc1)
+ self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
+ c.mock_calls[:] = []
+ i.subchannel_pauseProducing(sc2)
+ self.assertEqual(c.mock_calls, []) # was already paused
+
+ # tolerate duplicate pauseProducing
+ i.subchannel_pauseProducing(sc2)
+ self.assertEqual(c.mock_calls, [])
+
+ # stopProducing is treated like a terminal resumeProducing
+ i.subchannel_stopProducing(sc1)
+ self.assertEqual(c.mock_calls, [])
+ i.subchannel_stopProducing(sc2)
+ self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
+ c.mock_calls[:] = []
diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py
new file mode 100644
index 0000000..625039d
--- /dev/null
+++ b/src/wormhole/test/dilate/test_manager.py
@@ -0,0 +1,205 @@
+from __future__ import print_function, unicode_literals
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from twisted.internet.defer import Deferred
+from twisted.internet.task import Clock, Cooperator
+import mock
+from ...eventual import EventualQueue
+from ..._interfaces import ISend, IDilationManager
+from ...util import dict_to_bytes
+from ..._dilation.manager import (Dilator,
+ OldPeerCannotDilateError,
+ UnknownDilationMessageType)
+from ..._dilation.subchannel import _WormholeAddress
+from .common import clear_mock_calls
+
+def make_dilator():
+ reactor = object()
+ clock = Clock()
+ eq = EventualQueue(clock)
+ term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
+ term_factory = lambda: term
+ coop = Cooperator(terminationPredicateFactory=term_factory,
+ scheduler=eq.eventually)
+ send = mock.Mock()
+ alsoProvides(send, ISend)
+ dil = Dilator(reactor, eq, coop)
+ dil.wire(send)
+ return dil, send, reactor, eq, clock, coop
+
+class TestDilator(unittest.TestCase):
+ def test_leader(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ d1 = dil.dilate()
+ d2 = dil.dilate()
+ self.assertNoResult(d1)
+ self.assertNoResult(d2)
+
+ key = b"key"
+ transit_key = object()
+ with mock.patch("wormhole._dilation.manager.derive_key",
+ return_value=transit_key) as dk:
+ dil.got_key(key)
+ self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
+ self.assertIdentical(dil._transit_key, transit_key)
+ self.assertNoResult(d1)
+ self.assertNoResult(d2)
+
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ m.when_first_connected.return_value = wfc_d = Deferred()
+ # TODO: test missing can-dilate, and no-overlap
+ with mock.patch("wormhole._dilation.manager.ManagerLeader",
+ return_value=m) as ml:
+ dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
+ # that should create the Manager. Because "us" > "them", we're
+ # the leader
+ self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
+ None, reactor, eq, coop)])
+ self.assertEqual(m.mock_calls, [mock.call.start(),
+ mock.call.when_first_connected(),
+ ])
+ clear_mock_calls(m)
+ self.assertNoResult(d1)
+ self.assertNoResult(d2)
+
+ host_addr = _WormholeAddress()
+ m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress",
+ return_value=host_addr)
+ peer_addr = object()
+ m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
+ return_value=peer_addr)
+ ce = mock.Mock()
+ m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
+ return_value=ce)
+ sc = mock.Mock()
+ m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
+ return_value=sc)
+
+ lep = object()
+ m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
+ return_value=lep)
+
+ with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m:
+ wfc_d.callback(None)
+ eq.flush_sync()
+ scid0 = b"\x00\x00\x00\x00"
+ self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
+ self.assertEqual(m_sc_m.mock_calls,
+ [mock.call(scid0, m, host_addr, peer_addr)])
+ self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
+ self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
+ self.assertEqual(m.mock_calls,
+ [mock.call.set_subchannel_zero(scid0, sc),
+ mock.call.set_listener_endpoint(lep),
+ ])
+ clear_mock_calls(m)
+
+ eps = self.successResultOf(d1)
+ self.assertEqual(eps, self.successResultOf(d2))
+ d3 = dil.dilate()
+ eq.flush_sync()
+ self.assertEqual(eps, self.successResultOf(d3))
+
+ self.assertEqual(m.mock_calls, [])
+ dil.received_dilate(dict_to_bytes(dict(type="please")))
+ self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE()])
+ clear_mock_calls(m)
+
+ hintmsg = dict(type="connection-hints")
+ dil.received_dilate(dict_to_bytes(hintmsg))
+ self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
+ clear_mock_calls(m)
+
+ dil.received_dilate(dict_to_bytes(dict(type="dilate")))
+ self.assertEqual(m.mock_calls, [mock.call.rx_DILATE()])
+ clear_mock_calls(m)
+
+ dil.received_dilate(dict_to_bytes(dict(type="unknown")))
+ self.assertEqual(m.mock_calls, [])
+ self.flushLoggedErrors(UnknownDilationMessageType)
+
+ def test_follower(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ d1 = dil.dilate()
+ self.assertNoResult(d1)
+
+ key = b"key"
+ transit_key = object()
+ with mock.patch("wormhole._dilation.manager.derive_key",
+ return_value=transit_key):
+ dil.got_key(key)
+
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ m.when_first_connected.return_value = Deferred()
+ with mock.patch("wormhole._dilation.manager.ManagerFollower",
+ return_value=m) as mf:
+ dil.got_wormhole_versions("me", "you", {"can-dilate": [1]})
+ # "me" < "you", so we're the follower
+ self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key,
+ None, reactor, eq, coop)])
+ self.assertEqual(m.mock_calls, [mock.call.start(),
+ mock.call.when_first_connected(),
+ ])
+
+ def test_peer_cannot_dilate(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ d1 = dil.dilate()
+ self.assertNoResult(d1)
+
+ dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate"
+ eq.flush_sync()
+ f = self.failureResultOf(d1)
+ f.check(OldPeerCannotDilateError)
+
+
+ def test_disjoint_versions(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ d1 = dil.dilate()
+ self.assertNoResult(d1)
+
+ dil.got_wormhole_versions("me", "you", {"can-dilate": [-1]})
+ eq.flush_sync()
+ f = self.failureResultOf(d1)
+ f.check(OldPeerCannotDilateError)
+
+
+ def test_early_dilate_messages(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ dil._transit_key = b"key"
+ d1 = dil.dilate()
+ self.assertNoResult(d1)
+ dil.received_dilate(dict_to_bytes(dict(type="please")))
+ hintmsg = dict(type="connection-hints")
+ dil.received_dilate(dict_to_bytes(hintmsg))
+
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ m.when_first_connected.return_value = Deferred()
+
+ with mock.patch("wormhole._dilation.manager.ManagerLeader",
+ return_value=m) as ml:
+ dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
+ self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
+ None, reactor, eq, coop)])
+ self.assertEqual(m.mock_calls, [mock.call.start(),
+ mock.call.rx_PLEASE(),
+ mock.call.rx_HINTS(hintmsg),
+ mock.call.when_first_connected()])
+
+
+
+ def test_transit_relay(self):
+ dil, send, reactor, eq, clock, coop = make_dilator()
+ dil._transit_key = b"key"
+ relay = object()
+ d1 = dil.dilate(transit_relay_location=relay)
+ self.assertNoResult(d1)
+
+ with mock.patch("wormhole._dilation.manager.ManagerLeader") as ml:
+ dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
+ self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
+ relay, reactor, eq, coop),
+ mock.call().start(),
+ mock.call().when_first_connected()])
diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py
new file mode 100644
index 0000000..fab596a
--- /dev/null
+++ b/src/wormhole/test/dilate/test_outbound.py
@@ -0,0 +1,645 @@
+from __future__ import print_function, unicode_literals
+from collections import namedtuple
+from itertools import cycle
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from twisted.internet.task import Clock, Cooperator
+from twisted.internet.interfaces import IPullProducer
+from ...eventual import EventualQueue
+from ..._interfaces import IDilationManager
+from ..._dilation.connection import KCM, Open, Data, Close, Ack
+from ..._dilation.outbound import Outbound, PullToPush
+from .common import clear_mock_calls
+
+Pauser = namedtuple("Pauser", ["seqnum"])
+NonPauser = namedtuple("NonPauser", ["seqnum"])
+Stopper = namedtuple("Stopper", ["sc"])
+
+def make_outbound():
+ m = mock.Mock()
+ alsoProvides(m, IDilationManager)
+ clock = Clock()
+ eq = EventualQueue(clock)
+ term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
+ term_factory = lambda: term
+ coop = Cooperator(terminationPredicateFactory=term_factory,
+ scheduler=eq.eventually)
+ o = Outbound(m, coop)
+ c = mock.Mock() # Connection
+ def maybe_pause(r):
+ if isinstance(r, Pauser):
+ o.pauseProducing()
+ elif isinstance(r, Stopper):
+ o.subchannel_unregisterProducer(r.sc)
+ c.send_record = mock.Mock(side_effect=maybe_pause)
+ o._test_eq = eq
+ o._test_term = term
+ return o, m, c
+
+class OutboundTest(unittest.TestCase):
+ def test_build_record(self):
+ o, m, c = make_outbound()
+ scid1 = b"scid"
+ self.assertEqual(o.build_record(Open, scid1),
+ Open(seqnum=0, scid=b"scid"))
+ self.assertEqual(o.build_record(Data, scid1, b"dataaa"),
+ Data(seqnum=1, scid=b"scid", data=b"dataaa"))
+ self.assertEqual(o.build_record(Close, scid1),
+ Close(seqnum=2, scid=b"scid"))
+ self.assertEqual(o.build_record(Close, scid1),
+ Close(seqnum=3, scid=b"scid"))
+
+ def test_outbound_queue(self):
+ o, m, c = make_outbound()
+ scid1 = b"scid"
+ r1 = o.build_record(Open, scid1)
+ r2 = o.build_record(Data, scid1, b"data1")
+ r3 = o.build_record(Data, scid1, b"data2")
+ o.queue_and_send_record(r1)
+ o.queue_and_send_record(r2)
+ o.queue_and_send_record(r3)
+ self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
+
+ # we would never normally receive an ACK without first getting a
+ # connection
+ o.handle_ack(r2.seqnum)
+ self.assertEqual(list(o._outbound_queue), [r3])
+
+ o.handle_ack(r3.seqnum)
+ self.assertEqual(list(o._outbound_queue), [])
+
+ o.handle_ack(r3.seqnum) # ignored
+ self.assertEqual(list(o._outbound_queue), [])
+
+ o.handle_ack(r1.seqnum) # ignored
+ self.assertEqual(list(o._outbound_queue), [])
+
+ def test_duplicate_registerProducer(self):
+ o, m, c = make_outbound()
+ sc1 = object()
+ p1 = mock.Mock()
+ o.subchannel_registerProducer(sc1, p1, True)
+ with self.assertRaises(ValueError) as ar:
+ o.subchannel_registerProducer(sc1, p1, True)
+ s = str(ar.exception)
+ self.assertIn("registering producer", s)
+ self.assertIn("before previous one", s)
+ self.assertIn("was unregistered", s)
+
+ def test_connection_send_queued_unpaused(self):
+ o, m, c = make_outbound()
+ scid1 = b"scid"
+ r1 = o.build_record(Open, scid1)
+ r2 = o.build_record(Data, scid1, b"data1")
+ r3 = o.build_record(Data, scid1, b"data2")
+ o.queue_and_send_record(r1)
+ o.queue_and_send_record(r2)
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [])
+
+ # as soon as the connection is established, everything is sent
+ o.use_connection(c)
+ self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
+ mock.call.send_record(r1),
+ mock.call.send_record(r2)])
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [])
+ clear_mock_calls(c)
+
+ o.queue_and_send_record(r3)
+ self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
+ self.assertEqual(list(o._queued_unsent), [])
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
+
+ def test_connection_send_queued_paused(self):
+ o, m, c = make_outbound()
+ r1 = Pauser(seqnum=1)
+ r2 = Pauser(seqnum=2)
+ r3 = Pauser(seqnum=3)
+ o.queue_and_send_record(r1)
+ o.queue_and_send_record(r2)
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [])
+
+ # pausing=True, so our mock Manager will pause the Outbound producer
+ # after each write. So only r1 should have been sent before getting
+ # paused
+ o.use_connection(c)
+ self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
+ mock.call.send_record(r1)])
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [r2])
+ clear_mock_calls(c)
+
+ # Outbound is responsible for sending all records, so when Manager
+ # wants to send a new one, and Outbound is still in the middle of
+ # draining the beginning-of-connection queue, the new message gets
+ # queued behind the rest (in addition to being queued in
+ # _outbound_queue until an ACK retires it).
+ o.queue_and_send_record(r3)
+ self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
+ self.assertEqual(list(o._queued_unsent), [r2, r3])
+ self.assertEqual(c.mock_calls, [])
+
+ o.handle_ack(r1.seqnum)
+ self.assertEqual(list(o._outbound_queue), [r2, r3])
+ self.assertEqual(list(o._queued_unsent), [r2, r3])
+ self.assertEqual(c.mock_calls, [])
+
+ def test_premptive_ack(self):
+ # one mode I have in mind is for each side to send an immediate ACK,
+ # with everything they've ever seen, as the very first message on each
+ # new connection. The idea is that you might preempt sending stuff from
+ # the _queued_unsent list if it arrives fast enough (in practice this
+ # is more likely to be delivered via the DILATE mailbox message, but
+ # the effects might be vaguely similar, so it seems worth testing
+ # here). A similar situation would be if each side sends ACKs with the
+ # highest seqnum they've ever seen, instead of merely ACKing the
+ # message which was just received.
+ o, m, c = make_outbound()
+ r1 = Pauser(seqnum=1)
+ r2 = Pauser(seqnum=2)
+ r3 = Pauser(seqnum=3)
+ o.queue_and_send_record(r1)
+ o.queue_and_send_record(r2)
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [])
+
+ o.use_connection(c)
+ self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
+ mock.call.send_record(r1)])
+ self.assertEqual(list(o._outbound_queue), [r1, r2])
+ self.assertEqual(list(o._queued_unsent), [r2])
+ clear_mock_calls(c)
+
+ o.queue_and_send_record(r3)
+ self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
+ self.assertEqual(list(o._queued_unsent), [r2, r3])
+ self.assertEqual(c.mock_calls, [])
+
+ o.handle_ack(r2.seqnum)
+ self.assertEqual(list(o._outbound_queue), [r3])
+ self.assertEqual(list(o._queued_unsent), [r3])
+ self.assertEqual(c.mock_calls, [])
+
+ def test_pause(self):
+ o, m, c = make_outbound()
+ o.use_connection(c)
+ self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
+ self.assertEqual(list(o._outbound_queue), [])
+ self.assertEqual(list(o._queued_unsent), [])
+ clear_mock_calls(c)
+
+ sc1, sc2, sc3 = object(), object(), object()
+ p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3")
+
+ # we aren't paused yet, since we haven't sent any data
+ o.subchannel_registerProducer(sc1, p1, True)
+ self.assertEqual(p1.mock_calls, [])
+
+ r1 = Pauser(seqnum=1)
+ o.queue_and_send_record(r1)
+ # now we should be paused
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
+ self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
+ clear_mock_calls(p1, c)
+
+ # so an IPushProducer will be paused right away
+ o.subchannel_registerProducer(sc2, p2, True)
+ self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()])
+ clear_mock_calls(p2)
+
+ o.subchannel_registerProducer(sc3, p3, True)
+ self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()])
+ self.assertEqual(o._paused_producers, set([p1, p2, p3]))
+ self.assertEqual(list(o._all_producers), [p1, p2, p3])
+ clear_mock_calls(p3)
+
+ # one resumeProducing should cause p1 to get a turn, since p2 was added
+ # after we were paused and p1 was at the "end" of a one-element list.
+ # If it writes anything, it will get paused again immediately.
+ r2 = Pauser(seqnum=2)
+ p1.resumeProducing.side_effect = lambda: c.send_record(r2)
+ o.resumeProducing()
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p2.mock_calls, [])
+ self.assertEqual(p3.mock_calls, [])
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r2)])
+ clear_mock_calls(p1, p2, p3, c)
+ # p2 should now be at the head of the queue
+ self.assertEqual(list(o._all_producers), [p2, p3, p1])
+
+ # next turn: p2 has nothing to send, but p3 does. we should see p3
+ # called but not p1. The actual sequence of expected calls is:
+ # p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause)
+ r3 = Pauser(seqnum=3)
+ p2.resumeProducing.side_effect = lambda: None
+ p3.resumeProducing.side_effect = lambda: c.send_record(r3)
+ o.resumeProducing()
+ self.assertEqual(p1.mock_calls, [])
+ self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
+ clear_mock_calls(p1, p2, p3, c)
+ # p1 should now be at the head of the queue
+ self.assertEqual(list(o._all_producers), [p1, p2, p3])
+
+ # next turn: p1 has data to send, but not enough to cause a pause. same
+ # for p2. p3 causes a pause
+ r4 = NonPauser(seqnum=4)
+ r5 = NonPauser(seqnum=5)
+ r6 = Pauser(seqnum=6)
+ p1.resumeProducing.side_effect = lambda: c.send_record(r4)
+ p2.resumeProducing.side_effect = lambda: c.send_record(r5)
+ p3.resumeProducing.side_effect = lambda: c.send_record(r6)
+ o.resumeProducing()
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r4),
+ mock.call.send_record(r5),
+ mock.call.send_record(r6),
+ ])
+ clear_mock_calls(p1, p2, p3, c)
+ # p1 should now be at the head of the queue again
+ self.assertEqual(list(o._all_producers), [p1, p2, p3])
+
+ # now we let it catch up. p1 and p2 send non-pausing data, p3 sends
+ # nothing.
+ r7 = NonPauser(seqnum=4)
+ r8 = NonPauser(seqnum=5)
+ p1.resumeProducing.side_effect = lambda: c.send_record(r7)
+ p2.resumeProducing.side_effect = lambda: c.send_record(r8)
+ p3.resumeProducing.side_effect = lambda: None
+
+ o.resumeProducing()
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
+ ])
+ self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
+ ])
+ self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
+ ])
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r7),
+ mock.call.send_record(r8),
+ ])
+ clear_mock_calls(p1, p2, p3, c)
+ # p1 should now be at the head of the queue again
+ self.assertEqual(list(o._all_producers), [p1, p2, p3])
+ self.assertFalse(o._paused)
+
+ # now a producer disconnects itself (spontaneously, not from inside a
+ # resumeProducing)
+ o.subchannel_unregisterProducer(sc1)
+ self.assertEqual(list(o._all_producers), [p2, p3])
+ self.assertEqual(p1.mock_calls, [])
+ self.assertFalse(o._paused)
+
+ # and another disconnects itself when called
+ p2.resumeProducing.side_effect = lambda: None
+ p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3)
+ o.pauseProducing()
+ o.resumeProducing()
+ self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(),
+ mock.call.resumeProducing()])
+ self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(),
+ mock.call.resumeProducing()])
+ clear_mock_calls(p2, p3)
+ self.assertEqual(list(o._all_producers), [p2])
+ self.assertFalse(o._paused)
+
+ def test_subchannel_closed(self):
+ o, m, c = make_outbound()
+
+ sc1 = mock.Mock()
+ p1 = mock.Mock(name="p1")
+ o.subchannel_registerProducer(sc1, p1, True)
+ self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
+ clear_mock_calls(p1)
+
+ o.subchannel_closed(sc1)
+ self.assertEqual(p1.mock_calls, [])
+ self.assertEqual(list(o._all_producers), [])
+
+ sc2 = mock.Mock()
+ o.subchannel_closed(sc2)
+
+ def test_disconnect(self):
+ o, m, c = make_outbound()
+ o.use_connection(c)
+
+ sc1 = mock.Mock()
+ p1 = mock.Mock(name="p1")
+ o.subchannel_registerProducer(sc1, p1, True)
+ self.assertEqual(p1.mock_calls, [])
+ o.stop_using_connection()
+ self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
+
+ def OFF_test_push_pull(self):
+ # use one IPushProducer and one IPullProducer. They should take turns
+ o, m, c = make_outbound()
+ o.use_connection(c)
+ clear_mock_calls(c)
+
+ sc1, sc2 = object(), object()
+ p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2")
+ r1 = Pauser(seqnum=1)
+ r2 = NonPauser(seqnum=2)
+
+ # we aren't paused yet, since we haven't sent any data
+ o.subchannel_registerProducer(sc1, p1, True) # push
+ o.queue_and_send_record(r1)
+ # now we're paused
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
+ self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
+ self.assertEqual(p2.mock_calls, [])
+ clear_mock_calls(p1, p2, c)
+
+ p1.resumeProducing.side_effect = lambda: c.send_record(r1)
+ p2.resumeProducing.side_effect = lambda: c.send_record(r2)
+ o.subchannel_registerProducer(sc2, p2, False) # pull: always ready
+
+ # p1 is still first, since p2 was just added (at the end)
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls, [])
+ self.assertEqual(p1.mock_calls, [])
+ self.assertEqual(p2.mock_calls, [])
+ self.assertEqual(list(o._all_producers), [p1, p2])
+ clear_mock_calls(p1, p2, c)
+
+ # resume should send r1, which should pause everything
+ o.resumeProducing()
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r1),
+ ])
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p2.mock_calls, [])
+ self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next
+ clear_mock_calls(p1, p2, c)
+
+ # next should fire p2, then p1
+ o.resumeProducing()
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls, [mock.call.send_record(r2),
+ mock.call.send_record(r1),
+ ])
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
+ mock.call.pauseProducing(),
+ ])
+ self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
+ ])
+ self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat
+ clear_mock_calls(p1, p2, c)
+
+ def test_pull_producer(self):
+ # a single pull producer should write until it is paused, rate-limited
+ # by the cooperator (so we'll see back-to-back resumeProducing calls
+ # until the Connection is paused, or 10ms have passed, whichever comes
+ # first, and if it's stopped by the timer, then the next EventualQueue
+ # turn will start it off again)
+
+ o, m, c = make_outbound()
+ eq = o._test_eq
+ o.use_connection(c)
+ clear_mock_calls(c)
+ self.assertFalse(o._paused)
+
+ sc1 = mock.Mock()
+ p1 = mock.Mock(name="p1")
+ alsoProvides(p1, IPullProducer)
+
+ records = [NonPauser(seqnum=1)] * 10
+ records.append(Pauser(seqnum=2))
+ records.append(Stopper(sc1))
+ it = iter(records)
+ p1.resumeProducing.side_effect = lambda: c.send_record(next(it))
+ o.subchannel_registerProducer(sc1, p1, False)
+ eq.flush_sync() # fast forward into the glorious (paused) future
+
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls,
+ [mock.call.send_record(r) for r in records[:-1]])
+ self.assertEqual(p1.mock_calls,
+ [mock.call.resumeProducing()]*(len(records)-1))
+ clear_mock_calls(c, p1)
+
+ # next resumeProducing should cause it to disconnect
+ o.resumeProducing()
+ eq.flush_sync()
+ self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])])
+ self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()])
+ self.assertEqual(len(o._all_producers), 0)
+ self.assertFalse(o._paused)
+
+ def test_two_pull_producers(self):
+ # we should alternate between them until paused
+ p1_records = ([NonPauser(seqnum=i) for i in range(5)] +
+ [Pauser(seqnum=5)] +
+ [NonPauser(seqnum=i) for i in range(6, 10)])
+ p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] +
+ [Pauser(seqnum=19)])
+ expected1 = [NonPauser(0), NonPauser(10),
+ NonPauser(1), NonPauser(11),
+ NonPauser(2), NonPauser(12),
+ NonPauser(3), NonPauser(13),
+ NonPauser(4), NonPauser(14),
+ Pauser(5)]
+ expected2 = [ NonPauser(15),
+ NonPauser(6), NonPauser(16),
+ NonPauser(7), NonPauser(17),
+ NonPauser(8), NonPauser(18),
+ NonPauser(9), Pauser(19),
+ ]
+
+ o, m, c = make_outbound()
+ eq = o._test_eq
+ o.use_connection(c)
+ clear_mock_calls(c)
+ self.assertFalse(o._paused)
+
+ sc1 = mock.Mock()
+ p1 = mock.Mock(name="p1")
+ alsoProvides(p1, IPullProducer)
+ it1 = iter(p1_records)
+ p1.resumeProducing.side_effect = lambda: c.send_record(next(it1))
+ o.subchannel_registerProducer(sc1, p1, False)
+
+ sc2 = mock.Mock()
+ p2 = mock.Mock(name="p2")
+ alsoProvides(p2, IPullProducer)
+ it2 = iter(p2_records)
+ p2.resumeProducing.side_effect = lambda: c.send_record(next(it2))
+ o.subchannel_registerProducer(sc2, p2, False)
+
+ eq.flush_sync() # fast forward into the glorious (paused) future
+
+ sends = [mock.call.resumeProducing()]
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls,
+ [mock.call.send_record(r) for r in expected1])
+ self.assertEqual(p1.mock_calls, 6*sends)
+ self.assertEqual(p2.mock_calls, 5*sends)
+ clear_mock_calls(c, p1, p2)
+
+ o.resumeProducing()
+ eq.flush_sync()
+ self.assertTrue(o._paused)
+ self.assertEqual(c.mock_calls,
+ [mock.call.send_record(r) for r in expected2])
+ self.assertEqual(p1.mock_calls, 4*sends)
+ self.assertEqual(p2.mock_calls, 5*sends)
+ clear_mock_calls(c, p1, p2)
+
+ def test_send_if_connected(self):
+ o, m, c = make_outbound()
+ o.send_if_connected(Ack(1)) # not connected yet
+
+ o.use_connection(c)
+ o.send_if_connected(KCM())
+ self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
+ mock.call.send_record(KCM())])
+
+ def test_tolerate_duplicate_pause_resume(self):
+ o, m, c = make_outbound()
+ self.assertTrue(o._paused) # no connection
+ o.use_connection(c)
+ self.assertFalse(o._paused)
+ o.pauseProducing()
+ self.assertTrue(o._paused)
+ o.pauseProducing()
+ self.assertTrue(o._paused)
+ o.resumeProducing()
+ self.assertFalse(o._paused)
+ o.resumeProducing()
+ self.assertFalse(o._paused)
+
+ def test_stopProducing(self):
+ o, m, c = make_outbound()
+ o.use_connection(c)
+ self.assertFalse(o._paused)
+ o.stopProducing() # connection does this before loss
+ self.assertTrue(o._paused)
+ o.stop_using_connection()
+ self.assertTrue(o._paused)
+
+ def test_resume_error(self):
+ o, m, c = make_outbound()
+ o.use_connection(c)
+ sc1 = mock.Mock()
+ p1 = mock.Mock(name="p1")
+ alsoProvides(p1, IPullProducer)
+ p1.resumeProducing.side_effect = PretendResumptionError
+ o.subchannel_registerProducer(sc1, p1, False)
+ o._test_eq.flush_sync()
+ # the error is supposed to automatically unregister the producer
+ self.assertEqual(list(o._all_producers), [])
+ self.flushLoggedErrors(PretendResumptionError)
+
+
+def make_pushpull(pauses):
+ p = mock.Mock()
+ alsoProvides(p, IPullProducer)
+ unregister = mock.Mock()
+
+ clock = Clock()
+ eq = EventualQueue(clock)
+ term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
+ term_factory = lambda: term
+ coop = Cooperator(terminationPredicateFactory=term_factory,
+ scheduler=eq.eventually)
+ pp = PullToPush(p, unregister, coop)
+
+ it = cycle(pauses)
+ def action(i):
+ if isinstance(i, Exception):
+ raise i
+ elif i:
+ pp.pauseProducing()
+ p.resumeProducing.side_effect = lambda: action(next(it))
+ return p, unregister, pp, eq
+
+class PretendResumptionError(Exception):
+ pass
+class PretendUnregisterError(Exception):
+ pass
+
+class PushPull(unittest.TestCase):
+ # test our wrapper utility, which I copied from
+ # twisted.internet._producer_helpers since it isn't publically exposed
+
+ def test_start_unpaused(self):
+ p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
+ # if it starts unpaused, it gets one write before being halted
+ pp.startStreaming(False)
+ eq.flush_sync()
+ self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1)
+ clear_mock_calls(p)
+
+ # now each time we call resumeProducing, we should see one delivered to
+ # the underlying IPullProducer
+ pp.resumeProducing()
+ eq.flush_sync()
+ self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1)
+
+ pp.stopStreaming()
+ pp.stopStreaming() # should tolerate this
+
+ def test_start_unpaused_two_writes(self):
+ p, unr, pp, eq = make_pushpull([False, True]) # pause every other time
+ # it should get two writes, since the first didn't pause
+ pp.startStreaming(False)
+ eq.flush_sync()
+ self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2)
+
+ def test_start_paused(self):
+ p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
+ pp.startStreaming(True)
+ eq.flush_sync()
+ self.assertEqual(p.mock_calls, [])
+ pp.stopStreaming()
+
+ def test_stop(self):
+ p, unr, pp, eq = make_pushpull([True])
+ pp.startStreaming(True)
+ pp.stopProducing()
+ eq.flush_sync()
+ self.assertEqual(p.mock_calls, [mock.call.stopProducing()])
+
+ def test_error(self):
+ p, unr, pp, eq = make_pushpull([PretendResumptionError()])
+ unr.side_effect = lambda: pp.stopStreaming()
+ pp.startStreaming(False)
+ eq.flush_sync()
+ self.assertEqual(unr.mock_calls, [mock.call()])
+ self.flushLoggedErrors(PretendResumptionError)
+
+ def test_error_during_unregister(self):
+ p, unr, pp, eq = make_pushpull([PretendResumptionError()])
+ unr.side_effect = PretendUnregisterError()
+ pp.startStreaming(False)
+ eq.flush_sync()
+ self.assertEqual(unr.mock_calls, [mock.call()])
+ self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError)
+
+
+
+
+
+ # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I
+ # could capture the inter-call ordering that way
diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py
new file mode 100644
index 0000000..8365e62
--- /dev/null
+++ b/src/wormhole/test/dilate/test_parse.py
@@ -0,0 +1,43 @@
+from __future__ import print_function, unicode_literals
+import mock
+from twisted.trial import unittest
+from ..._dilation.connection import (parse_record, encode_record,
+ KCM, Ping, Pong, Open, Data, Close, Ack)
+
+class Parse(unittest.TestCase):
+ def test_parse(self):
+ self.assertEqual(parse_record(b"\x00"), KCM())
+ self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"),
+ Ping(ping_id=b"\x55\x44\x33\x22"))
+ self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
+ Pong(ping_id=b"\x55\x44\x33\x22"))
+ self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
+ Open(scid=513, seqnum=256))
+ self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
+ Data(scid=514, seqnum=257, data=b"dataaa"))
+ self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
+ Close(scid=515, seqnum=258))
+ self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
+ Ack(resp_seqnum=259))
+ with mock.patch("wormhole._dilation.connection.log.err") as le:
+ with self.assertRaises(ValueError):
+ parse_record(b"\x07unknown")
+ self.assertEqual(le.mock_calls,
+ [mock.call("received unknown message type: {}".format(
+ b"\x07unknown"))])
+
+ def test_encode(self):
+ self.assertEqual(encode_record(KCM()), b"\x00")
+ self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
+ self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
+ self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
+ b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
+ self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
+ b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
+ self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
+ b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
+ self.assertEqual(encode_record(Ack(resp_seqnum=19)),
+ b"\x06\x00\x00\x00\x13")
+ with self.assertRaises(TypeError) as ar:
+ encode_record("not a record")
+ self.assertEqual(str(ar.exception), "not a record")
diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py
new file mode 100644
index 0000000..810396c
--- /dev/null
+++ b/src/wormhole/test/dilate/test_record.py
@@ -0,0 +1,268 @@
+from __future__ import print_function, unicode_literals
+import mock
+from zope.interface import alsoProvides
+from twisted.trial import unittest
+from noise.exceptions import NoiseInvalidMessage
+from ..._dilation.connection import (IFramer, Frame, Prologue,
+ _Record, Handshake,
+ Disconnect, Ping)
+
+def make_record():
+ f = mock.Mock()
+ alsoProvides(f, IFramer)
+ n = mock.Mock() # pretends to be a Noise object
+ r = _Record(f, n)
+ return r, f, n
+
+class Record(unittest.TestCase):
+ def test_good2(self):
+ f = mock.Mock()
+ alsoProvides(f, IFramer)
+ f.add_and_parse = mock.Mock(side_effect=[
+ [],
+ [Prologue()],
+ [Frame(frame=b"rx-handshake")],
+ [Frame(frame=b"frame1"), Frame(frame=b"frame2")],
+ ])
+ n = mock.Mock()
+ n.write_message = mock.Mock(return_value=b"tx-handshake")
+ p1, p2 = object(), object()
+ n.decrypt = mock.Mock(side_effect=[p1, p2])
+ r = _Record(f, n)
+ self.assertEqual(f.mock_calls, [])
+ r.connectionMade()
+ self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ n.mock_calls[:] = []
+
+ # Pretend to deliver the prologue in two parts. The text we send in
+ # doesn't matter: the side_effect= is what causes the prologue to be
+ # recognized by the second call.
+ self.assertEqual(list(r.add_and_unframe(b"pro")), [])
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [])
+
+ # recognizing the prologue causes a handshake frame to be sent
+ self.assertEqual(list(r.add_and_unframe(b"logue")), [])
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
+ mock.call.send_frame(b"tx-handshake")])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [mock.call.write_message()])
+ n.mock_calls[:] = []
+
+ # next add_and_unframe is recognized as the Handshake
+ self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()])
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")])
+ n.mock_calls[:] = []
+
+ # next is a pair of Records
+ r1, r2 = object() , object()
+ with mock.patch("wormhole._dilation.connection.parse_record",
+ side_effect=[r1,r2]) as pr:
+ self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2])
+ self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"),
+ mock.call.decrypt(b"frame2")])
+ self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)])
+
+ def test_bad_handshake(self):
+ f = mock.Mock()
+ alsoProvides(f, IFramer)
+ f.add_and_parse = mock.Mock(return_value=[Prologue(),
+ Frame(frame=b"rx-handshake")])
+ n = mock.Mock()
+ n.write_message = mock.Mock(return_value=b"tx-handshake")
+ nvm = NoiseInvalidMessage()
+ n.read_message = mock.Mock(side_effect=nvm)
+ r = _Record(f, n)
+ self.assertEqual(f.mock_calls, [])
+ r.connectionMade()
+ self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ n.mock_calls[:] = []
+
+ with mock.patch("wormhole._dilation.connection.log.err") as le:
+ with self.assertRaises(Disconnect):
+ list(r.add_and_unframe(b"data"))
+ self.assertEqual(le.mock_calls,
+ [mock.call(nvm, "bad inbound noise handshake")])
+
+ def test_bad_message(self):
+ f = mock.Mock()
+ alsoProvides(f, IFramer)
+ f.add_and_parse = mock.Mock(return_value=[Prologue(),
+ Frame(frame=b"rx-handshake"),
+ Frame(frame=b"bad-message")])
+ n = mock.Mock()
+ n.write_message = mock.Mock(return_value=b"tx-handshake")
+ nvm = NoiseInvalidMessage()
+ n.decrypt = mock.Mock(side_effect=nvm)
+ r = _Record(f, n)
+ self.assertEqual(f.mock_calls, [])
+ r.connectionMade()
+ self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
+ f.mock_calls[:] = []
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ n.mock_calls[:] = []
+
+ with mock.patch("wormhole._dilation.connection.log.err") as le:
+ with self.assertRaises(Disconnect):
+ list(r.add_and_unframe(b"data"))
+ self.assertEqual(le.mock_calls,
+ [mock.call(nvm, "bad inbound noise frame")])
+
+ def test_send_record(self):
+ f = mock.Mock()
+ alsoProvides(f, IFramer)
+ n = mock.Mock()
+ f1 = object()
+ n.encrypt = mock.Mock(return_value=f1)
+ r1 = Ping(b"pingid")
+ r = _Record(f, n)
+ self.assertEqual(f.mock_calls, [])
+ m1 = object()
+ with mock.patch("wormhole._dilation.connection.encode_record",
+ return_value=m1) as er:
+ r.send_record(r1)
+ self.assertEqual(er.mock_calls, [mock.call(r1)])
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake(),
+ mock.call.encrypt(m1)])
+ self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)])
+
+ def test_good(self):
+ # Exercise the success path. The Record instance is given each chunk
+ # of data as it arrives on Protocol.dataReceived, and is supposed to
+ # return a series of Tokens (maybe none, if the chunk was incomplete,
+ # or more than one, if the chunk was larger). Internally, it delivers
+ # the chunks to the Framer for unframing (which returns 0 or more
+ # frames), manages the Noise decryption object, and parses any
+ # decrypted messages into tokens (some of which are consumed
+ # internally, others for delivery upstairs).
+ #
+ # in the normal flow, we get:
+ #
+ # | | Inbound | NoiseAction | Outbound | ToUpstairs |
+ # | | - | - | - | - |
+ # | 1 | | | prologue | |
+ # | 2 | prologue | | | |
+ # | 3 | | write_message | handshake | |
+ # | 4 | handshake | read_message | | Handshake |
+ # | 5 | | encrypt | KCM | |
+ # | 6 | KCM | decrypt | | KCM |
+ # | 7 | msg1 | decrypt | | msg1 |
+
+ # 1: instantiating the Record instance causes the outbound prologue
+ # to be sent
+
+ # 2+3: receipt of the inbound prologue triggers creation of the
+ # ephemeral key (the "handshake") by calling noise.write_message()
+ # and then writes the handshake to the outbound transport
+
+ # 4: when the peer's handshake is received, it is delivered to
+ # noise.read_message(), which generates the shared key (enabling
+ # noise.send() and noise.decrypt()). It also delivers the Handshake
+ # token upstairs, which might (on the Follower) trigger immediate
+ # transmission of the Key Confirmation Message (KCM)
+
+ # 5: the outbound KCM is framed and fed into noise.encrypt(), then
+ # sent outbound
+
+ # 6: the peer's KCM is decrypted then delivered upstairs. The
+ # Follower treats this as a signal that it should use this connection
+ # (and drop all others).
+
+ # 7: the peer's first message is decrypted, parsed, and delivered
+ # upstairs. This might be an Open or a Data, depending upon what
+ # queued messages were left over from the previous connection
+
+ r, f, n = make_record()
+ outbound_handshake = object()
+ kcm, msg1 = object(), object()
+ f_kcm, f_msg1 = object(), object()
+ n.write_message = mock.Mock(return_value=outbound_handshake)
+ n.decrypt = mock.Mock(side_effect=[kcm, msg1])
+ n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1])
+ f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet
+ [Prologue()],
+ [Frame("f_handshake")],
+ [Frame("f_kcm"),
+ Frame("f_msg1")],
+ ])
+
+ self.assertEqual(f.mock_calls, [])
+ self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
+ n.mock_calls[:] = []
+
+ # 1. The Framer is responsible for sending the prologue, so we don't
+ # have to check that here, we just check that the Framer was told
+ # about connectionMade properly.
+ r.connectionMade()
+ self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
+ self.assertEqual(n.mock_calls, [])
+ f.mock_calls[:] = []
+
+ # 2
+ # we dribble the prologue in over two messages, to make sure we can
+ # handle a dataReceived that doesn't complete the token
+
+ # remember, add_and_unframe is a generator
+ self.assertEqual(list(r.add_and_unframe(b"pro")), [])
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
+ self.assertEqual(n.mock_calls, [])
+ f.mock_calls[:] = []
+
+ self.assertEqual(list(r.add_and_unframe(b"logue")), [])
+ # 3: write_message, send outbound handshake
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
+ mock.call.send_frame(outbound_handshake),
+ ])
+ self.assertEqual(n.mock_calls, [mock.call.write_message()])
+ f.mock_calls[:] = []
+ n.mock_calls[:] = []
+
+ # 4
+ # Now deliver the Noise "handshake", the ephemeral public key. This
+ # is framed, but not a record, so it shouldn't decrypt or parse
+ # anything, but the handshake is delivered to the Noise object, and
+ # it does return a Handshake token so we can let the next layer up
+ # react (by sending the KCM frame if we're a Follower, or not if
+ # we're the Leader)
+
+ self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()])
+ self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")])
+ self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")])
+ f.mock_calls[:] = []
+ n.mock_calls[:] = []
+
+
+ # 5: at this point we ought to be able to send a messge, the KCM
+ with mock.patch("wormhole._dilation.connection.encode_record",
+ side_effect=[b"r-kcm"]) as er:
+ r.send_record(kcm)
+ self.assertEqual(er.mock_calls, [mock.call(kcm)])
+ self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")])
+ self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)])
+ n.mock_calls[:] = []
+ f.mock_calls[:] = []
+
+ # 6: Now we deliver two messages stacked up: the KCM (Key
+ # Confirmation Message) and the first real message. Concatenating
+ # them tests that we can handle more than one token in a single
+ # chunk. We need to mock parse_record() because everything past the
+ # handshake is decrypted and parsed.
+
+ with mock.patch("wormhole._dilation.connection.parse_record",
+ side_effect=[kcm, msg1]) as pr:
+ self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")),
+ [kcm, msg1])
+ self.assertEqual(f.mock_calls,
+ [mock.call.add_and_parse(b"kcm,msg1")])
+ self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"),
+ mock.call.decrypt("f_msg1")])
+ self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)])
+ n.mock_calls[:] = []
+ f.mock_calls[:] = []
diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py
new file mode 100644
index 0000000..69fa001
--- /dev/null
+++ b/src/wormhole/test/dilate/test_subchannel.py
@@ -0,0 +1,142 @@
+from __future__ import print_function, unicode_literals
+import mock
+from twisted.trial import unittest
+from twisted.internet.interfaces import ITransport
+from twisted.internet.error import ConnectionDone
+from ..._dilation.subchannel import (Once, SubChannel,
+ _WormholeAddress, _SubchannelAddress,
+ AlreadyClosedError)
+from .common import mock_manager
+
+def make_sc(set_protocol=True):
+ scid = b"scid"
+ hostaddr = _WormholeAddress()
+ peeraddr = _SubchannelAddress(scid)
+ m = mock_manager()
+ sc = SubChannel(scid, m, hostaddr, peeraddr)
+ p = mock.Mock()
+ if set_protocol:
+ sc._set_protocol(p)
+ return sc, m, scid, hostaddr, peeraddr, p
+
+class SubChannelAPI(unittest.TestCase):
+ def test_once(self):
+ o = Once(ValueError)
+ o()
+ with self.assertRaises(ValueError):
+ o()
+
+ def test_create(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+ self.assert_(ITransport.providedBy(sc))
+ self.assertEqual(m.mock_calls, [])
+ self.assertIdentical(sc.getHost(), hostaddr)
+ self.assertIdentical(sc.getPeer(), peeraddr)
+
+ def test_write(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ sc.write(b"data")
+ self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
+ m.mock_calls[:] = []
+ sc.writeSequence([b"more", b"data"])
+ self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
+
+ def test_write_when_closing(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ sc.loseConnection()
+ self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
+ m.mock_calls[:] = []
+
+ with self.assertRaises(AlreadyClosedError) as e:
+ sc.write(b"data")
+ self.assertEqual(str(e.exception),
+ "write not allowed on closed subchannel")
+
+ def test_local_close(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ sc.loseConnection()
+ self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
+ m.mock_calls[:] = []
+
+ # late arriving data is still delivered
+ sc.remote_data(b"late")
+ self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")])
+ p.mock_calls[:] = []
+
+ sc.remote_close()
+ self.assert_connectionDone(p.mock_calls)
+
+ def test_local_double_close(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ sc.loseConnection()
+ self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
+ m.mock_calls[:] = []
+
+ with self.assertRaises(AlreadyClosedError) as e:
+ sc.loseConnection()
+ self.assertEqual(str(e.exception),
+ "loseConnection not allowed on closed subchannel")
+
+ def assert_connectionDone(self, mock_calls):
+ self.assertEqual(len(mock_calls), 1)
+ self.assertEqual(mock_calls[0][0], "connectionLost")
+ self.assertEqual(len(mock_calls[0][1]), 1)
+ self.assertIsInstance(mock_calls[0][1][0], ConnectionDone)
+
+ def test_remote_close(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+ sc.remote_close()
+ self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
+ self.assert_connectionDone(p.mock_calls)
+
+ def test_data(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+ sc.remote_data(b"data")
+ self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
+ p.mock_calls[:] = []
+ sc.remote_data(b"not")
+ sc.remote_data(b"coalesced")
+ self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"),
+ mock.call.dataReceived(b"coalesced"),
+ ])
+
+ def test_data_before_open(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
+ sc.remote_data(b"data")
+ self.assertEqual(p.mock_calls, [])
+ sc._set_protocol(p)
+ self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
+ p.mock_calls[:] = []
+ sc.remote_data(b"more")
+ self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])
+
+ def test_close_before_open(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
+ sc.remote_close()
+ self.assertEqual(p.mock_calls, [])
+ sc._set_protocol(p)
+ self.assert_connectionDone(p.mock_calls)
+
+ def test_producer(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ sc.pauseProducing()
+ self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)])
+ m.mock_calls[:] = []
+ sc.resumeProducing()
+ self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)])
+ m.mock_calls[:] = []
+ sc.stopProducing()
+ self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)])
+ m.mock_calls[:] = []
+
+ def test_consumer(self):
+ sc, m, scid, hostaddr, peeraddr, p = make_sc()
+
+ # TODO: more, once this is implemented
+ sc.registerProducer(None, True)
+ sc.unregisterProducer()
diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py
index 5b5e9f1..35854ae 100644
--- a/src/wormhole/test/test_machines.py
+++ b/src/wormhole/test/test_machines.py
@@ -12,7 +12,7 @@ import mock
from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
_mailbox, _nameplate, _order, _receive, _rendezvous, _send,
_terminator, errors, timing)
-from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister,
+from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister,
IMailbox, INameplate, IOrder, IReceive,
IRendezvousConnector, ISend, ITerminator, IWordlist)
from .._key import derive_key, derive_phase_key, encrypt_data
@@ -1300,6 +1300,7 @@ class Boss(unittest.TestCase):
b._RC = Dummy("rc", events, IRendezvousConnector, "start")
b._C = Dummy("c", events, ICode, "allocate_code", "input_code",
"set_code")
+ b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key")
return b, events
def test_basic(self):
@@ -1327,7 +1328,9 @@ class Boss(unittest.TestCase):
b.got_message("side", "0", b"msg1")
self.assertEqual(events, [
("w.got_key", b"key"),
+ ("d.got_key", b"key"),
("w.got_verifier", b"verifier"),
+ ("d.got_wormhole_versions", "side", "side", {}),
("w.got_versions", {}),
("w.received", b"msg1"),
])
diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py
index c02aa60..f967c25 100644
--- a/src/wormhole/wormhole.py
+++ b/src/wormhole/wormhole.py
@@ -9,6 +9,7 @@ from twisted.internet.task import Cooperator
from zope.interface import implementer
from ._boss import Boss
+from ._dilation.connector import Connector
from ._interfaces import IDeferredWormhole, IWormhole
from ._key import derive_key
from .errors import NoKeyError, WormholeClosed
@@ -189,6 +190,9 @@ class _DeferredWormhole(object):
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length)
+ def dilate(self):
+ return self._boss.dilate() # fires with (endpoints)
+
def close(self):
# fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of
@@ -265,8 +269,12 @@ def create(
w = _DelegatedWormhole(delegate)
else:
w = _DeferredWormhole(reactor, eq)
- wormhole_versions = {} # will be used to indicate Wormhole capabilities
- wormhole_versions["app_versions"] = versions # app-specific capabilities
+ # this indicates Wormhole capabilities
+ wormhole_versions = {
+ "can-dilate": [1],
+ "dilation-abilities": Connector.get_connection_abilities(),
+ }
+ wormhole_versions["app_versions"] = versions # app-specific capabilities
v = __version__
if isinstance(v, type(b"")):
v = v.decode("utf-8", errors="replace")