Merge branch 'dilate-5'

This adds (but does not enable/expose) the low-level code for the new
Dilation protocol (refs #312). The spec and docs are done, the unit tests
pass (with full branch coverage).

The next step is to write some higher-level integration tests, which use a
fake/short-circuited mailbox connection (Manager.send) but real localhost TCP
sockets.

Then we need to figure out backwards compatibility with non-dilation-capable
versions. I've got a table in my notes, I'll add it to the ticket.
This commit is contained in:
Brian Warner 2018-12-24 23:23:16 -05:00
commit ddba0fc840
45 changed files with 8434 additions and 351 deletions

View File

@ -17,13 +17,16 @@ before_script:
flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ;
fi
script:
- tox -e coverage
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 || $TRAVIS_PYTHON_VERSION == 3.4 ]]; then
tox -e no-dilate ;
else
tox -e coverage ;
fi
after_success:
- codecov
matrix:
include:
- python: 2.7
- python: 3.3
- python: 3.4
- python: 3.5
- python: 3.6
@ -34,5 +37,4 @@ matrix:
dist: xenial
- python: nightly
allow_failures:
- python: 3.3
- python: nightly

View File

@ -524,25 +524,33 @@ object twice.
## Dilation
(NOTE: this section is speculative: this code has not yet been written)
(NOTE: this API is still in development)
In the longer term, the Wormhole object will incorporate the "Transit"
functionality (see transit.md) directly, removing the need to instantiate a
second object. A Wormhole can be "dilated" into a form that is suitable for
bulk data transfer.
To send bulk data, or anything more than a handful of messages, a Wormhole
can be "dilated" into a form that uses a direct TCP connection between the
two endpoints.
All wormholes start out "undilated". In this state, all messages are queued
on the Rendezvous Server for the lifetime of the wormhole, and server-imposed
number/size/rate limits apply. Calling `w.dilate()` initiates the dilation
process, and success is signalled via either `d=w.when_dilated()` firing, or
`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used
as an IConsumer/IProducer, and messages will be sent on a direct connection
(if possible) or through the transit relay (if not).
process, and eventually yields a set of Endpoints. Once dilated, the usual
`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and
these endpoints can be used to establish multiple (encrypted) "subchannel"
connections to the other side.
Each subchannel behaves like a regular Twisted `ITransport`, so they can be
glued to the Protocol instance of your choice. They also implement the
IConsumer/IProducer interfaces.
These subchannels are *durable*: as long as the processes on both sides keep
running, the subchannel will survive the network connection being dropped.
For example, a file transfer can be started from a laptop, then while it is
running, the laptop can be closed, moved to a new wifi network, opened back
up, and the transfer will resume from the new IP address.
What's good about a non-dilated wormhole?:
* setup is faster: no delay while it tries to make a direct connection
* survives temporary network outages, since messages are queued
* works with "journaled mode", allowing progress to be made even when both
sides are never online at the same time, by serializing the wormhole
@ -556,21 +564,103 @@ Use non-dilated wormholes when your application only needs to exchange a
couple of messages, for example to set up public keys or provision access
tokens. Use a dilated wormhole to move files.
Dilated wormholes can provide multiple "channels": these are multiplexed
through the single (encrypted) TCP connection. Each channel is a separate
stream (offering IProducer/IConsumer)
Dilated wormholes can provide multiple "subchannels": these are multiplexed
through the single (encrypted) TCP connection. Each subchannel is a separate
stream (offering IProducer/IConsumer for flow control), and is opened and
closed independently. A special "control channel" is available to both sides
so they can coordinate how they use the subchannels.
To create a channel, call `c = w.create_channel()` on a dilated wormhole. The
"channel ID" can be obtained with `c.get_id()`. This ID will be a short
(unicode) string, which can be sent to the other side via a normal
`w.send()`, or any other means. On the other side, use `c =
w.open_channel(channel_id)` to get a matching channel object.
The `d = w.dilate()` Deferred fires with a triple of Endpoints:
Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire
them up with `c.registerProducer()`. Note that channels do not close until
the wormhole connection is closed, so they do not have separate `close()`
methods or events. Therefore if you plan to send files through them, you'll
need to inform the recipient ahead of time about how many bytes to expect.
```python
d = w.dilate()
def _dilated(res):
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
d.addCallback(_dilated)
```
The `control_channel_ep` endpoint is a client-style endpoint, so both sides
will connect to it with `ep.connect(factory)`. This endpoint is single-use:
calling `.connect()` a second time will fail. The control channel is
symmetric: it doesn't matter which side is the application-level
client/server or initiator/responder, if the application even has such
concepts. The two applications can use the control channel to negotiate who
goes first, if necessary.
The subchannel endpoints are *not* symmetric: for each subchannel, one side
must listen as a server, and the other must connect as a client. Subchannels
can be established by either side at any time. This supports e.g.
bidirectional file transfer, where either user of a GUI app can drop files
into the "wormhole" whenever they like.
The `subchannel_client_ep` on one side is used to connect to the other side's
`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The
server endpoint is single-use: `.listen(factory)` may only be called once.
Applications are under no obligation to use subchannels: for many use cases,
the control channel is enough.
To use subchannels, once the wormhole is dilated and the endpoints are
available, the listening-side application should attach a listener to the
`subchannel_server_ep` endpoint:
```python
def _dilated(res):
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
f = Factory(MyListeningProtocol)
subchannel_server_ep.listen(f)
```
When the connecting-side application wants to connect to that listening
protocol, it should use `.connect()` with a suitable connecting protocol
factory:
```python
def _connect():
f = Factory(MyConnectingProtocol)
subchannel_client_ep.connect(f)
```
For a bidirectional file-transfer application, both sides will establish a
listening protocol. Later, if/when the user drops a file on the application
window, that side will initiate a connection, use the resulting subchannel to
transfer the single file, and then close the subchannel.
```python
def FileSendingProtocol(internet.Protocol):
def __init__(self, metadata, filename):
self.file_metadata = metadata
self.file_name = filename
def connectionMade(self):
self.transport.write(self.file_metadata)
sender = protocols.basic.FileSender()
f = open(self.file_name,"rb")
d = sender.beginFileTransfer(f, self.transport)
d.addBoth(self._done, f)
def _done(res, f):
self.transport.loseConnection()
f.close()
def _send(metadata, filename):
f = protocol.ClientCreator(reactor,
FileSendingProtocol, metadata, filename)
subchannel_client_ep.connect(f)
def FileReceivingProtocol(internet.Protocol):
state = INITIAL
def dataReceived(self, data):
if state == INITIAL:
self.state = DATA
metadata = parse(data)
self.f = open(metadata.filename, "wb")
else:
# local file writes are blocking, so don't bother with IConsumer
self.f.write(data)
def connectionLost(self, reason):
self.f.close()
def _dilated(res):
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
f = Factory(FileReceivingProtocol)
subchannel_server_ep.listen(f)
```
## Bytes, Strings, Unicode, and Python 3

540
docs/dilation-protocol.md Normal file
View File

@ -0,0 +1,540 @@
# Dilation Internals
Wormhole dilation involves several moving parts. Both sides exchange messages
through the Mailbox server to coordinate the establishment of a more direct
connection. This connection might flow in either direction, so they trade
"connection hints" to point at potential listening ports. This process might
succeed in making multiple connections at about the same time, so one side
must select the best one to use, and cleanly shut down the others. To make
the dilated connection *durable*, this side must also decide when the
connection has been lost, and then coordinate the construction of a
replacement. Within this connection, a series of queued-and-acked subchannel
messages are used to open/use/close the application-visible subchannels.
## Versions and can-dilate
The Wormhole protocol includes a `versions` message sent immediately after
the shared PAKE key is established. This also serves as a key-confirmation
message, allowing each side to confirm that the other side knows the right
key. The body of the `versions` message is a JSON-formatted string with keys
that are available for learning the abilities of the peer. Dilation is
signaled by a key named `can-dilate`, whose value is a list of strings. Any
version present in both side's lists is eligible for use.
## Leaders and Followers
Each side of a Wormhole has a randomly-generated dilation `side` string (this
is included in the `please-dilate` message, and is independent of the
Wormhole's mailbox "side"). When the wormhole is dilated, the side with the
lexicographically-higher "side" value is named the "Leader", and the other
side is named the "Follower". The general wormhole protocol treats both sides
identically, but the distinction matters for the dilation protocol.
Both sides send a `please-dilate` as soon as dilation is triggered. Each side
discovers whether it is the Leader or the Follower when the peer's
"please-dilate" arrives. The Leader has exclusive control over whether a
given connection is considered established or not: if there are multiple
potential connections to use, the Leader decides which one to use, and the
Leader gets to decide when the connection is no longer viable (and triggers
the establishment of a new one).
The `please-dilate` includes a `use-version` key, computed as the "best"
version of the intersection of the two sides' abilities as reported in the
`versions` message. Both sides will use whichever `use-version` was specified
by the Leader (they learn which side is the Leader at the same moment they
learn the peer's `use-version` value). If the Follower cannot handle the
`use-version` value, dilation fails (this shouldn't happen, as the Leader
knew what the Follower was and was not capable of before sending that
message).
## Connection Layers
We describe the protocol as a series of layers. Messages sent on one layer
may be encoded or transformed before being delivered on some other layer.
L1 is the mailbox channel (queued store-and-forward messages that always go
to the mailbox server, and then are forwarded to other clients subscribed to
the same mailbox). Both clients remain connected to the mailbox server until
the Wormhole is closed. They send DILATE-n messages to each other to manage
the dilation process, including records like `please`, `connection-hints`,
`reconnect`, and `reconnecting`.
L2 is the set of competing connection attempts for a given generation of
connection. Each time the Leader decides to establish a new connection, a new
generation number is used. Hopefully these are direct TCP connections between
the two peers, but they may also include connections through the transit
relay. Each connection must go through an encrypted handshake process before
it is considered viable. Viable connections are then submitted to a selection
process (on the Leader side), which chooses exactly one to use, and drops the
others. It may wait an extra few seconds in the hopes of getting a "better"
connection (faster, cheaper, etc), but eventually it will select one.
L3 is the current selected connection. There is one L3 for each generation.
At all times, the wormhole will have exactly zero or one L3 connection. L3 is
responsible for the selection process, connection monitoring/keepalives, and
serialization/deserialization of the plaintext frames. L3 delivers decoded
frames and connection-establishment events up to L4.
L4 is the persistent higher-level channel. It is created as soon as the first
L3 connection is selected, and lasts until wormhole is closed entirely. L4
contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number
(scoped to the L4 connection and the direction of travel), and the ACK
messages reference those sequence numbers. When a message is given to the L4
channel for delivery to the remote side, it is always queued, then
transmitted if there is an L3 connection available. This message remains in
the queue until an ACK is received to retire it. If a new L3 connection is
made, all queued messages will be re-sent (in seqnum order).
L5 are subchannels. There is one pre-established subchannel 0 known as the
"control channel", which does not require an OPEN message. All other
subchannels are created by the receipt of an OPEN message with the subchannel
number. DATA frames are delivered to a specific subchannel. When the
subchannel is no longer needed, one side will invoke the ``close()`` API
(``loseConnection()`` in Twisted), which will cause a CLOSE message to be
sent, and the local L5 object will be put into the "closing "state. When the
other side receives the CLOSE, it will send its own CLOSE for the same
subchannel, and fully close its local object (``connectionLost()``). When the
first side receives CLOSE in the "closing" state, it will fully close its
local object too.
All L5 subchannels will be paused (``pauseProducing()``) when the L3
connection is paused or lost. They are resumed when the L3 connection is
resumed or reestablished.
## Initiating Dilation
Dilation is triggered by calling the `w.dilate()` API. This returns a
Deferred that will fire once the first L3 connection is established. It fires
with a 3-tuple of endpoints that can be used to establish subchannels.
For dilation to succeed, both sides must call `w.dilate()`, since the
resulting endpoints are the only way to access the subchannels. If the other
side never calls `w.dilate()`, the Deferred will never fire.
The L1 (mailbox) path is used to deliver dilation requests and connection
hints. The current mailbox protocol uses named "phases" to distinguish
messages (rather than behaving like a regular ordered channel of arbitrary
frames or bytes), and all-number phase names are reserved for application
data (sent via `w.send_message()`). Therefore the dilation control messages
use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own
counter, so one side might be up to e.g. `DILATE-5` while the other has only
gotten as far as `DILATE-2`. This effectively creates a pair of
unidirectional streams of `DILATE-n` messages, each containing one or more
dilation record, of various types described below. Note that all phases
beyond the initial VERSION and PAKE phases are encrypted by the shared
session key.
A future mailbox protocol might provide a simple ordered stream of typed
messages, with application records and dilation records mixed together.
Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that
has a string value. The dictionary will have other keys that depend upon the
type.
`w.dilate()` triggers transmission of a `please` (i.e. "please dilate")
record with a set of versions that can be accepted. Versions use strings,
rather than integers, to support experimental protocols, however there is
still a total ordering of preferability.
```
{ "type": "please",
"side": "abcdef",
"accepted-versions": ["1"]
}
```
If one side receives a `please` before `w.dilate()` has been called locally,
the contents are stored in case `w.dilate()` is called in the future. Once
both `w.dilate()` has been called and the peer's `please` has been received,
the side determines whether it is the Leader or the Follower. Both sides also
compare `accepted-versions` fields to choose the best mutually-compatible
version to use: they should always pick the same one.
Then both sides begin the connection process for generation 1 by opening
listening sockets and sending `connection-hint` records for each one. After a
slight delay they will also open connections to the Transit Relay of their
choice and produce hints for it too. The receipt of inbound hints (on both
sides) will trigger outbound connection attempts.
Some number of these connections may succeed, and the Leader decides which to
use (via an in-band signal on the established connection). The others are
dropped.
If something goes wrong with the established connection and the Leader
decides a new one is necessary, the Leader will send a `reconnect` message.
This might happen while connections are still being established, or while the
Follower thinks it still has a viable connection (the Leader might observe
problems that the Follower does not), or after the Follower thinks the
connection has been lost. In all cases, the Leader is the only side which
should send `reconnect`. The state machine code looks the same on both sides,
for simplicity, but one path on each side is never used.
Upon receiving a `reconnect`, the Follower should stop any pending connection
attempts and terminate any existing connections (even if they appear viable).
Listening sockets may be retained, but any previous connection made through
them must be dropped.
Once all connections have stopped, the Follower should send a `reconnecting`
message, then start the connection process for the next generation, which
will send new `connection-hint` messages for all listening sockets.
Generations are non-overlapping. The Leader will drop all connections from
generation 1 before sending the `reconnect` for generation 2, and will not
initiate any gen-2 connections until it receives the matching `reconnecting`
from the Follower. The Follower must drop all gen-1 connections before it
sends the `reconnecting` response (even if it thinks they are still
functioning: if the Leader thought the gen-1 connection still worked, it
wouldn't have started gen-2).
(TODO: what about a follower->leader connection that was started before
start-dilation is received, and gets established on the Leader side after
start-dilation is sent? the follower will drop it after it receives
start-dilation, but meanwhile the leader may accept it as gen2)
(probably need to include the generation number in the handshake, or in the
derived key)
(TODO: reduce the number of round-trip stalls here, I've added too many)
Each side is in the "connecting" state (which encompasses both making
connection attempts and having an established connection) starting with the
receipt of a `please-dilate` message and a local `w.dilate()` call. The
Leader remains in that state until it abandons the connection and sends a
`reconnect` message, at which point it remains in the "flushing" state until
the Follower's `reconnecting` message is received. The Follower remains in
"connecting" until it receives `reconnect`, then it stays in "dropping" until
it finishes halting all outstanding connections, after which it sends
`reconnecting` and switches back to "connecting".
"Connection hints" are type/address/port records that tell the other side of
likely targets for L2 connections. Both sides will try to determine their
external IP addresses, listen on a TCP port, and advertise `(tcp,
external-IP, port)` as a connection hint. The Transit Relay is also used as a
(lower-priority) hint. These are sent in `connection-hint` records, which can
be sent any time after both sending and receiving a `please` record. Each
side will initiate connections upon receipt of the hints.
```
{ "type": "connection-hints",
"hints": [ ... ]
}
```
Hints can arrive at any time. One side might immediately send hints that can
be computed quickly, then send additional hints later as they become
available. For example, it might enumerate the local network interfaces and
send hints for all of the LAN addresses first, then send port-forwarding
(UPnP) requests to the local router. When the forwarding is established
(providing an externally-visible IP address and port), it can send additional
hints for that new endpoint. If the other peer happens to be on the same LAN,
the local connection can be established without waiting for the router's
response.
### Connection Hint Format
Each member of the `hints` field describes a potential L2 connection target
endpoint, with an associated priority and a set of hints.
The priority is a number (positive or negative float), where larger numbers
indicate that the client supplying that hint would prefer to use this
connection over others of lower number. This indicates a sense of cost or
performance. For example, the Transit Relay is lower priority than a direct
TCP connection, because it incurs a bandwidth cost (on the relay operator),
as well as adding latency.
Each endpoint has a set of hints, because the same target might be reachable
by multiple hints. Once one hint succeeds, there is no point in using the
other hints.
TODO: think this through some more. What's the example of a single endpoint
reachable by multiple hints? Should each hint have its own priority, or just
each endpoint?
## L2 protocol
Upon ``connectionMade()``, both sides send their handshake message. The
Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower
sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should
trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP
servers that were contacted by accident). If the wrong handshake is received,
the connection will be dropped. For debugging purposes, the node might want
to keep looking at data beyond the first incorrect character and log
everything until the first newline.
Everything beyond that point is a Noise protocol message, which consists of a
4-byte big-endian length field, followed by the indicated number of bytes.
This uses the `NNpsk0` pattern with the Leader as the first party ("-> psk,
e" in the Noise spec), and the Follower as the second ("<- e, ee"). The
pre-shared-key is the "dilation key", which is statically derived from the
master PAKE key using HKDF. Each L2 connection uses the same dilation key,
but different ephemeral keys, so each gets a different session key.
The Leader sends the first message, which is a psk-encrypted ephemeral key.
The Follower sends the next message, its own psk-encrypted ephemeral key. The
Follower then sends an empty packet as the "key confirmation message", which
will be encrypted by the shared key.
The Leader sees the KCM and knows the connection is viable. It delivers the
protocol object to the L3 manager, which will decide which connection to
select. When the L2 connection is selected to be the new L3, it will send an
empty KCM of its own, to let the Follower know the connection being selected.
All other L2 connections (either viable or still in handshake) are dropped,
all other connection attempts are cancelled. All listening sockets may or may
not be shut down (TODO: think about it).
The Follower will wait for either an empty KCM (at which point the L2
connection is delivered to the Dilation manager as the new L3), a
disconnection, or an invalid message (which causes the connection to be
dropped). Other connections and/or listening sockets are stopped.
Internally, the L2Protocol object manages the Noise session itself. It knows
(via a constructor argument) whether it is on the Leader or Follower side,
which affects both the role is plays in the Noise pattern, and the reaction
to receiving the ephemeral key (for which only the Follower sends an empty
KCM message). After that, the L2Protocol notifies the L3 object in three
situations:
* the Noise session produces a valid decrypted frame (for Leader, this
includes the Follower's KCM, and thus indicates a viable candidate for
connection selection)
* the Noise session reports a failed decryption
* the TCP session is lost
All notifications include a reference to the L2Protocol object (`self`). The
L3 object uses this reference to either close the connection (for errors or
when the selection process chooses someone else), to send the KCM message
(after selection, only for the Leader), or to send other L4 messages. The L3
object will retain a reference to the winning L2 object.
## L3 protocol
The L3 layer is responsible for connection selection, monitoring/keepalives,
and message (de)serialization. Framing is handled by L2, so the inbound L3
codepath receives single-message bytestrings, and delivers the same down to
L2 for encryption, framing, and transmission.
Connection selection takes place exclusively on the Leader side, and includes
the following:
* receipt of viable L2 connections from below (indicated by the first valid
decrypted frame received for any given connection)
* expiration of a timer
* comparison of TBD quality/desirability/cost metrics of viable connections
* selection of winner
* instructions to losing connections to disconnect
* delivery of KCM message through winning connection
* retain reference to winning connection
On the Follower side, the L3 manager just waits for the first connection to
receive the Leader's KCM, at which point it is retained and all others are
dropped.
The L3 manager knows which "generation" of connection is being established.
Each generation uses a different dilation key (?), and is triggered by a new
set of L1 messages. Connections from one generation should not be confused
with those of a different generation.
Each time a new L3 connection is established, the L4 protocol is notified. It
will will immediately send all the L4 messages waiting in its outbound queue.
The L3 protocol simply wraps these in Noise frames and sends them to the
other side.
The L3 manager monitors the viability of the current connection, and declares
it as lost when bidirectional traffic cannot be maintained. It uses PING and
PONG messages to detect this. These also serve to keep NAT entries alive,
since many firewalls will stop forwarding packets if they don't observe any
traffic for e.g. 5 minutes.
Our goals are:
* don't allow more than 30? seconds to pass without at least *some* data
being sent along each side of the connection
* allow the Leader to detect silent connection loss within 60? seconds
* minimize overhead
We need both sides to:
* maintain a 30-second repeating timer
* set a flag each time we write to the connection
* each time the timer fires, if the flag was clear then send a PONG,
otherwise clear the flag
In addition, the Leader must:
* run a 60-second repeating timer (ideally somewhat offset from the other)
* set a flag each time we receive data from the connection
* each time the timer fires, if the flag was clear then drop the connection,
otherwise clear the flag
In the future, we might have L2 links that are less connection-oriented,
which might have a unidirectional failure mode, at which point we'll need to
monitor full roundtrips. To accomplish this, the Leader will send periodic
unconditional PINGs, and the Follower will respond with PONGs. If the
Leader->Follower connection is down, the PINGs won't arrive and no PONGs will
be produced. If the Follower->Leader direction has failed, the PONGs won't
arrive. The delivery of both will be delayed by actual data, so the timeouts
should be adjusted if we see regular data arriving.
If the connection is dropped before the wormhole is closed (either the other
end explicitly dropped it, we noticed a problem and told TCP to drop it, or
TCP noticed a problem itself), the Leader-side L3 manager will initiate a
reconnection attempt. This uses L1 to send a new DILATE message through the
mailbox server, along with new connection hints. Eventually this will result
in a new L3 connection being established.
Finally, L3 is responsible for message serialization and deserialization. L2
performs decryption and delivers plaintext frames to L3. Each frame starts
with a one-byte type indicator. The rest of the message depends upon the
type:
* 0x00 PING, 4-byte ping-id
* 0x01 PONG, 4-byte ping-id
* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum
* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload
* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum
* 0x05 ACK, 4-byte response-seqnum
All seqnums are big-endian, and are provided by the L4 protocol. The other
fields are arbitrary and not interpreted as integers. The subchannel-ids must
be allocated by both sides without collision, but otherwise they are only
used to look up L5 objects for dispatch. The response-seqnum is always copied
from the OPEN/DATA/CLOSE packet being acknowledged.
L3 consumes the PING and PONG messages. Receiving any PING will provoke a
PONG in response, with a copy of the ping-id field. The 30-second timer will
produce unprovoked PONGs with a ping-id of all zeros. A future viability
protocol will use PINGs to test for roundtrip functionality.
All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered
"upstairs" to the L4 protocol handler.
The current L3 connection's `IProducer`/`IConsumer` interface is made
available to the L4 flow-control manager.
## L4 protocol
The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages.
Since each will be enclosed in a Noise frame before they pass to L3, they do
not need length fields or other framing.
Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically
increasing by 1 for each message. Each direction has a separate number space.
The L4 manager maintains a double-ended queue of unacknowledged outbound
messages. Subchannel activity (opening, closing, sending data) cause messages
to be added to this queue. If an L3 connection is available, these messages
are also sent over that connection, but they remain in the queue in case the
connection is lost and they must be retransmitted on some future replacement
connection. Messages stay in the queue until they can be retired by the
receipt of an ACK with a matching response-sequence-number. This provides
reliable message delivery that survives the L3 connection being replaced.
ACKs are not acked, nor do they have seqnums of their own. Each inbound side
remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE
messages with that sequence number or higher. This ensures in-order
at-most-once processing of OPEN/DATA/CLOSE messages.
Each inbound OPEN message causes a new L5 subchannel object to be created.
Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to
that object.
Each time an L3 connection is established, the side will immediately send all
L4 messages waiting in the outbound queue. A future protocol might reduce
this duplication by including the highest received sequence number in the L1
PLEASE-DILATE message, which would effectively retire queued messages before
initiating the L2 connection process. On any given L3 connection, all
messages are sent in-order. The receipt of an ACK for seqnum `N` allows all
messages with `seqnum <= N` to be retired.
The L4 layer is also responsible for managing flow control among the L3
connection and the various L5 subchannels.
## L5 subchannels
The L5 layer consists of a collection of "subchannel" objects, a dispatcher,
and the endpoints that provide the Twisted-flavored API.
Other than the "control channel", all subchannels are created by a client
endpoint connection API. The side that calls this API is named the Initiator,
and the other side is named the Acceptor. Subchannels can be initiated in
either direction, independent of the Leader/Follower distinction. For a
typical file-transfer application, the subchannel would be initiated by the
side seeking to send a file.
Each subchannel uses a distinct subchannel-id, which is a four-byte
identifier. Both directions share a number space (unlike L4 seqnums), so the
rule is that the Leader side sets the last bit of the last byte to a 1, while
the Follower sets it to a 0. These are not generally treated as integers,
however for the sake of debugging, the implementation generates them with a
simple big-endian-encoded counter (`counter*2+1` for the Leader,
`counter*2+2` for the Follower, with id `0` reserved for the control
channel).
When the `client_ep.connect()` API is called, the Initiator allocates a
subchannel-id and sends an OPEN. It can then immediately send DATA messages
with the outbound data (there is no special response to an OPEN, so there is
no need to wait). The Acceptor will trigger their `.connectionMade` handler
upon receipt of the OPEN.
Subchannels are durable: they do not close until one side calls
`.loseConnection` on the subchannel object (or the enclosing Wormhole is
closed). Either the Initiator or the Acceptor can call `.loseConnection`.
This causes a CLOSE message to be sent (with the subchannel-id). The other
side will send its own CLOSE message in response. Each side will signal the
`.connectionLost()` event upon receipt of a CLOSE.
There is no equivalent to TCP's "half-closed" state, however if only one side
calls `close()`, then all data written before that call will be delivered
before the other side observes `.connectionLost()`. Any inbound data that was
queued for delivery before the other side sees the CLOSE will still be
delivered to the side that called `close()` before it sees
`.connectionLost()`. Internally, the side which called `.loseConnection` will
remain in a special "closing" state until the CLOSE response arrives, during
which time DATA payloads are still delivered. After calling `close()` (or
receiving CLOSE), any outbound `.write()` calls will trigger an error.
DATA payloads that arrive for a non-open subchannel are logged and discarded.
This protocol calls for one OPEN and two CLOSE messages for each subchannel,
with some arbitrary number of DATA messages in between. Subchannel-ids should
not be reused (it would probably work, the protocol hasn't been analyzed
enough to be sure).
The "control channel" is special. It uses a subchannel-id of all zeros, and
is opened implicitly by both sides as soon as the first L3 connection is
selected. It is routed to a special client-on-both-sides endpoint, rather
than causing the listening endpoint to accept a new connection. This avoids
the need for application-level code to negotiate who should be the one to
open it (the Leader/Follower distinction is private to the Wormhole
internals: applications are not obligated to pick a side).
OPEN and CLOSE messages for the control channel are logged and discarded. The
control-channel client endpoints can only be used once, and does not close
until the Wormhole itself is closed.
Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing,
delivery, and eventual retirement. The L5 layer does not keep track of old
messages.
### Flow Control
Subchannels are flow-controlled by pausing their writes when the L3
connection is paused, and pausing the L3 connection when the subchannel
signals a pause. When the outbound L3 connection is full, *all* subchannels
are paused. Likewise the inbound connection is paused if *any* of the
subchannels asks for a pause. This is much easier to implement and improves
our utilization factor (we can use TCP's window-filling algorithm, instead of
rolling our own), but will block all subchannels even if only one of them
gets full. This shouldn't matter for many applications, but might be
noticeable when combining very different kinds of traffic (e.g. a chat
conversation sharing a wormhole with file-transfer might prefer the IM text
to take priority).
Each subchannel implements Twisted's `ITransport`, `IProducer`, and
`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to
be created (by the caller's factory) and glued to the subchannel object in
the `.transport` property, as is standard in Twisted-based applications.
All subchannels are also paused when the L3 connection is lost, and are
unpaused when a new replacement connection is selected.

2000
docs/new-protocol.svg Normal file

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 91 KiB

View File

@ -23,12 +23,13 @@ digraph {
Terminator [shape="box" color="blue" fontcolor="blue"]
InputHelperAPI [shape="oval" label="input\nhelper\nAPI"
color="blue" fontcolor="blue"]
Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"]
#Connection -> websocket [color="blue"]
#Connection -> Order [color="blue"]
Wormhole -> Boss [style="dashed"
label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)"
label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)"
color="red" fontcolor="red"]
#Wormhole -> Boss [color="blue"]
Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"]
@ -112,4 +113,7 @@ digraph {
Terminator -> Boss [style="dashed" label="closed\n(once)"]
Boss -> Terminator [style="dashed" color="red" fontcolor="red"
label="close"]
Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"]
Dilator -> Send [style="dashed" label="send(dilate-N)"]
}

View File

@ -7,3 +7,6 @@ versionfile_source = src/wormhole/_version.py
versionfile_build = wormhole/_version.py
tag_prefix =
parentdir_prefix = magic-wormhole
[flake8]
max-line-length = 85

View File

@ -54,6 +54,7 @@ setup(name="magic-wormhole",
"dev": ["mock", "tox", "pyflakes",
"magic-wormhole-transit-relay==0.1.2",
"magic-wormhole-mailbox-server==0.3.1"],
"dilate": ["noiseprotocol"],
},
test_suite="wormhole.test",
cmdclass=commands,

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import, print_function, unicode_literals
from .cli import cli
if __name__ != "__main__":
raise ImportError('this module should not be imported')
cli.wormhole()
if __name__ == "__main__":
from .cli import cli
cli.wormhole()
else:
# raise ImportError('this module should not be imported')
pass

View File

@ -12,6 +12,7 @@ from zope.interface import implementer
from . import _interfaces
from ._allocator import Allocator
from ._code import Code, validate_code
from ._dilation.manager import Dilator
from ._input import Input
from ._key import Key
from ._lister import Lister
@ -38,6 +39,8 @@ class Boss(object):
_versions = attrib(validator=instance_of(dict))
_client_version = attrib(validator=instance_of(tuple))
_reactor = attrib()
_eventual_queue = attrib()
_cooperator = attrib()
_journal = attrib(validator=provides(_interfaces.IJournal))
_tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
_timing = attrib(validator=provides(_interfaces.ITiming))
@ -64,6 +67,8 @@ class Boss(object):
self._I = Input(self._timing)
self._C = Code(self._timing)
self._T = Terminator()
self._D = Dilator(self._reactor, self._eventual_queue,
self._cooperator)
self._N.wire(self._M, self._I, self._RC, self._T)
self._M.wire(self._N, self._RC, self._O, self._T)
@ -77,6 +82,7 @@ class Boss(object):
self._I.wire(self._C, self._L)
self._C.wire(self, self._A, self._N, self._K, self._I)
self._T.wire(self, self._RC, self._N, self._M)
self._D.wire(self._S)
def _init_other_state(self):
self._did_start_code = False
@ -84,6 +90,9 @@ class Boss(object):
self._next_rx_phase = 0
self._rx_phases = {} # phase -> plaintext
self._next_rx_dilate_seqnum = 0
self._rx_dilate_seqnums = {} # seqnum -> plaintext
self._result = "empty"
# these methods are called from outside
@ -196,6 +205,9 @@ class Boss(object):
self._did_start_code = True
self._C.set_code(code)
def dilate(self):
return self._D.dilate() # fires with endpoints
@m.input()
def send(self, plaintext):
pass
@ -255,8 +267,11 @@ class Boss(object):
def got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
d_mo = re.search(r'^dilate-(\d+)$', phase)
if phase == "version":
self._got_version(plaintext)
elif d_mo:
self._got_dilate(int(d_mo.group(1)), plaintext)
elif re.search(r'^\d+$', phase):
self._got_phase(int(phase), plaintext)
else:
@ -272,6 +287,10 @@ class Boss(object):
def _got_phase(self, phase, plaintext):
pass
@m.input()
def _got_dilate(self, seqnum, plaintext):
pass
@m.input()
def got_key(self, key):
pass
@ -294,6 +313,7 @@ class Boss(object):
# most of this is wormhole-to-wormhole, ignored for now
# in the future, this is how Dilation is signalled
self._their_versions = bytes_to_dict(plaintext)
self._D.got_wormhole_versions(self._their_versions)
# but this part is app-to-app
app_versions = self._their_versions.get("app_versions", {})
self._W.got_versions(app_versions)
@ -335,6 +355,10 @@ class Boss(object):
def W_got_key(self, key):
self._W.got_key(key)
@m.output()
def D_got_key(self, key):
self._D.got_key(key)
@m.output()
def W_got_verifier(self, verifier):
self._W.got_verifier(verifier)
@ -348,6 +372,16 @@ class Boss(object):
self._W.received(self._rx_phases.pop(self._next_rx_phase))
self._next_rx_phase += 1
@m.output()
def D_received_dilate(self, seqnum, plaintext):
assert isinstance(seqnum, six.integer_types), type(seqnum)
# strict phase order, no gaps
self._rx_dilate_seqnums[seqnum] = plaintext
while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums:
m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum)
self._D.received_dilate(m)
self._next_rx_dilate_seqnum += 1
@m.output()
def W_close_with_error(self, err):
self._result = err # exception
@ -370,7 +404,7 @@ class Boss(object):
S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared])
S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely])
S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send])
S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key])
S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key])
S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error])
S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])
@ -378,6 +412,7 @@ class Boss(object):
S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier])
S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate])
S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
S2_happy.upon(send, enter=S2_happy, outputs=[S_send])
@ -389,6 +424,7 @@ class Boss(object):
S3_closing.upon(got_verifier, enter=S3_closing, outputs=[])
S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[])
S3_closing.upon(happy, enter=S3_closing, outputs=[])
S3_closing.upon(scared, enter=S3_closing, outputs=[])
S3_closing.upon(close, enter=S3_closing, outputs=[])
@ -400,6 +436,7 @@ class Boss(object):
S4_closed.upon(got_verifier, enter=S4_closed, outputs=[])
S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[])
S4_closed.upon(happy, enter=S4_closed, outputs=[])
S4_closed.upon(scared, enter=S4_closed, outputs=[])
S4_closed.upon(close, enter=S4_closed, outputs=[])

View File

View File

@ -0,0 +1,11 @@
try:
from noise.exceptions import NoiseInvalidMessage
except ImportError:
class NoiseInvalidMessage(Exception):
pass
try:
from noise.connection import NoiseConnection
except ImportError:
# allow imports to work on py2.7, even if dilation doesn't
NoiseConnection = None

View File

@ -0,0 +1,520 @@
from __future__ import print_function, unicode_literals
from collections import namedtuple
import six
from attr import attrs, attrib
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from zope.interface import Interface, implementer
from twisted.python import log
from twisted.internet.protocol import Protocol
from twisted.internet.interfaces import ITransport
from .._interfaces import IDilationConnector
from ..observer import OneShotObserver
from .encode import to_be4, from_be4
from .roles import FOLLOWER
from ._noise import NoiseInvalidMessage
# InboundFraming is given data and returns Frames (Noise wire-side
# bytestrings). It handles the relay handshake and the prologue. The Frames it
# returns are either the ephemeral key (the Noise "handshake") or ciphertext
# messages.
# The next object up knows whether it's expecting a Handshake or a message. It
# feeds the first into Noise as a handshake, it feeds the rest into Noise as a
# message (which produces a plaintext stream). It emits tokens that are either
# "i've finished with the handshake (so you can send the KCM if you want)", or
# "here is a decrypted message (which might be the KCM)".
# the transmit direction goes directly to transport.write, and doesn't touch
# the state machine. we can do this because the way we encode/encrypt/frame
# things doesn't depend upon the receiver state. It would be more safe to e.g.
# prohibit sending ciphertext frames unless we're in the received-handshake
# state, but then we'll be in the middle of an inbound state transition ("we
# just received the handshake, so you can send a KCM now") when we perform an
# operation that depends upon the state (send_plaintext(kcm)), which is not a
# coherent/safe place to touch the state machine.
# we could set a flag and test it from inside send_plaintext, which kind of
# violates the state machine owning the state (ideally all "if" statements
# would be translated into same-input transitions from different starting
# states). For the specific question of sending plaintext frames, Noise will
# refuse us unless it's ready anyways, so the question is probably moot.
class IFramer(Interface):
pass
class IRecord(Interface):
pass
def first(l):
return l[0]
class Disconnect(Exception):
pass
RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", [])
Frame = namedtuple("Frame", ["frame"])
@attrs
@implementer(IFramer)
class _Framer(object):
_transport = attrib(validator=provides(ITransport))
_outbound_prologue = attrib(validator=instance_of(bytes))
_inbound_prologue = attrib(validator=instance_of(bytes))
_buffer = b""
_can_send_frames = False
# in: use_relay
# in: connectionMade, dataReceived
# out: prologue_received, frame_received
# out (shared): transport.loseConnection
# out (shared): transport.write (relay handshake, prologue)
# states: want_relay, want_prologue, want_frame
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
@m.state()
def want_relay(self):
pass # pragma: no cover
@m.state(initial=True)
def want_prologue(self):
pass # pragma: no cover
@m.state()
def want_frame(self):
pass # pragma: no cover
@m.input()
def use_relay(self, relay_handshake):
pass
@m.input()
def connectionMade(self):
pass
@m.input()
def parse(self):
pass
@m.input()
def got_relay_ok(self):
pass
@m.input()
def got_prologue(self):
pass
@m.output()
def store_relay_handshake(self, relay_handshake):
self._outbound_relay_handshake = relay_handshake
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
@m.output()
def send_relay_handshake(self):
self._transport.write(self._outbound_relay_handshake)
@m.output()
def send_prologue(self):
self._transport.write(self._outbound_prologue)
@m.output()
def parse_relay_ok(self):
if self._get_expected("relay_ok", self._expected_relay_handshake):
return RelayOK()
@m.output()
def parse_prologue(self):
if self._get_expected("prologue", self._inbound_prologue):
return Prologue()
@m.output()
def can_send_frames(self):
self._can_send_frames = True # for assertion in send_frame()
@m.output()
def parse_frame(self):
if len(self._buffer) < 4:
return None
frame_length = from_be4(self._buffer[0:4])
if len(self._buffer) < 4 + frame_length:
return None
frame = self._buffer[4:4 + frame_length]
self._buffer = self._buffer[4 + frame_length:] # TODO: avoid copy
return Frame(frame=frame)
want_prologue.upon(use_relay, outputs=[store_relay_handshake],
enter=want_relay)
want_relay.upon(connectionMade, outputs=[send_relay_handshake],
enter=want_relay)
want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay,
collector=first)
want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue)
want_prologue.upon(connectionMade, outputs=[send_prologue],
enter=want_prologue)
want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue,
collector=first)
want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame)
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
collector=first)
def _get_expected(self, name, expected):
lb = len(self._buffer)
le = len(expected)
if self._buffer.startswith(expected):
# if the buffer starts with the expected string, consume it and
# return True
self._buffer = self._buffer[le:]
return True
if not expected.startswith(self._buffer):
# we're not on track: the data we've received so far does not
# match the expected value, so this can't possibly be right.
# Don't complain until we see the expected length, or a newline,
# so we can capture the weird input in the log for debugging.
if (b"\n" in self._buffer or lb >= le):
log.msg("bad {}: {}".format(name, self._buffer[:le]))
raise Disconnect()
return False # wait a bit longer
# good so far, just waiting for the rest
return False
# external API is: connectionMade, add_and_parse, and send_frame
def add_and_parse(self, data):
# we can't make this an @m.input because we can't change the state
# from within an input. Instead, let the state choose the parser to
# use, and use the parsed token drive a state transition.
self._buffer += data
while True:
# it'd be nice to use an iterator here, but since self.parse()
# dispatches to a different parser (depending upon the current
# state), we'd be using multiple iterators
token = self.parse()
if isinstance(token, RelayOK):
self.got_relay_ok()
elif isinstance(token, Prologue):
self.got_prologue()
yield token # triggers send_handshake
elif isinstance(token, Frame):
yield token
else:
break
def send_frame(self, frame):
assert self._can_send_frames
self._transport.write(to_be4(len(frame)) + frame)
# RelayOK: Newline-terminated buddy-is-connected response from Relay.
# First data received from relay.
# Prologue: double-newline-terminated this-is-really-wormhole response
# from peer. First data received from peer.
# Frame: Either handshake or encrypted message. Length-prefixed on wire.
# Handshake: the Noise ephemeral key, first framed message
# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK
# KCM: Key Confirmation Message (encrypted b"\x00"). First frame
# from peer. Sent immediately by Follower, after Selection by Leader.
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", [])
Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records
T_KCM = b"\x00"
T_PING = b"\x01"
T_PONG = b"\x02"
T_OPEN = b"\x03"
T_DATA = b"\x04"
T_CLOSE = b"\x05"
T_ACK = b"\x06"
def parse_record(plaintext):
msgtype = plaintext[0:1]
if msgtype == T_KCM:
return KCM()
if msgtype == T_PING:
ping_id = plaintext[1:5]
return Ping(ping_id)
if msgtype == T_PONG:
ping_id = plaintext[1:5]
return Pong(ping_id)
if msgtype == T_OPEN:
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
return Open(seqnum, scid)
if msgtype == T_DATA:
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
data = plaintext[9:]
return Data(seqnum, scid, data)
if msgtype == T_CLOSE:
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
return Close(seqnum, scid)
if msgtype == T_ACK:
resp_seqnum = from_be4(plaintext[1:5])
return Ack(resp_seqnum)
log.err("received unknown message type: {}".format(plaintext))
raise ValueError()
def encode_record(r):
if isinstance(r, KCM):
return b"\x00"
if isinstance(r, Ping):
return b"\x01" + r.ping_id
if isinstance(r, Pong):
return b"\x02" + r.ping_id
if isinstance(r, Open):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
if isinstance(r, Data):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
if isinstance(r, Close):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
if isinstance(r, Ack):
assert isinstance(r.resp_seqnum, six.integer_types)
return b"\x06" + to_be4(r.resp_seqnum)
raise TypeError(r)
@attrs
@implementer(IRecord)
class _Record(object):
_framer = attrib(validator=provides(IFramer))
_noise = attrib()
n = MethodicalMachine()
# TODO: set_trace
def __attrs_post_init__(self):
self._noise.start_handshake()
# in: role=
# in: prologue_received, frame_received
# out: handshake_received, record_received
# out: transport.write (noise handshake, encrypted records)
# states: want_prologue, want_handshake, want_record
@n.state(initial=True)
def want_prologue(self):
pass # pragma: no cover
@n.state()
def want_handshake(self):
pass # pragma: no cover
@n.state()
def want_message(self):
pass # pragma: no cover
@n.input()
def got_prologue(self):
pass
@n.input()
def got_frame(self, frame):
pass
@n.output()
def send_handshake(self):
handshake = self._noise.write_message() # generate the ephemeral key
self._framer.send_frame(handshake)
@n.output()
def process_handshake(self, frame):
try:
payload = self._noise.read_message(frame)
# Noise can include unencrypted data in the handshake, but we don't
# use it
del payload
except NoiseInvalidMessage as e:
log.err(e, "bad inbound noise handshake")
raise Disconnect()
return Handshake()
@n.output()
def decrypt_message(self, frame):
try:
message = self._noise.decrypt(frame)
except NoiseInvalidMessage as e:
# if this happens during tests, flunk the test
log.err(e, "bad inbound noise frame")
raise Disconnect()
return parse_record(message)
want_prologue.upon(got_prologue, outputs=[send_handshake],
enter=want_handshake)
want_handshake.upon(got_frame, outputs=[process_handshake],
collector=first, enter=want_message)
want_message.upon(got_frame, outputs=[decrypt_message],
collector=first, enter=want_message)
# external API is: connectionMade, dataReceived, send_record
def connectionMade(self):
self._framer.connectionMade()
def add_and_unframe(self, data):
for token in self._framer.add_and_parse(data):
if isinstance(token, Prologue):
self.got_prologue() # triggers send_handshake
else:
assert isinstance(token, Frame)
yield self.got_frame(token.frame) # Handshake or a Record type
def send_record(self, r):
message = encode_record(r)
frame = self._noise.encrypt(message)
self._framer.send_frame(frame)
@attrs
class DilatedConnectionProtocol(Protocol, object):
"""I manage an L2 connection.
When a new L2 connection is needed (as determined by the Leader),
both Leader and Follower will initiate many simultaneous connections
(probably TCP, but conceivably others). A subset will actually
connect. A subset of those will successfully pass negotiation by
exchanging handshakes to demonstrate knowledge of the session key.
One of the negotiated connections will be selected by the Leader for
active use, and the others will be dropped.
At any given time, there is at most one active L2 connection.
"""
_eventual_queue = attrib()
_role = attrib()
_connector = attrib(validator=provides(IDilationConnector))
_noise = attrib()
_outbound_prologue = attrib(validator=instance_of(bytes))
_inbound_prologue = attrib(validator=instance_of(bytes))
_use_relay = False
_relay_handshake = None
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._manager = None # set if/when we are selected
self._disconnected = OneShotObserver(self._eventual_queue)
self._can_send_records = False
@m.state(initial=True)
def unselected(self):
pass # pragma: no cover
@m.state()
def selecting(self):
pass # pragma: no cover
@m.state()
def selected(self):
pass # pragma: no cover
@m.input()
def got_kcm(self):
pass
@m.input()
def select(self, manager):
pass # fires set_manager()
@m.input()
def got_record(self, record):
pass
@m.output()
def add_candidate(self):
self._connector.add_candidate(self)
@m.output()
def set_manager(self, manager):
self._manager = manager
@m.output()
def can_send_records(self, manager):
self._can_send_records = True
@m.output()
def deliver_record(self, record):
self._manager.got_record(record)
unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting)
selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected)
selected.upon(got_record, outputs=[deliver_record], enter=selected)
# called by Connector
def use_relay(self, relay_handshake):
assert isinstance(relay_handshake, bytes)
self._use_relay = True
self._relay_handshake = relay_handshake
def when_disconnected(self):
return self._disconnected.when_fired()
def disconnect(self):
self.transport.loseConnection()
# select() called by Connector
# called by Manager
def send_record(self, record):
assert self._can_send_records
self._record.send_record(record)
# IProtocol methods
def connectionMade(self):
framer = _Framer(self.transport,
self._outbound_prologue, self._inbound_prologue)
if self._use_relay:
framer.use_relay(self._relay_handshake)
self._record = _Record(framer, self._noise)
self._record.connectionMade()
def dataReceived(self, data):
try:
for token in self._record.add_and_unframe(data):
assert isinstance(token, Handshake_or_Records)
if isinstance(token, Handshake):
if self._role is FOLLOWER:
self._record.send_record(KCM())
elif isinstance(token, KCM):
# if we're the leader, add this connection as a candiate.
# if we're the follower, accept this connection.
self.got_kcm() # connector.add_candidate()
else:
self.got_record(token) # manager.got_record()
except Disconnect:
self.transport.loseConnection()
def connectionLost(self, why=None):
self._disconnected.fire(self)

View File

@ -0,0 +1,387 @@
from __future__ import print_function, unicode_literals
from collections import defaultdict
from binascii import hexlify
from attr import attrs, attrib
from attr.validators import instance_of, provides, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.python import log
from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager
from ..timing import DebugTiming
from ..observer import EmptyableSet
from ..util import HKDF, to_unicode
from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER
from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
encode_hint)
from ._noise import NoiseConnection
def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8 * 2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
def build_noise():
return NoiseConnection.from_name(NOISEPROTO)
@attrs
@implementer(IDilationConnector)
class Connector(object):
_dilation_key = attrib(validator=instance_of(type(b"")))
_transit_relay_location = attrib(validator=optional(instance_of(type(u""))))
_manager = attrib(validator=provides(IDilationManager))
_reactor = attrib()
_eventual_queue = attrib()
_no_listen = attrib(validator=instance_of(bool))
_tor = attrib()
_timing = attrib()
_side = attrib(validator=instance_of(type(u"")))
# was self._side = bytes_to_hexstr(os.urandom(8)) # unicode
_role = attrib()
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
RELAY_DELAY = 2.0
def __attrs_post_init__(self):
if self._transit_relay_location:
# TODO: allow multiple hints for a single relay
relay_hint = parse_hint_argv(self._transit_relay_location)
relay = RelayV1Hint(hints=(relay_hint,))
self._transit_relays = [relay]
else:
self._transit_relays = []
self._listeners = set() # IListeningPorts that can be stopped
self._pending_connectors = set() # Deferreds that can be cancelled
self._pending_connections = EmptyableSet(
_eventual_queue=self._eventual_queue) # Protocols to be stopped
self._contenders = set() # viable connections
self._winning_connection = None
self._timing = self._timing or DebugTiming()
self._timing.add("transit")
# this describes what our Connector can do, for the initial advertisement
@classmethod
def get_connection_abilities(klass):
return [{"type": "direct-tcp-v1"},
{"type": "relay-v1"},
]
def build_protocol(self, addr):
# encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
# ephemeral keys plus a pre-shared symmetric key (the Transit key), a
# different one for each potential connection.
noise = build_noise()
noise.set_psks(self._dilation_key)
if self._role is LEADER:
noise.set_as_initiator()
outbound_prologue = PROLOGUE_LEADER
inbound_prologue = PROLOGUE_FOLLOWER
else:
noise.set_as_responder()
outbound_prologue = PROLOGUE_FOLLOWER
inbound_prologue = PROLOGUE_LEADER
p = DilatedConnectionProtocol(self._eventual_queue, self._role,
self, noise,
outbound_prologue, inbound_prologue)
return p
@m.state(initial=True)
def connecting(self):
pass # pragma: no cover
@m.state()
def connected(self):
pass # pragma: no cover
@m.state(terminal=True)
def stopped(self):
pass # pragma: no cover
# TODO: unify the tense of these method-name verbs
# add_relay() and got_hints() are called by the Manager as it receives
# messages from our peer. stop() is called when the Manager shuts down
@m.input()
def add_relay(self, hint_objs):
pass
@m.input()
def got_hints(self, hint_objs):
pass
@m.input()
def stop(self):
pass
# called by ourselves, when _start_listener() is ready
@m.input()
def listener_ready(self, hint_objs):
pass
# called when DilatedConnectionProtocol submits itself, after KCM
# received
@m.input()
def add_candidate(self, c):
pass
# called by ourselves, via consider()
@m.input()
def accept(self, c):
pass
@m.output()
def use_hints(self, hint_objs):
self._use_hints(hint_objs)
@m.output()
def publish_hints(self, hint_objs):
self._publish_hints(hint_objs)
def _publish_hints(self, hint_objs):
self._manager.send_hints([encode_hint(h) for h in hint_objs])
@m.output()
def consider(self, c):
self._contenders.add(c)
if self._role is LEADER:
# for now, just accept the first one. TODO: be clever.
self._eventual_queue.eventually(self.accept, c)
else:
# the follower always uses the first contender, since that's the
# only one the leader picked
self._eventual_queue.eventually(self.accept, c)
@m.output()
def select_and_stop_remaining(self, c):
self._winning_connection = c
self._contenders.clear() # we no longer care who else came close
# remove this winner from the losers, so we don't shut it down
self._pending_connections.discard(c)
# shut down losing connections
self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
self.stop_pending_connectors()
self.stop_pending_connections()
c.select(self._manager) # subsequent frames go directly to the manager
if self._role is LEADER:
# TODO: this should live in Connection
c.send_record(KCM()) # leader sends KCM now
self._manager.use_connection(c) # manager sends frames to Connection
@m.output()
def stop_everything(self):
self.stop_listeners()
self.stop_pending_connectors()
self.stop_pending_connections()
self.break_cycles()
def stop_listeners(self):
d = DeferredList([l.stopListening() for l in self._listeners])
self._listeners.clear()
return d # synchronization for tests
def stop_pending_connectors(self):
return DeferredList([d.cancel() for d in self._pending_connectors])
def stop_pending_connections(self):
d = self._pending_connections.when_next_empty()
[c.loseConnection() for c in self._pending_connections]
return d
def break_cycles(self):
# help GC by forgetting references to things that reference us
self._listeners.clear()
self._pending_connectors.clear()
self._pending_connections.clear()
self._winning_connection = None
connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints])
connecting.upon(add_relay, enter=connecting, outputs=[use_hints,
publish_hints])
connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
connecting.upon(add_candidate, enter=connecting, outputs=[consider])
connecting.upon(accept, enter=connected, outputs=[
select_and_stop_remaining])
connecting.upon(stop, enter=stopped, outputs=[stop_everything])
# once connected, we ignore everything except stop
connected.upon(listener_ready, enter=connected, outputs=[])
connected.upon(add_relay, enter=connected, outputs=[])
connected.upon(got_hints, enter=connected, outputs=[])
# TODO: tell them to disconnect? will they hang out forever? I *think*
# they'll drop this once they get a KCM on the winning connection.
connected.upon(add_candidate, enter=connected, outputs=[])
connected.upon(accept, enter=connected, outputs=[])
connected.upon(stop, enter=stopped, outputs=[stop_everything])
# from Manager: start, got_hints, stop
# maybe add_candidate, accept
def start(self):
if not self._no_listen and not self._tor:
addresses = self._get_listener_addresses()
self._start_listener(addresses)
if self._transit_relays:
self._publish_hints(self._transit_relays)
self._use_hints(self._transit_relays)
def _get_listener_addresses(self):
addresses = ipaddrs.find_addresses()
non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
if non_loopback_addresses:
# some test hosts, including the appveyor VMs, *only* have
# 127.0.0.1, and the tests will hang badly if we remove it.
addresses = non_loopback_addresses
return addresses
def _start_listener(self, addresses):
# TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also
# to make firewall configs easier
# TODO: retain listening port between connection generations?
ep = serverFromString(self._reactor, "tcp:0")
f = InboundConnectionFactory(self)
d = ep.listen(f)
def _listening(lp):
# lp is an IListeningPort
self._listeners.add(lp) # for shutdown and tests
portnum = lp.getHost().port
direct_hints = [DirectTCPV1Hint(to_unicode(addr), portnum, 0.0)
for addr in addresses]
self.listener_ready(direct_hints)
d.addCallback(_listening)
d.addErrback(log.err)
def _schedule_connection(self, delay, h, is_relay):
ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
desc = describe_hint_obj(h, is_relay, self._tor)
d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay)
d.addErrback(log.err)
self._pending_connectors.add(d)
def _use_hints(self, hints):
# first, pull out all the relays, we'll connect to them later
relays = []
direct = defaultdict(list)
for h in hints:
if isinstance(h, RelayV1Hint):
relays.append(h)
else:
direct[h.priority].append(h)
delay = 0.0
made_direct = False
priorities = sorted(set(direct.keys()), reverse=True)
for p in priorities:
for h in direct[p]:
if isinstance(h, TorTCPV1Hint) and not self._tor:
continue
self._schedule_connection(delay, h, is_relay=False)
made_direct = True
# Make all direct connections immediately. Later, we'll change
# the add_candidate() function to look at the priority when
# deciding whether to accept a successful connection or not,
# and it can wait for more options if it sees a higher-priority
# one still running. But if we bail on that, we might consider
# putting an inter-direct-hint delay here to influence the
# process.
# delay += 1.0
if made_direct and not self._no_listen:
# Prefer direct connections by stalling relay connections by a
# few seconds. We don't wait until direct connections have
# failed, because many direct hints will be to unused
# local-network IP address, which won't answer, and can take the
# full 30s TCP timeout to fail.
#
# If we didn't make any direct connections, or we're using
# --no-listen, then we're probably going to have to use the
# relay, so don't delay it at all.
delay += self.RELAY_DELAY
# It might be nice to wire this so that a failure in the direct hints
# causes the relay hints to be used right away (fast failover). But
# none of our current use cases would take advantage of that: if we
# have any viable direct hints, then they're either going to succeed
# quickly or hang for a long time.
for r in relays:
for h in r.hints:
self._schedule_connection(delay, h, is_relay=True)
# TODO:
# if not contenders:
# raise TransitError("No contenders for connection")
# TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
# the initial connection
def _connect(self, ep, description, is_relay=False):
relay_handshake = None
if is_relay:
relay_handshake = build_sided_relay_handshake(self._dilation_key,
self._side)
f = OutboundConnectionFactory(self, relay_handshake)
d = ep.connect(f)
# fires with protocol, or ConnectError
def _connected(p):
self._pending_connections.add(p)
# c might not be in _pending_connections, if it turned out to be a
# winner, which is why we use discard() and not remove()
p.when_disconnected().addCallback(self._pending_connections.discard)
d.addCallback(_connected)
return d
# Connection selection. All instances of DilatedConnectionProtocol which
# look viable get passed into our add_contender() method.
# On the Leader side, "viable" means we've seen their KCM frame, which is
# the first Noise-encrypted packet on any given connection, and it has an
# empty body. We gather viable connections until we see one that we like,
# or a timer expires. Then we "select" it, close the others, and tell our
# Manager to use it.
# On the Follower side, we'll only see a KCM on the one connection selected
# by the Leader, so the first viable connection wins.
# our Connection protocols call: add_candidate
@attrs
class OutboundConnectionFactory(ClientFactory, object):
_connector = attrib(validator=provides(IDilationConnector))
_relay_handshake = attrib(validator=optional(instance_of(bytes)))
def buildProtocol(self, addr):
p = self._connector.build_protocol(addr)
p.factory = self
if self._relay_handshake is not None:
p.use_relay(self._relay_handshake)
return p
@attrs
class InboundConnectionFactory(ServerFactory, object):
_connector = attrib(validator=provides(IDilationConnector))
def buildProtocol(self, addr):
p = self._connector.build_protocol(addr)
p.factory = self
return p

View File

@ -0,0 +1,19 @@
from __future__ import print_function, unicode_literals
import struct
assert len(struct.pack("<L", 0)) == 4
assert len(struct.pack("<Q", 0)) == 8
def to_be4(value):
if not 0 <= value < 2**32:
raise ValueError
return struct.pack(">L", value)
def from_be4(b):
if not isinstance(b, bytes):
raise TypeError(repr(b))
if len(b) != 4:
raise ValueError
return struct.unpack(">L", b)[0]

View File

@ -0,0 +1,134 @@
from __future__ import print_function, unicode_literals
from attr import attrs, attrib
from attr.validators import provides
from zope.interface import implementer
from twisted.python import log
from .._interfaces import IDilationManager, IInbound
from .subchannel import (SubChannel, _SubchannelAddress)
class DuplicateOpenError(Exception):
pass
class DataForMissingSubchannelError(Exception):
pass
class CloseForMissingSubchannelError(Exception):
pass
@attrs
@implementer(IInbound)
class Inbound(object):
# Inbound flow control: TCP delivers data to Connection.dataReceived,
# Connection delivers to our handle_data, we deliver to
# SubChannel.remote_data, subchannel delivers to proto.dataReceived
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib()
def __attrs_post_init__(self):
# we route inbound Data records to Subchannels .dataReceived
self._open_subchannels = {} # scid -> Subchannel
self._paused_subchannels = set() # Subchannels that have paused us
# the set is non-empty, we pause the transport
self._highest_inbound_acked = -1
self._connection = None
# from our Manager
def set_listener_endpoint(self, listener_endpoint):
self._listener_endpoint = listener_endpoint
def set_subchannel_zero(self, scid0, sc0):
self._open_subchannels[scid0] = sc0
def use_connection(self, c):
self._connection = c
# We can pause the connection's reads when we receive too much data. If
# this is a non-initial connection, then we might already have
# subchannels that are paused from before, so we might need to pause
# the new connection before it can send us any data
if self._paused_subchannels:
self._connection.pauseProducing()
# Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not
def is_record_old(self, r):
if r.seqnum <= self._highest_inbound_acked:
return True
return False
def update_ack_watermark(self, r):
self._highest_inbound_acked = max(self._highest_inbound_acked,
r.seqnum)
def handle_open(self, scid):
if scid in self._open_subchannels:
log.err(DuplicateOpenError(
"received duplicate OPEN for {}".format(scid)))
return
peer_addr = _SubchannelAddress(scid)
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
self._open_subchannels[scid] = sc
self._listener_endpoint._got_open(sc, peer_addr)
def handle_data(self, scid, data):
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(DataForMissingSubchannelError(
"received DATA for non-existent subchannel {}".format(scid)))
return
sc.remote_data(data)
def handle_close(self, scid):
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(CloseForMissingSubchannelError(
"received CLOSE for non-existent subchannel {}".format(scid)))
return
sc.remote_close()
def subchannel_closed(self, scid, sc):
# connectionLost has just been signalled
assert self._open_subchannels[scid] is sc
del self._open_subchannels[scid]
def stop_using_connection(self):
self._connection = None
# from our Subchannel, or rather from the Protocol above it and sent
# through the subchannel
# The subchannel is an IProducer, and application protocols can always
# thell them to pauseProducing if we're delivering inbound data too
# quickly. They don't need to register anything.
def subchannel_pauseProducing(self, sc):
was_paused = bool(self._paused_subchannels)
self._paused_subchannels.add(sc)
if self._connection and not was_paused:
self._connection.pauseProducing()
def subchannel_resumeProducing(self, sc):
was_paused = bool(self._paused_subchannels)
self._paused_subchannels.discard(sc)
if self._connection and was_paused and not self._paused_subchannels:
self._connection.resumeProducing()
def subchannel_stopProducing(self, sc):
# This protocol doesn't want any additional data. If we were a normal
# (single-owner) Transport, we'd call .loseConnection now. But our
# Connection is shared among many subchannels, so instead we just
# stop letting them pause the connection.
was_paused = bool(self._paused_subchannels)
self._paused_subchannels.discard(sc)
if self._connection and was_paused and not self._paused_subchannels:
self._connection.resumeProducing()
# TODO: we might refactor these pause/resume/stop methods by building a
# context manager that look at the paused/not-paused state first, then
# lets the caller modify self._paused_subchannels, then looks at it a
# second time, and calls c.pauseProducing/c.resumeProducing as
# appropriate. I'm not sure it would be any cleaner, though.

View File

@ -0,0 +1,580 @@
from __future__ import print_function, unicode_literals
import os
from collections import deque
from attr import attrs, attrib
from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python import log
from .._interfaces import IDilator, IDilationManager, ISend
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
from .._key import derive_key
from .encode import to_be4
from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress,
ControlEndpoint, SubchannelConnectorEndpoint,
SubchannelListenerEndpoint)
from .connector import Connector
from .._hints import parse_hint
from .roles import LEADER, FOLLOWER
from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
from .inbound import Inbound
from .outbound import Outbound
# exported to Wormhole() for inclusion in versions message
DILATION_VERSIONS = ["1"]
class OldPeerCannotDilateError(Exception):
pass
class UnknownDilationMessageType(Exception):
pass
class ReceivedHintsTooEarly(Exception):
pass
class UnexpectedKCM(Exception):
pass
class UnknownMessageType(Exception):
pass
def make_side():
return bytes_to_hexstr(os.urandom(6))
# new scheme:
# * both sides send PLEASE as soon as they have an unverified key and
# w.dilate has been called,
# * PLEASE includes a dilation-specific "side" (independent of the "side"
# used by mailbox messages)
# * higher "side" is Leader, lower is Follower
# * PLEASE includes can-dilate list of version integers, requires overlap
# "1" is current
# * we start dilation after both w.dilate() and receiving VERSION, putting us
# in WANTING, then we process all previously-queued inbound DILATE-n
# messages. When PLEASE arrives, we move to CONNECTING
# * HINTS sent after dilation starts
# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This
# is the only difference between the two sides, and is not enforced
# by the protocol (i.e. if the Follower sends RECONNECT to the Leader,
# the Leader will obey, although TODO how confusing will this get?)
# * upon receiving RECONNECT: drop Connector, start new Connector, send
# RECONNECTING, start sending HINTS
# * upon sending RECONNECT: go into FLUSHING state and ignore all HINTS until
# RECONNECTING received. The new Connector can be spun up earlier, and it
# can send HINTS, but it must not be given any HINTS that arrive before
# RECONNECTING (since they're probably stale)
# * after VERSIONS(KCM) received, we might learn that they other side cannot
# dilate. w.dilate errbacks at this point
# * maybe signal warning if we stay in a "want" state for too long
# * nobody sends HINTS until they're ready to receive
# * nobody sends HINTS unless they've called w.dilate() and received PLEASE
# * nobody connects to inbound hints unless they've called w.dilate()
# * if leader calls w.dilate() but not follower, leader waits forever in
# "want" (doesn't send anything)
# * if follower calls w.dilate() but not leader, follower waits forever
# in "want", leader waits forever in "wanted"
@attrs
@implementer(IDilationManager)
class Manager(object):
_S = attrib(validator=provides(ISend))
_my_side = attrib(validator=instance_of(type(u"")))
_transit_key = attrib(validator=instance_of(bytes))
_transit_relay_location = attrib(validator=optional(instance_of(str)))
_reactor = attrib()
_eventual_queue = attrib()
_cooperator = attrib()
_no_listen = False # TODO
_tor = None # TODO
_timing = None # TODO
_next_subchannel_id = None # initialized in choose_role
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
def __attrs_post_init__(self):
self._got_versions_d = Deferred()
self._my_role = None # determined upon rx_PLEASE
self._connection = None
self._made_first_connection = False
self._first_connected = OneShotObserver(self._eventual_queue)
self._host_addr = _WormholeAddress()
self._next_dilation_phase = 0
# I kept getting confused about which methods were for inbound data
# (and thus flow-control methods go "out") and which were for
# outbound data (with flow-control going "in"), so I split them up
# into separate pieces.
self._inbound = Inbound(self, self._host_addr)
self._outbound = Outbound(self, self._cooperator) # from us to peer
def set_listener_endpoint(self, listener_endpoint):
self._inbound.set_listener_endpoint(listener_endpoint)
def set_subchannel_zero(self, scid0, sc0):
self._inbound.set_subchannel_zero(scid0, sc0)
def when_first_connected(self):
return self._first_connected.when_fired()
def send_dilation_phase(self, **fields):
dilation_phase = self._next_dilation_phase
self._next_dilation_phase += 1
self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))
def send_hints(self, hints): # from Connector
self.send_dilation_phase(type="connection-hints", hints=hints)
# forward inbound-ish things to _Inbound
def subchannel_pauseProducing(self, sc):
self._inbound.subchannel_pauseProducing(sc)
def subchannel_resumeProducing(self, sc):
self._inbound.subchannel_resumeProducing(sc)
def subchannel_stopProducing(self, sc):
self._inbound.subchannel_stopProducing(sc)
# forward outbound-ish things to _Outbound
def subchannel_registerProducer(self, sc, producer, streaming):
self._outbound.subchannel_registerProducer(sc, producer, streaming)
def subchannel_unregisterProducer(self, sc):
self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid):
self._queue_and_send(Open, scid)
def send_data(self, scid, data):
self._queue_and_send(Data, scid, data)
def send_close(self, scid):
self._queue_and_send(Close, scid)
def _queue_and_send(self, record_type, *args):
r = self._outbound.build_record(record_type, *args)
# Outbound owns the send_record() pipe, so that it can stall new
# writes after a new connection is made until after all queued
# messages are written (to preserve ordering).
self._outbound.queue_and_send_record(r) # may trigger pauseProducing
def subchannel_closed(self, scid, sc):
# let everyone clean up. This happens just after we delivered
# connectionLost to the Protocol, except for the control channel,
# which might get connectionLost later after they use ep.connect.
# TODO: is this inversion a problem?
self._inbound.subchannel_closed(scid, sc)
self._outbound.subchannel_closed(scid, sc)
# our Connector calls these
def connector_connection_made(self, c):
self.connection_made() # state machine update
self._connection = c
self._inbound.use_connection(c)
self._outbound.use_connection(c) # does c.registerProducer
if not self._made_first_connection:
self._made_first_connection = True
self._first_connected.fire(None)
pass
def connector_connection_lost(self):
self._stop_using_connection()
if self._my_role is LEADER:
self.connection_lost_leader() # state machine
else:
self.connection_lost_follower()
def _stop_using_connection(self):
# the connection is already lost by this point
self._connection = None
self._inbound.stop_using_connection()
self._outbound.stop_using_connection() # does c.unregisterProducer
# from our active Connection
def got_record(self, r):
# records with sequence numbers: always ack, ignore old ones
if isinstance(r, (Open, Data, Close)):
self.send_ack(r.seqnum) # always ack, even for old ones
if self._inbound.is_record_old(r):
return
self._inbound.update_ack_watermark(r.seqnum)
if isinstance(r, Open):
self._inbound.handle_open(r.scid)
elif isinstance(r, Data):
self._inbound.handle_data(r.scid, r.data)
else: # isinstance(r, Close)
self._inbound.handle_close(r.scid)
return
if isinstance(r, KCM):
log.err(UnexpectedKCM())
elif isinstance(r, Ping):
self.handle_ping(r.ping_id)
elif isinstance(r, Pong):
self.handle_pong(r.ping_id)
elif isinstance(r, Ack):
self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
else:
log.err(UnknownMessageType("{}".format(r)))
# pings, pongs, and acks are not queued
def send_ping(self, ping_id):
self._outbound.send_if_connected(Ping(ping_id))
def send_pong(self, ping_id):
self._outbound.send_if_connected(Pong(ping_id))
def send_ack(self, resp_seqnum):
self._outbound.send_if_connected(Ack(resp_seqnum))
def handle_ping(self, ping_id):
self.send_pong(ping_id)
def handle_pong(self, ping_id):
# TODO: update is-alive timer
pass
# subchannel maintenance
def allocate_subchannel_id(self):
scid_num = self._next_subchannel_id
self._next_subchannel_id += 2
return to_be4(scid_num)
# state machine
# We are born WANTING after the local app calls w.dilate(). We start
# CONNECTING when we receive PLEASE from the remote side
def start(self):
self.send_please()
def send_please(self):
self.send_dilation_phase(type="please", side=self._my_side)
@m.state(initial=True)
def WANTING(self):
pass # pragma: no cover
@m.state()
def CONNECTING(self):
pass # pragma: no cover
@m.state()
def CONNECTED(self):
pass # pragma: no cover
@m.state()
def FLUSHING(self):
pass # pragma: no cover
@m.state()
def ABANDONING(self):
pass # pragma: no cover
@m.state()
def LONELY(self):
pass # pragma: no cover
@m.state()
def STOPPING(self):
pass # pragma: no cover
@m.state(terminal=True)
def STOPPED(self):
pass # pragma: no cover
@m.input()
def rx_PLEASE(self, message):
pass # pragma: no cover
@m.input() # only sent by Follower
def rx_HINTS(self, hint_message):
pass # pragma: no cover
@m.input() # only Leader sends RECONNECT, so only Follower receives it
def rx_RECONNECT(self):
pass # pragma: no cover
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
def rx_RECONNECTING(self):
pass # pragma: no cover
# Connector gives us connection_made()
@m.input()
def connection_made(self):
pass # pragma: no cover
# our connection_lost() fires connection_lost_leader or
# connection_lost_follower depending upon our role. If either side sees a
# problem with the connection (timeouts, bad authentication) then they
# just drop it and let connection_lost() handle the cleanup.
@m.input()
def connection_lost_leader(self):
pass # pragma: no cover
@m.input()
def connection_lost_follower(self):
pass
@m.input()
def stop(self):
pass # pragma: no cover
@m.output()
def choose_role(self, message):
their_side = message["side"]
if self._my_side > their_side:
self._my_role = LEADER
# scid 0 is reserved for the control channel. the leader uses odd
# numbers starting with 1
self._next_subchannel_id = 1
elif their_side > self._my_side:
self._my_role = FOLLOWER
# the follower uses even numbers starting with 2
self._next_subchannel_id = 2
else:
raise ValueError("their side shouldn't be equal: reflection?")
# these Outputs behave differently for the Leader vs the Follower
@m.output()
def start_connecting_ignore_message(self, message):
del message # ignored
return self._start_connecting()
@m.output()
def start_connecting(self):
self._start_connecting()
def _start_connecting(self):
assert self._my_role is not None
self._connector = Connector(self._transit_key,
self._transit_relay_location,
self,
self._reactor, self._eventual_queue,
self._no_listen, self._tor,
self._timing,
self._my_side, # needed for relay handshake
self._my_role)
self._connector.start()
@m.output()
def send_reconnect(self):
self.send_dilation_phase(type="reconnect") # TODO: generation number?
@m.output()
def send_reconnecting(self):
self.send_dilation_phase(type="reconnecting") # TODO: generation?
@m.output()
def use_hints(self, hint_message):
hint_objs = filter(lambda h: h, # ignore None, unrecognizable
[parse_hint(hs) for hs in hint_message["hints"]])
hint_objs = list(hint_objs)
self._connector.got_hints(hint_objs)
@m.output()
def stop_connecting(self):
self._connector.stop()
@m.output()
def abandon_connection(self):
# we think we're still connected, but the Leader disagrees. Or we've
# been told to shut down.
self._connection.disconnect() # let connection_lost do cleanup
# we start CONNECTING when we get rx_PLEASE
WANTING.upon(rx_PLEASE, enter=CONNECTING,
outputs=[choose_role, start_connecting_ignore_message])
CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[])
# Leader
CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
outputs=[send_reconnect])
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
outputs=[start_connecting])
# Follower
# if we notice a lost connection, just wait for the Leader to notice too
CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[])
LONELY.upon(rx_RECONNECT, enter=CONNECTING,
outputs=[send_reconnecting, start_connecting])
# but if they notice it first, abandon our (seemingly functional)
# connection, then tell them that we're ready to try again
CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, outputs=[abandon_connection])
ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
outputs=[send_reconnecting, start_connecting])
# and if they notice a problem while we're still connecting, abandon our
# incomplete attempt and try again. in this case we don't have to wait
# for a connection to finish shutdown
CONNECTING.upon(rx_RECONNECT, enter=CONNECTING,
outputs=[stop_connecting,
send_reconnecting,
start_connecting])
# rx_HINTS never changes state, they're just accepted or ignored
WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
WANTING.upon(stop, enter=STOPPED, outputs=[])
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
ABANDONING.upon(stop, enter=STOPPING, outputs=[])
FLUSHING.upon(stop, enter=STOPPED, outputs=[])
LONELY.upon(stop, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
@attrs
@implementer(IDilator)
class Dilator(object):
"""I launch the dilation process.
I am created with every Wormhole (regardless of whether .dilate()
was called or not), and I handle the initial phase of dilation,
before we know whether we'll be the Leader or the Follower. Once we
hear the other side's VERSION message (which tells us that we have a
connection, they are capable of dilating, and which side we're on),
then we build a Manager and hand control to it.
"""
_reactor = attrib()
_eventual_queue = attrib()
_cooperator = attrib()
def __attrs_post_init__(self):
self._got_versions_d = Deferred()
self._started = False
self._endpoints = OneShotObserver(self._eventual_queue)
self._pending_inbound_dilate_messages = deque()
self._manager = None
def wire(self, sender):
self._S = ISend(sender)
# this is the primary entry point, called when w.dilate() is invoked
def dilate(self, transit_relay_location=None):
self._transit_relay_location = transit_relay_location
if not self._started:
self._started = True
self._start().addBoth(self._endpoints.fire)
return self._endpoints.when_fired()
@inlineCallbacks
def _start(self):
# first, we wait until we hear the VERSION message, which tells us 1:
# the PAKE key works, so we can talk securely, 2: that they can do
# dilation at all (if they can't then w.dilate() errbacks)
dilation_version = yield self._got_versions_d
# TODO: we could probably return the endpoints earlier, if we flunk
# any connection/listen attempts upon OldPeerCannotDilateError, or
# if/when we give up on the initial connection
if not dilation_version: # "1" or None
# TODO: be more specific about the error. dilation_version==None
# means we had no version in common with them, which could either
# be because they're so old they don't dilate at all, or because
# they're so new that they no longer accomodate our old version
raise OldPeerCannotDilateError()
my_dilation_side = make_side()
self._manager = Manager(self._S, my_dilation_side,
self._transit_key,
self._transit_relay_location,
self._reactor, self._eventual_queue,
self._cooperator)
self._manager.start()
while self._pending_inbound_dilate_messages:
plaintext = self._pending_inbound_dilate_messages.popleft()
self.received_dilate(plaintext)
yield self._manager.when_first_connected()
# we can open subchannels as soon as we get our first connection
scid0 = b"\x00\x00\x00\x00"
self._host_addr = _WormholeAddress() # TODO: share with Manager
peer_addr0 = _SubchannelAddress(scid0)
control_ep = ControlEndpoint(peer_addr0)
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
control_ep._subchannel_zero_opened(sc0)
self._manager.set_subchannel_zero(scid0, sc0)
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
self._manager.set_listener_endpoint(listen_ep)
endpoints = (control_ep, connect_ep, listen_ep)
returnValue(endpoints)
# from Boss
def got_key(self, key):
# TODO: verify this happens before got_wormhole_versions, or add a gate
# to tolerate either ordering
purpose = b"dilation-v1"
LENGTH = 32 # TODO: whatever Noise wants, I guess
self._transit_key = derive_key(key, purpose, LENGTH)
def got_wormhole_versions(self, their_wormhole_versions):
assert self._transit_key is not None
# this always happens before received_dilate
dilation_version = None
their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
my_versions = set(DILATION_VERSIONS)
shared_versions = my_versions.intersection(their_dilation_versions)
if "1" in shared_versions:
dilation_version = "1"
self._got_versions_d.callback(dilation_version)
def received_dilate(self, plaintext):
# this receives new in-order DILATE-n payloads, decrypted but not
# de-JSONed.
# this can appear before our .dilate() method is called, in which case
# we queue them for later
if not self._manager:
self._pending_inbound_dilate_messages.append(plaintext)
return
message = bytes_to_dict(plaintext)
type = message["type"]
if type == "please":
self._manager.rx_PLEASE(message)
elif type == "connection-hints":
self._manager.rx_HINTS(message)
elif type == "reconnect":
self._manager.rx_RECONNECT()
elif type == "reconnecting":
self._manager.rx_RECONNECTING()
else:
log.err(UnknownDilationMessageType(message))
return

View File

@ -0,0 +1,389 @@
from __future__ import print_function, unicode_literals
from collections import deque
from attr import attrs, attrib
from attr.validators import provides
from zope.interface import implementer
from twisted.internet.interfaces import IPushProducer, IPullProducer
from twisted.python import log
from twisted.python.reflect import safe_str
from .._interfaces import IDilationManager, IOutbound
from .connection import KCM, Ping, Pong, Ack
# Outbound flow control: app writes to subchannel, we write to Connection
# The app can register an IProducer of their choice, to let us throttle their
# outbound data. Not all subchannels will have producers registered, and the
# producer probably won't be the IProtocol instance (it'll be something else
# which feeds data out through the protocol, like a t.p.basic.FileSender). If
# a producerless subchannel writes too much, we won't be able to stop them,
# and we'll keep writing records into the Connection even though it's asked
# us to pause. Likewise, when the connection is down (and we're busily trying
# to reestablish a new one), registered subchannels will be paused, but
# unregistered ones will just dump everything in _outbound_queue, and we'll
# consume memory without bound until they stop.
# We need several things:
#
# * Add each registered IProducer to a list, whose order remains stable. We
# want fairness under outbound throttling: each time the outbound
# connection opens up (our resumeProducing method is called), we should let
# just one producer have an opportunity to do transport.write, and then we
# should pause them again, and not come back to them until everyone else
# has gotten a turn. So we want an ordered list of producers to track this
# rotation.
#
# * Remove the IProducer if/when the protocol uses unregisterProducer
#
# * Remove any registered IProducer when the associated Subchannel is closed.
# This isn't a problem for normal transports, because usually there's a
# one-to-one mapping from Protocol to Transport, so when the Transport you
# forget the only reference to the Producer anyways. Our situation is
# unusual because we have multiple Subchannels that get merged into the
# same underlying Connection: each Subchannel's Protocol can register a
# producer on the Subchannel (which is an ITransport), but that adds it to
# a set of Producers for the Connection (which is also an ITransport). So
# if the Subchannel is closed, we need to remove its Producer (if any) even
# though the Connection remains open.
#
# * Register ourselves as an IPushProducer with each successive Connection
# object. These connections will come and go, but there will never be more
# than one. When the connection goes away, pause all our producers. When a
# new one is established, write all our queued messages, then unpause our
# producers as we would in resumeProducing.
#
# * Inside our resumeProducing call, we'll cycle through all producers,
# calling their individual resumeProducing methods one at a time. If they
# write so much data that the Connection pauses us again, we'll find out
# because our pauseProducing will be called inside that loop. When that
# happens, we need to stop looping. If we make it through the whole loop
# without being paused, then all subchannel Producers are left unpaused,
# and are free to write whenever they want. During this loop, some
# Producers will be paused, and others will be resumed
#
# * If our pauseProducing is called, all Producers must be paused, and a flag
# should be set to notify the resumeProducing loop to exit
#
# * In between calls to our resumeProducing method, we're in one of two
# states.
# * If we're writing data too fast, then we'll be left in the "paused"
# state, in which all Subchannel producers are paused, and the aggregate
# is paused too (our Connection told us to pauseProducing and hasn't yet
# told us to resumeProducing). In this state, activity is driven by the
# outbound TCP window opening up, which calls resumeProducing and allows
# (probably just) one message to be sent. We receive pauseProducing in
# the middle of their transport.write, so the loop exits early, and the
# only state change is that some other Producer should get to go next
# time.
# * If we're writing too slowly, we'll be left in the "unpaused" state: all
# Subchannel producers are unpaused, and the aggregate is unpaused too
# (resumeProducing is the last thing we've been told). In this satte,
# activity is driven by the Subchannels doing a transport.write, which
# queues some data on the TCP connection (and then might call
# pauseProducing if it's now full).
#
# * We want to guard against:
#
# * application protocol registering a Producer without first unregistering
# the previous one
#
# * application protocols writing data despite being told to pause
# (Subchannels without a registered Producer cannot be throttled, and we
# can't do anything about that, but we must also handle the case where
# they give us a pause switch and then proceed to ignore it)
#
# * our Connection calling resumeProducing or pauseProducing without an
# intervening call of the other kind
#
# * application protocols that don't handle a resumeProducing or
# pauseProducing call without an intervening call of the other kind (i.e.
# we should keep track of the last thing we told them, and not repeat
# ourselves)
#
# * If the Wormhole is closed, all Subchannels should close. This is not our
# responsibility: it lives in (Manager? Inbound?)
#
# * If we're given an IPullProducer, we should keep calling its
# resumeProducing until it runs out of data. We still want fairness, so we
# won't call it a second time until everyone else has had a turn.
# There are a couple of different ways to approach this. The one I've
# selected is:
#
# * keep a dict that maps from Subchannel to Producer, which only contains
# entries for Subchannels that have registered a producer. We use this to
# remove Producers when Subchannels are closed
#
# * keep a Deque of Producers. This represents the fair-throttling rotation:
# the left-most item gets the next upcoming turn, and then they'll be moved
# to the end of the queue.
#
# * keep a set of IPushProducers which are paused, a second set of
# IPushProducers which are unpaused, and a third set of IPullProducers
# (which are always left paused) Enforce the invariant that these three
# sets are disjoint, and that their union equals the contents of the deque.
#
# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and
# set upon entry to pauseProducing. The loop inside resumeProducing checks
# this flag after each call to producer.resumeProducing, to sense whether
# they used their turn to write data, and if that write was large enough to
# fill the TCP window. If set, we break out of the loop. If not, we look
# for the next producer to unpause. The loop finishes when all producers
# are unpaused (evidenced by the two sets of paused producers being empty)
#
# * the "paused" flag also determines whether new IPushProducers are added to
# the paused or unpaused set (IPullProducers are always added to the
# pull+paused set). If we have any IPullProducers, we're always in the
# "writing data too fast" state.
# other approaches that I didn't decide to do at this time (but might use in
# the future):
#
# * use one set instead of two. pros: fewer moving parts. cons: harder to
# spot decoherence bugs like adding a producer to the deque but forgetting
# to add it to one of the
#
# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as
# a visible boolean flag. This conflates Subchannels with their associated
# Producer (so if we went this way, we should also let them track their own
# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc
# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer
# mappings lying around to disagree with one another. Cons: exposes a bit
# too much of the Subchannel internals
@attrs
@implementer(IOutbound)
class Outbound(object):
# Manage outbound data: subchannel writes to us, we write to transport
_manager = attrib(validator=provides(IDilationManager))
_cooperator = attrib()
def __attrs_post_init__(self):
# _outbound_queue holds all messages we've ever sent but not retired
self._outbound_queue = deque()
self._next_outbound_seqnum = 0
# _queued_unsent are messages to retry with our new connection
self._queued_unsent = deque()
# outbound flow control: the Connection throttles our writes
self._subchannel_producers = {} # Subchannel -> IProducer
self._paused = True # our Connection called our pauseProducing
self._all_producers = deque() # rotates, left-is-next
self._paused_producers = set()
self._unpaused_producers = set()
self._check_invariants()
self._connection = None
def _check_invariants(self):
assert self._unpaused_producers.isdisjoint(self._paused_producers)
assert (self._paused_producers.union(self._unpaused_producers) ==
set(self._all_producers))
def build_record(self, record_type, *args):
seqnum = self._next_outbound_seqnum
self._next_outbound_seqnum += 1
r = record_type(seqnum, *args)
assert hasattr(r, "seqnum"), r # only Open/Data/Close
return r
def queue_and_send_record(self, r):
# we always queue it, to resend on a subsequent connection if
# necessary
self._outbound_queue.append(r)
if self._connection:
if self._queued_unsent:
# to maintain correct ordering, queue this instead of sending it
self._queued_unsent.append(r)
else:
# we're allowed to send it immediately
self._connection.send_record(r)
def send_if_connected(self, r):
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
if self._connection:
self._connection.send_record(r)
# our subchannels call these to register a producer
def subchannel_registerProducer(self, sc, producer, streaming):
# streaming==True: IPushProducer (pause/resume)
# streaming==False: IPullProducer (just resume)
if sc in self._subchannel_producers:
raise ValueError(
"registering producer %s before previous one (%s) was "
"unregistered" % (producer,
self._subchannel_producers[sc]))
# our underlying Connection uses streaming==True, so to make things
# easier, use an adapter when the Subchannel asks for streaming=False
if not streaming:
def unregister():
self.subchannel_unregisterProducer(sc)
producer = PullToPush(producer, unregister, self._cooperator)
self._subchannel_producers[sc] = producer
self._all_producers.append(producer)
if self._paused:
self._paused_producers.add(producer)
else:
self._unpaused_producers.add(producer)
self._check_invariants()
if streaming:
if self._paused:
# IPushProducers need to be paused immediately, before they
# speak
producer.pauseProducing() # you wake up sleeping
else:
# our PullToPush adapter must be started, but if we're paused then
# we tell it to pause before it gets a chance to write anything
producer.startStreaming(self._paused)
def subchannel_unregisterProducer(self, sc):
# TODO: what if the subchannel closes, so we unregister their
# producer for them, then the application reacts to connectionLost
# with a duplicate unregisterProducer?
p = self._subchannel_producers.pop(sc)
if isinstance(p, PullToPush):
p.stopStreaming()
self._all_producers.remove(p)
self._paused_producers.discard(p)
self._unpaused_producers.discard(p)
self._check_invariants()
def subchannel_closed(self, sc):
self._check_invariants()
if sc in self._subchannel_producers:
self.subchannel_unregisterProducer(sc)
# our Manager tells us when we've got a new Connection to work with
def use_connection(self, c):
self._connection = c
assert not self._queued_unsent
self._queued_unsent.extend(self._outbound_queue)
# the connection can tell us to pause when we send too much data
c.registerProducer(self, True) # IPushProducer: pause+resume
# send our queued messages
self.resumeProducing()
def stop_using_connection(self):
self._connection.unregisterProducer()
self._connection = None
self._queued_unsent.clear()
self.pauseProducing()
# TODO: I expect this will call pauseProducing twice: the first time
# when we get stopProducing (since we're registere with the
# underlying connection as the producer), and again when the manager
# notices the connectionLost and calls our _stop_using_connection
def handle_ack(self, resp_seqnum):
# we've received an inbound ack, so retire something
while (self._outbound_queue and
self._outbound_queue[0].seqnum <= resp_seqnum):
self._outbound_queue.popleft()
while (self._queued_unsent and
self._queued_unsent[0].seqnum <= resp_seqnum):
self._queued_unsent.popleft()
# Inbound is responsible for tracking the high watermark and deciding
# whether to ignore inbound messages or not
# IProducer: the active connection calls these because we used
# c.registerProducer to ask for them
def pauseProducing(self):
if self._paused:
return # someone is confused and called us twice
self._paused = True
for p in self._all_producers:
if p in self._unpaused_producers:
self._unpaused_producers.remove(p)
self._paused_producers.add(p)
p.pauseProducing()
def resumeProducing(self):
if not self._paused:
return # someone is confused and called us twice
self._paused = False
while not self._paused:
if self._queued_unsent:
r = self._queued_unsent.popleft()
self._connection.send_record(r)
continue
p = self._get_next_unpaused_producer()
if not p:
break
self._paused_producers.remove(p)
self._unpaused_producers.add(p)
p.resumeProducing()
def _get_next_unpaused_producer(self):
self._check_invariants()
if not self._paused_producers:
return None
while True:
p = self._all_producers[0]
self._all_producers.rotate(-1) # p moves to the end of the line
# the only unpaused Producers are at the end of the list
assert p in self._paused_producers
return p
def stopProducing(self):
# we'll hopefully have a new connection to work with in the future,
# so we don't shut anything down. We do pause everyone, though.
self.pauseProducing()
# modelled after twisted.internet._producer_helper._PullToPush , but with a
# configurable Cooperator, a pause-immediately argument to startStreaming()
@implementer(IPushProducer)
@attrs(cmp=False)
class PullToPush(object):
_producer = attrib(validator=provides(IPullProducer))
_unregister = attrib(validator=lambda _a, _b, v: callable(v))
_cooperator = attrib()
_finished = False
def _pull(self):
while True:
try:
self._producer.resumeProducing()
except Exception:
log.err(None, "%s failed, producing will be stopped:" %
(safe_str(self._producer),))
try:
self._unregister()
# The consumer should now call stopStreaming() on us,
# thus stopping the streaming.
except Exception:
# Since the consumer blew up, we may not have had
# stopStreaming() called, so we just stop on our own:
log.err(None, "%s failed to unregister producer:" %
(safe_str(self._unregister),))
self._finished = True
return
yield None
def startStreaming(self, paused):
self._coopTask = self._cooperator.cooperate(self._pull())
if paused:
self.pauseProducing() # timer is scheduled, but task is removed
def stopStreaming(self):
if self._finished:
return
self._finished = True
self._coopTask.stop()
def pauseProducing(self):
self._coopTask.pause()
def resumeProducing(self):
self._coopTask.resume()
def stopProducing(self):
self.stopStreaming()
self._producer.stopProducing()

View File

@ -0,0 +1 @@
LEADER, FOLLOWER = object(), object()

View File

@ -0,0 +1,300 @@
from attr import attrs, attrib
from attr.validators import instance_of, provides
from zope.interface import implementer
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
succeed)
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IStreamClientEndpoint,
IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager
@attrs
class Once(object):
_errtype = attrib()
def __attrs_post_init__(self):
self._called = False
def __call__(self):
if self._called:
raise self._errtype()
self._called = True
class SingleUseEndpointError(Exception):
pass
# created in the (OPEN) state, by either:
# * receipt of an OPEN message
# * or local client_endpoint.connect()
# then transitions are:
# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN)
# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED)
# (OPEN) local .write(): send DATA, -> (OPEN)
# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING)
# (CLOSING) local .write(): error
# (CLOSING) local .loseConnection(): error
# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING)
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
# object is deleted upon transition to (CLOSED)
class AlreadyClosedError(Exception):
pass
@implementer(IAddress)
class _WormholeAddress(object):
pass
@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
_scid = attrib()
@attrs
@implementer(ITransport)
@implementer(IProducer)
@implementer(IConsumer)
@implementer(ISubChannel)
class SubChannel(object):
_id = attrib(validator=instance_of(bytes))
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=instance_of(_WormholeAddress))
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
m = MethodicalMachine()
set_trace = getattr(m, "_setTrace", lambda self,
f: None) # pragma: no cover
def __attrs_post_init__(self):
# self._mailbox = None
# self._pending_outbound = {}
# self._processed = set()
self._protocol = None
self._pending_dataReceived = []
self._pending_connectionLost = (False, None)
@m.state(initial=True)
def open(self):
pass # pragma: no cover
@m.state()
def closing():
pass # pragma: no cover
@m.state()
def closed():
pass # pragma: no cover
@m.input()
def remote_data(self, data):
pass
@m.input()
def remote_close(self):
pass
@m.input()
def local_data(self, data):
pass
@m.input()
def local_close(self):
pass
@m.output()
def send_data(self, data):
self._manager.send_data(self._id, data)
@m.output()
def send_close(self):
self._manager.send_close(self._id)
@m.output()
def signal_dataReceived(self, data):
if self._protocol:
self._protocol.dataReceived(data)
else:
self._pending_dataReceived.append(data)
@m.output()
def signal_connectionLost(self):
if self._protocol:
self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone())
self._manager.subchannel_closed(self)
# we're deleted momentarily
@m.output()
def error_closed_write(self, data):
raise AlreadyClosedError("write not allowed on closed subchannel")
@m.output()
def error_closed_close(self):
raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")
# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
# error cases
# we won't ever see an OPEN, since L4 will log+ignore those for us
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
closing.upon(local_close, enter=closing, outputs=[error_closed_close])
# the CLOSED state won't ever see messages, since we'll be deleted
# our endpoints use this
def _set_protocol(self, protocol):
assert not self._protocol
self._protocol = protocol
if self._pending_dataReceived:
for data in self._pending_dataReceived:
self._protocol.dataReceived(data)
self._pending_dataReceived = []
cl, what = self._pending_connectionLost
if cl:
self._protocol.connectionLost(what)
# ITransport
def write(self, data):
assert isinstance(data, type(b""))
self.local_data(data)
def writeSequence(self, iovec):
self.write(b"".join(iovec))
def loseConnection(self):
self.local_close()
def getHost(self):
# we define "host addr" as the overall wormhole
return self._host_addr
def getPeer(self):
# and "peer addr" as the subchannel within that wormhole
return self._peer_addr
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
def stopProducing(self):
self._manager.subchannel_stopProducing(self)
def pauseProducing(self):
self._manager.subchannel_pauseProducing(self)
def resumeProducing(self):
self._manager.subchannel_resumeProducing(self)
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
def registerProducer(self, producer, streaming):
self._manager.subchannel_registerProducer(self, producer, streaming)
def unregisterProducer(self):
self._manager.subchannel_unregisterProducer(self)
@implementer(IStreamClientEndpoint)
class ControlEndpoint(object):
_used = False
def __init__(self, peer_addr):
self._subchannel_zero = Deferred()
self._peer_addr = peer_addr
self._once = Once(SingleUseEndpointError)
# from manager
def _subchannel_zero_opened(self, subchannel):
assert ISubChannel.providedBy(subchannel), subchannel
self._subchannel_zero.callback(subchannel)
@inlineCallbacks
def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError)
self._once()
t = yield self._subchannel_zero
p = protocolFactory.buildProtocol(self._peer_addr)
t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade()
returnValue(p)
@implementer(IStreamClientEndpoint)
@attrs
class SubchannelConnectorEndpoint(object):
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=instance_of(_WormholeAddress))
def connect(self, protocolFactory):
# return Deferred that fires with IProtocol or Failure(ConnectError)
scid = self._manager.allocate_subchannel_id()
self._manager.send_open(scid)
peer_addr = _SubchannelAddress(scid)
# ? f.doStart()
# ? f.startedConnecting(CONNECTOR) # ??
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
p = protocolFactory.buildProtocol(peer_addr)
t._set_protocol(p)
p.makeConnection(t) # set p.transport = t and call connectionMade()
return succeed(p)
@implementer(IStreamServerEndpoint)
@attrs
class SubchannelListenerEndpoint(object):
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=provides(IAddress))
def __attrs_post_init__(self):
self._factory = None
self._pending_opens = []
# from manager
def _got_open(self, t, peer_addr):
if self._factory:
self._connect(t, peer_addr)
else:
self._pending_opens.append((t, peer_addr))
def _connect(self, t, peer_addr):
p = self._factory.buildProtocol(peer_addr)
t._set_protocol(p)
p.makeConnection(t)
# IStreamServerEndpoint
def listen(self, protocolFactory):
self._factory = protocolFactory
for (t, peer_addr) in self._pending_opens:
self._connect(t, peer_addr)
self._pending_opens = []
lp = SubchannelListeningPort(self._host_addr)
return succeed(lp)
@implementer(IListeningPort)
@attrs
class SubchannelListeningPort(object):
_host_addr = attrib(validator=provides(IAddress))
def startListening(self):
pass
def stopListening(self):
# TODO
pass
def getHost(self):
return self._host_addr

138
src/wormhole/_hints.py Normal file
View File

@ -0,0 +1,138 @@
from __future__ import print_function, unicode_literals
import sys
import re
import six
from collections import namedtuple
from twisted.internet.endpoints import HostnameEndpoint
from twisted.python import log
# These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts".
# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
# * make a TCP connection (possibly via Tor)
# * send the sender/receiver handshake bytes first
# * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each
# one, make the TCP connection, send the relay handshake, then complete the
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint, relay, tor):
prefix = "tor->" if tor else "->"
if relay:
prefix = prefix + "relay:"
if isinstance(hint, DirectTCPV1Hint):
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
else:
return prefix + str(hint)
def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
priority = 0.0
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
print("unparseable hint '%s'" % (hint, ), file=stderr)
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint),
file=stderr)
return None
hint_value = mo.group(2)
pieces = hint_value.split(":")
if len(pieces) < 2:
print("unparseable TCP hint (need more colons) '%s'" % (hint, ),
file=stderr)
return None
mo = re.search(r'^(\d+)$', pieces[1])
if not mo:
print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr)
return None
hint_host = pieces[0]
hint_port = int(pieces[1])
for more in pieces[2:]:
if more.startswith("priority="):
more_pieces = more.split("=")
try:
priority = float(more_pieces[1])
except ValueError:
print("non-float priority= in TCP hint '%s'" % (hint, ),
file=stderr)
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
def endpoint_from_hint_obj(hint, tor, reactor):
if tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return HostnameEndpoint(reactor, hint.hostname, hint.port)
return None
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
hint_type = hint.get("type", "")
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not ("hostname" in hint and
isinstance(hint["hostname"], type(""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not ("port" in hint and
isinstance(hint["port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get("priority", 0.0)
if hint_type == "direct-tcp-v1":
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
else:
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
def parse_hint(hint_struct):
hint_type = hint_struct.get("type", "")
if hint_type == "relay-v1":
# the struct can include multiple ways to reach the same relay
rhints = filter(lambda h: h, # drop None (unrecognized)
[parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
return RelayV1Hint(list(rhints))
return parse_tcp_v1_hint(hint_struct)
def encode_hint(h):
if isinstance(h, DirectTCPV1Hint):
return {"type": "direct-tcp-v1",
"priority": h.priority,
"hostname": h.hostname,
"port": h.port, # integer
}
elif isinstance(h, RelayV1Hint):
rhint = {"type": "relay-v1", "hints": []}
for rh in h.hints:
rhint["hints"].append({"type": "direct-tcp-v1",
"priority": rh.priority,
"hostname": rh.hostname,
"port": rh.port})
return rhint
elif isinstance(h, TorTCPV1Hint):
return {"type": "tor-tcp-v1",
"priority": h.priority,
"hostname": h.hostname,
"port": h.port, # integer
}
raise ValueError("unknown hint type", h)

View File

@ -433,3 +433,27 @@ class IInputHelper(Interface):
class IJournal(Interface): # TODO: this needs to be public
pass
class IDilator(Interface):
pass
class IDilationManager(Interface):
pass
class IDilationConnector(Interface):
pass
class ISubChannel(Interface):
pass
class IInbound(Interface):
pass
class IOutbound(Interface):
pass

View File

@ -6,7 +6,6 @@ import six
from attr import attrib, attrs
from attr.validators import instance_of, provides
from automat import MethodicalMachine
from hkdf import Hkdf
from nacl import utils
from nacl.exceptions import CryptoError
from nacl.secret import SecretBox
@ -15,16 +14,12 @@ from zope.interface import implementer
from . import _interfaces
from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
hexstr_to_bytes, to_bytes)
hexstr_to_bytes, to_bytes, HKDF)
CryptoError
__all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"]
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
def derive_key(key, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(key, type(b"")):
raise TypeError(type(key))

View File

@ -89,6 +89,7 @@ class RendezvousConnector(object):
# if the initial connection fails, signal an error and shut down. do
# this in a different reactor turn to avoid some hazards
d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
# TODO: use EventualQueue
d.addErrback(self._initial_connection_failed)
self._debug_record_inbound_f = None

View File

@ -70,3 +70,24 @@ class SequenceObserver(object):
if self._observers:
d = self._observers.pop(0)
self._eq.eventually(d.callback, self._results.pop(0))
class EmptyableSet(set):
# manage a set which grows and shrinks over time. Fire a Deferred the first
# time it becomes empty after you start watching for it.
def __init__(self, *args, **kwargs):
self._eq = kwargs.pop("_eventual_queue") # required
super(EmptyableSet, self).__init__(*args, **kwargs)
self._observer = None
def when_next_empty(self):
if not self._observer:
self._observer = OneShotObserver(self._eq)
return self._observer.when_fired()
def discard(self, o):
super(EmptyableSet, self).discard(o)
if self._observer and not self:
self._observer.fire(None)
self._observer = None

View File

View File

@ -0,0 +1,21 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from ..._interfaces import IDilationManager, IWormhole
def mock_manager():
m = mock.Mock()
alsoProvides(m, IDilationManager)
return m
def mock_wormhole():
m = mock.Mock()
alsoProvides(m, IWormhole)
return m
def clear_mock_calls(*args):
for a in args:
a.mock_calls[:] = []

View File

@ -0,0 +1,217 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.internet.interfaces import ITransport
from ...eventual import EventualQueue
from ..._interfaces import IDilationConnector
from ..._dilation.roles import LEADER, FOLLOWER
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
KCM, Open, Ack)
from .common import clear_mock_calls
def make_con(role, use_relay=False):
clock = Clock()
eq = EventualQueue(clock)
connector = mock.Mock()
alsoProvides(connector, IDilationConnector)
n = mock.Mock() # pretends to be a Noise object
n.write_message = mock.Mock(side_effect=[b"handshake"])
c = DilatedConnectionProtocol(eq, role, connector, n,
b"outbound_prologue\n", b"inbound_prologue\n")
if use_relay:
c.use_relay(b"relay_handshake\n")
t = mock.Mock()
alsoProvides(t, ITransport)
return c, n, connector, t, eq
class Connection(unittest.TestCase):
def test_bad_prologue(self):
c, n, connector, t, eq = make_con(LEADER)
c.makeConnection(t)
d = c.when_disconnected()
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
clear_mock_calls(n, connector, t)
c.dataReceived(b"prologue\n")
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
eq.flush_sync()
self.assertNoResult(d)
c.connectionLost(b"why")
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), c)
def _test_no_relay(self, role):
c, n, connector, t, eq = make_con(role)
t_kcm = KCM()
t_open = Open(seqnum=1, scid=0x11223344)
t_ack = Ack(resp_seqnum=2)
n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm),
encode_record(t_open),
])
exp_kcm = b"\x00\x00\x00\x03kcm"
n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"])
m = mock.Mock() # Manager
c.makeConnection(t)
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"inbound_prologue\n")
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(connector.mock_calls, [])
exp_handshake = b"\x00\x00\x00\x09handshake"
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
if role is LEADER:
# we're the leader, so we don't send the KCM right away
self.assertEqual(n.mock_calls, [
mock.call.read_message(b"handshake2")])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(c._manager, None)
else:
# we're the follower, so we encrypt and send the KCM immediately
self.assertEqual(n.mock_calls, [
mock.call.read_message(b"handshake2"),
mock.call.encrypt(encode_record(t_kcm)),
])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [
mock.call.write(exp_kcm)])
self.assertEqual(c._manager, None)
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"\x00\x00\x00\x03KCM")
# leader: inbound KCM means we add the candidate
# follower: inbound KCM means we've been selected.
# in both cases we notify Connector.add_candidate(), and the Connector
# decides if/when to call .select()
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")])
self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)])
self.assertEqual(t.mock_calls, [])
clear_mock_calls(n, connector, t, m)
# now pretend this connection wins (either the Leader decides to use
# this one among all the candiates, or we're the Follower and the
# Connector is reacting to add_candidate() by recognizing we're the
# only candidate there is)
c.select(m)
self.assertIdentical(c._manager, m)
if role is LEADER:
# TODO: currently Connector.select_and_stop_remaining() is
# responsible for sending the KCM just before calling c.select()
# iff we're the LEADER, therefore Connection.select won't send
# anything. This should be moved to c.select().
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(m.mock_calls, [])
c.send_record(KCM())
self.assertEqual(n.mock_calls, [
mock.call.encrypt(encode_record(t_kcm)),
])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)])
self.assertEqual(m.mock_calls, [])
else:
# follower: we already sent the KCM, do nothing
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(m.mock_calls, [])
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"\x00\x00\x00\x04msg1")
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)])
clear_mock_calls(n, connector, t, m)
c.send_record(t_ack)
exp_ack = b"\x06\x00\x00\x00\x02"
self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")])
self.assertEqual(m.mock_calls, [])
clear_mock_calls(n, connector, t, m)
c.disconnect()
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
self.assertEqual(m.mock_calls, [])
clear_mock_calls(n, connector, t, m)
def test_no_relay_leader(self):
return self._test_no_relay(LEADER)
def test_no_relay_follower(self):
return self._test_no_relay(FOLLOWER)
def test_relay(self):
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
c.makeConnection(t)
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
clear_mock_calls(n, connector, t)
c.dataReceived(b"ok\n")
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
clear_mock_calls(n, connector, t)
c.dataReceived(b"inbound_prologue\n")
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(connector.mock_calls, [])
exp_handshake = b"\x00\x00\x00\x09handshake"
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
clear_mock_calls(n, connector, t)
def test_relay_jilted(self):
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
d = c.when_disconnected()
c.makeConnection(t)
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
clear_mock_calls(n, connector, t)
c.connectionLost(b"why")
eq.flush_sync()
self.assertIdentical(self.successResultOf(d), c)
def test_relay_bad_response(self):
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
c.makeConnection(t)
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
clear_mock_calls(n, connector, t)
c.dataReceived(b"not ok\n")
self.assertEqual(n.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
clear_mock_calls(n, connector, t)

View File

@ -0,0 +1,463 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock
from twisted.internet.defer import Deferred
from ...eventual import EventualQueue
from ..._interfaces import IDilationManager, IDilationConnector
from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint
from ..._dilation import roles
from ..._dilation._noise import NoiseConnection
from ..._dilation.connection import KCM
from ..._dilation.connector import (Connector,
build_sided_relay_handshake,
build_noise,
OutboundConnectionFactory,
InboundConnectionFactory,
PROLOGUE_LEADER, PROLOGUE_FOLLOWER,
)
from .common import clear_mock_calls
class Handshake(unittest.TestCase):
def test_build(self):
key = b"k"*32
side = "12345678abcdabcd"
self.assertEqual(build_sided_relay_handshake(key, side),
b"please relay 3f4147851dbd2589d25b654ee9fb35ed0d3e5f19c5c5403e8e6a195c70f0577a for side 12345678abcdabcd\n")
class Outbound(unittest.TestCase):
def test_no_relay(self):
c = mock.Mock()
alsoProvides(c, IDilationConnector)
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = None
f = OutboundConnectionFactory(c, relay_handshake)
addr = object()
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertEqual(p.mock_calls, [])
self.assertIdentical(p.factory, f)
def test_with_relay(self):
c = mock.Mock()
alsoProvides(c, IDilationConnector)
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
relay_handshake = b"relay handshake"
f = OutboundConnectionFactory(c, relay_handshake)
addr = object()
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)])
self.assertIdentical(p.factory, f)
class Inbound(unittest.TestCase):
def test_build(self):
c = mock.Mock()
alsoProvides(c, IDilationConnector)
p0 = mock.Mock()
c.build_protocol = mock.Mock(return_value=p0)
f = InboundConnectionFactory(c)
addr = object()
p = f.buildProtocol(addr)
self.assertIdentical(p, p0)
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
self.assertIdentical(p.factory, f)
def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER):
class Holder:
pass
h = Holder()
h.dilation_key = b"key"
h.relay = relay
h.manager = mock.Mock()
alsoProvides(h.manager, IDilationManager)
h.clock = Clock()
h.reactor = h.clock
h.eq = EventualQueue(h.clock)
h.tor = None
if tor:
h.tor = mock.Mock()
timing = None
h.side = u"abcd1234abcd5678"
h.role = role
c = Connector(h.dilation_key, h.relay, h.manager, h.reactor, h.eq,
not listen, h.tor, timing, h.side, h.role)
return c, h
class TestConnector(unittest.TestCase):
def test_build(self):
c, h = make_connector()
c, h = make_connector(relay="tcp:host:1234")
def test_connection_abilities(self):
self.assertEqual(Connector.get_connection_abilities(),
[{"type": "direct-tcp-v1"},
{"type": "relay-v1"},
])
def test_build_noise(self):
if not NoiseConnection:
raise unittest.SkipTest("noiseprotocol unavailable")
build_noise()
def test_build_protocol_leader(self):
c, h = make_connector(role=roles.LEADER)
n0 = mock.Mock()
p0 = mock.Mock()
addr = object()
with mock.patch("wormhole._dilation.connector.build_noise",
return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp:
p = c.build_protocol(addr)
self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_initiator()])
self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0,
PROLOGUE_LEADER, PROLOGUE_FOLLOWER)])
def test_build_protocol_follower(self):
c, h = make_connector(role=roles.FOLLOWER)
n0 = mock.Mock()
p0 = mock.Mock()
addr = object()
with mock.patch("wormhole._dilation.connector.build_noise",
return_value=n0) as bn:
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
return_value=p0) as dcp:
p = c.build_protocol(addr)
self.assertEqual(bn.mock_calls, [mock.call()])
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
mock.call.set_as_responder()])
self.assertIdentical(p, p0)
self.assertEqual(dcp.mock_calls,
[mock.call(h.eq, h.role, c, n0,
PROLOGUE_FOLLOWER, PROLOGUE_LEADER)])
def test_start_stop(self):
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
c.start()
# no relays, so it publishes no hints
self.assertEqual(h.manager.mock_calls, [])
# and no listener, so nothing happens until we provide a hint
c.stop()
# we stop while we're connecting, so no connections must be stopped
def test_empty(self):
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
# no relays, so it publishes no hints
self.assertEqual(h.manager.mock_calls, [])
# and no listener, so nothing happens until we provide a hint
self.assertEqual(c._schedule_connection.mock_calls, [])
c.stop()
def test_basic(self):
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
# no relays, so it publishes no hints
self.assertEqual(h.manager.mock_calls, [])
# and no listener, so nothing happens until we provide a hint
self.assertEqual(c._schedule_connection.mock_calls, [])
hint = DirectTCPV1Hint("foo", 55, 0.0)
c.got_hints([hint])
# received hints don't get published
self.assertEqual(h.manager.mock_calls, [])
# they just schedule a connection
self.assertEqual(c._schedule_connection.mock_calls,
[mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
is_relay=False)])
def test_listen_addresses(self):
c, h = make_connector(listen=True, role=roles.LEADER)
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]):
self.assertEqual(c._get_listener_addresses(),
["1.2.3.4", "5.6.7.8"])
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=["127.0.0.1"]):
# some test hosts, including the appveyor VMs, *only* have
# 127.0.0.1, and the tests will hang badly if we remove it.
self.assertEqual(c._get_listener_addresses(), ["127.0.0.1"])
def test_listen(self):
c, h = make_connector(listen=True, role=roles.LEADER)
c._start_listener = mock.Mock()
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]):
c.start()
self.assertEqual(c._start_listener.mock_calls,
[mock.call(["1.2.3.4", "5.6.7.8"])])
def test_start_listen(self):
c, h = make_connector(listen=True, role=roles.LEADER)
ep = mock.Mock()
d = Deferred()
ep.listen = mock.Mock(return_value=d)
with mock.patch("wormhole._dilation.connector.serverFromString",
return_value=ep) as sfs:
c._start_listener(["1.2.3.4", "5.6.7.8"])
self.assertEqual(sfs.mock_calls, [mock.call(h.reactor, "tcp:0")])
lp = mock.Mock()
host = mock.Mock()
host.port = 66
lp.getHost = mock.Mock(return_value=host)
d.callback(lp)
self.assertEqual(h.manager.mock_calls,
[mock.call.send_hints([{"type": "direct-tcp-v1",
"hostname": "1.2.3.4",
"port": 66,
"priority": 0.0
},
{"type": "direct-tcp-v1",
"hostname": "5.6.7.8",
"port": 66,
"priority": 0.0
},
])])
def test_schedule_connection_no_relay(self):
c, h = make_connector(listen=True, role=roles.LEADER)
hint = DirectTCPV1Hint("foo", 55, 0.0)
ep = mock.Mock()
with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj",
side_effect=[ep]) as efho:
c._schedule_connection(0.0, hint, False)
self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)])
self.assertEqual(ep.mock_calls, [])
d = Deferred()
ep.connect = mock.Mock(side_effect=[d])
# direct hints are scheduled for T+0.0
f = mock.Mock()
with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
return_value=f) as ocf:
h.clock.advance(1.0)
self.assertEqual(ocf.mock_calls, [mock.call(c, None)])
self.assertEqual(ep.connect.mock_calls, [mock.call(f)])
p = mock.Mock()
d.callback(p)
self.assertEqual(p.mock_calls,
[mock.call.when_disconnected(),
mock.call.when_disconnected().addCallback(c._pending_connections.discard)])
def test_schedule_connection_relay(self):
c, h = make_connector(listen=True, role=roles.LEADER)
hint = DirectTCPV1Hint("foo", 55, 0.0)
ep = mock.Mock()
with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj",
side_effect=[ep]) as efho:
c._schedule_connection(0.0, hint, True)
self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)])
self.assertEqual(ep.mock_calls, [])
d = Deferred()
ep.connect = mock.Mock(side_effect=[d])
# direct hints are scheduled for T+0.0
f = mock.Mock()
with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
return_value=f) as ocf:
h.clock.advance(1.0)
handshake = build_sided_relay_handshake(h.dilation_key, h.side)
self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)])
def test_listen_but_tor(self):
c, h = make_connector(listen=True, tor=True, role=roles.LEADER)
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa:
c.start()
# don't even look up addresses
self.assertEqual(fa.mock_calls, [])
# no relays and the listener isn't ready yet, so no hints yet
self.assertEqual(h.manager.mock_calls, [])
def test_no_listen(self):
c, h = make_connector(listen=False, tor=False, role=roles.LEADER)
with mock.patch("wormhole.ipaddrs.find_addresses",
return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa:
c.start()
# don't even look up addresses
self.assertEqual(fa.mock_calls, [])
self.assertEqual(h.manager.mock_calls, [])
def test_relay_delay(self):
# given a direct connection and a relay, we should see the direct
# connection initiated at T+0 seconds, and the relay at T+RELAY_DELAY
c, h = make_connector(listen=True, relay=None, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c._start_listener = mock.Mock()
c.start()
hint1 = DirectTCPV1Hint("foo", 55, 0.0)
hint2 = DirectTCPV1Hint("bar", 55, 0.0)
hint3 = RelayV1Hint([DirectTCPV1Hint("relay", 55, 0.0)])
c.got_hints([hint1, hint2, hint3])
self.assertEqual(c._schedule_connection.mock_calls,
[mock.call(0.0, hint1, is_relay=False),
mock.call(0.0, hint2, is_relay=False),
mock.call(c.RELAY_DELAY, hint3.hints[0], is_relay=True),
])
def test_initial_relay(self):
c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
self.assertEqual(h.manager.mock_calls,
[mock.call.send_hints([{"type": "relay-v1",
"hints": [
{"type": "direct-tcp-v1",
"hostname": "foo",
"port": 55,
"priority": 0.0
},
],
}])])
self.assertEqual(c._schedule_connection.mock_calls,
[mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
is_relay=True)])
def test_add_relay(self):
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
self.assertEqual(h.manager.mock_calls, [])
self.assertEqual(c._schedule_connection.mock_calls, [])
hint = RelayV1Hint([DirectTCPV1Hint("foo", 55, 0.0)])
c.add_relay([hint])
self.assertEqual(h.manager.mock_calls,
[mock.call.send_hints([{"type": "relay-v1",
"hints": [
{"type": "direct-tcp-v1",
"hostname": "foo",
"port": 55,
"priority": 0.0
},
],
}])])
self.assertEqual(c._schedule_connection.mock_calls,
[mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
is_relay=True)])
def test_tor_no_manager(self):
# tor hints should be ignored if we don't have a Tor manager to use them
c, h = make_connector(listen=False, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
hint = TorTCPV1Hint("foo", 55, 0.0)
c.got_hints([hint])
self.assertEqual(h.manager.mock_calls, [])
self.assertEqual(c._schedule_connection.mock_calls, [])
def test_tor_with_manager(self):
# tor hints should be processed if we do have a Tor manager
c, h = make_connector(listen=False, tor=True, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
hint = TorTCPV1Hint("foo", 55, 0.0)
c.got_hints([hint])
self.assertEqual(c._schedule_connection.mock_calls,
[mock.call(0.0, hint, is_relay=False)])
def test_priorities(self):
# given two hints with different priorities, we should somehow prefer
# one. This is a placeholder to fill in once we implement priorities.
pass
class Race(unittest.TestCase):
def test_one_leader(self):
c, h = make_connector(listen=True, role=roles.LEADER)
lp = mock.Mock()
def start_listener(addresses):
c._listeners.add(lp)
c._start_listener = start_listener
c._schedule_connection = mock.Mock()
c.start()
self.assertEqual(c._listeners, set([lp]))
p1 = mock.Mock() # DilatedConnectionProtocol instance
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
mock.call.send_record(KCM())])
self.assertEqual(lp.mock_calls[0], mock.call.stopListening())
# stop_listeners() uses a DeferredList, so we ignore the second call
def test_one_follower(self):
c, h = make_connector(listen=True, role=roles.FOLLOWER)
lp = mock.Mock()
def start_listener(addresses):
c._listeners.add(lp)
c._start_listener = start_listener
c._schedule_connection = mock.Mock()
c.start()
self.assertEqual(c._listeners, set([lp]))
p1 = mock.Mock() # DilatedConnectionProtocol instance
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
# just like LEADER, but follower doesn't send KCM now (it sent one
# earlier, to tell the leader that this connection looks viable)
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager)])
self.assertEqual(lp.mock_calls[0], mock.call.stopListening())
# stop_listeners() uses a DeferredList, so we ignore the second call
# TODO: make sure a pending connection is abandoned when the listener
# answers successfully
# TODO: make sure a second pending connection is abandoned when the first
# connection succeeds
def test_late(self):
c, h = make_connector(listen=False, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
p1 = mock.Mock() # DilatedConnectionProtocol instance
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
clear_mock_calls(h.manager)
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
mock.call.send_record(KCM())])
# late connection is ignored
p2 = mock.Mock()
c.add_candidate(p2)
self.assertEqual(h.manager.mock_calls, [])
# make sure an established connection is dropped when stop() is called
def test_stop(self):
c, h = make_connector(listen=False, role=roles.LEADER)
c._schedule_connection = mock.Mock()
c.start()
p1 = mock.Mock() # DilatedConnectionProtocol instance
c.add_candidate(p1)
self.assertEqual(h.manager.mock_calls, [])
h.eq.flush_sync()
self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
self.assertEqual(p1.mock_calls,
[mock.call.select(h.manager),
mock.call.send_record(KCM())])
c.stop()

View File

@ -0,0 +1,26 @@
from __future__ import print_function, unicode_literals
from twisted.trial import unittest
from ..._dilation.encode import to_be4, from_be4
class Encoding(unittest.TestCase):
def test_be4(self):
self.assertEqual(to_be4(0), b"\x00\x00\x00\x00")
self.assertEqual(to_be4(1), b"\x00\x00\x00\x01")
self.assertEqual(to_be4(256), b"\x00\x00\x01\x00")
self.assertEqual(to_be4(257), b"\x00\x00\x01\x01")
with self.assertRaises(ValueError):
to_be4(-1)
with self.assertRaises(ValueError):
to_be4(2**32)
self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0)
self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1)
self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256)
self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257)
with self.assertRaises(TypeError):
from_be4(0)
with self.assertRaises(ValueError):
from_be4(b"\x01\x00\x00\x00\x00")

View File

@ -0,0 +1,98 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._interfaces import ISubChannel
from ..._dilation.subchannel import (ControlEndpoint,
SubchannelConnectorEndpoint,
SubchannelListenerEndpoint,
SubchannelListeningPort,
_WormholeAddress, _SubchannelAddress,
SingleUseEndpointError)
from .common import mock_manager
class Endpoints(unittest.TestCase):
def test_control(self):
scid0 = b"scid0"
peeraddr = _SubchannelAddress(scid0)
ep = ControlEndpoint(peeraddr)
f = mock.Mock()
p = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
d = ep.connect(f)
self.assertNoResult(d)
t = mock.Mock()
alsoProvides(t, ISubChannel)
ep._subchannel_zero_opened(t)
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
d = ep.connect(f)
self.failureResultOf(d, SingleUseEndpointError)
def assert_makeConnection(self, mock_calls):
self.assertEqual(len(mock_calls), 1)
self.assertEqual(mock_calls[0][0], "makeConnection")
self.assertEqual(len(mock_calls[0][1]), 1)
return mock_calls[0][1][0]
def test_connector(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(b"scid")
ep = SubchannelConnectorEndpoint(m, hostaddr)
f = mock.Mock()
p = mock.Mock()
t = mock.Mock()
f.buildProtocol = mock.Mock(return_value=p)
with mock.patch("wormhole._dilation.subchannel.SubChannel",
return_value=t) as sc:
d = ep.connect(f)
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_listener(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
hostaddr = _WormholeAddress()
ep = SubchannelListenerEndpoint(m, hostaddr)
f = mock.Mock()
p1 = mock.Mock()
p2 = mock.Mock()
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
# OPEN that arrives before we ep.listen() should be queued
t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(b"peer1")
ep._got_open(t1, peeraddr1)
d = ep.listen(f)
lp = self.successResultOf(d)
self.assertIsInstance(lp, SubchannelListeningPort)
self.assertEqual(lp.getHost(), hostaddr)
lp.startListening()
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(b"peer2")
ep._got_open(t2, peeraddr2)
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
lp.stopListening() # TODO: should this do more?

View File

@ -0,0 +1,112 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect
def make_framer():
t = mock.Mock()
alsoProvides(t, ITransport)
f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n")
return f, t
class Framer(unittest.TestCase):
def test_bad_prologue_length(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = []
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
self.assertEqual(t.mock_calls, [])
with mock.patch("wormhole._dilation.connection.log.msg") as m:
with self.assertRaises(Disconnect):
list(f.add_and_parse(b"not the prologue after all"))
self.assertEqual(m.mock_calls,
[mock.call("bad prologue: {}".format(
b"inbound_not the p"))])
self.assertEqual(t.mock_calls, [])
def test_bad_prologue_newline(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = []
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
self.assertEqual([], list(f.add_and_parse(b"not")))
with mock.patch("wormhole._dilation.connection.log.msg") as m:
with self.assertRaises(Disconnect):
list(f.add_and_parse(b"\n"))
self.assertEqual(m.mock_calls,
[mock.call("bad prologue: {}".format(
b"inbound_not\n"))])
self.assertEqual(t.mock_calls, [])
def test_good_prologue(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = []
self.assertEqual([Prologue()],
list(f.add_and_parse(b"inbound_prologue\n")))
self.assertEqual(t.mock_calls, [])
# now send_frame should work
f.send_frame(b"frame")
self.assertEqual(t.mock_calls,
[mock.call.write(b"\x00\x00\x00\x05frame")])
def test_bad_relay(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.use_relay(b"relay handshake\n")
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
t.mock_calls[:] = []
with mock.patch("wormhole._dilation.connection.log.msg") as m:
with self.assertRaises(Disconnect):
list(f.add_and_parse(b"goodbye\n"))
self.assertEqual(m.mock_calls,
[mock.call("bad relay_ok: {}".format(b"goo"))])
self.assertEqual(t.mock_calls, [])
def test_good_relay(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.use_relay(b"relay handshake\n")
self.assertEqual(t.mock_calls, [])
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
t.mock_calls[:] = []
self.assertEqual([], list(f.add_and_parse(b"ok\n")))
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
def test_frame(self):
f, t = make_framer()
self.assertEqual(t.mock_calls, [])
f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = []
self.assertEqual([Prologue()],
list(f.add_and_parse(b"inbound_prologue\n")))
self.assertEqual(t.mock_calls, [])
encoded_frame = b"\x00\x00\x00\x05frame"
self.assertEqual([], list(f.add_and_parse(encoded_frame[:2])))
self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6])))
self.assertEqual([Frame(frame=b"frame")],
list(f.add_and_parse(encoded_frame[6:])))

View File

@ -0,0 +1,174 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._interfaces import IDilationManager
from ..._dilation.connection import Open, Data, Close
from ..._dilation.inbound import (Inbound, DuplicateOpenError,
DataForMissingSubchannelError,
CloseForMissingSubchannelError)
def make_inbound():
m = mock.Mock()
alsoProvides(m, IDilationManager)
host_addr = object()
i = Inbound(m, host_addr)
return i, m, host_addr
class InboundTest(unittest.TestCase):
def test_seqnum(self):
i, m, host_addr = make_inbound()
r1 = Open(scid=513, seqnum=1)
r2 = Data(scid=513, seqnum=2, data=b"")
r3 = Close(scid=513, seqnum=3)
self.assertFalse(i.is_record_old(r1))
self.assertFalse(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))
i.update_ack_watermark(r1)
self.assertTrue(i.is_record_old(r1))
self.assertFalse(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))
i.update_ack_watermark(r2)
self.assertTrue(i.is_record_old(r1))
self.assertTrue(i.is_record_old(r2))
self.assertFalse(i.is_record_old(r3))
def test_open_data_close(self):
i, m, host_addr = make_inbound()
scid1 = b"scid"
scid2 = b"scXX"
c = mock.Mock()
lep = mock.Mock()
i.set_listener_endpoint(lep)
i.use_connection(c)
sc1 = mock.Mock()
peer_addr = object()
with mock.patch("wormhole._dilation.inbound.SubChannel",
side_effect=[sc1]) as sc:
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
side_effect=[peer_addr]) as sca:
i.handle_open(scid1)
self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)])
self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)])
self.assertEqual(sca.mock_calls, [mock.call(scid1)])
lep.mock_calls[:] = []
# a subsequent duplicate OPEN should be ignored
with mock.patch("wormhole._dilation.inbound.SubChannel",
side_effect=[sc1]) as sc:
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
side_effect=[peer_addr]) as sca:
i.handle_open(scid1)
self.assertEqual(lep.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(sca.mock_calls, [])
self.flushLoggedErrors(DuplicateOpenError)
i.handle_data(scid1, b"data")
self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")])
sc1.mock_calls[:] = []
i.handle_data(scid2, b"for non-existent subchannel")
self.assertEqual(sc1.mock_calls, [])
self.flushLoggedErrors(DataForMissingSubchannelError)
i.handle_close(scid1)
self.assertEqual(sc1.mock_calls, [mock.call.remote_close()])
sc1.mock_calls[:] = []
i.handle_close(scid2)
self.assertEqual(sc1.mock_calls, [])
self.flushLoggedErrors(CloseForMissingSubchannelError)
# after the subchannel is closed, the Manager will notify Inbound
i.subchannel_closed(scid1, sc1)
i.stop_using_connection()
def test_control_channel(self):
i, m, host_addr = make_inbound()
lep = mock.Mock()
i.set_listener_endpoint(lep)
scid0 = b"scid"
sc0 = mock.Mock()
i.set_subchannel_zero(scid0, sc0)
# OPEN on the control channel identifier should be ignored as a
# duplicate, since the control channel is already registered
sc1 = mock.Mock()
peer_addr = object()
with mock.patch("wormhole._dilation.inbound.SubChannel",
side_effect=[sc1]) as sc:
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
side_effect=[peer_addr]) as sca:
i.handle_open(scid0)
self.assertEqual(lep.mock_calls, [])
self.assertEqual(sc.mock_calls, [])
self.assertEqual(sca.mock_calls, [])
self.flushLoggedErrors(DuplicateOpenError)
# and DATA to it should be delivered correctly
i.handle_data(scid0, b"data")
self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")])
sc0.mock_calls[:] = []
def test_pause(self):
i, m, host_addr = make_inbound()
c = mock.Mock()
lep = mock.Mock()
i.set_listener_endpoint(lep)
# add two subchannels, pause one, then add a connection
scid1 = b"sci1"
scid2 = b"sci2"
sc1 = mock.Mock()
sc2 = mock.Mock()
peer_addr = object()
with mock.patch("wormhole._dilation.inbound.SubChannel",
side_effect=[sc1, sc2]):
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
return_value=peer_addr):
i.handle_open(scid1)
i.handle_open(scid2)
self.assertEqual(c.mock_calls, [])
i.subchannel_pauseProducing(sc1)
self.assertEqual(c.mock_calls, [])
i.subchannel_resumeProducing(sc1)
self.assertEqual(c.mock_calls, [])
i.subchannel_pauseProducing(sc1)
self.assertEqual(c.mock_calls, [])
i.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
c.mock_calls[:] = []
i.subchannel_resumeProducing(sc1)
self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
c.mock_calls[:] = []
# consumers aren't really supposed to do this, but tolerate it
i.subchannel_resumeProducing(sc1)
self.assertEqual(c.mock_calls, [])
i.subchannel_pauseProducing(sc1)
self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
c.mock_calls[:] = []
i.subchannel_pauseProducing(sc2)
self.assertEqual(c.mock_calls, []) # was already paused
# tolerate duplicate pauseProducing
i.subchannel_pauseProducing(sc2)
self.assertEqual(c.mock_calls, [])
# stopProducing is treated like a terminal resumeProducing
i.subchannel_stopProducing(sc1)
self.assertEqual(c.mock_calls, [])
i.subchannel_stopProducing(sc2)
self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
c.mock_calls[:] = []

View File

@ -0,0 +1,650 @@
from __future__ import print_function, unicode_literals
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator
import mock
from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.encode import to_be4
from ..._dilation.manager import (Dilator, Manager, make_side,
OldPeerCannotDilateError,
UnknownDilationMessageType,
UnexpectedKCM,
UnknownMessageType)
from ..._dilation.subchannel import _WormholeAddress
from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong
from .common import clear_mock_calls
def make_dilator():
reactor = object()
clock = Clock()
eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
def term_factory():
return term
coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually)
send = mock.Mock()
alsoProvides(send, ISend)
dil = Dilator(reactor, eq, coop)
dil.wire(send)
return dil, send, reactor, eq, clock, coop
class TestDilator(unittest.TestCase):
def test_manager_and_endpoints(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
d2 = dil.dilate()
self.assertNoResult(d1)
self.assertNoResult(d2)
key = b"key"
transit_key = object()
with mock.patch("wormhole._dilation.manager.derive_key",
return_value=transit_key) as dk:
dil.got_key(key)
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
self.assertIdentical(dil._transit_key, transit_key)
self.assertNoResult(d1)
self.assertNoResult(d2)
m = mock.Mock()
alsoProvides(m, IDilationManager)
m.when_first_connected.return_value = wfc_d = Deferred()
with mock.patch("wormhole._dilation.manager.Manager",
return_value=m) as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
dil.got_wormhole_versions({"can-dilate": ["1"]})
# that should create the Manager
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
None, reactor, eq, coop)])
# and tell it to start, and get wait-for-it-to-connect Deferred
self.assertEqual(m.mock_calls, [mock.call.start(),
mock.call.when_first_connected(),
])
clear_mock_calls(m)
self.assertNoResult(d1)
self.assertNoResult(d2)
host_addr = _WormholeAddress()
m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress",
return_value=host_addr)
peer_addr = object()
m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
return_value=peer_addr)
ce = mock.Mock()
m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
return_value=ce)
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
lep = object()
m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
return_value=lep)
with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m:
wfc_d.callback(None)
eq.flush_sync()
scid0 = b"\x00\x00\x00\x00"
self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
self.assertEqual(m_sc_m.mock_calls,
[mock.call(scid0, m, host_addr, peer_addr)])
self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
self.assertEqual(m.mock_calls,
[mock.call.set_subchannel_zero(scid0, sc),
mock.call.set_listener_endpoint(lep),
])
clear_mock_calls(m)
eps = self.successResultOf(d1)
self.assertEqual(eps, self.successResultOf(d2))
d3 = dil.dilate()
eq.flush_sync()
self.assertEqual(eps, self.successResultOf(d3))
# all subsequent DILATE-n messages should get passed to the manager
self.assertEqual(m.mock_calls, [])
pleasemsg = dict(type="please", side="them")
dil.received_dilate(dict_to_bytes(pleasemsg))
self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)])
clear_mock_calls(m)
hintmsg = dict(type="connection-hints")
dil.received_dilate(dict_to_bytes(hintmsg))
self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
clear_mock_calls(m)
# we're nominally the LEADER, and the leader would not normally be
# receiving a RECONNECT, but since we've mocked out the Manager it
# won't notice
dil.received_dilate(dict_to_bytes(dict(type="reconnect")))
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()])
clear_mock_calls(m)
dil.received_dilate(dict_to_bytes(dict(type="reconnecting")))
self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()])
clear_mock_calls(m)
dil.received_dilate(dict_to_bytes(dict(type="unknown")))
self.assertEqual(m.mock_calls, [])
self.flushLoggedErrors(UnknownDilationMessageType)
def test_peer_cannot_dilate(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
self.assertNoResult(d1)
dil._transit_key = b"\x01" * 32
dil.got_wormhole_versions({}) # missing "can-dilate"
eq.flush_sync()
f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError)
def test_disjoint_versions(self):
dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate()
self.assertNoResult(d1)
dil._transit_key = b"key"
dil.got_wormhole_versions({"can-dilate": [-1]})
eq.flush_sync()
f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError)
def test_early_dilate_messages(self):
dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key"
d1 = dil.dilate()
self.assertNoResult(d1)
pleasemsg = dict(type="please", side="them")
dil.received_dilate(dict_to_bytes(pleasemsg))
hintmsg = dict(type="connection-hints")
dil.received_dilate(dict_to_bytes(hintmsg))
m = mock.Mock()
alsoProvides(m, IDilationManager)
m.when_first_connected.return_value = Deferred()
with mock.patch("wormhole._dilation.manager.Manager",
return_value=m) as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
None, reactor, eq, coop)])
self.assertEqual(m.mock_calls, [mock.call.start(),
mock.call.rx_PLEASE(pleasemsg),
mock.call.rx_HINTS(hintmsg),
mock.call.when_first_connected()])
def test_transit_relay(self):
dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key"
relay = object()
d1 = dil.dilate(transit_relay_location=relay)
self.assertNoResult(d1)
with mock.patch("wormhole._dilation.manager.Manager") as ml:
with mock.patch("wormhole._dilation.manager.make_side",
return_value="us"):
dil.got_wormhole_versions({"can-dilate": ["1"]})
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
relay, reactor, eq, coop),
mock.call().start(),
mock.call().when_first_connected()])
LEADER = "ff3456abcdef"
FOLLOWER = "123456abcdef"
def make_manager(leader=True):
class Holder:
pass
h = Holder()
h.send = mock.Mock()
alsoProvides(h.send, ISend)
if leader:
side = LEADER
else:
side = FOLLOWER
h.key = b"\x00" * 32
h.relay = None
h.reactor = object()
h.clock = Clock()
h.eq = EventualQueue(h.clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
def term_factory():
return term
h.coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=h.eq.eventually)
h.inbound = mock.Mock()
h.Inbound = mock.Mock(return_value=h.inbound)
h.outbound = mock.Mock()
h.Outbound = mock.Mock(return_value=h.outbound)
h.hostaddr = object()
with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound):
with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
with mock.patch("wormhole._dilation.manager._WormholeAddress",
return_value=h.hostaddr):
m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop)
return m, h
class TestManager(unittest.TestCase):
def test_make_side(self):
side = make_side()
self.assertEqual(type(side), type(u""))
self.assertEqual(len(side), 2 * 6)
def test_create(self):
m, h = make_manager()
def test_leader(self):
m, h = make_manager(leader=True)
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)])
self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)])
m.start()
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-0",
dict_to_bytes({"type": "please", "side": LEADER}))
])
clear_mock_calls(h.send)
wfc_d = m.when_first_connected()
self.assertNoResult(wfc_d)
# ignore early hints
m.rx_HINTS({})
self.assertEqual(h.send.mock_calls, [])
c = mock.Mock()
connector = mock.Mock(return_value=c)
with mock.patch("wormhole._dilation.manager.Connector", connector):
# receiving this PLEASE triggers creation of the Connector
m.rx_PLEASE({"side": FOLLOWER})
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(connector.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
LEADER, roles.LEADER),
])
self.assertEqual(c.mock_calls, [mock.call.start()])
clear_mock_calls(connector, c)
self.assertNoResult(wfc_d)
# now any inbound hints should get passed to our Connector
with mock.patch("wormhole._dilation.manager.parse_hint",
side_effect=["p1", None, "p3"]) as ph:
m.rx_HINTS({"hints": [1, 2, 3]})
self.assertEqual(ph.mock_calls, [mock.call(1), mock.call(2), mock.call(3)])
self.assertEqual(c.mock_calls, [mock.call.got_hints(["p1", "p3"])])
clear_mock_calls(ph, c)
# and we send out any (listening) hints from our Connector
m.send_hints([1, 2])
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-1",
dict_to_bytes({"type": "connection-hints",
"hints": [1, 2]}))
])
clear_mock_calls(h.send)
# the first successful connection fires when_first_connected(), so
# the Dilator can create and return the endpoints
c1 = mock.Mock()
m.connector_connection_made(c1)
self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)])
self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
clear_mock_calls(h.inbound, h.outbound)
h.eq.flush_sync()
self.successResultOf(wfc_d) # fires with None
wfc_d2 = m.when_first_connected()
h.eq.flush_sync()
self.successResultOf(wfc_d2)
scid0 = b"\x00\x00\x00\x00"
sc0 = mock.Mock()
m.set_subchannel_zero(scid0, sc0)
listen_ep = mock.Mock()
m.set_listener_endpoint(listen_ep)
self.assertEqual(h.inbound.mock_calls, [
mock.call.set_subchannel_zero(scid0, sc0),
mock.call.set_listener_endpoint(listen_ep),
])
clear_mock_calls(h.inbound)
# the Leader making a new outbound channel should get scid=1
scid1 = to_be4(1)
self.assertEqual(m.allocate_subchannel_id(), scid1)
r1 = Open(10, scid1) # seqnum=10
h.outbound.build_record = mock.Mock(return_value=r1)
m.send_open(scid1)
self.assertEqual(h.outbound.mock_calls, [
mock.call.build_record(Open, scid1),
mock.call.queue_and_send_record(r1),
])
clear_mock_calls(h.outbound)
r2 = Data(11, scid1, b"data")
h.outbound.build_record = mock.Mock(return_value=r2)
m.send_data(scid1, b"data")
self.assertEqual(h.outbound.mock_calls, [
mock.call.build_record(Data, scid1, b"data"),
mock.call.queue_and_send_record(r2),
])
clear_mock_calls(h.outbound)
r3 = Close(12, scid1)
h.outbound.build_record = mock.Mock(return_value=r3)
m.send_close(scid1)
self.assertEqual(h.outbound.mock_calls, [
mock.call.build_record(Close, scid1),
mock.call.queue_and_send_record(r3),
])
clear_mock_calls(h.outbound)
# ack the OPEN
m.got_record(Ack(10))
self.assertEqual(h.outbound.mock_calls, [
mock.call.handle_ack(10)
])
clear_mock_calls(h.outbound)
# test that inbound records get acked and routed to Inbound
h.inbound.is_record_old = mock.Mock(return_value=False)
scid2 = to_be4(2)
o200 = Open(200, scid2)
m.got_record(o200)
self.assertEqual(h.outbound.mock_calls, [
mock.call.send_if_connected(Ack(200))
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.is_record_old(o200),
mock.call.update_ack_watermark(200),
mock.call.handle_open(scid2),
])
clear_mock_calls(h.outbound, h.inbound)
# old (duplicate) records should provoke new Acks, but not get
# forwarded
h.inbound.is_record_old = mock.Mock(return_value=True)
m.got_record(o200)
self.assertEqual(h.outbound.mock_calls, [
mock.call.send_if_connected(Ack(200))
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.is_record_old(o200),
])
clear_mock_calls(h.outbound, h.inbound)
# check Data and Close too
h.inbound.is_record_old = mock.Mock(return_value=False)
d201 = Data(201, scid2, b"data")
m.got_record(d201)
self.assertEqual(h.outbound.mock_calls, [
mock.call.send_if_connected(Ack(201))
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.is_record_old(d201),
mock.call.update_ack_watermark(201),
mock.call.handle_data(scid2, b"data"),
])
clear_mock_calls(h.outbound, h.inbound)
c202 = Close(202, scid2)
m.got_record(c202)
self.assertEqual(h.outbound.mock_calls, [
mock.call.send_if_connected(Ack(202))
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.is_record_old(c202),
mock.call.update_ack_watermark(202),
mock.call.handle_close(scid2),
])
clear_mock_calls(h.outbound, h.inbound)
# Now we lose the connection. The Leader should tell the other side
# that we're reconnecting.
m.connector_connection_lost()
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-2",
dict_to_bytes({"type": "reconnect"}))
])
self.assertEqual(h.inbound.mock_calls, [
mock.call.stop_using_connection()
])
self.assertEqual(h.outbound.mock_calls, [
mock.call.stop_using_connection()
])
clear_mock_calls(h.send, h.inbound, h.outbound)
# leader does nothing (stays in FLUSHING) until the follower acks by
# sending RECONNECTING
# inbound hints should be ignored during FLUSHING
with mock.patch("wormhole._dilation.manager.parse_hint",
return_value=None) as ph:
m.rx_HINTS({"hints": [1, 2, 3]})
self.assertEqual(ph.mock_calls, []) # ignored
c2 = mock.Mock()
connector2 = mock.Mock(return_value=c2)
with mock.patch("wormhole._dilation.manager.Connector", connector2):
# this triggers creation of a new Connector
m.rx_RECONNECTING()
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(connector2.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
LEADER, roles.LEADER),
])
self.assertEqual(c2.mock_calls, [mock.call.start()])
clear_mock_calls(connector2, c2)
self.assertEqual(h.inbound.mock_calls, [])
self.assertEqual(h.outbound.mock_calls, [])
# and a new connection should re-register with Inbound/Outbound,
# which are responsible for re-sending unacked queued messages
c3 = mock.Mock()
m.connector_connection_made(c3)
self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c3)])
self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c3)])
clear_mock_calls(h.inbound, h.outbound)
def test_follower(self):
m, h = make_manager(leader=False)
m.start()
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-0",
dict_to_bytes({"type": "please", "side": FOLLOWER}))
])
clear_mock_calls(h.send)
c = mock.Mock()
connector = mock.Mock(return_value=c)
with mock.patch("wormhole._dilation.manager.Connector", connector):
# receiving this PLEASE triggers creation of the Connector
m.rx_PLEASE({"side": LEADER})
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(connector.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
FOLLOWER, roles.FOLLOWER),
])
self.assertEqual(c.mock_calls, [mock.call.start()])
clear_mock_calls(connector, c)
# get connected, then lose the connection
c1 = mock.Mock()
m.connector_connection_made(c1)
self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)])
self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
clear_mock_calls(h.inbound, h.outbound)
# now lose the connection. As the follower, we don't notify the
# leader, we just wait for them to notice
m.connector_connection_lost()
self.assertEqual(h.send.mock_calls, [])
self.assertEqual(h.inbound.mock_calls, [
mock.call.stop_using_connection()
])
self.assertEqual(h.outbound.mock_calls, [
mock.call.stop_using_connection()
])
clear_mock_calls(h.send, h.inbound, h.outbound)
# now we get a RECONNECT: we should send RECONNECTING
c2 = mock.Mock()
connector2 = mock.Mock(return_value=c2)
with mock.patch("wormhole._dilation.manager.Connector", connector2):
m.rx_RECONNECT()
self.assertEqual(h.send.mock_calls, [
mock.call.send("dilate-1",
dict_to_bytes({"type": "reconnecting"}))
])
self.assertEqual(connector2.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
FOLLOWER, roles.FOLLOWER),
])
self.assertEqual(c2.mock_calls, [mock.call.start()])
clear_mock_calls(connector2, c2)
# while we're trying to connect, we get told to stop again, so we
# should abandon the connection attempt and start another
c3 = mock.Mock()
connector3 = mock.Mock(return_value=c3)
with mock.patch("wormhole._dilation.manager.Connector", connector3):
m.rx_RECONNECT()
self.assertEqual(c2.mock_calls, [mock.call.stop()])
self.assertEqual(connector3.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
FOLLOWER, roles.FOLLOWER),
])
self.assertEqual(c3.mock_calls, [mock.call.start()])
clear_mock_calls(c2, connector3, c3)
m.connector_connection_made(c3)
# finally if we're already connected, rx_RECONNECT means we should
# abandon this connection (even though it still looks ok to us), then
# when the attempt is finished stopping, we should start another
m.rx_RECONNECT()
c4 = mock.Mock()
connector4 = mock.Mock(return_value=c4)
with mock.patch("wormhole._dilation.manager.Connector", connector4):
m.connector_connection_lost()
self.assertEqual(c3.mock_calls, [mock.call.disconnect()])
self.assertEqual(connector4.mock_calls, [
mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
False, # no_listen
None, # tor
None, # timing
FOLLOWER, roles.FOLLOWER),
])
self.assertEqual(c4.mock_calls, [mock.call.start()])
clear_mock_calls(c3, connector4, c4)
def test_mirror(self):
# receive a PLEASE with the same side as us: shouldn't happen
m, h = make_manager(leader=True)
m.start()
clear_mock_calls(h.send)
e = self.assertRaises(ValueError, m.rx_PLEASE, {"side": LEADER})
self.assertEqual(str(e), "their side shouldn't be equal: reflection?")
def test_ping_pong(self):
m, h = make_manager(leader=False)
m.got_record(KCM())
self.flushLoggedErrors(UnexpectedKCM)
m.got_record(Ping(1))
self.assertEqual(h.outbound.mock_calls,
[mock.call.send_if_connected(Pong(1))])
clear_mock_calls(h.outbound)
m.got_record(Pong(2))
# currently ignored, will eventually update a timer
m.got_record("not recognized")
e = self.flushLoggedErrors(UnknownMessageType)
self.assertEqual(len(e), 1)
self.assertEqual(str(e[0].value), "not recognized")
m.send_ping(3)
self.assertEqual(h.outbound.mock_calls,
[mock.call.send_if_connected(Pong(3))])
clear_mock_calls(h.outbound)
def test_subchannel(self):
m, h = make_manager(leader=True)
sc = object()
m.subchannel_pauseProducing(sc)
self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_pauseProducing(sc)])
clear_mock_calls(h.inbound)
m.subchannel_resumeProducing(sc)
self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_resumeProducing(sc)])
clear_mock_calls(h.inbound)
m.subchannel_stopProducing(sc)
self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_stopProducing(sc)])
clear_mock_calls(h.inbound)
p = object()
streaming = object()
m.subchannel_registerProducer(sc, p, streaming)
self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_registerProducer(sc, p, streaming)])
clear_mock_calls(h.outbound)
m.subchannel_unregisterProducer(sc)
self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_unregisterProducer(sc)])
clear_mock_calls(h.outbound)
m.subchannel_closed("scid", sc)
self.assertEqual(h.inbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)])
self.assertEqual(h.outbound.mock_calls, [
mock.call.subchannel_closed("scid", sc)])
clear_mock_calls(h.inbound, h.outbound)

View File

@ -0,0 +1,655 @@
from __future__ import print_function, unicode_literals
from collections import namedtuple
from itertools import cycle
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IPullProducer
from ...eventual import EventualQueue
from ..._interfaces import IDilationManager
from ..._dilation.connection import KCM, Open, Data, Close, Ack
from ..._dilation.outbound import Outbound, PullToPush
from .common import clear_mock_calls
Pauser = namedtuple("Pauser", ["seqnum"])
NonPauser = namedtuple("NonPauser", ["seqnum"])
Stopper = namedtuple("Stopper", ["sc"])
def make_outbound():
m = mock.Mock()
alsoProvides(m, IDilationManager)
clock = Clock()
eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
def term_factory():
return term
coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually)
o = Outbound(m, coop)
c = mock.Mock() # Connection
def maybe_pause(r):
if isinstance(r, Pauser):
o.pauseProducing()
elif isinstance(r, Stopper):
o.subchannel_unregisterProducer(r.sc)
c.send_record = mock.Mock(side_effect=maybe_pause)
o._test_eq = eq
o._test_term = term
return o, m, c
class OutboundTest(unittest.TestCase):
def test_build_record(self):
o, m, c = make_outbound()
scid1 = b"scid"
self.assertEqual(o.build_record(Open, scid1),
Open(seqnum=0, scid=b"scid"))
self.assertEqual(o.build_record(Data, scid1, b"dataaa"),
Data(seqnum=1, scid=b"scid", data=b"dataaa"))
self.assertEqual(o.build_record(Close, scid1),
Close(seqnum=2, scid=b"scid"))
self.assertEqual(o.build_record(Close, scid1),
Close(seqnum=3, scid=b"scid"))
def test_outbound_queue(self):
o, m, c = make_outbound()
scid1 = b"scid"
r1 = o.build_record(Open, scid1)
r2 = o.build_record(Data, scid1, b"data1")
r3 = o.build_record(Data, scid1, b"data2")
o.queue_and_send_record(r1)
o.queue_and_send_record(r2)
o.queue_and_send_record(r3)
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
# we would never normally receive an ACK without first getting a
# connection
o.handle_ack(r2.seqnum)
self.assertEqual(list(o._outbound_queue), [r3])
o.handle_ack(r3.seqnum)
self.assertEqual(list(o._outbound_queue), [])
o.handle_ack(r3.seqnum) # ignored
self.assertEqual(list(o._outbound_queue), [])
o.handle_ack(r1.seqnum) # ignored
self.assertEqual(list(o._outbound_queue), [])
def test_duplicate_registerProducer(self):
o, m, c = make_outbound()
sc1 = object()
p1 = mock.Mock()
o.subchannel_registerProducer(sc1, p1, True)
with self.assertRaises(ValueError) as ar:
o.subchannel_registerProducer(sc1, p1, True)
s = str(ar.exception)
self.assertIn("registering producer", s)
self.assertIn("before previous one", s)
self.assertIn("was unregistered", s)
def test_connection_send_queued_unpaused(self):
o, m, c = make_outbound()
scid1 = b"scid"
r1 = o.build_record(Open, scid1)
r2 = o.build_record(Data, scid1, b"data1")
r3 = o.build_record(Data, scid1, b"data2")
o.queue_and_send_record(r1)
o.queue_and_send_record(r2)
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [])
# as soon as the connection is established, everything is sent
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
mock.call.send_record(r1),
mock.call.send_record(r2)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [])
clear_mock_calls(c)
o.queue_and_send_record(r3)
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
self.assertEqual(list(o._queued_unsent), [])
self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
def test_connection_send_queued_paused(self):
o, m, c = make_outbound()
r1 = Pauser(seqnum=1)
r2 = Pauser(seqnum=2)
r3 = Pauser(seqnum=3)
o.queue_and_send_record(r1)
o.queue_and_send_record(r2)
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [])
# pausing=True, so our mock Manager will pause the Outbound producer
# after each write. So only r1 should have been sent before getting
# paused
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
mock.call.send_record(r1)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [r2])
clear_mock_calls(c)
# Outbound is responsible for sending all records, so when Manager
# wants to send a new one, and Outbound is still in the middle of
# draining the beginning-of-connection queue, the new message gets
# queued behind the rest (in addition to being queued in
# _outbound_queue until an ACK retires it).
o.queue_and_send_record(r3)
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
self.assertEqual(list(o._queued_unsent), [r2, r3])
self.assertEqual(c.mock_calls, [])
o.handle_ack(r1.seqnum)
self.assertEqual(list(o._outbound_queue), [r2, r3])
self.assertEqual(list(o._queued_unsent), [r2, r3])
self.assertEqual(c.mock_calls, [])
def test_premptive_ack(self):
# one mode I have in mind is for each side to send an immediate ACK,
# with everything they've ever seen, as the very first message on each
# new connection. The idea is that you might preempt sending stuff from
# the _queued_unsent list if it arrives fast enough (in practice this
# is more likely to be delivered via the DILATE mailbox message, but
# the effects might be vaguely similar, so it seems worth testing
# here). A similar situation would be if each side sends ACKs with the
# highest seqnum they've ever seen, instead of merely ACKing the
# message which was just received.
o, m, c = make_outbound()
r1 = Pauser(seqnum=1)
r2 = Pauser(seqnum=2)
r3 = Pauser(seqnum=3)
o.queue_and_send_record(r1)
o.queue_and_send_record(r2)
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [])
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
mock.call.send_record(r1)])
self.assertEqual(list(o._outbound_queue), [r1, r2])
self.assertEqual(list(o._queued_unsent), [r2])
clear_mock_calls(c)
o.queue_and_send_record(r3)
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
self.assertEqual(list(o._queued_unsent), [r2, r3])
self.assertEqual(c.mock_calls, [])
o.handle_ack(r2.seqnum)
self.assertEqual(list(o._outbound_queue), [r3])
self.assertEqual(list(o._queued_unsent), [r3])
self.assertEqual(c.mock_calls, [])
def test_pause(self):
o, m, c = make_outbound()
o.use_connection(c)
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
self.assertEqual(list(o._outbound_queue), [])
self.assertEqual(list(o._queued_unsent), [])
clear_mock_calls(c)
sc1, sc2, sc3 = object(), object(), object()
p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(
name="p2"), mock.Mock(name="p3")
# we aren't paused yet, since we haven't sent any data
o.subchannel_registerProducer(sc1, p1, True)
self.assertEqual(p1.mock_calls, [])
r1 = Pauser(seqnum=1)
o.queue_and_send_record(r1)
# now we should be paused
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
clear_mock_calls(p1, c)
# so an IPushProducer will be paused right away
o.subchannel_registerProducer(sc2, p2, True)
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()])
clear_mock_calls(p2)
o.subchannel_registerProducer(sc3, p3, True)
self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()])
self.assertEqual(o._paused_producers, set([p1, p2, p3]))
self.assertEqual(list(o._all_producers), [p1, p2, p3])
clear_mock_calls(p3)
# one resumeProducing should cause p1 to get a turn, since p2 was added
# after we were paused and p1 was at the "end" of a one-element list.
# If it writes anything, it will get paused again immediately.
r2 = Pauser(seqnum=2)
p1.resumeProducing.side_effect = lambda: c.send_record(r2)
o.resumeProducing()
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p2.mock_calls, [])
self.assertEqual(p3.mock_calls, [])
self.assertEqual(c.mock_calls, [mock.call.send_record(r2)])
clear_mock_calls(p1, p2, p3, c)
# p2 should now be at the head of the queue
self.assertEqual(list(o._all_producers), [p2, p3, p1])
# next turn: p2 has nothing to send, but p3 does. we should see p3
# called but not p1. The actual sequence of expected calls is:
# p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause)
r3 = Pauser(seqnum=3)
p2.resumeProducing.side_effect = lambda: None
p3.resumeProducing.side_effect = lambda: c.send_record(r3)
o.resumeProducing()
self.assertEqual(p1.mock_calls, [])
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
clear_mock_calls(p1, p2, p3, c)
# p1 should now be at the head of the queue
self.assertEqual(list(o._all_producers), [p1, p2, p3])
# next turn: p1 has data to send, but not enough to cause a pause. same
# for p2. p3 causes a pause
r4 = NonPauser(seqnum=4)
r5 = NonPauser(seqnum=5)
r6 = Pauser(seqnum=6)
p1.resumeProducing.side_effect = lambda: c.send_record(r4)
p2.resumeProducing.side_effect = lambda: c.send_record(r5)
p3.resumeProducing.side_effect = lambda: c.send_record(r6)
o.resumeProducing()
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(c.mock_calls, [mock.call.send_record(r4),
mock.call.send_record(r5),
mock.call.send_record(r6),
])
clear_mock_calls(p1, p2, p3, c)
# p1 should now be at the head of the queue again
self.assertEqual(list(o._all_producers), [p1, p2, p3])
# now we let it catch up. p1 and p2 send non-pausing data, p3 sends
# nothing.
r7 = NonPauser(seqnum=4)
r8 = NonPauser(seqnum=5)
p1.resumeProducing.side_effect = lambda: c.send_record(r7)
p2.resumeProducing.side_effect = lambda: c.send_record(r8)
p3.resumeProducing.side_effect = lambda: None
o.resumeProducing()
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
])
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
])
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
])
self.assertEqual(c.mock_calls, [mock.call.send_record(r7),
mock.call.send_record(r8),
])
clear_mock_calls(p1, p2, p3, c)
# p1 should now be at the head of the queue again
self.assertEqual(list(o._all_producers), [p1, p2, p3])
self.assertFalse(o._paused)
# now a producer disconnects itself (spontaneously, not from inside a
# resumeProducing)
o.subchannel_unregisterProducer(sc1)
self.assertEqual(list(o._all_producers), [p2, p3])
self.assertEqual(p1.mock_calls, [])
self.assertFalse(o._paused)
# and another disconnects itself when called
p2.resumeProducing.side_effect = lambda: None
p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(
sc3)
o.pauseProducing()
o.resumeProducing()
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(),
mock.call.resumeProducing()])
self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(),
mock.call.resumeProducing()])
clear_mock_calls(p2, p3)
self.assertEqual(list(o._all_producers), [p2])
self.assertFalse(o._paused)
def test_subchannel_closed(self):
o, m, c = make_outbound()
sc1 = mock.Mock()
p1 = mock.Mock(name="p1")
o.subchannel_registerProducer(sc1, p1, True)
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
clear_mock_calls(p1)
o.subchannel_closed(sc1)
self.assertEqual(p1.mock_calls, [])
self.assertEqual(list(o._all_producers), [])
sc2 = mock.Mock()
o.subchannel_closed(sc2)
def test_disconnect(self):
o, m, c = make_outbound()
o.use_connection(c)
sc1 = mock.Mock()
p1 = mock.Mock(name="p1")
o.subchannel_registerProducer(sc1, p1, True)
self.assertEqual(p1.mock_calls, [])
o.stop_using_connection()
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
def OFF_test_push_pull(self):
# use one IPushProducer and one IPullProducer. They should take turns
o, m, c = make_outbound()
o.use_connection(c)
clear_mock_calls(c)
sc1, sc2 = object(), object()
p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2")
r1 = Pauser(seqnum=1)
r2 = NonPauser(seqnum=2)
# we aren't paused yet, since we haven't sent any data
o.subchannel_registerProducer(sc1, p1, True) # push
o.queue_and_send_record(r1)
# now we're paused
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
self.assertEqual(p2.mock_calls, [])
clear_mock_calls(p1, p2, c)
p1.resumeProducing.side_effect = lambda: c.send_record(r1)
p2.resumeProducing.side_effect = lambda: c.send_record(r2)
o.subchannel_registerProducer(sc2, p2, False) # pull: always ready
# p1 is still first, since p2 was just added (at the end)
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, [])
self.assertEqual(p1.mock_calls, [])
self.assertEqual(p2.mock_calls, [])
self.assertEqual(list(o._all_producers), [p1, p2])
clear_mock_calls(p1, p2, c)
# resume should send r1, which should pause everything
o.resumeProducing()
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, [mock.call.send_record(r1),
])
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p2.mock_calls, [])
self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next
clear_mock_calls(p1, p2, c)
# next should fire p2, then p1
o.resumeProducing()
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, [mock.call.send_record(r2),
mock.call.send_record(r1),
])
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
mock.call.pauseProducing(),
])
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
])
self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat
clear_mock_calls(p1, p2, c)
def test_pull_producer(self):
# a single pull producer should write until it is paused, rate-limited
# by the cooperator (so we'll see back-to-back resumeProducing calls
# until the Connection is paused, or 10ms have passed, whichever comes
# first, and if it's stopped by the timer, then the next EventualQueue
# turn will start it off again)
o, m, c = make_outbound()
eq = o._test_eq
o.use_connection(c)
clear_mock_calls(c)
self.assertFalse(o._paused)
sc1 = mock.Mock()
p1 = mock.Mock(name="p1")
alsoProvides(p1, IPullProducer)
records = [NonPauser(seqnum=1)] * 10
records.append(Pauser(seqnum=2))
records.append(Stopper(sc1))
it = iter(records)
p1.resumeProducing.side_effect = lambda: c.send_record(next(it))
o.subchannel_registerProducer(sc1, p1, False)
eq.flush_sync() # fast forward into the glorious (paused) future
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in records[:-1]])
self.assertEqual(p1.mock_calls,
[mock.call.resumeProducing()] * (len(records) - 1))
clear_mock_calls(c, p1)
# next resumeProducing should cause it to disconnect
o.resumeProducing()
eq.flush_sync()
self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])])
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()])
self.assertEqual(len(o._all_producers), 0)
self.assertFalse(o._paused)
def test_two_pull_producers(self):
# we should alternate between them until paused
p1_records = ([NonPauser(seqnum=i) for i in range(5)] +
[Pauser(seqnum=5)] +
[NonPauser(seqnum=i) for i in range(6, 10)])
p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] +
[Pauser(seqnum=19)])
expected1 = [NonPauser(0), NonPauser(10),
NonPauser(1), NonPauser(11),
NonPauser(2), NonPauser(12),
NonPauser(3), NonPauser(13),
NonPauser(4), NonPauser(14),
Pauser(5)]
expected2 = [NonPauser(15),
NonPauser(6), NonPauser(16),
NonPauser(7), NonPauser(17),
NonPauser(8), NonPauser(18),
NonPauser(9), Pauser(19),
]
o, m, c = make_outbound()
eq = o._test_eq
o.use_connection(c)
clear_mock_calls(c)
self.assertFalse(o._paused)
sc1 = mock.Mock()
p1 = mock.Mock(name="p1")
alsoProvides(p1, IPullProducer)
it1 = iter(p1_records)
p1.resumeProducing.side_effect = lambda: c.send_record(next(it1))
o.subchannel_registerProducer(sc1, p1, False)
sc2 = mock.Mock()
p2 = mock.Mock(name="p2")
alsoProvides(p2, IPullProducer)
it2 = iter(p2_records)
p2.resumeProducing.side_effect = lambda: c.send_record(next(it2))
o.subchannel_registerProducer(sc2, p2, False)
eq.flush_sync() # fast forward into the glorious (paused) future
sends = [mock.call.resumeProducing()]
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in expected1])
self.assertEqual(p1.mock_calls, 6 * sends)
self.assertEqual(p2.mock_calls, 5 * sends)
clear_mock_calls(c, p1, p2)
o.resumeProducing()
eq.flush_sync()
self.assertTrue(o._paused)
self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in expected2])
self.assertEqual(p1.mock_calls, 4 * sends)
self.assertEqual(p2.mock_calls, 5 * sends)
clear_mock_calls(c, p1, p2)
def test_send_if_connected(self):
o, m, c = make_outbound()
o.send_if_connected(Ack(1)) # not connected yet
o.use_connection(c)
o.send_if_connected(KCM())
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
mock.call.send_record(KCM())])
def test_tolerate_duplicate_pause_resume(self):
o, m, c = make_outbound()
self.assertTrue(o._paused) # no connection
o.use_connection(c)
self.assertFalse(o._paused)
o.pauseProducing()
self.assertTrue(o._paused)
o.pauseProducing()
self.assertTrue(o._paused)
o.resumeProducing()
self.assertFalse(o._paused)
o.resumeProducing()
self.assertFalse(o._paused)
def test_stopProducing(self):
o, m, c = make_outbound()
o.use_connection(c)
self.assertFalse(o._paused)
o.stopProducing() # connection does this before loss
self.assertTrue(o._paused)
o.stop_using_connection()
self.assertTrue(o._paused)
def test_resume_error(self):
o, m, c = make_outbound()
o.use_connection(c)
sc1 = mock.Mock()
p1 = mock.Mock(name="p1")
alsoProvides(p1, IPullProducer)
p1.resumeProducing.side_effect = PretendResumptionError
o.subchannel_registerProducer(sc1, p1, False)
o._test_eq.flush_sync()
# the error is supposed to automatically unregister the producer
self.assertEqual(list(o._all_producers), [])
self.flushLoggedErrors(PretendResumptionError)
def make_pushpull(pauses):
p = mock.Mock()
alsoProvides(p, IPullProducer)
unregister = mock.Mock()
clock = Clock()
eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
def term_factory():
return term
coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually)
pp = PullToPush(p, unregister, coop)
it = cycle(pauses)
def action(i):
if isinstance(i, Exception):
raise i
elif i:
pp.pauseProducing()
p.resumeProducing.side_effect = lambda: action(next(it))
return p, unregister, pp, eq
class PretendResumptionError(Exception):
pass
class PretendUnregisterError(Exception):
pass
class PushPull(unittest.TestCase):
# test our wrapper utility, which I copied from
# twisted.internet._producer_helpers since it isn't publically exposed
def test_start_unpaused(self):
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
# if it starts unpaused, it gets one write before being halted
pp.startStreaming(False)
eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
clear_mock_calls(p)
# now each time we call resumeProducing, we should see one delivered to
# the underlying IPullProducer
pp.resumeProducing()
eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
pp.stopStreaming()
pp.stopStreaming() # should tolerate this
def test_start_unpaused_two_writes(self):
p, unr, pp, eq = make_pushpull([False, True]) # pause every other time
# it should get two writes, since the first didn't pause
pp.startStreaming(False)
eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2)
def test_start_paused(self):
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
pp.startStreaming(True)
eq.flush_sync()
self.assertEqual(p.mock_calls, [])
pp.stopStreaming()
def test_stop(self):
p, unr, pp, eq = make_pushpull([True])
pp.startStreaming(True)
pp.stopProducing()
eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.stopProducing()])
def test_error(self):
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
unr.side_effect = lambda: pp.stopStreaming()
pp.startStreaming(False)
eq.flush_sync()
self.assertEqual(unr.mock_calls, [mock.call()])
self.flushLoggedErrors(PretendResumptionError)
def test_error_during_unregister(self):
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
unr.side_effect = PretendUnregisterError()
pp.startStreaming(False)
eq.flush_sync()
self.assertEqual(unr.mock_calls, [mock.call()])
self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError)
# TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I
# could capture the inter-call ordering that way

View File

@ -0,0 +1,44 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from ..._dilation.connection import (parse_record, encode_record,
KCM, Ping, Pong, Open, Data, Close, Ack)
class Parse(unittest.TestCase):
def test_parse(self):
self.assertEqual(parse_record(b"\x00"), KCM())
self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"),
Ping(ping_id=b"\x55\x44\x33\x22"))
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
Pong(ping_id=b"\x55\x44\x33\x22"))
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
Open(scid=513, seqnum=256))
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
Data(scid=514, seqnum=257, data=b"dataaa"))
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
Close(scid=515, seqnum=258))
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
Ack(resp_seqnum=259))
with mock.patch("wormhole._dilation.connection.log.err") as le:
with self.assertRaises(ValueError):
parse_record(b"\x07unknown")
self.assertEqual(le.mock_calls,
[mock.call("received unknown message type: {}".format(
b"\x07unknown"))])
def test_encode(self):
self.assertEqual(encode_record(KCM()), b"\x00")
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
b"\x06\x00\x00\x00\x13")
with self.assertRaises(TypeError) as ar:
encode_record("not a record")
self.assertEqual(str(ar.exception), "not a record")

View File

@ -0,0 +1,269 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._dilation._noise import NoiseInvalidMessage
from ..._dilation.connection import (IFramer, Frame, Prologue,
_Record, Handshake,
Disconnect, Ping)
def make_record():
f = mock.Mock()
alsoProvides(f, IFramer)
n = mock.Mock() # pretends to be a Noise object
r = _Record(f, n)
return r, f, n
class Record(unittest.TestCase):
def test_good2(self):
f = mock.Mock()
alsoProvides(f, IFramer)
f.add_and_parse = mock.Mock(side_effect=[
[],
[Prologue()],
[Frame(frame=b"rx-handshake")],
[Frame(frame=b"frame1"), Frame(frame=b"frame2")],
])
n = mock.Mock()
n.write_message = mock.Mock(return_value=b"tx-handshake")
p1, p2 = object(), object()
n.decrypt = mock.Mock(side_effect=[p1, p2])
r = _Record(f, n)
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
n.mock_calls[:] = []
# Pretend to deliver the prologue in two parts. The text we send in
# doesn't matter: the side_effect= is what causes the prologue to be
# recognized by the second call.
self.assertEqual(list(r.add_and_unframe(b"pro")), [])
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [])
# recognizing the prologue causes a handshake frame to be sent
self.assertEqual(list(r.add_and_unframe(b"logue")), [])
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
mock.call.send_frame(b"tx-handshake")])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [mock.call.write_message()])
n.mock_calls[:] = []
# next add_and_unframe is recognized as the Handshake
self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()])
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")])
n.mock_calls[:] = []
# next is a pair of Records
r1, r2 = object(), object()
with mock.patch("wormhole._dilation.connection.parse_record",
side_effect=[r1, r2]) as pr:
self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2])
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"),
mock.call.decrypt(b"frame2")])
self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)])
def test_bad_handshake(self):
f = mock.Mock()
alsoProvides(f, IFramer)
f.add_and_parse = mock.Mock(return_value=[Prologue(),
Frame(frame=b"rx-handshake")])
n = mock.Mock()
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.read_message = mock.Mock(side_effect=nvm)
r = _Record(f, n)
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
n.mock_calls[:] = []
with mock.patch("wormhole._dilation.connection.log.err") as le:
with self.assertRaises(Disconnect):
list(r.add_and_unframe(b"data"))
self.assertEqual(le.mock_calls,
[mock.call(nvm, "bad inbound noise handshake")])
def test_bad_message(self):
f = mock.Mock()
alsoProvides(f, IFramer)
f.add_and_parse = mock.Mock(return_value=[Prologue(),
Frame(frame=b"rx-handshake"),
Frame(frame=b"bad-message")])
n = mock.Mock()
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.decrypt = mock.Mock(side_effect=nvm)
r = _Record(f, n)
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
f.mock_calls[:] = []
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
n.mock_calls[:] = []
with mock.patch("wormhole._dilation.connection.log.err") as le:
with self.assertRaises(Disconnect):
list(r.add_and_unframe(b"data"))
self.assertEqual(le.mock_calls,
[mock.call(nvm, "bad inbound noise frame")])
def test_send_record(self):
f = mock.Mock()
alsoProvides(f, IFramer)
n = mock.Mock()
f1 = object()
n.encrypt = mock.Mock(return_value=f1)
r1 = Ping(b"pingid")
r = _Record(f, n)
self.assertEqual(f.mock_calls, [])
m1 = object()
with mock.patch("wormhole._dilation.connection.encode_record",
return_value=m1) as er:
r.send_record(r1)
self.assertEqual(er.mock_calls, [mock.call(r1)])
self.assertEqual(n.mock_calls, [mock.call.start_handshake(),
mock.call.encrypt(m1)])
self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)])
def test_good(self):
# Exercise the success path. The Record instance is given each chunk
# of data as it arrives on Protocol.dataReceived, and is supposed to
# return a series of Tokens (maybe none, if the chunk was incomplete,
# or more than one, if the chunk was larger). Internally, it delivers
# the chunks to the Framer for unframing (which returns 0 or more
# frames), manages the Noise decryption object, and parses any
# decrypted messages into tokens (some of which are consumed
# internally, others for delivery upstairs).
#
# in the normal flow, we get:
#
# | | Inbound | NoiseAction | Outbound | ToUpstairs |
# | | - | - | - | - |
# | 1 | | | prologue | |
# | 2 | prologue | | | |
# | 3 | | write_message | handshake | |
# | 4 | handshake | read_message | | Handshake |
# | 5 | | encrypt | KCM | |
# | 6 | KCM | decrypt | | KCM |
# | 7 | msg1 | decrypt | | msg1 |
# 1: instantiating the Record instance causes the outbound prologue
# to be sent
# 2+3: receipt of the inbound prologue triggers creation of the
# ephemeral key (the "handshake") by calling noise.write_message()
# and then writes the handshake to the outbound transport
# 4: when the peer's handshake is received, it is delivered to
# noise.read_message(), which generates the shared key (enabling
# noise.send() and noise.decrypt()). It also delivers the Handshake
# token upstairs, which might (on the Follower) trigger immediate
# transmission of the Key Confirmation Message (KCM)
# 5: the outbound KCM is framed and fed into noise.encrypt(), then
# sent outbound
# 6: the peer's KCM is decrypted then delivered upstairs. The
# Follower treats this as a signal that it should use this connection
# (and drop all others).
# 7: the peer's first message is decrypted, parsed, and delivered
# upstairs. This might be an Open or a Data, depending upon what
# queued messages were left over from the previous connection
r, f, n = make_record()
outbound_handshake = object()
kcm, msg1 = object(), object()
f_kcm, f_msg1 = object(), object()
n.write_message = mock.Mock(return_value=outbound_handshake)
n.decrypt = mock.Mock(side_effect=[kcm, msg1])
n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1])
f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet
[Prologue()],
[Frame("f_handshake")],
[Frame("f_kcm"),
Frame("f_msg1")],
])
self.assertEqual(f.mock_calls, [])
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
n.mock_calls[:] = []
# 1. The Framer is responsible for sending the prologue, so we don't
# have to check that here, we just check that the Framer was told
# about connectionMade properly.
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
self.assertEqual(n.mock_calls, [])
f.mock_calls[:] = []
# 2
# we dribble the prologue in over two messages, to make sure we can
# handle a dataReceived that doesn't complete the token
# remember, add_and_unframe is a generator
self.assertEqual(list(r.add_and_unframe(b"pro")), [])
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
self.assertEqual(n.mock_calls, [])
f.mock_calls[:] = []
self.assertEqual(list(r.add_and_unframe(b"logue")), [])
# 3: write_message, send outbound handshake
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
mock.call.send_frame(outbound_handshake),
])
self.assertEqual(n.mock_calls, [mock.call.write_message()])
f.mock_calls[:] = []
n.mock_calls[:] = []
# 4
# Now deliver the Noise "handshake", the ephemeral public key. This
# is framed, but not a record, so it shouldn't decrypt or parse
# anything, but the handshake is delivered to the Noise object, and
# it does return a Handshake token so we can let the next layer up
# react (by sending the KCM frame if we're a Follower, or not if
# we're the Leader)
self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()])
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")])
self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")])
f.mock_calls[:] = []
n.mock_calls[:] = []
# 5: at this point we ought to be able to send a messge, the KCM
with mock.patch("wormhole._dilation.connection.encode_record",
side_effect=[b"r-kcm"]) as er:
r.send_record(kcm)
self.assertEqual(er.mock_calls, [mock.call(kcm)])
self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")])
self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)])
n.mock_calls[:] = []
f.mock_calls[:] = []
# 6: Now we deliver two messages stacked up: the KCM (Key
# Confirmation Message) and the first real message. Concatenating
# them tests that we can handle more than one token in a single
# chunk. We need to mock parse_record() because everything past the
# handshake is decrypted and parsed.
with mock.patch("wormhole._dilation.connection.parse_record",
side_effect=[kcm, msg1]) as pr:
self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")),
[kcm, msg1])
self.assertEqual(f.mock_calls,
[mock.call.add_and_parse(b"kcm,msg1")])
self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"),
mock.call.decrypt("f_msg1")])
self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)])
n.mock_calls[:] = []
f.mock_calls[:] = []

View File

@ -0,0 +1,144 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
_WormholeAddress, _SubchannelAddress,
AlreadyClosedError)
from .common import mock_manager
def make_sc(set_protocol=True):
scid = b"scid"
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(scid)
m = mock_manager()
sc = SubChannel(scid, m, hostaddr, peeraddr)
p = mock.Mock()
if set_protocol:
sc._set_protocol(p)
return sc, m, scid, hostaddr, peeraddr, p
class SubChannelAPI(unittest.TestCase):
def test_once(self):
o = Once(ValueError)
o()
with self.assertRaises(ValueError):
o()
def test_create(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
self.assert_(ITransport.providedBy(sc))
self.assertEqual(m.mock_calls, [])
self.assertIdentical(sc.getHost(), hostaddr)
self.assertIdentical(sc.getPeer(), peeraddr)
def test_write(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.write(b"data")
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
m.mock_calls[:] = []
sc.writeSequence([b"more", b"data"])
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
def test_write_when_closing(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
with self.assertRaises(AlreadyClosedError) as e:
sc.write(b"data")
self.assertEqual(str(e.exception),
"write not allowed on closed subchannel")
def test_local_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
# late arriving data is still delivered
sc.remote_data(b"late")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")])
p.mock_calls[:] = []
sc.remote_close()
self.assert_connectionDone(p.mock_calls)
def test_local_double_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.loseConnection()
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
m.mock_calls[:] = []
with self.assertRaises(AlreadyClosedError) as e:
sc.loseConnection()
self.assertEqual(str(e.exception),
"loseConnection not allowed on closed subchannel")
def assert_connectionDone(self, mock_calls):
self.assertEqual(len(mock_calls), 1)
self.assertEqual(mock_calls[0][0], "connectionLost")
self.assertEqual(len(mock_calls[0][1]), 1)
self.assertIsInstance(mock_calls[0][1][0], ConnectionDone)
def test_remote_close(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.remote_close()
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
self.assert_connectionDone(p.mock_calls)
def test_data(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.remote_data(b"data")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
p.mock_calls[:] = []
sc.remote_data(b"not")
sc.remote_data(b"coalesced")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"),
mock.call.dataReceived(b"coalesced"),
])
def test_data_before_open(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
sc.remote_data(b"data")
self.assertEqual(p.mock_calls, [])
sc._set_protocol(p)
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
p.mock_calls[:] = []
sc.remote_data(b"more")
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])
def test_close_before_open(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
sc.remote_close()
self.assertEqual(p.mock_calls, [])
sc._set_protocol(p)
self.assert_connectionDone(p.mock_calls)
def test_producer(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
sc.pauseProducing()
self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)])
m.mock_calls[:] = []
sc.resumeProducing()
self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)])
m.mock_calls[:] = []
sc.stopProducing()
self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)])
m.mock_calls[:] = []
def test_consumer(self):
sc, m, scid, hostaddr, peeraddr, p = make_sc()
# TODO: more, once this is implemented
sc.registerProducer(None, True)
sc.unregisterProducer()

View File

@ -0,0 +1,196 @@
from __future__ import print_function, unicode_literals
import io
from collections import namedtuple
import mock
from twisted.internet import endpoints, reactor
from twisted.trial import unittest
from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint,
describe_hint_obj, parse_hint, encode_hint,
DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint)
UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self):
def efho(hint, tor=None):
return endpoint_from_hint_obj(hint, tor, reactor)
self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# tor=None
self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None)
self.assertEqual(efho(UnknownHint("foo")), None)
tor = mock.Mock()
def tor_ep(hostname, port):
if hostname == "non-public":
raise ValueError
return ("tor_ep", hostname, port)
tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor),
("tor_ep", "host", 1234))
self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
("tor_ep", "host2.onion", 1234))
self.assertEqual( efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor), None)
self.assertEqual(efho(UnknownHint("foo"), tor), None)
def test_comparable(self):
h1 = DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = RelayV1Hint(tuple(sorted([h1, h2])))
r2 = RelayV1Hint(tuple(sorted([h2, h1])))
r3 = RelayV1Hint(tuple(sorted([h1b, h2])))
self.assertEqual(r1, r2)
self.assertEqual(r2, r3)
self.assertEqual(len(set([r1, r2, r3])), 1)
def test_parse_tcp_v1_hint(self):
p = parse_tcp_v1_hint
self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "tor-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({
"type": "direct-tcp-v1"
}), None) # missing hostname
self.assertEqual(p({
"type": "direct-tcp-v1",
"hostname": 12
}), None) # invalid hostname
self.assertEqual(
p({
"type": "direct-tcp-v1",
"hostname": "foo"
}), None) # missing port
self.assertEqual(
p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": "not a number"
}), None) # invalid port
def test_parse_hint(self):
p = parse_hint
self.assertEqual(p({"type": "direct-tcp-v1",
"hostname": "foo",
"port": 12}),
DirectTCPV1Hint("foo", 12, 0.0))
self.assertEqual(p({"type": "relay-v1",
"hints": [
{"type": "direct-tcp-v1",
"hostname": "foo",
"port": 12},
{"type": "unrecognized"},
{"type": "direct-tcp-v1",
"hostname": "bar",
"port": 13}]}),
RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0),
DirectTCPV1Hint("bar", 13, 0.0)]))
def test_parse_hint_argv(self):
def p(hint):
stderr = io.StringIO()
value = parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue()
h, stderr = p("tcp:host:1234")
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:priority=2.6")
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:unknown=stuff")
self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("$!@#^")
self.assertEqual(h, None)
self.assertEqual(stderr, "unparseable hint '$!@#^'\n")
h, stderr = p("unknown:stuff")
self.assertEqual(h, None)
self.assertEqual(stderr,
"unknown hint type 'unknown' in 'unknown:stuff'\n")
h, stderr = p("tcp:just-a-hostname")
self.assertEqual(h, None)
self.assertEqual(
stderr,
"unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")
h, stderr = p("tcp:host:number")
self.assertEqual(h, None)
self.assertEqual(stderr,
"non-numeric port in TCP hint 'tcp:host:number'\n")
h, stderr = p("tcp:host:1234:priority=bad")
self.assertEqual(h, None)
self.assertEqual(
stderr,
"non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")
def test_describe_hint_obj(self):
d = describe_hint_obj
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False),
"->tcp:host:1234")
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, False),
"->relay:tcp:host:1234")
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, True),
"tor->tcp:host:1234")
self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, True),
"tor->relay:tcp:host:1234")
self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False),
"->tor:host:1234")
self.assertEqual(d(UnknownHint("stuff"), False, False),
"->%s" % str(UnknownHint("stuff")))
def test_encode_hint(self):
e = encode_hint
self.assertEqual(e(DirectTCPV1Hint("host", 1234, 1.0)),
{"type": "direct-tcp-v1",
"priority": 1.0,
"hostname": "host",
"port": 1234})
self.assertEqual(e(RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0),
DirectTCPV1Hint("bar", 13, 0.0)])),
{"type": "relay-v1",
"hints": [
{"type": "direct-tcp-v1",
"hostname": "foo",
"port": 12,
"priority": 0.0},
{"type": "direct-tcp-v1",
"hostname": "bar",
"port": 13,
"priority": 0.0},
]})
self.assertEqual(e(TorTCPV1Hint("host", 1234, 1.0)),
{"type": "tor-tcp-v1",
"priority": 1.0,
"hostname": "host",
"port": 1234})
e = self.assertRaises(ValueError, e, "not a Hint")
self.assertIn("unknown hint type", str(e))
self.assertIn("not a Hint", str(e))

View File

@ -12,8 +12,8 @@ import mock
from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
_mailbox, _nameplate, _order, _receive, _rendezvous, _send,
_terminator, errors, timing)
from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister,
IMailbox, INameplate, IOrder, IReceive,
from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey,
ILister, IMailbox, INameplate, IOrder, IReceive,
IRendezvousConnector, ISend, ITerminator, IWordlist)
from .._key import derive_key, derive_phase_key, encrypt_data
from ..journal import ImmediateJournal
@ -1286,17 +1286,21 @@ class Boss(unittest.TestCase):
"closed")
versions = {"app": "version1"}
reactor = None
eq = None
cooperator = None
journal = ImmediateJournal()
tor_manager = None
client_version = ("python", __version__)
b = MockBoss(wormhole, "side", "url", "appid", versions,
client_version, reactor, journal, tor_manager,
client_version, reactor, eq, cooperator, journal,
tor_manager,
timing.DebugTiming())
b._T = Dummy("t", events, ITerminator, "close")
b._S = Dummy("s", events, ISend, "send")
b._RC = Dummy("rc", events, IRendezvousConnector, "start")
b._C = Dummy("c", events, ICode, "allocate_code", "input_code",
"set_code")
b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key")
return b, events
def test_basic(self):
@ -1324,7 +1328,9 @@ class Boss(unittest.TestCase):
b.got_message("0", b"msg1")
self.assertEqual(events, [
("w.got_key", b"key"),
("d.got_key", b"key"),
("w.got_verifier", b"verifier"),
("d.got_wormhole_versions", {}),
("w.got_versions", {}),
("w.received", b"msg1"),
])

View File

@ -3,7 +3,7 @@ from twisted.python.failure import Failure
from twisted.trial import unittest
from ..eventual import EventualQueue
from ..observer import OneShotObserver, SequenceObserver
from ..observer import OneShotObserver, SequenceObserver, EmptyableSet
class OneShot(unittest.TestCase):
@ -121,3 +121,28 @@ class Sequence(unittest.TestCase):
d2 = o.when_next_event()
eq.flush_sync()
self.assertIdentical(self.failureResultOf(d2), f)
class Empty(unittest.TestCase):
def test_set(self):
eq = EventualQueue(Clock())
s = EmptyableSet(_eventual_queue=eq)
d1 = s.when_next_empty()
eq.flush_sync()
self.assertNoResult(d1)
s.add(1)
eq.flush_sync()
self.assertNoResult(d1)
s.add(2)
s.discard(1)
d2 = s.when_next_empty()
eq.flush_sync()
self.assertNoResult(d1)
self.assertNoResult(d2)
s.discard(2)
eq.flush_sync()
self.assertEqual(self.successResultOf(d1), None)
self.assertEqual(self.successResultOf(d2), None)
s.add(3)
s.discard(3)

View File

@ -3,7 +3,6 @@ from __future__ import print_function, unicode_literals
import gc
import io
from binascii import hexlify, unhexlify
from collections import namedtuple
import six
from nacl.exceptions import CryptoError
@ -18,7 +17,9 @@ import mock
from wormhole_transit_relay import transit_server
from .. import transit
from .._hints import DirectTCPV1Hint
from ..errors import InternalError
from ..util import HKDF
from .common import ServerBase
@ -140,141 +141,6 @@ class Misc(unittest.TestCase):
self.assertIsInstance(portno, int)
UnknownHint = namedtuple("UnknownHint", ["stuff"])
class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self):
c = transit.Common("")
efho = c._endpoint_from_hint_obj
self.assertIsInstance(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# c._tor is currently None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
c._tor = mock.Mock()
def tor_ep(hostname, port):
if hostname == "non-public":
return None
return ("tor_ep", hostname, port)
c._tor.stream_via = mock.Mock(side_effect=tor_ep)
self.assertEqual(
efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
("tor_ep", "host", 1234))
self.assertEqual(
efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
("tor_ep", "host2.onion", 1234))
self.assertEqual(
efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None)
self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self):
h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
self.assertEqual(r1, r2)
self.assertEqual(r2, r3)
self.assertEqual(len(set([r1, r2, r3])), 1)
def test_parse_tcp_v1_hint(self):
c = transit.Common("")
p = c._parse_tcp_v1_hint
self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
h = p({
"type": "tor-tcp-v1",
"hostname": "foo",
"port": 1234,
"priority": 2.5
})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({
"type": "direct-tcp-v1"
}), None) # missing hostname
self.assertEqual(p({
"type": "direct-tcp-v1",
"hostname": 12
}), None) # invalid hostname
self.assertEqual(
p({
"type": "direct-tcp-v1",
"hostname": "foo"
}), None) # missing port
self.assertEqual(
p({
"type": "direct-tcp-v1",
"hostname": "foo",
"port": "not a number"
}), None) # invalid port
def test_parse_hint_argv(self):
def p(hint):
stderr = io.StringIO()
value = transit.parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue()
h, stderr = p("tcp:host:1234")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:priority=2.6")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "")
h, stderr = p("tcp:host:1234:unknown=stuff")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h, stderr = p("$!@#^")
self.assertEqual(h, None)
self.assertEqual(stderr, "unparseable hint '$!@#^'\n")
h, stderr = p("unknown:stuff")
self.assertEqual(h, None)
self.assertEqual(stderr,
"unknown hint type 'unknown' in 'unknown:stuff'\n")
h, stderr = p("tcp:just-a-hostname")
self.assertEqual(h, None)
self.assertEqual(
stderr,
"unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")
h, stderr = p("tcp:host:number")
self.assertEqual(h, None)
self.assertEqual(stderr,
"non-numeric port in TCP hint 'tcp:host:number'\n")
h, stderr = p("tcp:host:1234:priority=bad")
self.assertEqual(h, None)
self.assertEqual(
stderr,
"non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")
def test_describe_hint_obj(self):
d = transit.describe_hint_obj
self.assertEqual(
d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234")
self.assertEqual(
d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234")
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
# ipaddrs.py currently uses native strings: bytes on py2, unicode on
# py3
@ -437,7 +303,7 @@ class Listener(unittest.TestCase):
hints, ep = c._build_listener()
self.assertIsInstance(hints, (list, set))
if hints:
self.assertIsInstance(hints[0], transit.DirectTCPV1Hint)
self.assertIsInstance(hints[0], DirectTCPV1Hint)
self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)
def test_get_direct_hints(self):
@ -1507,8 +1373,8 @@ class Transit(unittest.TestCase):
self.assertEqual(self.successResultOf(d), "winner")
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
def _endpoint_from_hint_obj(self, hint):
if isinstance(hint, transit.DirectTCPV1Hint):
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
if isinstance(hint, DirectTCPV1Hint):
if hint.hostname == "unavailable":
return None
return hint.hostname
@ -1523,20 +1389,21 @@ class Transit(unittest.TestCase):
del hints
s.add_connection_hints(
[DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"])
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# the direct connectors are tried right away, but the relay
# connectors are stalled for a few seconds
self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay"])
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_priorities(self):
@ -1586,31 +1453,32 @@ class Transit(unittest.TestCase):
}]
},
])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# direct connector should be used first, then the priority=3.0 relay,
# then the two 2.0 relays, then the (default) 0.0 relay
self.assertEqual(self._connectors, ["direct"])
self.assertEqual(self._connectors, ["direct"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay3"])
clock.advance(s.RELAY_DELAY + 1.0)
self.assertEqual(self._connectors, ["direct", "relay3"])
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4"],
["direct", "relay3", "relay4", "relay2"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4"],
["direct", "relay3", "relay4", "relay2"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4", "relay"],
["direct", "relay3", "relay4", "relay2", "relay"]))
clock.advance(s.RELAY_DELAY)
self.assertIn(self._connectors,
(["direct", "relay3", "relay2", "relay4", "relay"],
["direct", "relay3", "relay4", "relay2", "relay"]))
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_no_direct_hints(self):
@ -1624,20 +1492,21 @@ class Transit(unittest.TestCase):
UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
UNAVAILABLE_RELAY_HINT_JSON
])
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
self.assertNoResult(d)
# since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds
self.assertEqual(self._connectors, [])
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
self.assertNoResult(d)
# since there are no usable direct hints, the relay connector will
# only be stalled for 0 seconds
self.assertEqual(self._connectors, [])
clock.advance(0)
self.assertEqual(self._connectors, ["relay"])
clock.advance(0)
self.assertEqual(self._connectors, ["relay"])
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
self._waiters[0].callback("winner")
self.assertEqual(self.successResultOf(d), "winner")
@inlineCallbacks
def test_no_contenders(self):
@ -1647,17 +1516,18 @@ class Transit(unittest.TestCase):
hints = yield s.get_connection_hints() # start the listener
del hints
s.add_connection_hints([]) # no hints at all
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
s._start_connector = self._start_connector
d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
self._endpoint_from_hint_obj):
d = s.connect()
f = self.failureResultOf(d, transit.TransitError)
self.assertEqual(str(f.value), "No contenders for connection")
class RelayHandshake(unittest.TestCase):
def old_build_relay_handshake(self, key):
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (token, b"please relay " + hexlify(token) + b"\n")
def test_old(self):

View File

@ -2,15 +2,13 @@
from __future__ import absolute_import, print_function
import os
import re
import socket
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque, namedtuple
from collections import deque
import six
from hkdf import Hkdf
from nacl.secret import SecretBox
from twisted.internet import (address, defer, endpoints, error, interfaces,
protocol, reactor, task)
@ -23,11 +21,10 @@ from zope.interface import implementer
from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
from .util import bytes_to_hexstr, HKDF
from ._hints import (DirectTCPV1Hint, RelayV1Hint,
parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
parse_tcp_v1_hint)
class TransitError(Exception):
@ -95,72 +92,6 @@ def build_sided_relay_handshake(key, side):
"ascii") + b"\n"
# These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts".
# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
# * make a TCP connection (possibly via Tor)
# * send the sender/receiver handshake bytes first
# * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each
# one, make the TCP connection, send the relay handshake, then complete the
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
def describe_hint_obj(hint):
if isinstance(hint, DirectTCPV1Hint):
return u"tcp:%s:%d" % (hint.hostname, hint.port)
elif isinstance(hint, TorTCPV1Hint):
return u"tor:%s:%d" % (hint.hostname, hint.port)
else:
return str(hint)
def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
priority = 0.0
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
print("unparseable hint '%s'" % (hint, ), file=stderr)
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print(
"unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
return None
hint_value = mo.group(2)
pieces = hint_value.split(":")
if len(pieces) < 2:
print(
"unparseable TCP hint (need more colons) '%s'" % (hint, ),
file=stderr)
return None
mo = re.search(r'^(\d+)$', pieces[1])
if not mo:
print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr)
return None
hint_host = pieces[0]
hint_port = int(pieces[1])
for more in pieces[2:]:
if more.startswith("priority="):
more_pieces = more.split("=")
try:
priority = float(more_pieces[1])
except ValueError:
print(
"non-float priority= in TCP hint '%s'" % (hint, ),
file=stderr)
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
TIMEOUT = 60 # seconds
@ -746,30 +677,11 @@ class Common:
self._listener_d.addErrback(lambda f: None)
self._listener_d.cancel()
def _parse_tcp_v1_hint(self, hint): # hint_struct -> hint_obj
hint_type = hint.get(u"type", u"")
if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
log.msg("unknown hint type: %r" % (hint, ))
return None
if not (u"hostname" in hint and
isinstance(hint[u"hostname"], type(u""))):
log.msg("invalid hostname in hint: %r" % (hint, ))
return None
if not (u"port" in hint and
isinstance(hint[u"port"], six.integer_types)):
log.msg("invalid port in hint: %r" % (hint, ))
return None
priority = hint.get(u"priority", 0.0)
if hint_type == u"direct-tcp-v1":
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
else:
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
def add_connection_hints(self, hints):
for h in hints: # hint structs
hint_type = h.get(u"type", u"")
if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
dh = self._parse_tcp_v1_hint(h)
dh = parse_tcp_v1_hint(h)
if dh:
self._their_direct_hints.append(dh) # hint_obj
elif hint_type == u"relay-v1":
@ -779,7 +691,7 @@ class Common:
# together like this.
relay_hints = []
for rhs in h.get(u"hints", []):
h = self._parse_tcp_v1_hint(rhs)
h = parse_tcp_v1_hint(rhs)
if h:
relay_hints.append(h)
if relay_hints:
@ -875,13 +787,11 @@ class Common:
# Check the hint type to see if we can support it (e.g. skip
# onion hints on a non-Tor client). Do not increase relay_delay
# unless we have at least one viable hint.
ep = self._endpoint_from_hint_obj(hint_obj)
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep:
continue
description = "->%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = self._start_connector(ep, description)
d = self._start_connector(ep,
describe_hint_obj(hint_obj, False, self._tor))
contenders.append(d)
relay_delay = self.RELAY_DELAY
@ -902,18 +812,15 @@ class Common:
for priority in sorted(prioritized_relays, reverse=True):
for hint_obj in prioritized_relays[priority]:
ep = self._endpoint_from_hint_obj(hint_obj)
ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
if not ep:
continue
description = "->relay:%s" % describe_hint_obj(hint_obj)
if self._tor:
description = "tor" + description
d = task.deferLater(
self._reactor,
relay_delay,
self._start_connector,
ep,
description,
describe_hint_obj(hint_obj, True, self._tor),
is_relay=True)
contenders.append(d)
relay_delay += self.RELAY_DELAY
@ -951,21 +858,6 @@ class Common:
d.addCallback(lambda p: p.startNegotiation())
return d
def _endpoint_from_hint_obj(self, hint):
if self._tor:
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
# this Tor object will throw ValueError for non-public IPv4
# addresses and any IPv6 address
try:
return self._tor.stream_via(hint.hostname, hint.port)
except ValueError:
return None
return None
if isinstance(hint, DirectTCPV1Hint):
return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
hint.port)
return None
def connection_ready(self, p):
# inbound/outbound Connection protocols call this when they finish
# negotiation. The first one wins and gets a "go". Any subsequent

View File

@ -3,11 +3,19 @@ import json
import os
import unicodedata
from binascii import hexlify, unhexlify
from hkdf import Hkdf
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
def to_unicode(any):
if isinstance(any, type(u"")):
return any
return any.decode("ascii")
def bytes_to_hexstr(b):
assert isinstance(b, type(b""))

View File

@ -5,9 +5,12 @@ import sys
from attr import attrib, attrs
from twisted.python import failure
from twisted.internet.task import Cooperator
from zope.interface import implementer
from ._boss import Boss
from ._dilation.manager import DILATION_VERSIONS
from ._dilation.connector import Connector
from ._interfaces import IDeferredWormhole, IWormhole
from ._key import derive_key
from .errors import NoKeyError, WormholeClosed
@ -122,7 +125,8 @@ class _DelegatedWormhole(object):
@implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object):
def __init__(self, eq):
def __init__(self, reactor, eq):
self._reactor = reactor
self._welcome_observer = OneShotObserver(eq)
self._code_observer = OneShotObserver(eq)
self._key = None
@ -187,6 +191,10 @@ class _DeferredWormhole(object):
raise NoKeyError()
return derive_key(self._key, to_bytes(purpose), length)
def dilate(self):
raise NotImplementedError
return self._boss.dilate() # fires with (endpoints)
def close(self):
# fails with WormholeError unless we established a connection
# (state=="happy"). Fails with WrongPasswordError (a subclass of
@ -258,18 +266,24 @@ def create(
side = bytes_to_hexstr(os.urandom(5))
journal = journal or ImmediateJournal()
eq = _eventual_queue or EventualQueue(reactor)
cooperator = Cooperator(scheduler=eq.eventually)
if delegate:
w = _DelegatedWormhole(delegate)
else:
w = _DeferredWormhole(eq)
wormhole_versions = {} # will be used to indicate Wormhole capabilities
w = _DeferredWormhole(reactor, eq)
# this indicates Wormhole capabilities
wormhole_versions = {
"can-dilate": DILATION_VERSIONS,
"dilation-abilities": Connector.get_connection_abilities(),
}
wormhole_versions = {} # don't advertise Dilation yet: not ready
wormhole_versions["app_versions"] = versions # app-specific capabilities
v = __version__
if isinstance(v, type(b"")):
v = v.decode("utf-8", errors="replace")
client_version = ("python", v)
b = Boss(w, side, relay_url, appid, wormhole_versions, client_version,
reactor, journal, tor, timing)
reactor, eq, cooperator, journal, tor, timing)
w._set_boss(b)
b.start()
return w

View File

@ -10,7 +10,7 @@ minversion = 2.4.0
[testenv]
usedevelop = True
extras = dev
extras = dev,dilate
deps =
pyflakes >= 1.2.3
commands =
@ -18,6 +18,8 @@ commands =
wormhole --version
python -m wormhole.test.run_trial {posargs:wormhole}
[testenv:no-dilate]
extras = dev
# on windows, trial is installed as venv/bin/trial.py, not .exe, but (at
# least appveyor) adds .PY to $PATHEXT. So "trial wormhole" might work on