Merge branch 'dilate-5'
This adds (but does not enable/expose) the low-level code for the new Dilation protocol (refs #312). The spec and docs are done, and the unit tests pass (with full branch coverage). The next step is to write some higher-level integration tests, which use a fake/short-circuited mailbox connection (Manager.send) but real localhost TCP sockets. Then we need to figure out backwards compatibility with non-dilation-capable versions. I've got a table in my notes; I'll add it to the ticket.
This commit is contained in: commit ddba0fc840
@@ -17,13 +17,16 @@ before_script:
      flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ;
    fi
  script:
  - tox -e coverage
  - if [[ $TRAVIS_PYTHON_VERSION == 2.7 || $TRAVIS_PYTHON_VERSION == 3.4 ]]; then
      tox -e no-dilate ;
    else
      tox -e coverage ;
    fi
  after_success:
  - codecov
  matrix:
    include:
      - python: 2.7
      - python: 3.3
      - python: 3.4
      - python: 3.5
      - python: 3.6
@@ -34,5 +37,4 @@ matrix:
        dist: xenial
      - python: nightly
    allow_failures:
      - python: 3.3
      - python: nightly
136 docs/api.md
@@ -524,25 +524,33 @@ object twice.

## Dilation

- (NOTE: this section is speculative: this code has not yet been written)
+ (NOTE: this API is still in development)

- In the longer term, the Wormhole object will incorporate the "Transit"
- functionality (see transit.md) directly, removing the need to instantiate a
- second object. A Wormhole can be "dilated" into a form that is suitable for
- bulk data transfer.
+ To send bulk data, or anything more than a handful of messages, a Wormhole
+ can be "dilated" into a form that uses a direct TCP connection between the
+ two endpoints.

All wormholes start out "undilated". In this state, all messages are queued
on the Rendezvous Server for the lifetime of the wormhole, and server-imposed
number/size/rate limits apply. Calling `w.dilate()` initiates the dilation
- process, and success is signalled via either `d=w.when_dilated()` firing, or
- `dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used
- as an IConsumer/IProducer, and messages will be sent on a direct connection
- (if possible) or through the transit relay (if not).
+ process, and eventually yields a set of Endpoints. Once dilated, the usual
+ `.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and
+ these endpoints can be used to establish multiple (encrypted) "subchannel"
+ connections to the other side.

+ Each subchannel behaves like a regular Twisted `ITransport`, so they can be
+ glued to the Protocol instance of your choice. They also implement the
+ IConsumer/IProducer interfaces.

+ These subchannels are *durable*: as long as the processes on both sides keep
+ running, the subchannel will survive the network connection being dropped.
+ For example, a file transfer can be started from a laptop; while it is
+ running, the laptop can be closed, moved to a new wifi network, opened back
+ up, and the transfer will resume from the new IP address.

What's good about a non-dilated wormhole?

* setup is faster: no delay while it tries to make a direct connection
* survives temporary network outages, since messages are queued
* works with "journaled mode", allowing progress to be made even when both
  sides are never online at the same time, by serializing the wormhole
@@ -556,21 +564,103 @@ Use non-dilated wormholes when your application only needs to exchange a

couple of messages, for example to set up public keys or provision access
tokens. Use a dilated wormhole to move files.

- Dilated wormholes can provide multiple "channels": these are multiplexed
- through the single (encrypted) TCP connection. Each channel is a separate
- stream (offering IProducer/IConsumer)
-
- To create a channel, call `c = w.create_channel()` on a dilated wormhole. The
- "channel ID" can be obtained with `c.get_id()`. This ID will be a short
- (unicode) string, which can be sent to the other side via a normal
- `w.send()`, or any other means. On the other side, use
- `c = w.open_channel(channel_id)` to get a matching channel object.
-
- Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire
- them up with `c.registerProducer()`. Note that channels do not close until
- the wormhole connection is closed, so they do not have separate `close()`
- methods or events. Therefore if you plan to send files through them, you'll
- need to inform the recipient ahead of time about how many bytes to expect.

Dilated wormholes can provide multiple "subchannels": these are multiplexed
through the single (encrypted) TCP connection. Each subchannel is a separate
stream (offering IProducer/IConsumer for flow control), and is opened and
closed independently. A special "control channel" is available to both sides
so they can coordinate how they use the subchannels.

The `d = w.dilate()` Deferred fires with a triple of Endpoints:

```python
d = w.dilate()
def _dilated(res):
    (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
d.addCallback(_dilated)
```

The `control_channel_ep` endpoint is a client-style endpoint, so both sides
will connect to it with `ep.connect(factory)`. This endpoint is single-use:
calling `.connect()` a second time will fail. The control channel is
symmetric: it doesn't matter which side is the application-level
client/server or initiator/responder, if the application even has such
concepts. The two applications can use the control channel to negotiate who
goes first, if necessary.

The subchannel endpoints are *not* symmetric: for each subchannel, one side
must listen as a server, and the other must connect as a client. Subchannels
can be established by either side at any time. This supports e.g.
bidirectional file transfer, where either user of a GUI app can drop files
into the "wormhole" whenever they like.

The `subchannel_client_ep` on one side is used to connect to the other side's
`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The
server endpoint is single-use: `.listen(factory)` may only be called once.

Applications are under no obligation to use subchannels: for many use cases,
the control channel is enough.

To use subchannels, once the wormhole is dilated and the endpoints are
available, the listening-side application should attach a listener to the
`subchannel_server_ep` endpoint:

```python
def _dilated(res):
    (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
    f = Factory(MyListeningProtocol)
    subchannel_server_ep.listen(f)
```

When the connecting-side application wants to connect to that listening
protocol, it should use `.connect()` with a suitable connecting-protocol
factory:

```python
def _connect():
    f = Factory(MyConnectingProtocol)
    subchannel_client_ep.connect(f)
```

For a bidirectional file-transfer application, both sides will establish a
listening protocol. Later, if/when the user drops a file on the application
window, that side will initiate a connection, use the resulting subchannel to
transfer the single file, and then close the subchannel.

```python
# (imports elided: Factory and Protocol from twisted.internet.protocol,
#  FileSender from twisted.protocols.basic; INITIAL/DATA and parse() are
#  app-defined)
class FileSendingProtocol(protocol.Protocol):
    def __init__(self, metadata, filename):
        self.file_metadata = metadata
        self.file_name = filename
    def connectionMade(self):
        self.transport.write(self.file_metadata)
        sender = basic.FileSender()
        f = open(self.file_name, "rb")
        d = sender.beginFileTransfer(f, self.transport)
        d.addBoth(self._done, f)
    def _done(self, res, f):
        self.transport.loseConnection()
        f.close()

def _send(metadata, filename):
    f = Factory()
    f.protocol = lambda: FileSendingProtocol(metadata, filename)
    subchannel_client_ep.connect(f)

class FileReceivingProtocol(protocol.Protocol):
    state = INITIAL
    def dataReceived(self, data):
        if self.state == INITIAL:
            self.state = DATA
            metadata = parse(data)
            self.f = open(metadata.filename, "wb")
        else:
            # local file writes are blocking, so don't bother with IConsumer
            self.f.write(data)
    def connectionLost(self, reason):
        self.f.close()

def _dilated(res):
    (control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
    f = Factory(FileReceivingProtocol)
    subchannel_server_ep.listen(f)
```

## Bytes, Strings, Unicode, and Python 3
540 docs/dilation-protocol.md (new file)
# Dilation Internals

Wormhole dilation involves several moving parts. Both sides exchange messages
through the Mailbox server to coordinate the establishment of a more direct
connection. This connection might flow in either direction, so they trade
"connection hints" to point at potential listening ports. This process might
succeed in making multiple connections at about the same time, so one side
must select the best one to use, and cleanly shut down the others. To make
the dilated connection *durable*, this side must also decide when the
connection has been lost, and then coordinate the construction of a
replacement. Within this connection, a series of queued-and-acked subchannel
messages are used to open/use/close the application-visible subchannels.

## Versions and can-dilate

The Wormhole protocol includes a `versions` message sent immediately after
the shared PAKE key is established. This also serves as a key-confirmation
message, allowing each side to confirm that the other side knows the right
key. The body of the `versions` message is a JSON-formatted string with keys
that are available for learning the abilities of the peer. Dilation is
signaled by a key named `can-dilate`, whose value is a list of strings. Any
version present in both sides' lists is eligible for use.
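
For illustration, a `versions` body advertising dilation support might look
like the following sketch. Only the `can-dilate` key is specified here; the
`app_versions` area is the app-to-app section referenced by the Boss code
later in this commit:

```
{ "can-dilate": ["1"],
  "app_versions": {}
}
```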

## Leaders and Followers

Each side of a Wormhole has a randomly-generated dilation `side` string (this
is included in the `please-dilate` message, and is independent of the
Wormhole's mailbox "side"). When the wormhole is dilated, the side with the
lexicographically-higher "side" value is named the "Leader", and the other
side is named the "Follower". The general wormhole protocol treats both sides
identically, but the distinction matters for the dilation protocol.
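
A minimal sketch of that role decision (the `side` values here are made up;
real ones are random strings):

```python
# Hedged sketch: the roles fall out of a plain string comparison of the two
# randomly-generated dilation "side" values.
my_side = "fc4a9d31"     # hypothetical local value
their_side = "b7e02ca8"  # hypothetical value from the peer's please-dilate

role = "Leader" if my_side > their_side else "Follower"
```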

Both sides send a `please-dilate` as soon as dilation is triggered. Each side
discovers whether it is the Leader or the Follower when the peer's
`please-dilate` arrives. The Leader has exclusive control over whether a
given connection is considered established or not: if there are multiple
potential connections to use, the Leader decides which one to use, and the
Leader gets to decide when the connection is no longer viable (and triggers
the establishment of a new one).

The `please-dilate` includes a `use-version` key, computed as the "best"
version of the intersection of the two sides' abilities as reported in the
`versions` message. Both sides will use whichever `use-version` was specified
by the Leader (they learn which side is the Leader at the same moment they
learn the peer's `use-version` value). If the Follower cannot handle the
`use-version` value, dilation fails (this shouldn't happen, as the Leader
knew what the Follower was and was not capable of before sending that
message).

## Connection Layers

We describe the protocol as a series of layers. Messages sent on one layer
may be encoded or transformed before being delivered on some other layer.

L1 is the mailbox channel (queued store-and-forward messages that always go
to the mailbox server, and then are forwarded to other clients subscribed to
the same mailbox). Both clients remain connected to the mailbox server until
the Wormhole is closed. They send DILATE-n messages to each other to manage
the dilation process, including records like `please`, `connection-hints`,
`reconnect`, and `reconnecting`.

L2 is the set of competing connection attempts for a given generation of
connection. Each time the Leader decides to establish a new connection, a new
generation number is used. Hopefully these are direct TCP connections between
the two peers, but they may also include connections through the transit
relay. Each connection must go through an encrypted handshake process before
it is considered viable. Viable connections are then submitted to a selection
process (on the Leader side), which chooses exactly one to use, and drops the
others. It may wait an extra few seconds in the hopes of getting a "better"
connection (faster, cheaper, etc.), but eventually it will select one.

L3 is the currently-selected connection. There is one L3 for each generation.
At all times, the wormhole will have exactly zero or one L3 connection. L3 is
responsible for the selection process, connection monitoring/keepalives, and
serialization/deserialization of the plaintext frames. L3 delivers decoded
frames and connection-establishment events up to L4.

L4 is the persistent higher-level channel. It is created as soon as the first
L3 connection is selected, and lasts until the wormhole is closed entirely.
L4 contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence
number (scoped to the L4 connection and the direction of travel), and the ACK
messages reference those sequence numbers. When a message is given to the L4
channel for delivery to the remote side, it is always queued, then
transmitted if there is an L3 connection available. This message remains in
the queue until an ACK is received to retire it. If a new L3 connection is
made, all queued messages will be re-sent (in seqnum order).

L5 is the set of subchannels. There is one pre-established subchannel 0 known
as the "control channel", which does not require an OPEN message. All other
subchannels are created by the receipt of an OPEN message with the subchannel
number. DATA frames are delivered to a specific subchannel. When the
subchannel is no longer needed, one side will invoke the `close()` API
(`loseConnection()` in Twisted), which will cause a CLOSE message to be
sent, and the local L5 object will be put into the "closing" state. When the
other side receives the CLOSE, it will send its own CLOSE for the same
subchannel, and fully close its local object (`connectionLost()`). When the
first side receives CLOSE in the "closing" state, it will fully close its
local object too.

All L5 subchannels will be paused (`pauseProducing()`) when the L3
connection is paused or lost. They are resumed when the L3 connection is
resumed or reestablished.

## Initiating Dilation

Dilation is triggered by calling the `w.dilate()` API. This returns a
Deferred that will fire once the first L3 connection is established. It fires
with a 3-tuple of endpoints that can be used to establish subchannels.

For dilation to succeed, both sides must call `w.dilate()`, since the
resulting endpoints are the only way to access the subchannels. If the other
side never calls `w.dilate()`, the Deferred will never fire.

The L1 (mailbox) path is used to deliver dilation requests and connection
hints. The current mailbox protocol uses named "phases" to distinguish
messages (rather than behaving like a regular ordered channel of arbitrary
frames or bytes), and all-number phase names are reserved for application
data (sent via `w.send_message()`). Therefore the dilation control messages
use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own
counter, so one side might be up to e.g. `DILATE-5` while the other has only
gotten as far as `DILATE-2`. This effectively creates a pair of
unidirectional streams of `DILATE-n` messages, each containing one or more
dilation records of the various types described below. Note that all phases
beyond the initial VERSION and PAKE phases are encrypted by the shared
session key.

A future mailbox protocol might provide a simple ordered stream of typed
messages, with application records and dilation records mixed together.

Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that
has a string value. The dictionary will have other keys that depend upon the
type.

`w.dilate()` triggers transmission of a `please` (i.e. "please dilate")
record with a set of versions that can be accepted. Versions use strings,
rather than integers, to support experimental protocols; however, there is
still a total ordering of preferability.

```
{ "type": "please",
  "side": "abcdef",
  "accepted-versions": ["1"]
}
```

If one side receives a `please` before `w.dilate()` has been called locally,
the contents are stored in case `w.dilate()` is called in the future. Once
both have happened (the local `w.dilate()` call and the receipt of the peer's
`please`), the side determines whether it is the Leader or the Follower. Both
sides also compare `accepted-versions` fields to choose the best
mutually-compatible version to use: they should always pick the same one.
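
A sketch of that comparison (hypothetical helper; the shipped code may order
versions differently):

```python
# Hedged sketch: pick the best mutually-supported version. Assumes a
# preference list ordered most-preferred first; "1" is the only version
# shown in this document, and "2" is a made-up future entry.
PREFERENCE = ["2", "1"]

def choose_version(ours, theirs):
    common = set(ours) & set(theirs)
    for v in PREFERENCE:
        if v in common:
            return v
    return None  # no overlap: dilation cannot proceed
```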

Then both sides begin the connection process for generation 1 by opening
listening sockets and sending `connection-hint` records for each one. After a
slight delay they will also open connections to the Transit Relay of their
choice and produce hints for it too. The receipt of inbound hints (on both
sides) will trigger outbound connection attempts.

Some number of these connections may succeed, and the Leader decides which to
use (via an in-band signal on the established connection). The others are
dropped.

If something goes wrong with the established connection and the Leader
decides a new one is necessary, the Leader will send a `reconnect` message.
This might happen while connections are still being established, or while the
Follower thinks it still has a viable connection (the Leader might observe
problems that the Follower does not), or after the Follower thinks the
connection has been lost. In all cases, the Leader is the only side which
should send `reconnect`. The state machine code looks the same on both sides,
for simplicity, but one path on each side is never used.

Upon receiving a `reconnect`, the Follower should stop any pending connection
attempts and terminate any existing connections (even if they appear viable).
Listening sockets may be retained, but any previous connection made through
them must be dropped.

Once all connections have stopped, the Follower should send a `reconnecting`
message, then start the connection process for the next generation, which
will send new `connection-hint` messages for all listening sockets.

Generations are non-overlapping. The Leader will drop all connections from
generation 1 before sending the `reconnect` for generation 2, and will not
initiate any gen-2 connections until it receives the matching `reconnecting`
from the Follower. The Follower must drop all gen-1 connections before it
sends the `reconnecting` response (even if it thinks they are still
functioning: if the Leader thought the gen-1 connection still worked, it
wouldn't have started gen-2).

(TODO: what about a follower->leader connection that was started before
start-dilation is received, and gets established on the Leader side after
start-dilation is sent? The follower will drop it after it receives
start-dilation, but meanwhile the leader may accept it as gen2.)

(probably need to include the generation number in the handshake, or in the
derived key)

(TODO: reduce the number of round-trip stalls here, I've added too many)

Each side is in the "connecting" state (which encompasses both making
connection attempts and having an established connection) starting with the
receipt of a `please-dilate` message and a local `w.dilate()` call. The
Leader remains in that state until it abandons the connection and sends a
`reconnect` message, at which point it remains in the "flushing" state until
the Follower's `reconnecting` message is received. The Follower remains in
"connecting" until it receives `reconnect`, then it stays in "dropping" until
it finishes halting all outstanding connections, after which it sends
`reconnecting` and switches back to "connecting".

"Connection hints" are type/address/port records that tell the other side of
likely targets for L2 connections. Both sides will try to determine their
external IP addresses, listen on a TCP port, and advertise `(tcp,
external-IP, port)` as a connection hint. The Transit Relay is also used as a
(lower-priority) hint. These are sent in `connection-hint` records, which can
be sent any time after both sending and receiving a `please` record. Each
side will initiate connections upon receipt of the hints.

```
{ "type": "connection-hints",
  "hints": [ ... ]
}
```

Hints can arrive at any time. One side might immediately send hints that can
be computed quickly, then send additional hints later as they become
available. For example, it might enumerate the local network interfaces and
send hints for all of the LAN addresses first, then send port-forwarding
(UPnP) requests to the local router. When the forwarding is established
(providing an externally-visible IP address and port), it can send additional
hints for that new endpoint. If the other peer happens to be on the same LAN,
the local connection can be established without waiting for the router's
response.

### Connection Hint Format

Each member of the `hints` field describes a potential L2 connection target
endpoint, with an associated priority and a set of hints.

The priority is a number (positive or negative float), where larger numbers
indicate that the client supplying that hint would prefer to use this
connection over others of lower number. This indicates a sense of cost or
performance. For example, the Transit Relay is lower priority than a direct
TCP connection, because it incurs a bandwidth cost (on the relay operator),
as well as adding latency.

Each endpoint has a set of hints, because the same target might be reachable
by multiple hints. Once one hint succeeds, there is no point in using the
other hints.

TODO: think this through some more. What's the example of a single endpoint
reachable by multiple hints? Should each hint have its own priority, or just
each endpoint?

## L2 protocol

Upon `connectionMade()`, both sides send their handshake message. The
Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower
sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should
trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP
servers that were contacted by accident). If the wrong handshake is received,
the connection will be dropped. For debugging purposes, the node might want
to keep looking at data beyond the first incorrect character and log
everything until the first newline.

Everything beyond that point is a Noise protocol message, which consists of a
4-byte big-endian length field, followed by the indicated number of bytes.
This uses the `NNpsk0` pattern with the Leader as the first party ("-> psk,
e" in the Noise spec), and the Follower as the second ("<- e, ee"). The
pre-shared key is the "dilation key", which is statically derived from the
master PAKE key using HKDF. Each L2 connection uses the same dilation key,
but different ephemeral keys, so each gets a different session key.

The Leader sends the first message, which is a psk-encrypted ephemeral key.
The Follower sends the next message, its own psk-encrypted ephemeral key. The
Follower then sends an empty packet as the "key confirmation message" (KCM),
which will be encrypted by the shared key.
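
A sketch of the Follower-side Noise setup, using the `noiseprotocol` package
named in `setup.py` below. The full cipher-suite name and the variable names
are assumptions; the text above only fixes the `NNpsk0` pattern:

```python
from noise.connection import NoiseConnection

# Hedged sketch: NNpsk0 with the Leader as initiator. `dilation_key` stands
# in for the HKDF-derived pre-shared key described above.
noise = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s")
noise.set_as_responder()          # the Follower is the second party
noise.set_psks(psk=dilation_key)
noise.start_handshake()
noise.read_message(leader_ephemeral_frame)  # "-> psk, e"
reply = noise.write_message()               # "<- e, ee": send inside a frame
kcm = noise.encrypt(b"")                    # empty key-confirmation message
```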

The Leader sees the KCM and knows the connection is viable. It delivers the
protocol object to the L3 manager, which will decide which connection to
select. When the L2 connection is selected to be the new L3, it will send an
empty KCM of its own, to let the Follower know which connection has been
selected. All other L2 connections (either viable or still in handshake) are
dropped, and all other connection attempts are cancelled. All listening
sockets may or may not be shut down (TODO: think about it).

The Follower will wait for either an empty KCM (at which point the L2
connection is delivered to the Dilation manager as the new L3), a
disconnection, or an invalid message (which causes the connection to be
dropped). Other connections and/or listening sockets are stopped.

Internally, the L2Protocol object manages the Noise session itself. It knows
(via a constructor argument) whether it is on the Leader or Follower side,
which affects both the role it plays in the Noise pattern and the reaction
to receiving the ephemeral key (only the Follower responds by sending an
empty KCM message). After that, the L2Protocol notifies the L3 object in
three situations:

* the Noise session produces a valid decrypted frame (for the Leader, this
  includes the Follower's KCM, and thus indicates a viable candidate for
  connection selection)
* the Noise session reports a failed decryption
* the TCP session is lost

All notifications include a reference to the L2Protocol object (`self`). The
L3 object uses this reference to either close the connection (for errors or
when the selection process chooses someone else), to send the KCM message
(after selection, only for the Leader), or to send other L4 messages. The L3
object will retain a reference to the winning L2 object.

## L3 protocol

The L3 layer is responsible for connection selection, monitoring/keepalives,
and message (de)serialization. Framing is handled by L2, so the inbound L3
codepath receives single-message bytestrings, and delivers the same down to
L2 for encryption, framing, and transmission.

Connection selection takes place exclusively on the Leader side, and includes
the following:

* receipt of viable L2 connections from below (indicated by the first valid
  decrypted frame received for any given connection)
* expiration of a timer
* comparison of TBD quality/desirability/cost metrics of viable connections
* selection of the winner
* instructions to losing connections to disconnect
* delivery of the KCM message through the winning connection
* retaining a reference to the winning connection

On the Follower side, the L3 manager just waits for the first connection to
receive the Leader's KCM, at which point it is retained and all others are
dropped.

The L3 manager knows which "generation" of connection is being established.
Each generation uses a different dilation key (?), and is triggered by a new
set of L1 messages. Connections from one generation should not be confused
with those of a different generation.

Each time a new L3 connection is established, the L4 protocol is notified. It
will immediately send all the L4 messages waiting in its outbound queue.
The L3 protocol simply wraps these in Noise frames and sends them to the
other side.

The L3 manager monitors the viability of the current connection, and declares
it lost when bidirectional traffic cannot be maintained. It uses PING and
PONG messages to detect this. These also serve to keep NAT entries alive,
since many firewalls will stop forwarding packets if they don't observe any
traffic for e.g. 5 minutes.

Our goals are:

* don't allow more than 30? seconds to pass without at least *some* data
  being sent along each side of the connection
* allow the Leader to detect silent connection loss within 60? seconds
* minimize overhead

We need both sides to:

* maintain a 30-second repeating timer
* set a flag each time we write to the connection
* each time the timer fires, if the flag was clear then send a PONG,
  otherwise clear the flag

In addition, the Leader must:

* run a 60-second repeating timer (ideally somewhat offset from the other)
* set a flag each time we receive data from the connection
* each time the timer fires, if the flag was clear then drop the connection,
  otherwise clear the flag
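
A sketch of the shared (both-sides) timer rule above. The class and the
`send_pong` helper are illustrative names, not the shipped API:

```python
from twisted.internet import task

# Hedged sketch: every 30 seconds, if nothing was written since the last
# tick, emit an unprovoked PONG (ping-id of all zeros) to keep the
# connection, and any NAT entries along the way, alive.
class Keepalive(object):
    def __init__(self, connection):
        self._connection = connection
        self._wrote = False
        self._timer = task.LoopingCall(self._tick)

    def start(self):
        self._timer.start(30.0, now=False)

    def note_write(self):
        # call this from every outbound write
        self._wrote = True

    def _tick(self):
        if not self._wrote:
            self._connection.send_pong(b"\x00\x00\x00\x00")
        self._wrote = False
```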

In the future, we might have L2 links that are less connection-oriented,
which might have a unidirectional failure mode, at which point we'll need to
monitor full roundtrips. To accomplish this, the Leader will send periodic
unconditional PINGs, and the Follower will respond with PONGs. If the
Leader->Follower connection is down, the PINGs won't arrive and no PONGs will
be produced. If the Follower->Leader direction has failed, the PONGs won't
arrive. The delivery of both will be delayed by actual data, so the timeouts
should be adjusted if we see regular data arriving.

If the connection is dropped before the wormhole is closed (either the other
end explicitly dropped it, we noticed a problem and told TCP to drop it, or
TCP noticed a problem itself), the Leader-side L3 manager will initiate a
reconnection attempt. This uses L1 to send a new DILATE message through the
mailbox server, along with new connection hints. Eventually this will result
in a new L3 connection being established.

Finally, L3 is responsible for message serialization and deserialization. L2
performs decryption and delivers plaintext frames to L3. Each frame starts
with a one-byte type indicator. The rest of the message depends upon the
type:

* 0x00 PING, 4-byte ping-id
* 0x01 PONG, 4-byte ping-id
* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum
* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload
* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum
* 0x05 ACK, 4-byte response-seqnum

All seqnums are big-endian, and are provided by the L4 protocol. The other
fields are arbitrary and not interpreted as integers. The subchannel-ids must
be allocated by both sides without collision, but otherwise they are only
used to look up L5 objects for dispatch. The response-seqnum is always copied
from the OPEN/DATA/CLOSE packet being acknowledged.
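
A sketch of decoding one plaintext frame per the table above (illustrative
only; the real parser lives in `src/wormhole/_dilation/`):

```python
import struct

# Hedged sketch: split a decrypted L3 frame into (type, fields) per the
# one-byte type indicator and big-endian 4-byte fields described above.
def parse_frame(frame):
    t, rest = frame[0:1], frame[1:]
    if t == b"\x00":
        return ("PING", rest[0:4])
    if t == b"\x01":
        return ("PONG", rest[0:4])
    if t in (b"\x02", b"\x04"):  # OPEN / CLOSE
        scid = rest[0:4]
        seqnum = struct.unpack(">L", rest[4:8])[0]
        return ("OPEN" if t == b"\x02" else "CLOSE", scid, seqnum)
    if t == b"\x03":  # DATA carries a variable-length payload
        scid = rest[0:4]
        seqnum = struct.unpack(">L", rest[4:8])[0]
        return ("DATA", scid, seqnum, rest[8:])
    if t == b"\x05":  # ACK
        return ("ACK", struct.unpack(">L", rest[0:4])[0])
    raise ValueError("unknown frame type %r" % (t,))
```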

L3 consumes the PING and PONG messages. Receiving any PING will provoke a
PONG in response, with a copy of the ping-id field. The 30-second timer will
produce unprovoked PONGs with a ping-id of all zeros. A future viability
protocol will use PINGs to test for roundtrip functionality.

All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered
"upstairs" to the L4 protocol handler.

The current L3 connection's `IProducer`/`IConsumer` interface is made
available to the L4 flow-control manager.

## L4 protocol

The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages.
Since each will be enclosed in a Noise frame before being passed to L3, they
do not need length fields or other framing.

Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically
increasing by 1 for each message. Each direction has a separate number space.

The L4 manager maintains a double-ended queue of unacknowledged outbound
messages. Subchannel activity (opening, closing, sending data) causes
messages to be added to this queue. If an L3 connection is available, these
messages are also sent over that connection, but they remain in the queue in
case the connection is lost and they must be retransmitted on some future
replacement connection. Messages stay in the queue until they can be retired
by the receipt of an ACK with a matching response-sequence-number. This
provides reliable message delivery that survives the L3 connection being
replaced.

ACKs are not acked, nor do they have seqnums of their own. Each inbound side
remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE
messages with that sequence number or lower. This ensures in-order
at-most-once processing of OPEN/DATA/CLOSE messages.
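
A sketch of the queue-and-retire behavior (hypothetical class; the real
manager also handles seqnum assignment and subchannel dispatch):

```python
from collections import deque

# Hedged sketch: outbound messages stay queued until an ACK retires them;
# a replacement L3 connection triggers an in-order resend of the queue.
class OutboundQueue(object):
    def __init__(self):
        self._pending = deque()  # (seqnum, encoded_message), in seqnum order

    def send(self, seqnum, msg, l3=None):
        self._pending.append((seqnum, msg))
        if l3 is not None:        # transmit now if an L3 connection exists
            l3.send_frame(msg)

    def ack_received(self, resp_seqnum):
        # an ACK for N retires every queued message with seqnum <= N
        while self._pending and self._pending[0][0] <= resp_seqnum:
            self._pending.popleft()

    def new_connection(self, l3):
        for _, msg in self._pending:  # resend everything, in seqnum order
            l3.send_frame(msg)
```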

Each inbound OPEN message causes a new L5 subchannel object to be created.
Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to
that object.

Each time an L3 connection is established, the side will immediately send all
L4 messages waiting in the outbound queue. A future protocol might reduce
this duplication by including the highest received sequence number in the L1
PLEASE-DILATE message, which would effectively retire queued messages before
initiating the L2 connection process. On any given L3 connection, all
messages are sent in order. The receipt of an ACK for seqnum `N` allows all
messages with `seqnum <= N` to be retired.

The L4 layer is also responsible for managing flow control among the L3
connection and the various L5 subchannels.

## L5 subchannels

The L5 layer consists of a collection of "subchannel" objects, a dispatcher,
and the endpoints that provide the Twisted-flavored API.

Other than the "control channel", all subchannels are created by a client
endpoint connection API. The side that calls this API is named the Initiator,
and the other side is named the Acceptor. Subchannels can be initiated in
either direction, independent of the Leader/Follower distinction. For a
typical file-transfer application, the subchannel would be initiated by the
side seeking to send a file.

Each subchannel uses a distinct subchannel-id, which is a four-byte
identifier. Both directions share a number space (unlike L4 seqnums), so the
rule is that the Leader side sets the last bit of the last byte to 1, while
the Follower sets it to 0. These are not generally treated as integers;
however, for the sake of debugging, the implementation generates them with a
simple big-endian-encoded counter (`counter*2+1` for the Leader,
`counter*2+2` for the Follower, with id `0` reserved for the control
channel).
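
That allocation rule, as a sketch (the function name is made up):

```python
import struct

# Hedged sketch of collision-free subchannel-id allocation: odd ids for the
# Leader (last bit 1), even nonzero ids for the Follower, with scid 0
# reserved for the control channel.
def allocate_scid(counter, is_leader):
    n = counter * 2 + (1 if is_leader else 2)
    return struct.pack(">L", n)  # four-byte big-endian subchannel-id

assert allocate_scid(0, True) == b"\x00\x00\x00\x01"
assert allocate_scid(0, False) == b"\x00\x00\x00\x02"
```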

When the `client_ep.connect()` API is called, the Initiator allocates a
subchannel-id and sends an OPEN. It can then immediately send DATA messages
with the outbound data (there is no special response to an OPEN, so there is
no need to wait). The Acceptor will trigger its `.connectionMade` handler
upon receipt of the OPEN.

Subchannels are durable: they do not close until one side calls
`.loseConnection` on the subchannel object (or the enclosing Wormhole is
closed). Either the Initiator or the Acceptor can call `.loseConnection`.
This causes a CLOSE message to be sent (with the subchannel-id). The other
side will send its own CLOSE message in response. Each side will signal the
`.connectionLost()` event upon receipt of a CLOSE.

There is no equivalent to TCP's "half-closed" state; however, if only one
side calls `close()`, then all data written before that call will be
delivered before the other side observes `.connectionLost()`. Any inbound
data that was queued for delivery before the other side sees the CLOSE will
still be delivered to the side that called `close()` before it sees
`.connectionLost()`. Internally, the side which called `.loseConnection` will
remain in a special "closing" state until the CLOSE response arrives, during
which time DATA payloads are still delivered. After calling `close()` (or
receiving CLOSE), any outbound `.write()` calls will trigger an error.
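
A sketch of that close handshake as a tiny state machine (hypothetical
names; the shipped code drives this through the L4 queue):

```python
# Hedged sketch: OPEN -> CLOSING (we sent CLOSE, awaiting the echo) -> CLOSED.
class CloseHandshake(object):
    state = "open"

    def loseConnection(self):      # local close() request
        self._send_close()         # hypothetical: queue a CLOSE at L4
        self.state = "closing"     # inbound DATA is still delivered here

    def close_received(self):      # peer's CLOSE arrived
        if self.state == "open":   # peer closed first: echo, then finish
            self._send_close()
        self.state = "closed"
        self._connection_lost()    # hypothetical: fire connectionLost()
```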

DATA payloads that arrive for a non-open subchannel are logged and discarded.

This protocol calls for one OPEN and two CLOSE messages for each subchannel,
with some arbitrary number of DATA messages in between. Subchannel-ids should
not be reused (it would probably work, but the protocol hasn't been analyzed
enough to be sure).

The "control channel" is special. It uses a subchannel-id of all zeros, and
is opened implicitly by both sides as soon as the first L3 connection is
selected. It is routed to a special client-on-both-sides endpoint, rather
than causing the listening endpoint to accept a new connection. This avoids
the need for application-level code to negotiate who should be the one to
open it (the Leader/Follower distinction is private to the Wormhole
internals: applications are not obligated to pick a side).

OPEN and CLOSE messages for the control channel are logged and discarded. The
control-channel client endpoint can only be used once, and the channel does
not close until the Wormhole itself is closed.

Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing,
delivery, and eventual retirement. The L5 layer does not keep track of old
messages.

### Flow Control

Subchannels are flow-controlled by pausing their writes when the L3
connection is paused, and pausing the L3 connection when any subchannel
signals a pause. When the outbound L3 connection is full, *all* subchannels
are paused. Likewise, the inbound connection is paused if *any* of the
subchannels asks for a pause. This is much easier to implement and improves
our utilization factor (we can use TCP's window-filling algorithm, instead of
rolling our own), but it will block all subchannels even if only one of them
gets full. This shouldn't matter for many applications, but might be
noticeable when combining very different kinds of traffic (e.g. a chat
conversation sharing a wormhole with a file transfer might prefer the IM
text to take priority).
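
A sketch of that all-or-nothing pause rule (hypothetical manager class):

```python
# Hedged sketch: backpressure from the single L3 transport fans out to
# every subchannel producer, and any subchannel's pause request propagates
# up to the L3 transport.
class PauseFanout(object):
    def __init__(self, l3_transport):
        self._l3 = l3_transport
        self._subchannels = set()

    def add_subchannel(self, sc):
        self._subchannels.add(sc)

    # called by the L3 transport (IProducer direction): pause/resume all
    def pauseProducing(self):
        for sc in self._subchannels:
            sc.pauseProducing()

    def resumeProducing(self):
        for sc in self._subchannels:
            sc.resumeProducing()

    # called when any subchannel's consumer fills up
    def subchannel_wants_pause(self):
        self._l3.pauseProducing()
```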

Each subchannel implements Twisted's `ITransport`, `IProducer`, and
`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to
be created (by the caller's factory) and glued to the subchannel object via
the `.transport` property, as is standard in Twisted-based applications.

All subchannels are also paused when the L3 connection is lost, and are
unpaused when a new replacement connection is selected.
2000 docs/new-protocol.svg (new file, 91 KiB; diff suppressed because it is too large)
@@ -23,12 +23,13 @@ digraph {

  Terminator [shape="box" color="blue" fontcolor="blue"]
  InputHelperAPI [shape="oval" label="input\nhelper\nAPI"
                  color="blue" fontcolor="blue"]
+ Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"]

  #Connection -> websocket [color="blue"]
  #Connection -> Order [color="blue"]

  Wormhole -> Boss [style="dashed"
-                   label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)"
+                   label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)"
                    color="red" fontcolor="red"]
  #Wormhole -> Boss [color="blue"]
  Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"]

@@ -112,4 +113,7 @@ digraph {

  Terminator -> Boss [style="dashed" label="closed\n(once)"]
  Boss -> Terminator [style="dashed" color="red" fontcolor="red"
                      label="close"]

+ Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"]
+ Dilator -> Send [style="dashed" label="send(dilate-N)"]
  }
@@ -7,3 +7,6 @@ versionfile_source = src/wormhole/_version.py

  versionfile_build = wormhole/_version.py
  tag_prefix =
  parentdir_prefix = magic-wormhole

+ [flake8]
+ max-line-length = 85
1 setup.py
@@ -54,6 +54,7 @@ setup(name="magic-wormhole",

          "dev": ["mock", "tox", "pyflakes",
                  "magic-wormhole-transit-relay==0.1.2",
                  "magic-wormhole-mailbox-server==0.3.1"],
+         "dilate": ["noiseprotocol"],
      },
      test_suite="wormhole.test",
      cmdclass=commands,
@@ -1,7 +1,7 @@

  from __future__ import absolute_import, print_function, unicode_literals
- from .cli import cli
-
- if __name__ != "__main__":
-     raise ImportError('this module should not be imported')
-
- cli.wormhole()
+ if __name__ == "__main__":
+     from .cli import cli
+     cli.wormhole()
+ else:
+     # raise ImportError('this module should not be imported')
+     pass
@@ -12,6 +12,7 @@ from zope.interface import implementer

  from . import _interfaces
  from ._allocator import Allocator
  from ._code import Code, validate_code
+ from ._dilation.manager import Dilator
  from ._input import Input
  from ._key import Key
  from ._lister import Lister

@@ -38,6 +39,8 @@ class Boss(object):

      _versions = attrib(validator=instance_of(dict))
      _client_version = attrib(validator=instance_of(tuple))
      _reactor = attrib()
+     _eventual_queue = attrib()
+     _cooperator = attrib()
      _journal = attrib(validator=provides(_interfaces.IJournal))
      _tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
      _timing = attrib(validator=provides(_interfaces.ITiming))

@@ -64,6 +67,8 @@ class Boss(object):

          self._I = Input(self._timing)
          self._C = Code(self._timing)
          self._T = Terminator()
+         self._D = Dilator(self._reactor, self._eventual_queue,
+                           self._cooperator)

          self._N.wire(self._M, self._I, self._RC, self._T)
          self._M.wire(self._N, self._RC, self._O, self._T)

@@ -77,6 +82,7 @@ class Boss(object):

          self._I.wire(self._C, self._L)
          self._C.wire(self, self._A, self._N, self._K, self._I)
          self._T.wire(self, self._RC, self._N, self._M)
+         self._D.wire(self._S)

      def _init_other_state(self):
          self._did_start_code = False

@@ -84,6 +90,9 @@ class Boss(object):

          self._next_rx_phase = 0
          self._rx_phases = {}  # phase -> plaintext

+         self._next_rx_dilate_seqnum = 0
+         self._rx_dilate_seqnums = {}  # seqnum -> plaintext
+
          self._result = "empty"

      # these methods are called from outside

@@ -196,6 +205,9 @@ class Boss(object):

          self._did_start_code = True
          self._C.set_code(code)

+     def dilate(self):
+         return self._D.dilate()  # fires with endpoints
+
      @m.input()
      def send(self, plaintext):
          pass

@@ -255,8 +267,11 @@ class Boss(object):

      def got_message(self, phase, plaintext):
          assert isinstance(phase, type("")), type(phase)
          assert isinstance(plaintext, type(b"")), type(plaintext)
+         d_mo = re.search(r'^dilate-(\d+)$', phase)
          if phase == "version":
              self._got_version(plaintext)
+         elif d_mo:
+             self._got_dilate(int(d_mo.group(1)), plaintext)
          elif re.search(r'^\d+$', phase):
              self._got_phase(int(phase), plaintext)
          else:

@@ -272,6 +287,10 @@ class Boss(object):

      def _got_phase(self, phase, plaintext):
          pass

+     @m.input()
+     def _got_dilate(self, seqnum, plaintext):
+         pass
+
      @m.input()
      def got_key(self, key):
          pass

@@ -294,6 +313,7 @@ class Boss(object):

          # most of this is wormhole-to-wormhole, ignored for now
          # in the future, this is how Dilation is signalled
          self._their_versions = bytes_to_dict(plaintext)
+         self._D.got_wormhole_versions(self._their_versions)
          # but this part is app-to-app
          app_versions = self._their_versions.get("app_versions", {})
          self._W.got_versions(app_versions)

@@ -335,6 +355,10 @@ class Boss(object):

      @m.output()
      def W_got_key(self, key):
          self._W.got_key(key)

+     @m.output()
+     def D_got_key(self, key):
+         self._D.got_key(key)
+
      @m.output()
      def W_got_verifier(self, verifier):
          self._W.got_verifier(verifier)

@@ -348,6 +372,16 @@ class Boss(object):

          self._W.received(self._rx_phases.pop(self._next_rx_phase))
          self._next_rx_phase += 1

+     @m.output()
+     def D_received_dilate(self, seqnum, plaintext):
+         assert isinstance(seqnum, six.integer_types), type(seqnum)
+         # strict phase order, no gaps
+         self._rx_dilate_seqnums[seqnum] = plaintext
+         while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums:
+             m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum)
+             self._D.received_dilate(m)
+             self._next_rx_dilate_seqnum += 1
+
      @m.output()
      def W_close_with_error(self, err):
          self._result = err  # exception

@@ -370,7 +404,7 @@ class Boss(object):

      S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared])
      S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely])
      S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send])
-     S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key])
+     S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key])
      S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error])
      S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])

@@ -378,6 +412,7 @@ class Boss(object):

      S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier])
      S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
      S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
+     S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate])
      S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
      S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
      S2_happy.upon(send, enter=S2_happy, outputs=[S_send])

@@ -389,6 +424,7 @@ class Boss(object):

      S3_closing.upon(got_verifier, enter=S3_closing, outputs=[])
      S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
      S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
+     S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[])
      S3_closing.upon(happy, enter=S3_closing, outputs=[])
      S3_closing.upon(scared, enter=S3_closing, outputs=[])
      S3_closing.upon(close, enter=S3_closing, outputs=[])

@@ -400,6 +436,7 @@ class Boss(object):

      S4_closed.upon(got_verifier, enter=S4_closed, outputs=[])
      S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
      S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
+     S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[])
      S4_closed.upon(happy, enter=S4_closed, outputs=[])
      S4_closed.upon(scared, enter=S4_closed, outputs=[])
      S4_closed.upon(close, enter=S4_closed, outputs=[])
0 src/wormhole/_dilation/__init__.py (new file)
11 src/wormhole/_dilation/_noise.py (new file)
try:
    from noise.exceptions import NoiseInvalidMessage
except ImportError:
    class NoiseInvalidMessage(Exception):
        pass

try:
    from noise.connection import NoiseConnection
except ImportError:
    # allow imports to work on py2.7, even if dilation doesn't
    NoiseConnection = None
520 src/wormhole/_dilation/connection.py (new file)
@ -0,0 +1,520 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from collections import namedtuple
|
||||
import six
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from automat import MethodicalMachine
|
||||
from zope.interface import Interface, implementer
|
||||
from twisted.python import log
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from .._interfaces import IDilationConnector
|
||||
from ..observer import OneShotObserver
|
||||
from .encode import to_be4, from_be4
|
||||
from .roles import FOLLOWER
|
||||
from ._noise import NoiseInvalidMessage
|
||||
|
||||
# InboundFraming is given data and returns Frames (Noise wire-side
|
||||
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
||||
# returns are either the ephemeral key (the Noise "handshake") or ciphertext
|
||||
# messages.
|
||||
|
||||
# The next object up knows whether it's expecting a Handshake or a message. It
|
||||
# feeds the first into Noise as a handshake, it feeds the rest into Noise as a
|
||||
# message (which produces a plaintext stream). It emits tokens that are either
|
||||
# "i've finished with the handshake (so you can send the KCM if you want)", or
|
||||
# "here is a decrypted message (which might be the KCM)".
|
||||
|
||||
# the transmit direction goes directly to transport.write, and doesn't touch
|
||||
# the state machine. we can do this because the way we encode/encrypt/frame
|
||||
# things doesn't depend upon the receiver state. It would be more safe to e.g.
|
||||
# prohibit sending ciphertext frames unless we're in the received-handshake
|
||||
# state, but then we'll be in the middle of an inbound state transition ("we
|
||||
# just received the handshake, so you can send a KCM now") when we perform an
|
||||
# operation that depends upon the state (send_plaintext(kcm)), which is not a
|
||||
# coherent/safe place to touch the state machine.
|
||||
|
||||
# we could set a flag and test it from inside send_plaintext, which kind of
|
||||
# violates the state machine owning the state (ideally all "if" statements
|
||||
# would be translated into same-input transitions from different starting
|
||||
# states). For the specific question of sending plaintext frames, Noise will
|
||||
# refuse us unless it's ready anyways, so the question is probably moot.
|
||||
|
||||
|
||||
class IFramer(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IRecord(Interface):
|
||||
pass
|
||||
|
||||
|
||||
def first(l):
|
||||
return l[0]
|
||||
|
||||
|
||||
class Disconnect(Exception):
|
||||
pass
|
||||
|
||||
|
||||
RelayOK = namedtuple("RelayOk", [])
|
||||
Prologue = namedtuple("Prologue", [])
|
||||
Frame = namedtuple("Frame", ["frame"])
|
||||
|
||||
|
||||
@attrs
@implementer(IFramer)
class _Framer(object):
    _transport = attrib(validator=provides(ITransport))
    _outbound_prologue = attrib(validator=instance_of(bytes))
    _inbound_prologue = attrib(validator=instance_of(bytes))
    _buffer = b""
    _can_send_frames = False

    # in: use_relay
    # in: connectionMade, dataReceived
    # out: prologue_received, frame_received
    # out (shared): transport.loseConnection
    # out (shared): transport.write (relay handshake, prologue)
    # states: want_relay, want_prologue, want_frame
    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    @m.state()
    def want_relay(self):
        pass  # pragma: no cover

    @m.state(initial=True)
    def want_prologue(self):
        pass  # pragma: no cover

    @m.state()
    def want_frame(self):
        pass  # pragma: no cover

    @m.input()
    def use_relay(self, relay_handshake):
        pass

    @m.input()
    def connectionMade(self):
        pass

    @m.input()
    def parse(self):
        pass

    @m.input()
    def got_relay_ok(self):
        pass

    @m.input()
    def got_prologue(self):
        pass

    @m.output()
    def store_relay_handshake(self, relay_handshake):
        self._outbound_relay_handshake = relay_handshake
        self._expected_relay_handshake = b"ok\n"  # TODO: make this configurable

    @m.output()
    def send_relay_handshake(self):
        self._transport.write(self._outbound_relay_handshake)

    @m.output()
    def send_prologue(self):
        self._transport.write(self._outbound_prologue)

    @m.output()
    def parse_relay_ok(self):
        if self._get_expected("relay_ok", self._expected_relay_handshake):
            return RelayOK()

    @m.output()
    def parse_prologue(self):
        if self._get_expected("prologue", self._inbound_prologue):
            return Prologue()

    @m.output()
    def can_send_frames(self):
        self._can_send_frames = True  # for assertion in send_frame()

    @m.output()
    def parse_frame(self):
        if len(self._buffer) < 4:
            return None
        frame_length = from_be4(self._buffer[0:4])
        if len(self._buffer) < 4 + frame_length:
            return None
        frame = self._buffer[4:4 + frame_length]
        self._buffer = self._buffer[4 + frame_length:]  # TODO: avoid copy
        return Frame(frame=frame)

    want_prologue.upon(use_relay, outputs=[store_relay_handshake],
                       enter=want_relay)

    want_relay.upon(connectionMade, outputs=[send_relay_handshake],
                    enter=want_relay)
    want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay,
                    collector=first)
    want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue)

    want_prologue.upon(connectionMade, outputs=[send_prologue],
                       enter=want_prologue)
    want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue,
                       collector=first)
    want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame)

    want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
                    collector=first)

    def _get_expected(self, name, expected):
        lb = len(self._buffer)
        le = len(expected)
        if self._buffer.startswith(expected):
            # if the buffer starts with the expected string, consume it and
            # return True
            self._buffer = self._buffer[le:]
            return True
        if not expected.startswith(self._buffer):
            # we're not on track: the data we've received so far does not
            # match the expected value, so this can't possibly be right.
            # Don't complain until we see the expected length, or a newline,
            # so we can capture the weird input in the log for debugging.
            if (b"\n" in self._buffer or lb >= le):
                log.msg("bad {}: {}".format(name, self._buffer[:le]))
                raise Disconnect()
            return False  # wait a bit longer
        # good so far, just waiting for the rest
        return False

    # external API is: connectionMade, add_and_parse, and send_frame

    def add_and_parse(self, data):
        # we can't make this an @m.input because we can't change the state
        # from within an input. Instead, let the state choose the parser to
        # use, and use the parsed token to drive a state transition.
        self._buffer += data
        while True:
            # it'd be nice to use an iterator here, but since self.parse()
            # dispatches to a different parser (depending upon the current
            # state), we'd be using multiple iterators
            token = self.parse()
            if isinstance(token, RelayOK):
                self.got_relay_ok()
            elif isinstance(token, Prologue):
                self.got_prologue()
                yield token  # triggers send_handshake
            elif isinstance(token, Frame):
                yield token
            else:
                break

    def send_frame(self, frame):
        assert self._can_send_frames
        self._transport.write(to_be4(len(frame)) + frame)

# RelayOK: Newline-terminated buddy-is-connected response from Relay.
#          First data received from relay.
# Prologue: double-newline-terminated this-is-really-wormhole response
#           from peer. First data received from peer.
# Frame: Either handshake or encrypted message. Length-prefixed on wire.
# Handshake: the Noise ephemeral key, first framed message
# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK
# KCM: Key Confirmation Message (encrypted b"\x00"). First frame
#      from peer. Sent immediately by Follower, after Selection by Leader.
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong

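# A minimal sketch of the frame layout described above (illustrative
# only, not part of the module): each frame is a 4-byte big-endian
# length followed by that many payload bytes.
#
#   payload = b"hello"
#   wire = to_be4(len(payload)) + payload   # b"\x00\x00\x00\x05hello"
#   assert from_be4(wire[0:4]) == 5
#   assert wire[4:4 + 5] == payload
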
Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", [])
Ping = namedtuple("Ping", ["ping_id"])  # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"])  # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"])  # scid is integer
Ack = namedtuple("Ack", ["resp_seqnum"])  # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records

T_KCM = b"\x00"
T_PING = b"\x01"
T_PONG = b"\x02"
T_OPEN = b"\x03"
T_DATA = b"\x04"
T_CLOSE = b"\x05"
T_ACK = b"\x06"

def parse_record(plaintext):
    msgtype = plaintext[0:1]
    if msgtype == T_KCM:
        return KCM()
    if msgtype == T_PING:
        ping_id = plaintext[1:5]
        return Ping(ping_id)
    if msgtype == T_PONG:
        ping_id = plaintext[1:5]
        return Pong(ping_id)
    if msgtype == T_OPEN:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        return Open(seqnum, scid)
    if msgtype == T_DATA:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        data = plaintext[9:]
        return Data(seqnum, scid, data)
    if msgtype == T_CLOSE:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        return Close(seqnum, scid)
    if msgtype == T_ACK:
        resp_seqnum = from_be4(plaintext[1:5])
        return Ack(resp_seqnum)
    log.err("received unknown message type: {}".format(plaintext))
    raise ValueError()

def encode_record(r):
    if isinstance(r, KCM):
        return b"\x00"
    if isinstance(r, Ping):
        return b"\x01" + r.ping_id
    if isinstance(r, Pong):
        return b"\x02" + r.ping_id
    if isinstance(r, Open):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
    if isinstance(r, Data):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
    if isinstance(r, Close):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
    if isinstance(r, Ack):
        assert isinstance(r.resp_seqnum, six.integer_types)
        return b"\x06" + to_be4(r.resp_seqnum)
    raise TypeError(r)

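# Round-trip sketch (illustrative only, not exercised by this module):
#
#   r = Open(seqnum=1, scid=3)
#   wire = encode_record(r)          # b"\x03" + to_be4(3) + to_be4(1)
#   assert parse_record(wire) == r
#
# Note the wire order is (scid, seqnum) while the namedtuple order is
# (seqnum, scid); parse_record() above accounts for that.
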
@attrs
@implementer(IRecord)
class _Record(object):
    _framer = attrib(validator=provides(IFramer))
    _noise = attrib()

    n = MethodicalMachine()
    # TODO: set_trace

    def __attrs_post_init__(self):
        self._noise.start_handshake()

    # in: role=
    # in: prologue_received, frame_received
    # out: handshake_received, record_received
    # out: transport.write (noise handshake, encrypted records)
    # states: want_prologue, want_handshake, want_message

    @n.state(initial=True)
    def want_prologue(self):
        pass  # pragma: no cover

    @n.state()
    def want_handshake(self):
        pass  # pragma: no cover

    @n.state()
    def want_message(self):
        pass  # pragma: no cover

    @n.input()
    def got_prologue(self):
        pass

    @n.input()
    def got_frame(self, frame):
        pass

    @n.output()
    def send_handshake(self):
        handshake = self._noise.write_message()  # generate the ephemeral key
        self._framer.send_frame(handshake)

    @n.output()
    def process_handshake(self, frame):
        try:
            payload = self._noise.read_message(frame)
            # Noise can include unencrypted data in the handshake, but we
            # don't use it
            del payload
        except NoiseInvalidMessage as e:
            log.err(e, "bad inbound noise handshake")
            raise Disconnect()
        return Handshake()

    @n.output()
    def decrypt_message(self, frame):
        try:
            message = self._noise.decrypt(frame)
        except NoiseInvalidMessage as e:
            # if this happens during tests, flunk the test
            log.err(e, "bad inbound noise frame")
            raise Disconnect()
        return parse_record(message)

    want_prologue.upon(got_prologue, outputs=[send_handshake],
                       enter=want_handshake)
    want_handshake.upon(got_frame, outputs=[process_handshake],
                        collector=first, enter=want_message)
    want_message.upon(got_frame, outputs=[decrypt_message],
                      collector=first, enter=want_message)

    # external API is: connectionMade, dataReceived, send_record

    def connectionMade(self):
        self._framer.connectionMade()

    def add_and_unframe(self, data):
        for token in self._framer.add_and_parse(data):
            if isinstance(token, Prologue):
                self.got_prologue()  # triggers send_handshake
            else:
                assert isinstance(token, Frame)
                yield self.got_frame(token.frame)  # Handshake or a Record type

    def send_record(self, r):
        message = encode_record(r)
        frame = self._noise.encrypt(message)
        self._framer.send_frame(frame)

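# A minimal sketch (not part of the module) of how _Framer and _Record
# compose; "FakeNoise" is a hypothetical stand-in for the real Noise
# object. DilatedConnectionProtocol below does the real wiring in its
# connectionMade():
#
#   framer = _Framer(transport, b"out prologue\n\n", b"in prologue\n\n")
#   record = _Record(framer, FakeNoise())
#   record.connectionMade()                 # writes our prologue
#   for token in record.add_and_unframe(data_from_peer):
#       ...  # yields Handshake, then KCM/Open/Data/Close/Ack/Ping/Pong
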
@attrs
class DilatedConnectionProtocol(Protocol, object):
    """I manage an L2 connection.

    When a new L2 connection is needed (as determined by the Leader),
    both Leader and Follower will initiate many simultaneous connections
    (probably TCP, but conceivably others). A subset will actually
    connect. A subset of those will successfully pass negotiation by
    exchanging handshakes to demonstrate knowledge of the session key.
    One of the negotiated connections will be selected by the Leader for
    active use, and the others will be dropped.

    At any given time, there is at most one active L2 connection.
    """

    _eventual_queue = attrib()
    _role = attrib()
    _connector = attrib(validator=provides(IDilationConnector))
    _noise = attrib()
    _outbound_prologue = attrib(validator=instance_of(bytes))
    _inbound_prologue = attrib(validator=instance_of(bytes))

    _use_relay = False
    _relay_handshake = None

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        self._manager = None  # set if/when we are selected
        self._disconnected = OneShotObserver(self._eventual_queue)
        self._can_send_records = False

    @m.state(initial=True)
    def unselected(self):
        pass  # pragma: no cover

    @m.state()
    def selecting(self):
        pass  # pragma: no cover

    @m.state()
    def selected(self):
        pass  # pragma: no cover

    @m.input()
    def got_kcm(self):
        pass

    @m.input()
    def select(self, manager):
        pass  # fires set_manager()

    @m.input()
    def got_record(self, record):
        pass

    @m.output()
    def add_candidate(self):
        self._connector.add_candidate(self)

    @m.output()
    def set_manager(self, manager):
        self._manager = manager

    @m.output()
    def can_send_records(self, manager):
        self._can_send_records = True

    @m.output()
    def deliver_record(self, record):
        self._manager.got_record(record)

    unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting)
    selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected)
    selected.upon(got_record, outputs=[deliver_record], enter=selected)

    # called by Connector

    def use_relay(self, relay_handshake):
        assert isinstance(relay_handshake, bytes)
        self._use_relay = True
        self._relay_handshake = relay_handshake

    def when_disconnected(self):
        return self._disconnected.when_fired()

    def disconnect(self):
        self.transport.loseConnection()

    # select() called by Connector

    # called by Manager
    def send_record(self, record):
        assert self._can_send_records
        self._record.send_record(record)

    # IProtocol methods

    def connectionMade(self):
        framer = _Framer(self.transport,
                         self._outbound_prologue, self._inbound_prologue)
        if self._use_relay:
            framer.use_relay(self._relay_handshake)
        self._record = _Record(framer, self._noise)
        self._record.connectionMade()

    def dataReceived(self, data):
        try:
            for token in self._record.add_and_unframe(data):
                assert isinstance(token, Handshake_or_Records)
                if isinstance(token, Handshake):
                    if self._role is FOLLOWER:
                        self._record.send_record(KCM())
                elif isinstance(token, KCM):
                    # if we're the leader, add this connection as a candidate.
                    # if we're the follower, accept this connection.
                    self.got_kcm()  # connector.add_candidate()
                else:
                    self.got_record(token)  # manager.got_record()
        except Disconnect:
            self.transport.loseConnection()

    def connectionLost(self, why=None):
        self._disconnected.fire(self)
387  src/wormhole/_dilation/connector.py  (new file)
@ -0,0 +1,387 @@
from __future__ import print_function, unicode_literals
from collections import defaultdict
from binascii import hexlify
from attr import attrs, attrib
from attr.validators import instance_of, provides, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.task import deferLater
from twisted.internet.defer import DeferredList
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.python import log
from .. import ipaddrs  # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager
from ..timing import DebugTiming
from ..observer import EmptyableSet
from ..util import HKDF, to_unicode
from .connection import DilatedConnectionProtocol, KCM
from .roles import LEADER

from .._hints import (DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint,
                      parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
                      encode_hint)
from ._noise import NoiseConnection


def build_sided_relay_handshake(key, side):
    assert isinstance(side, type(u""))
    assert len(side) == 8 * 2
    token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
    return (b"please relay " + hexlify(token) +
            b" for side " + side.encode("ascii") + b"\n")

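# For illustration (hypothetical values): the relay handshake is a single
# ASCII line, and the token is an HKDF of the dilation key, so both sides
# of one wormhole derive the same token (their "side" strings differ):
#
#   build_sided_relay_handshake(key, "0123456789abcdef")
#   => b"please relay <64 hex chars> for side 0123456789abcdef\n"
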
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
NOISEPROTO = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"


def build_noise():
    return NoiseConnection.from_name(NOISEPROTO)

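# For orientation (informal, not normative): NNpsk0 is a two-message
# handshake pattern. Each side sends one handshake message carrying its
# ephemeral key (with the pre-shared key mixed in), reads the peer's,
# and then switches to transport-mode encrypt()/decrypt(). _Record in
# connection.py drives exactly that: one write_message(), one
# read_message(), then encrypted records.
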
@attrs
@implementer(IDilationConnector)
class Connector(object):
    _dilation_key = attrib(validator=instance_of(type(b"")))
    _transit_relay_location = attrib(validator=optional(instance_of(type(u""))))
    _manager = attrib(validator=provides(IDilationManager))
    _reactor = attrib()
    _eventual_queue = attrib()
    _no_listen = attrib(validator=instance_of(bool))
    _tor = attrib()
    _timing = attrib()
    _side = attrib(validator=instance_of(type(u"")))
    # was self._side = bytes_to_hexstr(os.urandom(8))  # unicode
    _role = attrib()

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    RELAY_DELAY = 2.0

    def __attrs_post_init__(self):
        if self._transit_relay_location:
            # TODO: allow multiple hints for a single relay
            relay_hint = parse_hint_argv(self._transit_relay_location)
            relay = RelayV1Hint(hints=(relay_hint,))
            self._transit_relays = [relay]
        else:
            self._transit_relays = []
        self._listeners = set()  # IListeningPorts that can be stopped
        self._pending_connectors = set()  # Deferreds that can be cancelled
        self._pending_connections = EmptyableSet(
            _eventual_queue=self._eventual_queue)  # Protocols to be stopped
        self._contenders = set()  # viable connections
        self._winning_connection = None
        self._timing = self._timing or DebugTiming()
        self._timing.add("transit")

    # this describes what our Connector can do, for the initial advertisement
    @classmethod
    def get_connection_abilities(klass):
        return [{"type": "direct-tcp-v1"},
                {"type": "relay-v1"},
                ]

    def build_protocol(self, addr):
        # encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
        # ephemeral keys plus a pre-shared symmetric key (the Transit key), a
        # different one for each potential connection.
        noise = build_noise()
        noise.set_psks(self._dilation_key)
        if self._role is LEADER:
            noise.set_as_initiator()
            outbound_prologue = PROLOGUE_LEADER
            inbound_prologue = PROLOGUE_FOLLOWER
        else:
            noise.set_as_responder()
            outbound_prologue = PROLOGUE_FOLLOWER
            inbound_prologue = PROLOGUE_LEADER
        p = DilatedConnectionProtocol(self._eventual_queue, self._role,
                                      self, noise,
                                      outbound_prologue, inbound_prologue)
        return p

    @m.state(initial=True)
    def connecting(self):
        pass  # pragma: no cover

    @m.state()
    def connected(self):
        pass  # pragma: no cover

    @m.state(terminal=True)
    def stopped(self):
        pass  # pragma: no cover

    # TODO: unify the tense of these method-name verbs

    # add_relay() and got_hints() are called by the Manager as it receives
    # messages from our peer. stop() is called when the Manager shuts down
    @m.input()
    def add_relay(self, hint_objs):
        pass

    @m.input()
    def got_hints(self, hint_objs):
        pass

    @m.input()
    def stop(self):
        pass

    # called by ourselves, when _start_listener() is ready
    @m.input()
    def listener_ready(self, hint_objs):
        pass

    # called when DilatedConnectionProtocol submits itself, after KCM
    # received
    @m.input()
    def add_candidate(self, c):
        pass

    # called by ourselves, via consider()
    @m.input()
    def accept(self, c):
        pass

    @m.output()
    def use_hints(self, hint_objs):
        self._use_hints(hint_objs)

    @m.output()
    def publish_hints(self, hint_objs):
        self._publish_hints(hint_objs)

    def _publish_hints(self, hint_objs):
        self._manager.send_hints([encode_hint(h) for h in hint_objs])

    @m.output()
    def consider(self, c):
        self._contenders.add(c)
        if self._role is LEADER:
            # for now, just accept the first one. TODO: be clever.
            self._eventual_queue.eventually(self.accept, c)
        else:
            # the follower always uses the first contender, since that's the
            # only one the leader picked
            self._eventual_queue.eventually(self.accept, c)

    @m.output()
    def select_and_stop_remaining(self, c):
        self._winning_connection = c
        self._contenders.clear()  # we no longer care who else came close
        # remove this winner from the losers, so we don't shut it down
        self._pending_connections.discard(c)
        # shut down losing connections
        self.stop_listeners()  # TODO: maybe keep it open? NAT/p2p assist
        self.stop_pending_connectors()
        self.stop_pending_connections()

        c.select(self._manager)  # subsequent frames go directly to the manager
        if self._role is LEADER:
            # TODO: this should live in Connection
            c.send_record(KCM())  # leader sends KCM now
        self._manager.use_connection(c)  # manager sends frames to Connection

    @m.output()
    def stop_everything(self):
        self.stop_listeners()
        self.stop_pending_connectors()
        self.stop_pending_connections()
        self.break_cycles()

    def stop_listeners(self):
        d = DeferredList([l.stopListening() for l in self._listeners])
        self._listeners.clear()
        return d  # synchronization for tests

    def stop_pending_connectors(self):
        # Deferred.cancel() returns None, so cancel first and hand the
        # Deferreds themselves to DeferredList
        ds = list(self._pending_connectors)
        for d in ds:
            d.cancel()
        return DeferredList(ds)

    def stop_pending_connections(self):
        d = self._pending_connections.when_next_empty()
        for c in self._pending_connections:
            c.disconnect()  # DilatedConnectionProtocol drops its transport
        return d

    def break_cycles(self):
        # help GC by forgetting references to things that reference us
        self._listeners.clear()
        self._pending_connectors.clear()
        self._pending_connections.clear()
        self._winning_connection = None

    connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints])
    connecting.upon(add_relay, enter=connecting, outputs=[use_hints,
                                                          publish_hints])
    connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
    connecting.upon(add_candidate, enter=connecting, outputs=[consider])
    connecting.upon(accept, enter=connected, outputs=[
        select_and_stop_remaining])
    connecting.upon(stop, enter=stopped, outputs=[stop_everything])

    # once connected, we ignore everything except stop
    connected.upon(listener_ready, enter=connected, outputs=[])
    connected.upon(add_relay, enter=connected, outputs=[])
    connected.upon(got_hints, enter=connected, outputs=[])
    # TODO: tell them to disconnect? will they hang out forever? I *think*
    # they'll drop this once they get a KCM on the winning connection.
    connected.upon(add_candidate, enter=connected, outputs=[])
    connected.upon(accept, enter=connected, outputs=[])
    connected.upon(stop, enter=stopped, outputs=[stop_everything])

    # from Manager: start, got_hints, stop
    # maybe add_candidate, accept

    def start(self):
        if not self._no_listen and not self._tor:
            addresses = self._get_listener_addresses()
            self._start_listener(addresses)
        if self._transit_relays:
            self._publish_hints(self._transit_relays)
            self._use_hints(self._transit_relays)

    def _get_listener_addresses(self):
        addresses = ipaddrs.find_addresses()
        non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
        if non_loopback_addresses:
            # some test hosts, including the appveyor VMs, *only* have
            # 127.0.0.1, and the tests will hang badly if we remove it.
            addresses = non_loopback_addresses
        return addresses

    def _start_listener(self, addresses):
        # TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also
        # to make firewall configs easier
        # TODO: retain listening port between connection generations?
        ep = serverFromString(self._reactor, "tcp:0")
        f = InboundConnectionFactory(self)
        d = ep.listen(f)

        def _listening(lp):
            # lp is an IListeningPort
            self._listeners.add(lp)  # for shutdown and tests
            portnum = lp.getHost().port
            direct_hints = [DirectTCPV1Hint(to_unicode(addr), portnum, 0.0)
                            for addr in addresses]
            self.listener_ready(direct_hints)
        d.addCallback(_listening)
        d.addErrback(log.err)

    def _schedule_connection(self, delay, h, is_relay):
        ep = endpoint_from_hint_obj(h, self._tor, self._reactor)
        desc = describe_hint_obj(h, is_relay, self._tor)
        d = deferLater(self._reactor, delay,
                       self._connect, ep, desc, is_relay)
        d.addErrback(log.err)
        self._pending_connectors.add(d)

    def _use_hints(self, hints):
        # first, pull out all the relays, we'll connect to them later
        relays = []
        direct = defaultdict(list)
        for h in hints:
            if isinstance(h, RelayV1Hint):
                relays.append(h)
            else:
                direct[h.priority].append(h)
        delay = 0.0
        made_direct = False
        priorities = sorted(set(direct.keys()), reverse=True)
        for p in priorities:
            for h in direct[p]:
                if isinstance(h, TorTCPV1Hint) and not self._tor:
                    continue
                self._schedule_connection(delay, h, is_relay=False)
                made_direct = True
                # Make all direct connections immediately. Later, we'll change
                # the add_candidate() function to look at the priority when
                # deciding whether to accept a successful connection or not,
                # and it can wait for more options if it sees a higher-priority
                # one still running. But if we bail on that, we might consider
                # putting an inter-direct-hint delay here to influence the
                # process.
                # delay += 1.0

        if made_direct and not self._no_listen:
            # Prefer direct connections by stalling relay connections by a
            # few seconds. We don't wait until direct connections have
            # failed, because many direct hints will be to unused
            # local-network IP addresses, which won't answer, and can take the
            # full 30s TCP timeout to fail.
            #
            # If we didn't make any direct connections, or we're using
            # --no-listen, then we're probably going to have to use the
            # relay, so don't delay it at all.
            delay += self.RELAY_DELAY

        # It might be nice to wire this so that a failure in the direct hints
        # causes the relay hints to be used right away (fast failover). But
        # none of our current use cases would take advantage of that: if we
        # have any viable direct hints, then they're either going to succeed
        # quickly or hang for a long time.
        for r in relays:
            for h in r.hints:
                self._schedule_connection(delay, h, is_relay=True)
        # TODO:
        # if not contenders:
        #     raise TransitError("No contenders for connection")

    # TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
    # the initial connection

    def _connect(self, ep, description, is_relay=False):
        relay_handshake = None
        if is_relay:
            relay_handshake = build_sided_relay_handshake(self._dilation_key,
                                                          self._side)
        f = OutboundConnectionFactory(self, relay_handshake)
        d = ep.connect(f)
        # fires with protocol, or ConnectError

        def _connected(p):
            self._pending_connections.add(p)
            # c might not be in _pending_connections, if it turned out to be a
            # winner, which is why we use discard() and not remove()
            p.when_disconnected().addCallback(self._pending_connections.discard)
        d.addCallback(_connected)
        return d

    # Connection selection. All instances of DilatedConnectionProtocol which
    # look viable get passed into our add_candidate() method.

    # On the Leader side, "viable" means we've seen their KCM frame, which is
    # the first Noise-encrypted packet on any given connection, and it has an
    # empty body. We gather viable connections until we see one that we like,
    # or a timer expires. Then we "select" it, close the others, and tell our
    # Manager to use it.

    # On the Follower side, we'll only see a KCM on the one connection selected
    # by the Leader, so the first viable connection wins.

    # our Connection protocols call: add_candidate


@attrs
class OutboundConnectionFactory(ClientFactory, object):
    _connector = attrib(validator=provides(IDilationConnector))
    _relay_handshake = attrib(validator=optional(instance_of(bytes)))

    def buildProtocol(self, addr):
        p = self._connector.build_protocol(addr)
        p.factory = self
        if self._relay_handshake is not None:
            p.use_relay(self._relay_handshake)
        return p


@attrs
class InboundConnectionFactory(ServerFactory, object):
    _connector = attrib(validator=provides(IDilationConnector))

    def buildProtocol(self, addr):
        p = self._connector.build_protocol(addr)
        p.factory = self
        return p
19  src/wormhole/_dilation/encode.py  (new file)
@ -0,0 +1,19 @@
from __future__ import print_function, unicode_literals
import struct

assert len(struct.pack("<L", 0)) == 4
assert len(struct.pack("<Q", 0)) == 8


def to_be4(value):
    if not 0 <= value < 2**32:
        raise ValueError
    return struct.pack(">L", value)


def from_be4(b):
    if not isinstance(b, bytes):
        raise TypeError(repr(b))
    if len(b) != 4:
        raise ValueError
    return struct.unpack(">L", b)[0]
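# Quick usage sketch (illustrative): big-endian 4-byte round trip.
#
#   assert to_be4(258) == b"\x00\x00\x01\x02"
#   assert from_be4(b"\x00\x00\x01\x02") == 258
#   to_be4(2**32)   # raises ValueError (out of range)
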
134  src/wormhole/_dilation/inbound.py  (new file)
@ -0,0 +1,134 @@
from __future__ import print_function, unicode_literals
from attr import attrs, attrib
from attr.validators import provides
from zope.interface import implementer
from twisted.python import log
from .._interfaces import IDilationManager, IInbound
from .subchannel import (SubChannel, _SubchannelAddress)


class DuplicateOpenError(Exception):
    pass


class DataForMissingSubchannelError(Exception):
    pass


class CloseForMissingSubchannelError(Exception):
    pass


@attrs
@implementer(IInbound)
class Inbound(object):
    # Inbound flow control: TCP delivers data to Connection.dataReceived,
    # Connection delivers to our handle_data, we deliver to
    # SubChannel.remote_data, subchannel delivers to proto.dataReceived
    _manager = attrib(validator=provides(IDilationManager))
    _host_addr = attrib()

    def __attrs_post_init__(self):
        # we route inbound Data records to Subchannels .dataReceived
        self._open_subchannels = {}  # scid -> Subchannel
        self._paused_subchannels = set()  # Subchannels that have paused us
        # when the set is non-empty, we pause the transport
        self._highest_inbound_acked = -1
        self._connection = None

    # from our Manager
    def set_listener_endpoint(self, listener_endpoint):
        self._listener_endpoint = listener_endpoint

    def set_subchannel_zero(self, scid0, sc0):
        self._open_subchannels[scid0] = sc0

    def use_connection(self, c):
        self._connection = c
        # We can pause the connection's reads when we receive too much data. If
        # this is a non-initial connection, then we might already have
        # subchannels that are paused from before, so we might need to pause
        # the new connection before it can send us any data
        if self._paused_subchannels:
            self._connection.pauseProducing()

    # Inbound is responsible for tracking the high watermark and deciding
    # whether to ignore inbound messages or not

    def is_record_old(self, r):
        if r.seqnum <= self._highest_inbound_acked:
            return True
        return False

    def update_ack_watermark(self, seqnum):
        # the Manager hands us the record's seqnum (an integer), not the
        # record itself
        self._highest_inbound_acked = max(self._highest_inbound_acked,
                                          seqnum)

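    # Watermark sketch (illustrative): after records with seqnums 0, 1,
    # and 2 have been processed and acked, _highest_inbound_acked == 2,
    # so a re-delivered copy of record 1 satisfies is_record_old() and
    # is acked again but otherwise ignored.
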
    def handle_open(self, scid):
        if scid in self._open_subchannels:
            log.err(DuplicateOpenError(
                "received duplicate OPEN for {}".format(scid)))
            return
        peer_addr = _SubchannelAddress(scid)
        sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
        self._open_subchannels[scid] = sc
        self._listener_endpoint._got_open(sc, peer_addr)

    def handle_data(self, scid, data):
        sc = self._open_subchannels.get(scid)
        if sc is None:
            log.err(DataForMissingSubchannelError(
                "received DATA for non-existent subchannel {}".format(scid)))
            return
        sc.remote_data(data)

    def handle_close(self, scid):
        sc = self._open_subchannels.get(scid)
        if sc is None:
            log.err(CloseForMissingSubchannelError(
                "received CLOSE for non-existent subchannel {}".format(scid)))
            return
        sc.remote_close()

    def subchannel_closed(self, scid, sc):
        # connectionLost has just been signalled
        assert self._open_subchannels[scid] is sc
        del self._open_subchannels[scid]

    def stop_using_connection(self):
        self._connection = None

    # from our Subchannel, or rather from the Protocol above it and sent
    # through the subchannel

    # The subchannel is an IProducer, and application protocols can always
    # tell them to pauseProducing if we're delivering inbound data too
    # quickly. They don't need to register anything.

    def subchannel_pauseProducing(self, sc):
        was_paused = bool(self._paused_subchannels)
        self._paused_subchannels.add(sc)
        if self._connection and not was_paused:
            self._connection.pauseProducing()

    def subchannel_resumeProducing(self, sc):
        was_paused = bool(self._paused_subchannels)
        self._paused_subchannels.discard(sc)
        if self._connection and was_paused and not self._paused_subchannels:
            self._connection.resumeProducing()

    def subchannel_stopProducing(self, sc):
        # This protocol doesn't want any additional data. If we were a normal
        # (single-owner) Transport, we'd call .loseConnection now. But our
        # Connection is shared among many subchannels, so instead we just
        # stop letting them pause the connection.
        was_paused = bool(self._paused_subchannels)
        self._paused_subchannels.discard(sc)
        if self._connection and was_paused and not self._paused_subchannels:
            self._connection.resumeProducing()

    # TODO: we might refactor these pause/resume/stop methods by building a
    # context manager that looks at the paused/not-paused state first, then
    # lets the caller modify self._paused_subchannels, then looks at it a
    # second time, and calls c.pauseProducing/c.resumeProducing as
    # appropriate. I'm not sure it would be any cleaner, though. A rough
    # sketch appears below.
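# A rough sketch of that context-manager refactoring (hypothetical, not
# used by this module): the three methods above would shrink to one
# mutation of self._paused_subchannels inside the `with` block.
#
#   from contextlib import contextmanager
#
#   @contextmanager
#   def _pause_tracker(inbound):
#       was_paused = bool(inbound._paused_subchannels)
#       yield  # caller adds/discards subchannels here
#       now_paused = bool(inbound._paused_subchannels)
#       c = inbound._connection
#       if c and now_paused and not was_paused:
#           c.pauseProducing()
#       elif c and was_paused and not now_paused:
#           c.resumeProducing()
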
580  src/wormhole/_dilation/manager.py  (new file)
@ -0,0 +1,580 @@
from __future__ import print_function, unicode_literals
import os
from collections import deque
from attr import attrs, attrib
from attr.validators import provides, instance_of, optional
from automat import MethodicalMachine
from zope.interface import implementer
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
from twisted.python import log
from .._interfaces import IDilator, IDilationManager, ISend
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
from .._key import derive_key
from .encode import to_be4
from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress,
                         ControlEndpoint, SubchannelConnectorEndpoint,
                         SubchannelListenerEndpoint)
from .connector import Connector
from .._hints import parse_hint
from .roles import LEADER, FOLLOWER
from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
from .inbound import Inbound
from .outbound import Outbound


# exported to Wormhole() for inclusion in versions message
DILATION_VERSIONS = ["1"]


class OldPeerCannotDilateError(Exception):
    pass


class UnknownDilationMessageType(Exception):
    pass


class ReceivedHintsTooEarly(Exception):
    pass


class UnexpectedKCM(Exception):
    pass


class UnknownMessageType(Exception):
    pass


def make_side():
    return bytes_to_hexstr(os.urandom(6))

# new scheme:
# * both sides send PLEASE as soon as they have an unverified key and
#   w.dilate has been called
# * PLEASE includes a dilation-specific "side" (independent of the "side"
#   used by mailbox messages)
# * higher "side" is Leader, lower is Follower
# * PLEASE includes can-dilate list of version integers, requires overlap
#   "1" is current

# * we start dilation after both w.dilate() and receiving VERSION, putting us
#   in WANTING, then we process all previously-queued inbound DILATE-n
#   messages. When PLEASE arrives, we move to CONNECTING
# * HINTS sent after dilation starts
# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This
#   is the only difference between the two sides, and is not enforced
#   by the protocol (i.e. if the Follower sends RECONNECT to the Leader,
#   the Leader will obey, although TODO how confusing will this get?)
# * upon receiving RECONNECT: drop Connector, start new Connector, send
#   RECONNECTING, start sending HINTS
# * upon sending RECONNECT: go into FLUSHING state and ignore all HINTS until
#   RECONNECTING received. The new Connector can be spun up earlier, and it
#   can send HINTS, but it must not be given any HINTS that arrive before
#   RECONNECTING (since they're probably stale)

# * after VERSIONS(KCM) received, we might learn that the other side cannot
#   dilate. w.dilate errbacks at this point

# * maybe signal warning if we stay in a "want" state for too long
# * nobody sends HINTS until they're ready to receive
# * nobody sends HINTS unless they've called w.dilate() and received PLEASE
# * nobody connects to inbound hints unless they've called w.dilate()
# * if leader calls w.dilate() but not follower, leader waits forever in
#   "want" (doesn't send anything)
# * if follower calls w.dilate() but not leader, follower waits forever
#   in "want", leader waits forever in "wanted"

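# For reference, the mailbox-side payloads described above are small JSON
# bodies sent on successive dilate-0, dilate-1, ... phases (values here
# are hypothetical):
#
#   {"type": "please", "side": "0123456789ab"}
#   {"type": "connection-hints", "hints": [...encoded hint dicts...]}
#   {"type": "reconnect"}
#   {"type": "reconnecting"}
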
@attrs
@implementer(IDilationManager)
class Manager(object):
    _S = attrib(validator=provides(ISend))
    _my_side = attrib(validator=instance_of(type(u"")))
    _transit_key = attrib(validator=instance_of(bytes))
    _transit_relay_location = attrib(validator=optional(instance_of(str)))
    _reactor = attrib()
    _eventual_queue = attrib()
    _cooperator = attrib()
    _no_listen = False  # TODO
    _tor = None  # TODO
    _timing = None  # TODO
    _next_subchannel_id = None  # initialized in choose_role

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        self._got_versions_d = Deferred()

        self._my_role = None  # determined upon rx_PLEASE

        self._connection = None
        self._made_first_connection = False
        self._first_connected = OneShotObserver(self._eventual_queue)
        self._host_addr = _WormholeAddress()

        self._next_dilation_phase = 0

        # I kept getting confused about which methods were for inbound data
        # (and thus flow-control methods go "out") and which were for
        # outbound data (with flow-control going "in"), so I split them up
        # into separate pieces.
        self._inbound = Inbound(self, self._host_addr)
        self._outbound = Outbound(self, self._cooperator)  # from us to peer

    def set_listener_endpoint(self, listener_endpoint):
        self._inbound.set_listener_endpoint(listener_endpoint)

    def set_subchannel_zero(self, scid0, sc0):
        self._inbound.set_subchannel_zero(scid0, sc0)

    def when_first_connected(self):
        return self._first_connected.when_fired()

    def send_dilation_phase(self, **fields):
        dilation_phase = self._next_dilation_phase
        self._next_dilation_phase += 1
        self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))

    def send_hints(self, hints):  # from Connector
        self.send_dilation_phase(type="connection-hints", hints=hints)

    # forward inbound-ish things to _Inbound

    def subchannel_pauseProducing(self, sc):
        self._inbound.subchannel_pauseProducing(sc)

    def subchannel_resumeProducing(self, sc):
        self._inbound.subchannel_resumeProducing(sc)

    def subchannel_stopProducing(self, sc):
        self._inbound.subchannel_stopProducing(sc)

    # forward outbound-ish things to _Outbound
    def subchannel_registerProducer(self, sc, producer, streaming):
        self._outbound.subchannel_registerProducer(sc, producer, streaming)

    def subchannel_unregisterProducer(self, sc):
        self._outbound.subchannel_unregisterProducer(sc)

    def send_open(self, scid):
        self._queue_and_send(Open, scid)

    def send_data(self, scid, data):
        self._queue_and_send(Data, scid, data)

    def send_close(self, scid):
        self._queue_and_send(Close, scid)

    def _queue_and_send(self, record_type, *args):
        r = self._outbound.build_record(record_type, *args)
        # Outbound owns the send_record() pipe, so that it can stall new
        # writes after a new connection is made until after all queued
        # messages are written (to preserve ordering).
        self._outbound.queue_and_send_record(r)  # may trigger pauseProducing

    def subchannel_closed(self, scid, sc):
        # let everyone clean up. This happens just after we delivered
        # connectionLost to the Protocol, except for the control channel,
        # which might get connectionLost later after they use ep.connect.
        # TODO: is this inversion a problem?
        self._inbound.subchannel_closed(scid, sc)
        self._outbound.subchannel_closed(scid, sc)

    # our Connector calls these

    def connector_connection_made(self, c):
        self.connection_made()  # state machine update
        self._connection = c
        self._inbound.use_connection(c)
        self._outbound.use_connection(c)  # does c.registerProducer
        if not self._made_first_connection:
            self._made_first_connection = True
            self._first_connected.fire(None)

    def connector_connection_lost(self):
        self._stop_using_connection()
        if self._my_role is LEADER:
            self.connection_lost_leader()  # state machine
        else:
            self.connection_lost_follower()

    def _stop_using_connection(self):
        # the connection is already lost by this point
        self._connection = None
        self._inbound.stop_using_connection()
        self._outbound.stop_using_connection()  # does c.unregisterProducer

    # from our active Connection

    def got_record(self, r):
        # records with sequence numbers: always ack, ignore old ones
        if isinstance(r, (Open, Data, Close)):
            self.send_ack(r.seqnum)  # always ack, even for old ones
            if self._inbound.is_record_old(r):
                return
            self._inbound.update_ack_watermark(r.seqnum)
            if isinstance(r, Open):
                self._inbound.handle_open(r.scid)
            elif isinstance(r, Data):
                self._inbound.handle_data(r.scid, r.data)
            else:  # isinstance(r, Close)
                self._inbound.handle_close(r.scid)
            return
        if isinstance(r, KCM):
            log.err(UnexpectedKCM())
        elif isinstance(r, Ping):
            self.handle_ping(r.ping_id)
        elif isinstance(r, Pong):
            self.handle_pong(r.ping_id)
        elif isinstance(r, Ack):
            self._outbound.handle_ack(r.resp_seqnum)  # retire queued messages
        else:
            log.err(UnknownMessageType("{}".format(r)))

    # pings, pongs, and acks are not queued
    def send_ping(self, ping_id):
        self._outbound.send_if_connected(Ping(ping_id))

    def send_pong(self, ping_id):
        self._outbound.send_if_connected(Pong(ping_id))

    def send_ack(self, resp_seqnum):
        self._outbound.send_if_connected(Ack(resp_seqnum))

    def handle_ping(self, ping_id):
        self.send_pong(ping_id)

    def handle_pong(self, ping_id):
        # TODO: update is-alive timer
        pass

    # subchannel maintenance
    def allocate_subchannel_id(self):
        scid_num = self._next_subchannel_id
        self._next_subchannel_id += 2
        return to_be4(scid_num)

    # state machine

    # We are born WANTING after the local app calls w.dilate(). We start
    # CONNECTING when we receive PLEASE from the remote side

    def start(self):
        self.send_please()

    def send_please(self):
        self.send_dilation_phase(type="please", side=self._my_side)

    @m.state(initial=True)
    def WANTING(self):
        pass  # pragma: no cover

    @m.state()
    def CONNECTING(self):
        pass  # pragma: no cover

    @m.state()
    def CONNECTED(self):
        pass  # pragma: no cover

    @m.state()
    def FLUSHING(self):
        pass  # pragma: no cover

    @m.state()
    def ABANDONING(self):
        pass  # pragma: no cover

    @m.state()
    def LONELY(self):
        pass  # pragma: no cover

    @m.state()
    def STOPPING(self):
        pass  # pragma: no cover

    @m.state(terminal=True)
    def STOPPED(self):
        pass  # pragma: no cover

    @m.input()
    def rx_PLEASE(self, message):
        pass  # pragma: no cover

    @m.input()  # only sent by Follower
    def rx_HINTS(self, hint_message):
        pass  # pragma: no cover

    @m.input()  # only Leader sends RECONNECT, so only Follower receives it
    def rx_RECONNECT(self):
        pass  # pragma: no cover

    @m.input()  # only Follower sends RECONNECTING, so only Leader receives it
    def rx_RECONNECTING(self):
        pass  # pragma: no cover

    # Connector gives us connection_made()
    @m.input()
    def connection_made(self):
        pass  # pragma: no cover

    # our connection_lost() fires connection_lost_leader or
    # connection_lost_follower depending upon our role. If either side sees a
    # problem with the connection (timeouts, bad authentication) then they
    # just drop it and let connection_lost() handle the cleanup.
    @m.input()
    def connection_lost_leader(self):
        pass  # pragma: no cover

    @m.input()
    def connection_lost_follower(self):
        pass

    @m.input()
    def stop(self):
        pass  # pragma: no cover

    @m.output()
    def choose_role(self, message):
        their_side = message["side"]
        if self._my_side > their_side:
            self._my_role = LEADER
            # scid 0 is reserved for the control channel. the leader uses odd
            # numbers starting with 1
            self._next_subchannel_id = 1
        elif their_side > self._my_side:
            self._my_role = FOLLOWER
            # the follower uses even numbers starting with 2
            self._next_subchannel_id = 2
        else:
            raise ValueError("their side shouldn't be equal: reflection?")

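    # Role selection sketch (hypothetical sides): the lexicographically
    # larger 12-char hex side wins.
    #
    #   my_side="beef...", their_side="cafe..." -> FOLLOWER (scids 2, 4, ...)
    #   my_side="cafe...", their_side="beef..." -> LEADER   (scids 1, 3, ...)
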
    # these Outputs behave differently for the Leader vs the Follower

    @m.output()
    def start_connecting_ignore_message(self, message):
        del message  # ignored
        return self._start_connecting()

    @m.output()
    def start_connecting(self):
        self._start_connecting()

    def _start_connecting(self):
        assert self._my_role is not None
        self._connector = Connector(self._transit_key,
                                    self._transit_relay_location,
                                    self,
                                    self._reactor, self._eventual_queue,
                                    self._no_listen, self._tor,
                                    self._timing,
                                    self._my_side,  # needed for relay handshake
                                    self._my_role)
        self._connector.start()

    @m.output()
    def send_reconnect(self):
        self.send_dilation_phase(type="reconnect")  # TODO: generation number?

    @m.output()
    def send_reconnecting(self):
        self.send_dilation_phase(type="reconnecting")  # TODO: generation?

    @m.output()
    def use_hints(self, hint_message):
        hint_objs = filter(lambda h: h,  # ignore None, unrecognizable
                           [parse_hint(hs) for hs in hint_message["hints"]])
        hint_objs = list(hint_objs)
        self._connector.got_hints(hint_objs)

    @m.output()
    def stop_connecting(self):
        self._connector.stop()

    @m.output()
    def abandon_connection(self):
        # we think we're still connected, but the Leader disagrees. Or we've
        # been told to shut down.
        self._connection.disconnect()  # let connection_lost do cleanup

    # we start CONNECTING when we get rx_PLEASE
    WANTING.upon(rx_PLEASE, enter=CONNECTING,
                 outputs=[choose_role, start_connecting_ignore_message])

    CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[])

    # Leader
    CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
                   outputs=[send_reconnect])
    FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING,
                  outputs=[start_connecting])

    # Follower
    # if we notice a lost connection, just wait for the Leader to notice too
    CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[])
    LONELY.upon(rx_RECONNECT, enter=CONNECTING,
                outputs=[send_reconnecting, start_connecting])
    # but if they notice it first, abandon our (seemingly functional)
    # connection, then tell them that we're ready to try again
    CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, outputs=[abandon_connection])
    ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
                    outputs=[send_reconnecting, start_connecting])
    # and if they notice a problem while we're still connecting, abandon our
    # incomplete attempt and try again. in this case we don't have to wait
    # for a connection to finish shutdown
    CONNECTING.upon(rx_RECONNECT, enter=CONNECTING,
                    outputs=[stop_connecting,
                             send_reconnecting,
                             start_connecting])

    # rx_HINTS never changes state, they're just accepted or ignored
    WANTING.upon(rx_HINTS, enter=WANTING, outputs=[])  # too early
    CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
    CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[])  # too late, ignore
    FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[])  # stale, ignore
    LONELY.upon(rx_HINTS, enter=LONELY, outputs=[])  # stale, ignore
    ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[])  # shouldn't happen
    STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])

    WANTING.upon(stop, enter=STOPPED, outputs=[])
    CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
    CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
    ABANDONING.upon(stop, enter=STOPPING, outputs=[])
    FLUSHING.upon(stop, enter=STOPPED, outputs=[])
    LONELY.upon(stop, enter=STOPPED, outputs=[])
    STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
    STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])

@attrs
@implementer(IDilator)
class Dilator(object):
    """I launch the dilation process.

    I am created with every Wormhole (regardless of whether .dilate()
    was called or not), and I handle the initial phase of dilation,
    before we know whether we'll be the Leader or the Follower. Once we
    hear the other side's VERSION message (which tells us that we have a
    connection, they are capable of dilating, and which side we're on),
    then we build a Manager and hand control to it.
    """

    _reactor = attrib()
    _eventual_queue = attrib()
    _cooperator = attrib()

    def __attrs_post_init__(self):
        self._got_versions_d = Deferred()
        self._started = False
        self._endpoints = OneShotObserver(self._eventual_queue)
        self._pending_inbound_dilate_messages = deque()
        self._manager = None

    def wire(self, sender):
        self._S = ISend(sender)

    # this is the primary entry point, called when w.dilate() is invoked
    def dilate(self, transit_relay_location=None):
        self._transit_relay_location = transit_relay_location
        if not self._started:
            self._started = True
            self._start().addBoth(self._endpoints.fire)
        return self._endpoints.when_fired()

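    # Application-side usage sketch (the endpoint API is still settling;
    # see docs/api.md):
    #
    #   d = w.dilate()
    #
    #   def _got_endpoints(endpoints):
    #       control_ep, connect_ep, listen_ep = endpoints
    #       ...  # wire Protocols to the subchannel endpoints
    #   d.addCallback(_got_endpoints)
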
    @inlineCallbacks
    def _start(self):
        # first, we wait until we hear the VERSION message, which tells us 1:
        # the PAKE key works, so we can talk securely, 2: that they can do
        # dilation at all (if they can't then w.dilate() errbacks)

        dilation_version = yield self._got_versions_d

        # TODO: we could probably return the endpoints earlier, if we flunk
        # any connection/listen attempts upon OldPeerCannotDilateError, or
        # if/when we give up on the initial connection

        if not dilation_version:  # "1" or None
            # TODO: be more specific about the error. dilation_version==None
            # means we had no version in common with them, which could either
            # be because they're so old they don't dilate at all, or because
            # they're so new that they no longer accommodate our old version
            raise OldPeerCannotDilateError()

        my_dilation_side = make_side()
        self._manager = Manager(self._S, my_dilation_side,
                                self._transit_key,
                                self._transit_relay_location,
                                self._reactor, self._eventual_queue,
                                self._cooperator)
        self._manager.start()

        while self._pending_inbound_dilate_messages:
            plaintext = self._pending_inbound_dilate_messages.popleft()
            self.received_dilate(plaintext)

        yield self._manager.when_first_connected()

        # we can open subchannels as soon as we get our first connection
        scid0 = b"\x00\x00\x00\x00"
        self._host_addr = _WormholeAddress()  # TODO: share with Manager
        peer_addr0 = _SubchannelAddress(scid0)
        control_ep = ControlEndpoint(peer_addr0)
        sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
        control_ep._subchannel_zero_opened(sc0)
        self._manager.set_subchannel_zero(scid0, sc0)

        connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)

        listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
        self._manager.set_listener_endpoint(listen_ep)

        endpoints = (control_ep, connect_ep, listen_ep)
        returnValue(endpoints)

    # from Boss

    def got_key(self, key):
        # TODO: verify this happens before got_wormhole_versions, or add a gate
        # to tolerate either ordering
        purpose = b"dilation-v1"
        LENGTH = 32  # TODO: whatever Noise wants, I guess
        self._transit_key = derive_key(key, purpose, LENGTH)

    def got_wormhole_versions(self, their_wormhole_versions):
        assert self._transit_key is not None
        # this always happens before received_dilate
        dilation_version = None
        their_dilation_versions = set(their_wormhole_versions.get("can-dilate", []))
        my_versions = set(DILATION_VERSIONS)
        shared_versions = my_versions.intersection(their_dilation_versions)
        if "1" in shared_versions:
            dilation_version = "1"
        self._got_versions_d.callback(dilation_version)

    def received_dilate(self, plaintext):
        # this receives new in-order DILATE-n payloads, decrypted but not
        # de-JSONed.

        # this can appear before our .dilate() method is called, in which case
        # we queue them for later
        if not self._manager:
            self._pending_inbound_dilate_messages.append(plaintext)
            return

        message = bytes_to_dict(plaintext)
        type = message["type"]
        if type == "please":
            self._manager.rx_PLEASE(message)
        elif type == "connection-hints":
            self._manager.rx_HINTS(message)
        elif type == "reconnect":
            self._manager.rx_RECONNECT()
        elif type == "reconnecting":
            self._manager.rx_RECONNECTING()
        else:
            log.err(UnknownDilationMessageType(message))
            return
389  src/wormhole/_dilation/outbound.py  (new file)
@ -0,0 +1,389 @@
from __future__ import print_function, unicode_literals
|
||||
from collections import deque
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.interfaces import IPushProducer, IPullProducer
|
||||
from twisted.python import log
|
||||
from twisted.python.reflect import safe_str
|
||||
from .._interfaces import IDilationManager, IOutbound
|
||||
from .connection import KCM, Ping, Pong, Ack
|
||||
|
||||
|
||||
# Outbound flow control: app writes to subchannel, we write to Connection
|
||||
|
||||
# The app can register an IProducer of their choice, to let us throttle their
|
||||
# outbound data. Not all subchannels will have producers registered, and the
|
||||
# producer probably won't be the IProtocol instance (it'll be something else
|
||||
# which feeds data out through the protocol, like a t.p.basic.FileSender). If
|
||||
# a producerless subchannel writes too much, we won't be able to stop them,
|
||||
# and we'll keep writing records into the Connection even though it's asked
|
||||
# us to pause. Likewise, when the connection is down (and we're busily trying
|
||||
# to reestablish a new one), registered subchannels will be paused, but
|
||||
# unregistered ones will just dump everything in _outbound_queue, and we'll
|
||||
# consume memory without bound until they stop.
|
||||
|
||||
# We need several things:
|
||||
#
|
||||
# * Add each registered IProducer to a list, whose order remains stable. We
|
||||
# want fairness under outbound throttling: each time the outbound
|
||||
# connection opens up (our resumeProducing method is called), we should let
|
||||
# just one producer have an opportunity to do transport.write, and then we
|
||||
# should pause them again, and not come back to them until everyone else
|
||||
# has gotten a turn. So we want an ordered list of producers to track this
|
||||
# rotation.
|
||||
#
|
||||
# * Remove the IProducer if/when the protocol uses unregisterProducer
|
||||
#
|
||||
# * Remove any registered IProducer when the associated Subchannel is closed.
|
||||
# This isn't a problem for normal transports, because usually there's a
|
||||
# one-to-one mapping from Protocol to Transport, so when the Transport goes
|
||||
# away, you forget the only reference to the Producer anyway. Our situation is
|
||||
# unusual because we have multiple Subchannels that get merged into the
|
||||
# same underlying Connection: each Subchannel's Protocol can register a
|
||||
# producer on the Subchannel (which is an ITransport), but that adds it to
|
||||
# a set of Producers for the Connection (which is also an ITransport). So
|
||||
# if the Subchannel is closed, we need to remove its Producer (if any) even
|
||||
# though the Connection remains open.
|
||||
#
|
||||
# * Register ourselves as an IPushProducer with each successive Connection
|
||||
# object. These connections will come and go, but there will never be more
|
||||
# than one. When the connection goes away, pause all our producers. When a
|
||||
# new one is established, write all our queued messages, then unpause our
|
||||
# producers as we would in resumeProducing.
|
||||
#
|
||||
# * Inside our resumeProducing call, we'll cycle through all producers,
|
||||
# calling their individual resumeProducing methods one at a time. If they
|
||||
# write so much data that the Connection pauses us again, we'll find out
|
||||
# because our pauseProducing will be called inside that loop. When that
|
||||
# happens, we need to stop looping. If we make it through the whole loop
|
||||
# without being paused, then all subchannel Producers are left unpaused,
|
||||
# and are free to write whenever they want. During this loop, some
|
||||
# Producers will be paused, and others will be resumed.
|
||||
#
|
||||
# * If our pauseProducing is called, all Producers must be paused, and a flag
|
||||
# should be set to notify the resumeProducing loop to exit
|
||||
#
|
||||
# * In between calls to our resumeProducing method, we're in one of two
|
||||
# states.
|
||||
# * If we're writing data too fast, then we'll be left in the "paused"
|
||||
# state, in which all Subchannel producers are paused, and the aggregate
|
||||
# is paused too (our Connection told us to pauseProducing and hasn't yet
|
||||
# told us to resumeProducing). In this state, activity is driven by the
|
||||
# outbound TCP window opening up, which calls resumeProducing and allows
|
||||
# (probably just) one message to be sent. We receive pauseProducing in
|
||||
# the middle of their transport.write, so the loop exits early, and the
|
||||
# only state change is that some other Producer should get to go next
|
||||
# time.
|
||||
# * If we're writing too slowly, we'll be left in the "unpaused" state: all
|
||||
# Subchannel producers are unpaused, and the aggregate is unpaused too
|
||||
# (resumeProducing is the last thing we've been told). In this state,
|
||||
# activity is driven by the Subchannels doing a transport.write, which
|
||||
# queues some data on the TCP connection (and then might call
|
||||
# pauseProducing if it's now full).
|
||||
#
|
||||
# * We want to guard against:
|
||||
#
|
||||
# * application protocol registering a Producer without first unregistering
|
||||
# the previous one
|
||||
#
|
||||
# * application protocols writing data despite being told to pause
|
||||
# (Subchannels without a registered Producer cannot be throttled, and we
|
||||
# can't do anything about that, but we must also handle the case where
|
||||
# they give us a pause switch and then proceed to ignore it)
|
||||
#
|
||||
# * our Connection calling resumeProducing or pauseProducing without an
|
||||
# intervening call of the other kind
|
||||
#
|
||||
# * application protocols that don't handle a resumeProducing or
|
||||
# pauseProducing call without an intervening call of the other kind (i.e.
|
||||
# we should keep track of the last thing we told them, and not repeat
|
||||
# ourselves)
|
||||
#
|
||||
# * If the Wormhole is closed, all Subchannels should close. This is not our
|
||||
# responsibility: it lives in (Manager? Inbound?)
|
||||
#
|
||||
# * If we're given an IPullProducer, we should keep calling its
|
||||
# resumeProducing until it runs out of data. We still want fairness, so we
|
||||
# won't call it a second time until everyone else has had a turn.
|
||||
|
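As a small illustration of the fairness rotation described above (greatly simplified: no pause flag, no pull producers), the deque gives each producer one turn and then moves it to the back of the line:

    from collections import deque

    producers = deque(["A", "B", "C"])  # left-most gets the next turn

    def next_turn():
        p = producers[0]
        producers.rotate(-1)  # p moves to the end of the line
        return p

    assert [next_turn() for _ in range(4)] == ["A", "B", "C", "A"]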
||||
|
||||
# There are a couple of different ways to approach this. The one I've
|
||||
# selected is:
|
||||
#
|
||||
# * keep a dict that maps from Subchannel to Producer, which only contains
|
||||
# entries for Subchannels that have registered a producer. We use this to
|
||||
# remove Producers when Subchannels are closed
|
||||
#
|
||||
# * keep a Deque of Producers. This represents the fair-throttling rotation:
|
||||
# the left-most item gets the next upcoming turn, and then they'll be moved
|
||||
# to the end of the queue.
|
||||
#
|
||||
# * keep a set of IPushProducers which are paused, a second set of
|
||||
# IPushProducers which are unpaused, and a third set of IPullProducers
|
||||
# (which are always left paused). Enforce the invariant that these three
|
||||
# sets are disjoint, and that their union equals the contents of the deque.
|
||||
#
|
||||
# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and
|
||||
# set upon entry to pauseProducing. The loop inside resumeProducing checks
|
||||
# this flag after each call to producer.resumeProducing, to sense whether
|
||||
# they used their turn to write data, and if that write was large enough to
|
||||
# fill the TCP window. If set, we break out of the loop. If not, we look
|
||||
# for the next producer to unpause. The loop finishes when all producers
|
||||
# are unpaused (evidenced by the two sets of paused producers being empty)
|
||||
#
|
||||
# * the "paused" flag also determines whether new IPushProducers are added to
|
||||
# the paused or unpaused set (IPullProducers are always added to the
|
||||
# pull+paused set). If we have any IPullProducers, we're always in the
|
||||
# "writing data too fast" state.
|
||||
|
||||
# other approaches that I didn't decide to do at this time (but might use in
|
||||
# the future):
|
||||
#
|
||||
# * use one set instead of two. pros: fewer moving parts. cons: harder to
|
||||
# spot decoherence bugs like adding a producer to the deque but forgetting
|
||||
# to add it to one of the sets.
|
||||
#
|
||||
# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as
|
||||
# a visible boolean flag. This conflates Subchannels with their associated
|
||||
# Producer (so if we went this way, we should also let them track their own
|
||||
# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc
|
||||
# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer
|
||||
# mappings lying around to disagree with one another. Cons: exposes a bit
|
||||
# too much of the Subchannel internals
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IOutbound)
|
||||
class Outbound(object):
|
||||
# Manage outbound data: subchannel writes to us, we write to transport
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_cooperator = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# _outbound_queue holds all messages we've ever sent but not retired
|
||||
self._outbound_queue = deque()
|
||||
self._next_outbound_seqnum = 0
|
||||
# _queued_unsent are messages to retry with our new connection
|
||||
self._queued_unsent = deque()
|
||||
|
||||
# outbound flow control: the Connection throttles our writes
|
||||
self._subchannel_producers = {} # Subchannel -> IProducer
|
||||
self._paused = True # our Connection called our pauseProducing
|
||||
self._all_producers = deque() # rotates, left-is-next
|
||||
self._paused_producers = set()
|
||||
self._unpaused_producers = set()
|
||||
self._check_invariants()
|
||||
|
||||
self._connection = None
|
||||
|
||||
def _check_invariants(self):
|
||||
assert self._unpaused_producers.isdisjoint(self._paused_producers)
|
||||
assert (self._paused_producers.union(self._unpaused_producers) ==
|
||||
set(self._all_producers))
|
||||
|
||||
def build_record(self, record_type, *args):
|
||||
seqnum = self._next_outbound_seqnum
|
||||
self._next_outbound_seqnum += 1
|
||||
r = record_type(seqnum, *args)
|
||||
assert hasattr(r, "seqnum"), r # only Open/Data/Close
|
||||
return r
|
||||
|
||||
def queue_and_send_record(self, r):
|
||||
# we always queue it, to resend on a subsequent connection if
|
||||
# necessary
|
||||
self._outbound_queue.append(r)
|
||||
|
||||
if self._connection:
|
||||
if self._queued_unsent:
|
||||
# to maintain correct ordering, queue this instead of sending it
|
||||
self._queued_unsent.append(r)
|
||||
else:
|
||||
# we're allowed to send it immediately
|
||||
self._connection.send_record(r)
|
||||
|
||||
def send_if_connected(self, r):
|
||||
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
|
||||
if self._connection:
|
||||
self._connection.send_record(r)
|
||||
|
||||
# our subchannels call these to register a producer
|
||||
|
||||
def subchannel_registerProducer(self, sc, producer, streaming):
|
||||
# streaming==True: IPushProducer (pause/resume)
|
||||
# streaming==False: IPullProducer (just resume)
|
||||
if sc in self._subchannel_producers:
|
||||
raise ValueError(
|
||||
"registering producer %s before previous one (%s) was "
|
||||
"unregistered" % (producer,
|
||||
self._subchannel_producers[sc]))
|
||||
# our underlying Connection uses streaming==True, so to make things
|
||||
# easier, use an adapter when the Subchannel asks for streaming=False
|
||||
if not streaming:
|
||||
def unregister():
|
||||
self.subchannel_unregisterProducer(sc)
|
||||
producer = PullToPush(producer, unregister, self._cooperator)
|
||||
|
||||
self._subchannel_producers[sc] = producer
|
||||
self._all_producers.append(producer)
|
||||
if self._paused:
|
||||
self._paused_producers.add(producer)
|
||||
else:
|
||||
self._unpaused_producers.add(producer)
|
||||
self._check_invariants()
|
||||
if streaming:
|
||||
if self._paused:
|
||||
# IPushProducers need to be paused immediately, before they
|
||||
# speak
|
||||
producer.pauseProducing() # you wake up sleeping
|
||||
else:
|
||||
# our PullToPush adapter must be started, but if we're paused then
|
||||
# we tell it to pause before it gets a chance to write anything
|
||||
producer.startStreaming(self._paused)
|
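For example, t.p.basic.FileSender (the pull producer named in the design notes above) registers itself with streaming=False, so it takes the PullToPush branch. A sketch, assuming `subchannel` is an open SubChannel:

    from twisted.protocols.basic import FileSender

    sender = FileSender()
    # beginFileTransfer() calls subchannel.registerProducer(sender, False),
    # which the Subchannel forwards to subchannel_registerProducer above
    d = sender.beginFileTransfer(open("big.bin", "rb"), subchannel)
    d.addCallback(lambda _last_byte: subchannel.loseConnection())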
||||
|
||||
def subchannel_unregisterProducer(self, sc):
|
||||
# TODO: what if the subchannel closes, so we unregister their
|
||||
# producer for them, then the application reacts to connectionLost
|
||||
# with a duplicate unregisterProducer?
|
||||
p = self._subchannel_producers.pop(sc)
|
||||
if isinstance(p, PullToPush):
|
||||
p.stopStreaming()
|
||||
self._all_producers.remove(p)
|
||||
self._paused_producers.discard(p)
|
||||
self._unpaused_producers.discard(p)
|
||||
self._check_invariants()
|
||||
|
||||
def subchannel_closed(self, sc):
|
||||
self._check_invariants()
|
||||
if sc in self._subchannel_producers:
|
||||
self.subchannel_unregisterProducer(sc)
|
||||
|
||||
# our Manager tells us when we've got a new Connection to work with
|
||||
|
||||
def use_connection(self, c):
|
||||
self._connection = c
|
||||
assert not self._queued_unsent
|
||||
self._queued_unsent.extend(self._outbound_queue)
|
||||
# the connection can tell us to pause when we send too much data
|
||||
c.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
# send our queued messages
|
||||
self.resumeProducing()
|
||||
|
||||
def stop_using_connection(self):
|
||||
self._connection.unregisterProducer()
|
||||
self._connection = None
|
||||
self._queued_unsent.clear()
|
||||
self.pauseProducing()
|
||||
# TODO: I expect this will call pauseProducing twice: the first time
|
||||
# when we get stopProducing (since we're registered with the
|
||||
# underlying connection as the producer), and again when the manager
|
||||
# notices the connectionLost and calls our _stop_using_connection
|
||||
|
||||
def handle_ack(self, resp_seqnum):
|
||||
# we've received an inbound ack, so retire something
|
||||
while (self._outbound_queue and
|
||||
self._outbound_queue[0].seqnum <= resp_seqnum):
|
||||
self._outbound_queue.popleft()
|
||||
while (self._queued_unsent and
|
||||
self._queued_unsent[0].seqnum <= resp_seqnum):
|
||||
self._queued_unsent.popleft()
|
||||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
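A worked example of the retirement rule (Rec stands in for the real Open/Data/Close records):

    from collections import deque

    class Rec(object):
        def __init__(self, seqnum):
            self.seqnum = seqnum

    q = deque(Rec(n) for n in (3, 4, 5))
    resp_seqnum = 4  # the peer has processed everything up through seqnum 4
    while q and q[0].seqnum <= resp_seqnum:
        q.popleft()  # retired: will not be re-sent on the next connection
    assert [r.seqnum for r in q] == [5]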
||||
|
||||
# IProducer: the active connection calls these because we used
|
||||
# c.registerProducer to ask for them
|
||||
|
||||
def pauseProducing(self):
|
||||
if self._paused:
|
||||
return # someone is confused and called us twice
|
||||
self._paused = True
|
||||
for p in self._all_producers:
|
||||
if p in self._unpaused_producers:
|
||||
self._unpaused_producers.remove(p)
|
||||
self._paused_producers.add(p)
|
||||
p.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
if not self._paused:
|
||||
return # someone is confused and called us twice
|
||||
self._paused = False
|
||||
|
||||
while not self._paused:
|
||||
if self._queued_unsent:
|
||||
r = self._queued_unsent.popleft()
|
||||
self._connection.send_record(r)
|
||||
continue
|
||||
p = self._get_next_unpaused_producer()
|
||||
if not p:
|
||||
break
|
||||
self._paused_producers.remove(p)
|
||||
self._unpaused_producers.add(p)
|
||||
p.resumeProducing()
|
||||
|
||||
def _get_next_unpaused_producer(self):
|
||||
self._check_invariants()
|
||||
if not self._paused_producers:
|
||||
return None
|
||||
while True:
|
||||
p = self._all_producers[0]
|
||||
self._all_producers.rotate(-1) # p moves to the end of the line
|
||||
# the only unpaused Producers are at the end of the list
|
||||
assert p in self._paused_producers
|
||||
return p
|
||||
|
||||
def stopProducing(self):
|
||||
# we'll hopefully have a new connection to work with in the future,
|
||||
# so we don't shut anything down. We do pause everyone, though.
|
||||
self.pauseProducing()
|
||||
|
||||
|
||||
# modelled after twisted.internet._producer_helper._PullToPush, but with a
|
||||
# configurable Cooperator and a pause-immediately argument to startStreaming()
|
||||
@implementer(IPushProducer)
|
||||
@attrs(cmp=False)
|
||||
class PullToPush(object):
|
||||
_producer = attrib(validator=provides(IPullProducer))
|
||||
_unregister = attrib(validator=lambda _a, _b, v: callable(v))
|
||||
_cooperator = attrib()
|
||||
_finished = False
|
||||
|
||||
def _pull(self):
|
||||
while True:
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except Exception:
|
||||
log.err(None, "%s failed, producing will be stopped:" %
|
||||
(safe_str(self._producer),))
|
||||
try:
|
||||
self._unregister()
|
||||
# The consumer should now call stopStreaming() on us,
|
||||
# thus stopping the streaming.
|
||||
except Exception:
|
||||
# Since the consumer blew up, we may not have had
|
||||
# stopStreaming() called, so we just stop on our own:
|
||||
log.err(None, "%s failed to unregister producer:" %
|
||||
(safe_str(self._unregister),))
|
||||
self._finished = True
|
||||
return
|
||||
yield None
|
||||
|
||||
def startStreaming(self, paused):
|
||||
self._coopTask = self._cooperator.cooperate(self._pull())
|
||||
if paused:
|
||||
self.pauseProducing() # timer is scheduled, but task is removed
|
||||
|
||||
def stopStreaming(self):
|
||||
if self._finished:
|
||||
return
|
||||
self._finished = True
|
||||
self._coopTask.stop()
|
||||
|
||||
def pauseProducing(self):
|
||||
self._coopTask.pause()
|
||||
|
||||
def resumeProducing(self):
|
||||
self._coopTask.resume()
|
||||
|
||||
def stopProducing(self):
|
||||
self.stopStreaming()
|
||||
self._producer.stopProducing()
|
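A usage sketch for PullToPush; the Counter producer and `some_consumer` (anything with a write() method) are hypothetical:

    from zope.interface import implementer
    from twisted.internet.interfaces import IPullProducer
    from twisted.internet.task import Cooperator

    @implementer(IPullProducer)
    class Counter(object):  # a trivial pull producer for illustration
        def __init__(self, consumer):
            self._n = 0
            self._consumer = consumer
        def resumeProducing(self):
            self._n += 1
            self._consumer.write(("tick %d\n" % self._n).encode("ascii"))
        def stopProducing(self):
            pass

    coop = Cooperator()
    p2p = PullToPush(Counter(some_consumer), lambda: None, coop)
    p2p.startStreaming(paused=False)  # one resumeProducing() per cooperator turn
    # later: p2p.stopStreaming()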
1
src/wormhole/_dilation/roles.py
Normal file
|
@ -0,0 +1 @@
|
|||
LEADER, FOLLOWER = object(), object()
|
300
src/wormhole/_dilation/subchannel.py
Normal file
|
@ -0,0 +1,300 @@
|
|||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import (Deferred, inlineCallbacks, returnValue,
|
||||
succeed)
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IStreamClientEndpoint,
|
||||
IStreamServerEndpoint)
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from automat import MethodicalMachine
|
||||
from .._interfaces import ISubChannel, IDilationManager
|
||||
|
||||
|
||||
@attrs
|
||||
class Once(object):
|
||||
_errtype = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._called = False
|
||||
|
||||
def __call__(self):
|
||||
if self._called:
|
||||
raise self._errtype()
|
||||
self._called = True
|
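Once is a tiny single-use guard: the first call succeeds silently, and every later call raises the configured exception type:

    once = Once(SingleUseEndpointError)
    once()  # ok
    once()  # raises SingleUseEndpointError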
||||
|
||||
|
||||
class SingleUseEndpointError(Exception):
|
||||
pass
|
||||
|
||||
# created in the (OPEN) state, by either:
|
||||
# * receipt of an OPEN message
|
||||
# * or local client_endpoint.connect()
|
||||
# then transitions are:
|
||||
# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN)
|
||||
# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED)
|
||||
# (OPEN) local .write(): send DATA, -> (OPEN)
|
||||
# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING)
|
||||
# (CLOSING) local .write(): error
|
||||
# (CLOSING) local .loseConnection(): error
|
||||
# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING)
|
||||
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
|
||||
# object is deleted upon transition to (CLOSED)
|
||||
|
||||
|
||||
class AlreadyClosedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class _WormholeAddress(object):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
@attrs
|
||||
class _SubchannelAddress(object):
|
||||
_scid = attrib()
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(ITransport)
|
||||
@implementer(IProducer)
|
||||
@implementer(IConsumer)
|
||||
@implementer(ISubChannel)
|
||||
class SubChannel(object):
|
||||
_id = attrib(validator=instance_of(bytes))
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self,
|
||||
f: None) # pragma: no cover
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# self._mailbox = None
|
||||
# self._pending_outbound = {}
|
||||
# self._processed = set()
|
||||
self._protocol = None
|
||||
self._pending_dataReceived = []
|
||||
self._pending_connectionLost = (False, None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def open(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closing(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closed(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def remote_data(self, data):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def remote_close(self):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def local_data(self, data):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def local_close(self):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def send_data(self, data):
|
||||
self._manager.send_data(self._id, data)
|
||||
|
||||
@m.output()
|
||||
def send_close(self):
|
||||
self._manager.send_close(self._id)
|
||||
|
||||
@m.output()
|
||||
def signal_dataReceived(self, data):
|
||||
if self._protocol:
|
||||
self._protocol.dataReceived(data)
|
||||
else:
|
||||
self._pending_dataReceived.append(data)
|
||||
|
||||
@m.output()
|
||||
def signal_connectionLost(self):
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
else:
|
||||
self._pending_connectionLost = (True, ConnectionDone())
|
||||
self._manager.subchannel_closed(self)
|
||||
# we're deleted momentarily
|
||||
|
||||
@m.output()
|
||||
def error_closed_write(self, data):
|
||||
raise AlreadyClosedError("write not allowed on closed subchannel")
|
||||
|
||||
@m.output()
|
||||
def error_closed_close(self):
|
||||
raise AlreadyClosedError(
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
open.upon(local_data, enter=open, outputs=[send_data])
|
||||
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
|
||||
# error cases
|
||||
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
||||
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
||||
closing.upon(local_close, enter=closing, outputs=[error_closed_close])
|
||||
# the CLOSED state won't ever see messages, since we'll be deleted
|
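Concretely, writing after loseConnection() raises; a sketch, with `sc` an open SubChannel:

    sc.write(b"data")      # OPEN: emits a DATA record, stays OPEN
    sc.loseConnection()    # OPEN -> CLOSING: emits a CLOSE record
    try:
        sc.write(b"more")  # CLOSING: local_data routes to error_closed_write
    except AlreadyClosedError:
        pass               # writes are rejected while the close is in flight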
||||
|
||||
# our endpoints use this
|
||||
|
||||
def _set_protocol(self, protocol):
|
||||
assert not self._protocol
|
||||
self._protocol = protocol
|
||||
if self._pending_dataReceived:
|
||||
for data in self._pending_dataReceived:
|
||||
self._protocol.dataReceived(data)
|
||||
self._pending_dataReceived = []
|
||||
cl, what = self._pending_connectionLost
|
||||
if cl:
|
||||
self._protocol.connectionLost(what)
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
assert isinstance(data, type(b""))
|
||||
self.local_data(data)
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.local_close()
|
||||
|
||||
def getHost(self):
|
||||
# we define "host addr" as the overall wormhole
|
||||
return self._host_addr
|
||||
|
||||
def getPeer(self):
|
||||
# and "peer addr" as the subchannel within that wormhole
|
||||
return self._peer_addr
|
||||
|
||||
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
|
||||
def stopProducing(self):
|
||||
self._manager.subchannel_stopProducing(self)
|
||||
|
||||
def pauseProducing(self):
|
||||
self._manager.subchannel_pauseProducing(self)
|
||||
|
||||
def resumeProducing(self):
|
||||
self._manager.subchannel_resumeProducing(self)
|
||||
|
||||
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
|
||||
def registerProducer(self, producer, streaming):
|
||||
self._manager.subchannel_registerProducer(self, producer, streaming)
|
||||
|
||||
def unregisterProducer(self):
|
||||
self._manager.subchannel_unregisterProducer(self)
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class ControlEndpoint(object):
|
||||
_used = False
|
||||
|
||||
def __init__(self, peer_addr):
|
||||
self._subchannel_zero = Deferred()
|
||||
self._peer_addr = peer_addr
|
||||
self._once = Once(SingleUseEndpointError)
|
||||
|
||||
# from manager
|
||||
def _subchannel_zero_opened(self, subchannel):
|
||||
assert ISubChannel.providedBy(subchannel), subchannel
|
||||
self._subchannel_zero.callback(subchannel)
|
||||
|
||||
@inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
self._once()
|
||||
t = yield self._subchannel_zero
|
||||
p = protocolFactory.buildProtocol(self._peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
returnValue(p)
|
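Usage sketch: connect() is single-use, and its Deferred waits for subchannel zero to exist before the protocol is built (the Control protocol is hypothetical):

    from twisted.internet.protocol import Factory, Protocol

    class Control(Protocol):  # hypothetical control-channel protocol
        def connectionMade(self):
            self.transport.write(b"hello, subchannel zero")

    d = control_ep.connect(Factory.forProtocol(Control))
    # d fires with the connected Control instance; a second connect()
    # would raise SingleUseEndpointError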
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
@attrs
|
||||
class SubchannelConnectorEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
scid = self._manager.allocate_subchannel_id()
|
||||
self._manager.send_open(scid)
|
||||
peer_addr = _SubchannelAddress(scid)
|
||||
# ? f.doStart()
|
||||
# ? f.startedConnecting(CONNECTOR) # ??
|
||||
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
p = protocolFactory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
return succeed(p)
|
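Note the contrast with ControlEndpoint: this connect() returns an already-fired Deferred (succeed(p)), because opening a subchannel is optimistic and needs no round-trip before the first write (someFactory is a placeholder):

    # inside an inlineCallbacks function:
    p = yield connect_ep.connect(someFactory)  # fires in the same reactor turn
    p.transport.write(b"data rides behind the OPEN record")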
||||
|
||||
|
||||
@implementer(IStreamServerEndpoint)
|
||||
@attrs
|
||||
class SubchannelListenerEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._factory = None
|
||||
self._pending_opens = []
|
||||
|
||||
# from manager
|
||||
def _got_open(self, t, peer_addr):
|
||||
if self._factory:
|
||||
self._connect(t, peer_addr)
|
||||
else:
|
||||
self._pending_opens.append((t, peer_addr))
|
||||
|
||||
def _connect(self, t, peer_addr):
|
||||
p = self._factory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t)
|
||||
|
||||
# IStreamServerEndpoint
|
||||
|
||||
def listen(self, protocolFactory):
|
||||
self._factory = protocolFactory
|
||||
for (t, peer_addr) in self._pending_opens:
|
||||
self._connect(t, peer_addr)
|
||||
self._pending_opens = []
|
||||
lp = SubchannelListeningPort(self._host_addr)
|
||||
return succeed(lp)
|
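Usage sketch; OPENs that arrive before listen() is called are buffered in _pending_opens and flushed once the factory shows up (the Receiver protocol is hypothetical):

    from twisted.internet.protocol import Factory, Protocol

    class Receiver(Protocol):  # hypothetical handler for inbound subchannels
        def dataReceived(self, data):
            print("peer sent %d bytes" % len(data))

    d = listen_ep.listen(Factory.forProtocol(Receiver))
    # fires with a SubchannelListeningPort; queued opens connect immediately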
||||
|
||||
|
||||
@implementer(IListeningPort)
|
||||
@attrs
|
||||
class SubchannelListeningPort(object):
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
|
||||
def startListening(self):
|
||||
pass
|
||||
|
||||
def stopListening(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def getHost(self):
|
||||
return self._host_addr
|
138
src/wormhole/_hints.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import sys
|
||||
import re
|
||||
import six
|
||||
from collections import namedtuple
|
||||
from twisted.internet.endpoints import HostnameEndpoint
|
||||
from twisted.python import log
|
||||
|
||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||
# are "hint dicts".
|
||||
|
||||
# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
|
||||
# * make a TCP connection (possibly via Tor)
|
||||
# * send the sender/receiver handshake bytes first
|
||||
# * expect to see the receiver/sender handshake bytes from the other side
|
||||
# * the sender writes "go\n", the receiver waits for "go\n"
|
||||
# * the rest of the connection contains transit data
|
||||
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
|
||||
["hostname", "port", "priority"])
|
||||
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
||||
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
||||
# use a tuple rather than a list so they'll be hashable into a set). For each
|
||||
# one, make the TCP connection, send the relay handshake, then complete the
|
||||
# rest of the V1 protocol. Only one hint per relay is useful.
|
||||
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
||||
|
||||
def describe_hint_obj(hint, relay, tor):
|
||||
prefix = "tor->" if tor else "->"
|
||||
if relay:
|
||||
prefix = prefix + "relay:"
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return prefix + "tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return prefix + str(hint)
|
||||
|
||||
def parse_hint_argv(hint, stderr=sys.stderr):
|
||||
assert isinstance(hint, type(u""))
|
||||
# return tuple or None for an unparseable hint
|
||||
priority = 0.0
|
||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||
if not mo:
|
||||
print("unparseable hint '%s'" % (hint, ), file=stderr)
|
||||
return None
|
||||
hint_type = mo.group(1)
|
||||
if hint_type != "tcp":
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint),
|
||||
file=stderr)
|
||||
return None
|
||||
hint_value = mo.group(2)
|
||||
pieces = hint_value.split(":")
|
||||
if len(pieces) < 2:
|
||||
print("unparseable TCP hint (need more colons) '%s'" % (hint, ),
|
||||
file=stderr)
|
||||
return None
|
||||
mo = re.search(r'^(\d+)$', pieces[1])
|
||||
if not mo:
|
||||
print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr)
|
||||
return None
|
||||
hint_host = pieces[0]
|
||||
hint_port = int(pieces[1])
|
||||
for more in pieces[2:]:
|
||||
if more.startswith("priority="):
|
||||
more_pieces = more.split("=")
|
||||
try:
|
||||
priority = float(more_pieces[1])
|
||||
except ValueError:
|
||||
print("non-float priority= in TCP hint '%s'" % (hint, ),
|
||||
file=stderr)
|
||||
return None
|
||||
return DirectTCPV1Hint(hint_host, hint_port, priority)
|
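Some concrete inputs and outputs for parse_hint_argv:

    parse_hint_argv("tcp:example.com:4001")
    # -> DirectTCPV1Hint(hostname='example.com', port=4001, priority=0.0)

    parse_hint_argv("tcp:example.com:4001:priority=2.5")
    # -> DirectTCPV1Hint(hostname='example.com', port=4001, priority=2.5)

    parse_hint_argv("udp:example.com:4001")
    # -> None, after printing "unknown hint type 'udp' ..." to stderr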
||||
|
||||
def endpoint_from_hint_obj(hint, tor, reactor):
|
||||
if tor:
|
||||
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
|
||||
# this Tor object will throw ValueError for non-public IPv4
|
||||
# addresses and any IPv6 address
|
||||
try:
|
||||
return tor.stream_via(hint.hostname, hint.port)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return HostnameEndpoint(reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get("type", "")
|
||||
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint, ))
|
||||
return None
|
||||
if not ("hostname" in hint and
|
||||
isinstance(hint["hostname"], type(""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint, ))
|
||||
return None
|
||||
if not ("port" in hint and
|
||||
isinstance(hint["port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint, ))
|
||||
return None
|
||||
priority = hint.get("priority", 0.0)
|
||||
if hint_type == "direct-tcp-v1":
|
||||
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
else:
|
||||
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
|
||||
def parse_hint(hint_struct):
|
||||
hint_type = hint_struct.get("type", "")
|
||||
if hint_type == "relay-v1":
|
||||
# the struct can include multiple ways to reach the same relay
|
||||
rhints = filter(lambda h: h, # drop None (unrecognized)
|
||||
[parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
|
||||
return RelayV1Hint(list(rhints))
|
||||
return parse_tcp_v1_hint(hint_struct)
|
||||
|
||||
|
||||
def encode_hint(h):
|
||||
if isinstance(h, DirectTCPV1Hint):
|
||||
return {"type": "direct-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
}
|
||||
elif isinstance(h, RelayV1Hint):
|
||||
rhint = {"type": "relay-v1", "hints": []}
|
||||
for rh in h.hints:
|
||||
rhint["hints"].append({"type": "direct-tcp-v1",
|
||||
"priority": rh.priority,
|
||||
"hostname": rh.hostname,
|
||||
"port": rh.port})
|
||||
return rhint
|
||||
elif isinstance(h, TorTCPV1Hint):
|
||||
return {"type": "tor-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
}
|
||||
raise ValueError("unknown hint type", h)
|
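encode_hint and parse_hint are inverses for the supported hint types:

    h = DirectTCPV1Hint("example.com", 4001, 0.0)
    d = encode_hint(h)
    # {'type': 'direct-tcp-v1', 'priority': 0.0,
    #  'hostname': 'example.com', 'port': 4001}
    assert parse_hint(d) == h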
|
@ -433,3 +433,27 @@ class IInputHelper(Interface):
|
|||
|
||||
class IJournal(Interface): # TODO: this needs to be public
|
||||
pass
|
||||
|
||||
|
||||
class IDilator(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IDilationManager(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IDilationConnector(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class ISubChannel(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IInbound(Interface):
|
||||
pass
|
||||
|
||||
|
||||
class IOutbound(Interface):
|
||||
pass
|
||||
|
|
|
@ -6,7 +6,6 @@ import six
|
|||
from attr import attrib, attrs
|
||||
from attr.validators import instance_of, provides
|
||||
from automat import MethodicalMachine
|
||||
from hkdf import Hkdf
|
||||
from nacl import utils
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl.secret import SecretBox
|
||||
|
@ -15,16 +14,12 @@ from zope.interface import implementer
|
|||
|
||||
from . import _interfaces
|
||||
from .util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
|
||||
hexstr_to_bytes, to_bytes)
|
||||
hexstr_to_bytes, to_bytes, HKDF)
|
||||
|
||||
CryptoError
|
||||
__all__ = ["derive_key", "derive_phase_key", "CryptoError", "Key"]
|
||||
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
||||
|
||||
def derive_key(key, purpose, length=SecretBox.KEY_SIZE):
|
||||
if not isinstance(key, type(b"")):
|
||||
raise TypeError(type(key))
|
||||
|
|
|
@ -89,6 +89,7 @@ class RendezvousConnector(object):
|
|||
# if the initial connection fails, signal an error and shut down. do
|
||||
# this in a different reactor turn to avoid some hazards
|
||||
d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
|
||||
# TODO: use EventualQueue
|
||||
d.addErrback(self._initial_connection_failed)
|
||||
self._debug_record_inbound_f = None
|
||||
|
||||
|
|
|
@ -70,3 +70,24 @@ class SequenceObserver(object):
|
|||
if self._observers:
|
||||
d = self._observers.pop(0)
|
||||
self._eq.eventually(d.callback, self._results.pop(0))
|
||||
|
||||
|
||||
class EmptyableSet(set):
|
||||
# manage a set which grows and shrinks over time. Fire a Deferred the first
|
||||
# time it becomes empty after you start watching for it.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._eq = kwargs.pop("_eventual_queue") # required
|
||||
super(EmptyableSet, self).__init__(*args, **kwargs)
|
||||
self._observer = None
|
||||
|
||||
def when_next_empty(self):
|
||||
if not self._observer:
|
||||
self._observer = OneShotObserver(self._eq)
|
||||
return self._observer.when_fired()
|
||||
|
||||
def discard(self, o):
|
||||
super(EmptyableSet, self).discard(o)
|
||||
if self._observer and not self:
|
||||
self._observer.fire(None)
|
||||
self._observer = None
|
||||
|
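Usage sketch for EmptyableSet, using the same Clock-driven EventualQueue the tests use:

    from twisted.internet.task import Clock
    from wormhole.eventual import EventualQueue

    eq = EventualQueue(Clock())
    s = EmptyableSet(_eventual_queue=eq)
    s.add(1)
    d = s.when_next_empty()  # fires the next time the set drains
    s.discard(1)             # set is now empty, so the observer fires
    eq.flush_sync()          # run the queued eventual-send; d has now fired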
|
0
src/wormhole/test/dilate/__init__.py
Normal file
21
src/wormhole/test/dilate/common.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from ..._interfaces import IDilationManager, IWormhole
|
||||
|
||||
|
||||
def mock_manager():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
return m
|
||||
|
||||
|
||||
def mock_wormhole():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IWormhole)
|
||||
return m
|
||||
|
||||
|
||||
def clear_mock_calls(*args):
|
||||
for a in args:
|
||||
a.mock_calls[:] = []
|
217
src/wormhole/test/dilate/test_connection.py
Normal file
|
@ -0,0 +1,217 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import IDilationConnector
|
||||
from ..._dilation.roles import LEADER, FOLLOWER
|
||||
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
|
||||
KCM, Open, Ack)
|
||||
from .common import clear_mock_calls
|
||||
|
||||
|
||||
def make_con(role, use_relay=False):
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
connector = mock.Mock()
|
||||
alsoProvides(connector, IDilationConnector)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
n.write_message = mock.Mock(side_effect=[b"handshake"])
|
||||
c = DilatedConnectionProtocol(eq, role, connector, n,
|
||||
b"outbound_prologue\n", b"inbound_prologue\n")
|
||||
if use_relay:
|
||||
c.use_relay(b"relay_handshake\n")
|
||||
t = mock.Mock()
|
||||
alsoProvides(t, ITransport)
|
||||
return c, n, connector, t, eq
|
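The literal byte strings asserted below follow the connection's framing: each frame is a 4-byte big-endian length prefix followed by the (nominally encrypted) payload. For example:

    import struct

    frame = struct.pack(">L", len(b"handshake")) + b"handshake"
    assert frame == b"\x00\x00\x00\x09handshake"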
||||
|
||||
|
||||
class Connection(unittest.TestCase):
|
||||
def test_bad_prologue(self):
|
||||
c, n, connector, t, eq = make_con(LEADER)
|
||||
c.makeConnection(t)
|
||||
d = c.when_disconnected()
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"prologue\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
c.connectionLost(b"why")
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d), c)
|
||||
|
||||
def _test_no_relay(self, role):
|
||||
c, n, connector, t, eq = make_con(role)
|
||||
t_kcm = KCM()
|
||||
t_open = Open(seqnum=1, scid=0x11223344)
|
||||
t_ack = Ack(resp_seqnum=2)
|
||||
n.decrypt = mock.Mock(side_effect=[
|
||||
encode_record(t_kcm),
|
||||
encode_record(t_open),
|
||||
])
|
||||
exp_kcm = b"\x00\x00\x00\x03kcm"
|
||||
n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"])
|
||||
m = mock.Mock() # Manager
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
||||
if role is LEADER:
|
||||
# we're the leader, so we don't send the KCM right away
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2")])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(c._manager, None)
|
||||
else:
|
||||
# we're the follower, so we encrypt and send the KCM immediately
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2"),
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [
|
||||
mock.call.write(exp_kcm)])
|
||||
self.assertEqual(c._manager, None)
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x03KCM")
|
||||
# leader: inbound KCM means we add the candidate
|
||||
# follower: inbound KCM means we've been selected.
|
||||
# in both cases we notify Connector.add_candidate(), and the Connector
|
||||
# decides if/when to call .select()
|
||||
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")])
|
||||
self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
# now pretend this connection wins (either the Leader decides to use
|
||||
# this one among all the candidates, or we're the Follower and the
|
||||
# Connector is reacting to add_candidate() by recognizing we're the
|
||||
# only candidate there is)
|
||||
c.select(m)
|
||||
self.assertIdentical(c._manager, m)
|
||||
if role is LEADER:
|
||||
# TODO: currently Connector.select_and_stop_remaining() is
|
||||
# responsible for sending the KCM just before calling c.select()
|
||||
# iff we're the LEADER, therefore Connection.select won't send
|
||||
# anything. This should be moved to c.select().
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
|
||||
c.send_record(KCM())
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
else:
|
||||
# follower: we already sent the KCM, do nothing
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x04msg1")
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.send_record(t_ack)
|
||||
exp_ack = b"\x06\x00\x00\x00\x02"
|
||||
self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.disconnect()
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
def test_no_relay_leader(self):
|
||||
return self._test_no_relay(LEADER)
|
||||
|
||||
def test_no_relay_follower(self):
|
||||
return self._test_no_relay(FOLLOWER)
|
||||
|
||||
def test_relay(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"ok\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
def test_relay_jilted(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
d = c.when_disconnected()
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.connectionLost(b"why")
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d), c)
|
||||
|
||||
def test_relay_bad_response(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"not ok\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
clear_mock_calls(n, connector, t)
|
463
src/wormhole/test/dilate/test_connector.py
Normal file
|
@ -0,0 +1,463 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.defer import Deferred
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import IDilationManager, IDilationConnector
|
||||
from ..._hints import DirectTCPV1Hint, RelayV1Hint, TorTCPV1Hint
|
||||
from ..._dilation import roles
|
||||
from ..._dilation._noise import NoiseConnection
|
||||
from ..._dilation.connection import KCM
|
||||
from ..._dilation.connector import (Connector,
|
||||
build_sided_relay_handshake,
|
||||
build_noise,
|
||||
OutboundConnectionFactory,
|
||||
InboundConnectionFactory,
|
||||
PROLOGUE_LEADER, PROLOGUE_FOLLOWER,
|
||||
)
|
||||
from .common import clear_mock_calls
|
||||
|
||||
class Handshake(unittest.TestCase):
|
||||
def test_build(self):
|
||||
key = b"k"*32
|
||||
side = "12345678abcdabcd"
|
||||
self.assertEqual(build_sided_relay_handshake(key, side),
|
||||
b"please relay 3f4147851dbd2589d25b654ee9fb35ed0d3e5f19c5c5403e8e6a195c70f0577a for side 12345678abcdabcd\n")
|
||||
|
||||
class Outbound(unittest.TestCase):
|
||||
def test_no_relay(self):
|
||||
c = mock.Mock()
|
||||
alsoProvides(c, IDilationConnector)
|
||||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
relay_handshake = None
|
||||
f = OutboundConnectionFactory(c, relay_handshake)
|
||||
addr = object()
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
def test_with_relay(self):
|
||||
c = mock.Mock()
|
||||
alsoProvides(c, IDilationConnector)
|
||||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
relay_handshake = b"relay handshake"
|
||||
f = OutboundConnectionFactory(c, relay_handshake)
|
||||
addr = object()
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.use_relay(relay_handshake)])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
class Inbound(unittest.TestCase):
|
||||
def test_build(self):
|
||||
c = mock.Mock()
|
||||
alsoProvides(c, IDilationConnector)
|
||||
p0 = mock.Mock()
|
||||
c.build_protocol = mock.Mock(return_value=p0)
|
||||
f = InboundConnectionFactory(c)
|
||||
addr = object()
|
||||
p = f.buildProtocol(addr)
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(c.mock_calls, [mock.call.build_protocol(addr)])
|
||||
self.assertIdentical(p.factory, f)
|
||||
|
||||
def make_connector(listen=True, tor=False, relay=None, role=roles.LEADER):
|
||||
class Holder:
|
||||
pass
|
||||
h = Holder()
|
||||
h.dilation_key = b"key"
|
||||
h.relay = relay
|
||||
h.manager = mock.Mock()
|
||||
alsoProvides(h.manager, IDilationManager)
|
||||
h.clock = Clock()
|
||||
h.reactor = h.clock
|
||||
h.eq = EventualQueue(h.clock)
|
||||
h.tor = None
|
||||
if tor:
|
||||
h.tor = mock.Mock()
|
||||
timing = None
|
||||
h.side = u"abcd1234abcd5678"
|
||||
h.role = role
|
||||
c = Connector(h.dilation_key, h.relay, h.manager, h.reactor, h.eq,
|
||||
not listen, h.tor, timing, h.side, h.role)
|
||||
return c, h
|
||||
|
||||
class TestConnector(unittest.TestCase):
|
||||
def test_build(self):
|
||||
c, h = make_connector()
|
||||
c, h = make_connector(relay="tcp:host:1234")
|
||||
|
||||
def test_connection_abilities(self):
|
||||
self.assertEqual(Connector.get_connection_abilities(),
|
||||
[{"type": "direct-tcp-v1"},
|
||||
{"type": "relay-v1"},
|
||||
])
|
||||
|
||||
def test_build_noise(self):
|
||||
if not NoiseConnection:
|
||||
raise unittest.SkipTest("noiseprotocol unavailable")
|
||||
build_noise()
|
||||
|
||||
def test_build_protocol_leader(self):
|
||||
c, h = make_connector(role=roles.LEADER)
|
||||
n0 = mock.Mock()
|
||||
p0 = mock.Mock()
|
||||
addr = object()
|
||||
with mock.patch("wormhole._dilation.connector.build_noise",
|
||||
return_value=n0) as bn:
|
||||
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
|
||||
return_value=p0) as dcp:
|
||||
p = c.build_protocol(addr)
|
||||
self.assertEqual(bn.mock_calls, [mock.call()])
|
||||
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
|
||||
mock.call.set_as_initiator()])
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(dcp.mock_calls,
|
||||
[mock.call(h.eq, h.role, c, n0,
|
||||
PROLOGUE_LEADER, PROLOGUE_FOLLOWER)])
|
||||
|
||||
def test_build_protocol_follower(self):
|
||||
c, h = make_connector(role=roles.FOLLOWER)
|
||||
n0 = mock.Mock()
|
||||
p0 = mock.Mock()
|
||||
addr = object()
|
||||
with mock.patch("wormhole._dilation.connector.build_noise",
|
||||
return_value=n0) as bn:
|
||||
with mock.patch("wormhole._dilation.connector.DilatedConnectionProtocol",
|
||||
return_value=p0) as dcp:
|
||||
p = c.build_protocol(addr)
|
||||
self.assertEqual(bn.mock_calls, [mock.call()])
|
||||
self.assertEqual(n0.mock_calls, [mock.call.set_psks(h.dilation_key),
|
||||
mock.call.set_as_responder()])
|
||||
self.assertIdentical(p, p0)
|
||||
self.assertEqual(dcp.mock_calls,
|
||||
[mock.call(h.eq, h.role, c, n0,
|
||||
PROLOGUE_FOLLOWER, PROLOGUE_LEADER)])
|
||||
|
||||
def test_start_stop(self):
|
||||
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
|
||||
c.start()
|
||||
# no relays, so it publishes no hints
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
# and no listener, so nothing happens until we provide a hint
|
||||
c.stop()
|
||||
# we stop while we're connecting, so no connections must be stopped
|
||||
|
||||
def test_empty(self):
|
||||
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
|
||||
c._schedule_connection = mock.Mock()
|
||||
c.start()
|
||||
# no relays, so it publishes no hints
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
# and no listener, so nothing happens until we provide a hint
|
||||
self.assertEqual(c._schedule_connection.mock_calls, [])
|
||||
c.stop()
|
||||
|
||||
def test_basic(self):
|
||||
c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
|
||||
c._schedule_connection = mock.Mock()
|
||||
c.start()
|
||||
# no relays, so it publishes no hints
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
# and no listener, so nothing happens until we provide a hint
|
||||
self.assertEqual(c._schedule_connection.mock_calls, [])
|
||||
|
||||
hint = DirectTCPV1Hint("foo", 55, 0.0)
|
||||
c.got_hints([hint])
|
||||
|
||||
# received hints don't get published
|
||||
self.assertEqual(h.manager.mock_calls, [])
|
||||
# they just schedule a connection
|
||||
self.assertEqual(c._schedule_connection.mock_calls,
|
||||
[mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
|
||||
is_relay=False)])
|
||||
|
||||
def test_listen_addresses(self):
|
||||
        c, h = make_connector(listen=True, role=roles.LEADER)
        with mock.patch("wormhole.ipaddrs.find_addresses",
                        return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]):
            self.assertEqual(c._get_listener_addresses(),
                             ["1.2.3.4", "5.6.7.8"])
        with mock.patch("wormhole.ipaddrs.find_addresses",
                        return_value=["127.0.0.1"]):
            # some test hosts, including the appveyor VMs, *only* have
            # 127.0.0.1, and the tests will hang badly if we remove it.
            self.assertEqual(c._get_listener_addresses(), ["127.0.0.1"])

    def test_listen(self):
        c, h = make_connector(listen=True, role=roles.LEADER)
        c._start_listener = mock.Mock()
        with mock.patch("wormhole.ipaddrs.find_addresses",
                        return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]):
            c.start()
        self.assertEqual(c._start_listener.mock_calls,
                         [mock.call(["1.2.3.4", "5.6.7.8"])])

    def test_start_listen(self):
        c, h = make_connector(listen=True, role=roles.LEADER)
        ep = mock.Mock()
        d = Deferred()
        ep.listen = mock.Mock(return_value=d)
        with mock.patch("wormhole._dilation.connector.serverFromString",
                        return_value=ep) as sfs:
            c._start_listener(["1.2.3.4", "5.6.7.8"])
        self.assertEqual(sfs.mock_calls, [mock.call(h.reactor, "tcp:0")])
        lp = mock.Mock()
        host = mock.Mock()
        host.port = 66
        lp.getHost = mock.Mock(return_value=host)
        d.callback(lp)
        self.assertEqual(h.manager.mock_calls,
                         [mock.call.send_hints([{"type": "direct-tcp-v1",
                                                 "hostname": "1.2.3.4",
                                                 "port": 66,
                                                 "priority": 0.0
                                                 },
                                                {"type": "direct-tcp-v1",
                                                 "hostname": "5.6.7.8",
                                                 "port": 66,
                                                 "priority": 0.0
                                                 },
                                                ])])

    def test_schedule_connection_no_relay(self):
        c, h = make_connector(listen=True, role=roles.LEADER)
        hint = DirectTCPV1Hint("foo", 55, 0.0)
        ep = mock.Mock()
        with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj",
                        side_effect=[ep]) as efho:
            c._schedule_connection(0.0, hint, False)
        self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)])
        self.assertEqual(ep.mock_calls, [])
        d = Deferred()
        ep.connect = mock.Mock(side_effect=[d])
        # direct hints are scheduled for T+0.0
        f = mock.Mock()
        with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
                        return_value=f) as ocf:
            h.clock.advance(1.0)
        self.assertEqual(ocf.mock_calls, [mock.call(c, None)])
        self.assertEqual(ep.connect.mock_calls, [mock.call(f)])
        p = mock.Mock()
        d.callback(p)
        self.assertEqual(p.mock_calls,
                         [mock.call.when_disconnected(),
                          mock.call.when_disconnected().addCallback(c._pending_connections.discard)])

    def test_schedule_connection_relay(self):
        c, h = make_connector(listen=True, role=roles.LEADER)
        hint = DirectTCPV1Hint("foo", 55, 0.0)
        ep = mock.Mock()
        with mock.patch("wormhole._dilation.connector.endpoint_from_hint_obj",
                        side_effect=[ep]) as efho:
            c._schedule_connection(0.0, hint, True)
        self.assertEqual(efho.mock_calls, [mock.call(hint, h.tor, h.reactor)])
        self.assertEqual(ep.mock_calls, [])
        d = Deferred()
        ep.connect = mock.Mock(side_effect=[d])
        # direct hints are scheduled for T+0.0
        f = mock.Mock()
        with mock.patch("wormhole._dilation.connector.OutboundConnectionFactory",
                        return_value=f) as ocf:
            h.clock.advance(1.0)
        handshake = build_sided_relay_handshake(h.dilation_key, h.side)
        self.assertEqual(ocf.mock_calls, [mock.call(c, handshake)])

    def test_listen_but_tor(self):
        c, h = make_connector(listen=True, tor=True, role=roles.LEADER)
        with mock.patch("wormhole.ipaddrs.find_addresses",
                        return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa:
            c.start()
        # don't even look up addresses
        self.assertEqual(fa.mock_calls, [])
        # no relays and the listener isn't ready yet, so no hints yet
        self.assertEqual(h.manager.mock_calls, [])

    def test_no_listen(self):
        c, h = make_connector(listen=False, tor=False, role=roles.LEADER)
        with mock.patch("wormhole.ipaddrs.find_addresses",
                        return_value=["127.0.0.1", "1.2.3.4", "5.6.7.8"]) as fa:
            c.start()
        # don't even look up addresses
        self.assertEqual(fa.mock_calls, [])
        self.assertEqual(h.manager.mock_calls, [])

    def test_relay_delay(self):
        # given a direct connection and a relay, we should see the direct
        # connection initiated at T+0 seconds, and the relay at T+RELAY_DELAY
        c, h = make_connector(listen=True, relay=None, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c._start_listener = mock.Mock()
        c.start()
        hint1 = DirectTCPV1Hint("foo", 55, 0.0)
        hint2 = DirectTCPV1Hint("bar", 55, 0.0)
        hint3 = RelayV1Hint([DirectTCPV1Hint("relay", 55, 0.0)])
        c.got_hints([hint1, hint2, hint3])
        self.assertEqual(c._schedule_connection.mock_calls,
                         [mock.call(0.0, hint1, is_relay=False),
                          mock.call(0.0, hint2, is_relay=False),
                          mock.call(c.RELAY_DELAY, hint3.hints[0], is_relay=True),
                          ])

    def test_initial_relay(self):
        c, h = make_connector(listen=False, relay="tcp:foo:55", role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()
        self.assertEqual(h.manager.mock_calls,
                         [mock.call.send_hints([{"type": "relay-v1",
                                                 "hints": [
                                                     {"type": "direct-tcp-v1",
                                                      "hostname": "foo",
                                                      "port": 55,
                                                      "priority": 0.0
                                                      },
                                                 ],
                                                 }])])
        self.assertEqual(c._schedule_connection.mock_calls,
                         [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
                                    is_relay=True)])

    def test_add_relay(self):
        c, h = make_connector(listen=False, relay=None, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()
        self.assertEqual(h.manager.mock_calls, [])
        self.assertEqual(c._schedule_connection.mock_calls, [])
        hint = RelayV1Hint([DirectTCPV1Hint("foo", 55, 0.0)])
        c.add_relay([hint])
        self.assertEqual(h.manager.mock_calls,
                         [mock.call.send_hints([{"type": "relay-v1",
                                                 "hints": [
                                                     {"type": "direct-tcp-v1",
                                                      "hostname": "foo",
                                                      "port": 55,
                                                      "priority": 0.0
                                                      },
                                                 ],
                                                 }])])
        self.assertEqual(c._schedule_connection.mock_calls,
                         [mock.call(0.0, DirectTCPV1Hint("foo", 55, 0.0),
                                    is_relay=True)])

    def test_tor_no_manager(self):
        # tor hints should be ignored if we don't have a Tor manager to use them
        c, h = make_connector(listen=False, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()
        hint = TorTCPV1Hint("foo", 55, 0.0)
        c.got_hints([hint])
        self.assertEqual(h.manager.mock_calls, [])
        self.assertEqual(c._schedule_connection.mock_calls, [])

    def test_tor_with_manager(self):
        # tor hints should be processed if we do have a Tor manager
        c, h = make_connector(listen=False, tor=True, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()
        hint = TorTCPV1Hint("foo", 55, 0.0)
        c.got_hints([hint])
        self.assertEqual(c._schedule_connection.mock_calls,
                         [mock.call(0.0, hint, is_relay=False)])

    def test_priorities(self):
        # given two hints with different priorities, we should somehow prefer
        # one. This is a placeholder to fill in once we implement priorities.
        pass


class Race(unittest.TestCase):
    def test_one_leader(self):
        c, h = make_connector(listen=True, role=roles.LEADER)
        lp = mock.Mock()

        def start_listener(addresses):
            c._listeners.add(lp)
        c._start_listener = start_listener
        c._schedule_connection = mock.Mock()
        c.start()
        self.assertEqual(c._listeners, set([lp]))

        p1 = mock.Mock()  # DilatedConnectionProtocol instance
        c.add_candidate(p1)
        self.assertEqual(h.manager.mock_calls, [])
        h.eq.flush_sync()
        self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
        self.assertEqual(p1.mock_calls,
                         [mock.call.select(h.manager),
                          mock.call.send_record(KCM())])
        self.assertEqual(lp.mock_calls[0], mock.call.stopListening())
        # stop_listeners() uses a DeferredList, so we ignore the second call

    def test_one_follower(self):
        c, h = make_connector(listen=True, role=roles.FOLLOWER)
        lp = mock.Mock()

        def start_listener(addresses):
            c._listeners.add(lp)
        c._start_listener = start_listener
        c._schedule_connection = mock.Mock()
        c.start()
        self.assertEqual(c._listeners, set([lp]))

        p1 = mock.Mock()  # DilatedConnectionProtocol instance
        c.add_candidate(p1)
        self.assertEqual(h.manager.mock_calls, [])
        h.eq.flush_sync()
        self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
        # just like LEADER, but follower doesn't send KCM now (it sent one
        # earlier, to tell the leader that this connection looks viable)
        self.assertEqual(p1.mock_calls,
                         [mock.call.select(h.manager)])
        self.assertEqual(lp.mock_calls[0], mock.call.stopListening())
        # stop_listeners() uses a DeferredList, so we ignore the second call

    # TODO: make sure a pending connection is abandoned when the listener
    # answers successfully

    # TODO: make sure a second pending connection is abandoned when the first
    # connection succeeds

    def test_late(self):
        c, h = make_connector(listen=False, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()

        p1 = mock.Mock()  # DilatedConnectionProtocol instance
        c.add_candidate(p1)
        self.assertEqual(h.manager.mock_calls, [])
        h.eq.flush_sync()
        self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
        clear_mock_calls(h.manager)
        self.assertEqual(p1.mock_calls,
                         [mock.call.select(h.manager),
                          mock.call.send_record(KCM())])

        # late connection is ignored
        p2 = mock.Mock()
        c.add_candidate(p2)
        self.assertEqual(h.manager.mock_calls, [])

    # make sure an established connection is dropped when stop() is called
    def test_stop(self):
        c, h = make_connector(listen=False, role=roles.LEADER)
        c._schedule_connection = mock.Mock()
        c.start()

        p1 = mock.Mock()  # DilatedConnectionProtocol instance
        c.add_candidate(p1)
        self.assertEqual(h.manager.mock_calls, [])
        h.eq.flush_sync()
        self.assertEqual(h.manager.mock_calls, [mock.call.use_connection(p1)])
        self.assertEqual(p1.mock_calls,
                         [mock.call.select(h.manager),
                          mock.call.send_record(KCM())])

        c.stop()
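
[Editor's note] The hint-scheduling policy the tests above pin down, reduced to a standalone sketch: direct hints are tried immediately, relay hints wait RELAY_DELAY so a direct path gets a head start, and Tor hints are only scheduled when a Tor manager is available. DirectTCPV1Hint, RelayV1Hint, RELAY_DELAY, and _schedule_connection are the real connector names; the body itself is illustrative, not the shipped code.

    def got_hints_sketch(connector, hints):
        # what test_relay_delay asserts, written as straight-line logic
        for h in hints:
            if isinstance(h, DirectTCPV1Hint):
                connector._schedule_connection(0.0, h, is_relay=False)
            elif isinstance(h, RelayV1Hint):
                # delay each relay sub-hint, to favor a direct connection
                for rh in h.hints:
                    connector._schedule_connection(connector.RELAY_DELAY, rh,
                                                   is_relay=True)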
src/wormhole/test/dilate/test_encoding.py  (new file, 26 lines)
@ -0,0 +1,26 @@

from __future__ import print_function, unicode_literals
from twisted.trial import unittest
from ..._dilation.encode import to_be4, from_be4


class Encoding(unittest.TestCase):

    def test_be4(self):
        self.assertEqual(to_be4(0), b"\x00\x00\x00\x00")
        self.assertEqual(to_be4(1), b"\x00\x00\x00\x01")
        self.assertEqual(to_be4(256), b"\x00\x00\x01\x00")
        self.assertEqual(to_be4(257), b"\x00\x00\x01\x01")
        with self.assertRaises(ValueError):
            to_be4(-1)
        with self.assertRaises(ValueError):
            to_be4(2**32)

        self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0)
        self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1)
        self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256)
        self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257)

        with self.assertRaises(TypeError):
            from_be4(0)
        with self.assertRaises(ValueError):
            from_be4(b"\x01\x00\x00\x00\x00")
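
[Editor's note] A minimal sketch of the 4-byte big-endian encoding these assertions pin down; the struct-based bodies and explicit range/type checks are inferred from the tests, not copied from encode.py.

    import struct

    def to_be4(value):
        # 4-byte big-endian unsigned integer; struct raises its own error
        # type, so check the range explicitly to get ValueError
        if value < 0 or value >= 2 ** 32:
            raise ValueError
        return struct.pack(">L", value)

    def from_be4(b):
        if not isinstance(b, bytes):
            raise TypeError
        if len(b) != 4:
            raise ValueError
        return struct.unpack(">L", b)[0]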
src/wormhole/test/dilate/test_endpoints.py  (new file, 98 lines)
@ -0,0 +1,98 @@

from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._interfaces import ISubChannel
from ..._dilation.subchannel import (ControlEndpoint,
                                     SubchannelConnectorEndpoint,
                                     SubchannelListenerEndpoint,
                                     SubchannelListeningPort,
                                     _WormholeAddress, _SubchannelAddress,
                                     SingleUseEndpointError)
from .common import mock_manager


class Endpoints(unittest.TestCase):
    def test_control(self):
        scid0 = b"scid0"
        peeraddr = _SubchannelAddress(scid0)
        ep = ControlEndpoint(peeraddr)

        f = mock.Mock()
        p = mock.Mock()
        f.buildProtocol = mock.Mock(return_value=p)
        d = ep.connect(f)
        self.assertNoResult(d)

        t = mock.Mock()
        alsoProvides(t, ISubChannel)
        ep._subchannel_zero_opened(t)
        self.assertIdentical(self.successResultOf(d), p)
        self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
        self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
        self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])

        d = ep.connect(f)
        self.failureResultOf(d, SingleUseEndpointError)

    def assert_makeConnection(self, mock_calls):
        self.assertEqual(len(mock_calls), 1)
        self.assertEqual(mock_calls[0][0], "makeConnection")
        self.assertEqual(len(mock_calls[0][1]), 1)
        return mock_calls[0][1][0]

    def test_connector(self):
        m = mock_manager()
        m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
        hostaddr = _WormholeAddress()
        peeraddr = _SubchannelAddress(b"scid")
        ep = SubchannelConnectorEndpoint(m, hostaddr)

        f = mock.Mock()
        p = mock.Mock()
        t = mock.Mock()
        f.buildProtocol = mock.Mock(return_value=p)
        with mock.patch("wormhole._dilation.subchannel.SubChannel",
                        return_value=t) as sc:
            d = ep.connect(f)
        self.assertIdentical(self.successResultOf(d), p)
        self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
        self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)])
        self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
        self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])

    def test_listener(self):
        m = mock_manager()
        m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
        hostaddr = _WormholeAddress()
        ep = SubchannelListenerEndpoint(m, hostaddr)

        f = mock.Mock()
        p1 = mock.Mock()
        p2 = mock.Mock()
        f.buildProtocol = mock.Mock(side_effect=[p1, p2])

        # an OPEN that arrives before we ep.listen() should be queued

        t1 = mock.Mock()
        peeraddr1 = _SubchannelAddress(b"peer1")
        ep._got_open(t1, peeraddr1)

        d = ep.listen(f)
        lp = self.successResultOf(d)
        self.assertIsInstance(lp, SubchannelListeningPort)

        self.assertEqual(lp.getHost(), hostaddr)
        lp.startListening()

        self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
        self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])

        t2 = mock.Mock()
        peeraddr2 = _SubchannelAddress(b"peer2")
        ep._got_open(t2, peeraddr2)

        self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
        self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])

        lp.stopListening()  # TODO: should this do more?
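
[Editor's note] The single-use contract that test_control exercises, as a standalone sketch: the first connect() waits for subchannel zero to open, and any later connect() fails. Only the observable behavior comes from the test; the class body and names here are illustrative.

    from twisted.internet.defer import Deferred, fail

    class SingleUseEndpointError(Exception):
        pass

    class ControlEndpointSketch(object):
        def __init__(self, peer_addr):
            self._peer_addr = peer_addr
            self._used = False
            self._opened_d = Deferred()  # fires when subchannel zero opens

        def connect(self, protocolFactory):
            if self._used:
                return fail(SingleUseEndpointError())
            self._used = True

            def _opened(subchannel):
                # glue the caller's protocol to the subchannel transport
                p = protocolFactory.buildProtocol(self._peer_addr)
                subchannel._set_protocol(p)
                p.makeConnection(subchannel)
                return p
            return self._opened_d.addCallback(_opened)

        def _subchannel_zero_opened(self, subchannel):
            self._opened_d.callback(subchannel)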
src/wormhole/test/dilate/test_framer.py  (new file, 112 lines)
@ -0,0 +1,112 @@

from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect


def make_framer():
    t = mock.Mock()
    alsoProvides(t, ITransport)
    f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n")
    return f, t


class Framer(unittest.TestCase):
    def test_bad_prologue_length(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
        t.mock_calls[:] = []
        self.assertEqual([], list(f.add_and_parse(b"inbound_")))  # wait for it
        self.assertEqual(t.mock_calls, [])

        with mock.patch("wormhole._dilation.connection.log.msg") as m:
            with self.assertRaises(Disconnect):
                list(f.add_and_parse(b"not the prologue after all"))
        self.assertEqual(m.mock_calls,
                         [mock.call("bad prologue: {}".format(
                             b"inbound_not the p"))])
        self.assertEqual(t.mock_calls, [])

    def test_bad_prologue_newline(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
        t.mock_calls[:] = []
        self.assertEqual([], list(f.add_and_parse(b"inbound_")))  # wait for it

        self.assertEqual([], list(f.add_and_parse(b"not")))
        with mock.patch("wormhole._dilation.connection.log.msg") as m:
            with self.assertRaises(Disconnect):
                list(f.add_and_parse(b"\n"))
        self.assertEqual(m.mock_calls,
                         [mock.call("bad prologue: {}".format(
                             b"inbound_not\n"))])
        self.assertEqual(t.mock_calls, [])

    def test_good_prologue(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
        t.mock_calls[:] = []
        self.assertEqual([Prologue()],
                         list(f.add_and_parse(b"inbound_prologue\n")))
        self.assertEqual(t.mock_calls, [])

        # now send_frame should work
        f.send_frame(b"frame")
        self.assertEqual(t.mock_calls,
                         [mock.call.write(b"\x00\x00\x00\x05frame")])

    def test_bad_relay(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])
        f.use_relay(b"relay handshake\n")

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
        t.mock_calls[:] = []
        with mock.patch("wormhole._dilation.connection.log.msg") as m:
            with self.assertRaises(Disconnect):
                list(f.add_and_parse(b"goodbye\n"))
        self.assertEqual(m.mock_calls,
                         [mock.call("bad relay_ok: {}".format(b"goo"))])
        self.assertEqual(t.mock_calls, [])

    def test_good_relay(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])
        f.use_relay(b"relay handshake\n")
        self.assertEqual(t.mock_calls, [])

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
        t.mock_calls[:] = []

        self.assertEqual([], list(f.add_and_parse(b"ok\n")))
        self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])

    def test_frame(self):
        f, t = make_framer()
        self.assertEqual(t.mock_calls, [])

        f.connectionMade()
        self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
        t.mock_calls[:] = []
        self.assertEqual([Prologue()],
                         list(f.add_and_parse(b"inbound_prologue\n")))
        self.assertEqual(t.mock_calls, [])

        encoded_frame = b"\x00\x00\x00\x05frame"
        self.assertEqual([], list(f.add_and_parse(encoded_frame[:2])))
        self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6])))
        self.assertEqual([Frame(frame=b"frame")],
                         list(f.add_and_parse(encoded_frame[6:])))
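
[Editor's note] The wire format these tests imply is a 4-byte big-endian length prefix followed by that many bytes of frame body. A self-contained parser sketch (the real _Framer additionally tracks the prologue and relay-handshake states):

    import struct

    class LengthPrefixedParser(object):
        def __init__(self):
            self._buffer = b""

        def add_and_parse(self, data):
            # accumulate bytes, return each complete frame in arrival order
            self._buffer += data
            frames = []
            while len(self._buffer) >= 4:
                (length,) = struct.unpack(">L", self._buffer[:4])
                if len(self._buffer) < 4 + length:
                    break  # incomplete frame: wait for more bytes
                frames.append(self._buffer[4:4 + length])
                self._buffer = self._buffer[4 + length:]
            return frames

    p = LengthPrefixedParser()
    assert p.add_and_parse(b"\x00\x00\x00\x05fra") == []
    assert p.add_and_parse(b"me") == [b"frame"]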
src/wormhole/test/dilate/test_inbound.py  (new file, 174 lines)
@ -0,0 +1,174 @@

from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._interfaces import IDilationManager
from ..._dilation.connection import Open, Data, Close
from ..._dilation.inbound import (Inbound, DuplicateOpenError,
                                  DataForMissingSubchannelError,
                                  CloseForMissingSubchannelError)


def make_inbound():
    m = mock.Mock()
    alsoProvides(m, IDilationManager)
    host_addr = object()
    i = Inbound(m, host_addr)
    return i, m, host_addr


class InboundTest(unittest.TestCase):
    def test_seqnum(self):
        i, m, host_addr = make_inbound()
        r1 = Open(scid=513, seqnum=1)
        r2 = Data(scid=513, seqnum=2, data=b"")
        r3 = Close(scid=513, seqnum=3)
        self.assertFalse(i.is_record_old(r1))
        self.assertFalse(i.is_record_old(r2))
        self.assertFalse(i.is_record_old(r3))

        i.update_ack_watermark(r1.seqnum)
        self.assertTrue(i.is_record_old(r1))
        self.assertFalse(i.is_record_old(r2))
        self.assertFalse(i.is_record_old(r3))

        i.update_ack_watermark(r2.seqnum)
        self.assertTrue(i.is_record_old(r1))
        self.assertTrue(i.is_record_old(r2))
        self.assertFalse(i.is_record_old(r3))

    def test_open_data_close(self):
        i, m, host_addr = make_inbound()
        scid1 = b"scid"
        scid2 = b"scXX"
        c = mock.Mock()
        lep = mock.Mock()
        i.set_listener_endpoint(lep)
        i.use_connection(c)
        sc1 = mock.Mock()
        peer_addr = object()
        with mock.patch("wormhole._dilation.inbound.SubChannel",
                        side_effect=[sc1]) as sc:
            with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
                            side_effect=[peer_addr]) as sca:
                i.handle_open(scid1)
        self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)])
        self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)])
        self.assertEqual(sca.mock_calls, [mock.call(scid1)])
        lep.mock_calls[:] = []

        # a subsequent duplicate OPEN should be ignored
        with mock.patch("wormhole._dilation.inbound.SubChannel",
                        side_effect=[sc1]) as sc:
            with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
                            side_effect=[peer_addr]) as sca:
                i.handle_open(scid1)
        self.assertEqual(lep.mock_calls, [])
        self.assertEqual(sc.mock_calls, [])
        self.assertEqual(sca.mock_calls, [])
        self.flushLoggedErrors(DuplicateOpenError)

        i.handle_data(scid1, b"data")
        self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")])
        sc1.mock_calls[:] = []

        i.handle_data(scid2, b"for non-existent subchannel")
        self.assertEqual(sc1.mock_calls, [])
        self.flushLoggedErrors(DataForMissingSubchannelError)

        i.handle_close(scid1)
        self.assertEqual(sc1.mock_calls, [mock.call.remote_close()])
        sc1.mock_calls[:] = []

        i.handle_close(scid2)
        self.assertEqual(sc1.mock_calls, [])
        self.flushLoggedErrors(CloseForMissingSubchannelError)

        # after the subchannel is closed, the Manager will notify Inbound
        i.subchannel_closed(scid1, sc1)

        i.stop_using_connection()

    def test_control_channel(self):
        i, m, host_addr = make_inbound()
        lep = mock.Mock()
        i.set_listener_endpoint(lep)

        scid0 = b"scid"
        sc0 = mock.Mock()
        i.set_subchannel_zero(scid0, sc0)

        # an OPEN on the control channel identifier should be ignored as a
        # duplicate, since the control channel is already registered
        sc1 = mock.Mock()
        peer_addr = object()
        with mock.patch("wormhole._dilation.inbound.SubChannel",
                        side_effect=[sc1]) as sc:
            with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
                            side_effect=[peer_addr]) as sca:
                i.handle_open(scid0)
        self.assertEqual(lep.mock_calls, [])
        self.assertEqual(sc.mock_calls, [])
        self.assertEqual(sca.mock_calls, [])
        self.flushLoggedErrors(DuplicateOpenError)

        # and DATA to it should be delivered correctly
        i.handle_data(scid0, b"data")
        self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")])
        sc0.mock_calls[:] = []

    def test_pause(self):
        i, m, host_addr = make_inbound()
        c = mock.Mock()
        lep = mock.Mock()
        i.set_listener_endpoint(lep)

        # add two subchannels, pause one, then add a connection
        scid1 = b"sci1"
        scid2 = b"sci2"
        sc1 = mock.Mock()
        sc2 = mock.Mock()
        peer_addr = object()
        with mock.patch("wormhole._dilation.inbound.SubChannel",
                        side_effect=[sc1, sc2]):
            with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
                            return_value=peer_addr):
                i.handle_open(scid1)
                i.handle_open(scid2)
        self.assertEqual(c.mock_calls, [])

        i.subchannel_pauseProducing(sc1)
        self.assertEqual(c.mock_calls, [])
        i.subchannel_resumeProducing(sc1)
        self.assertEqual(c.mock_calls, [])
        i.subchannel_pauseProducing(sc1)
        self.assertEqual(c.mock_calls, [])

        i.use_connection(c)
        self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
        c.mock_calls[:] = []

        i.subchannel_resumeProducing(sc1)
        self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
        c.mock_calls[:] = []

        # consumers aren't really supposed to do this, but tolerate it
        i.subchannel_resumeProducing(sc1)
        self.assertEqual(c.mock_calls, [])

        i.subchannel_pauseProducing(sc1)
        self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
        c.mock_calls[:] = []
        i.subchannel_pauseProducing(sc2)
        self.assertEqual(c.mock_calls, [])  # was already paused

        # tolerate duplicate pauseProducing
        i.subchannel_pauseProducing(sc2)
        self.assertEqual(c.mock_calls, [])

        # stopProducing is treated like a terminal resumeProducing
        i.subchannel_stopProducing(sc1)
        self.assertEqual(c.mock_calls, [])
        i.subchannel_stopProducing(sc2)
        self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
        c.mock_calls[:] = []
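
[Editor's note] The duplicate-suppression rule that test_seqnum pins down, in isolation: a record is "old" once the ack watermark has reached its seqnum. The attribute name below is an assumption; the behavior comes straight from the test.

    class WatermarkSketch(object):
        def __init__(self):
            self._highest_observed_record_seqnum = -1

        def is_record_old(self, record):
            # anything at or below the watermark was already processed
            return record.seqnum <= self._highest_observed_record_seqnum

        def update_ack_watermark(self, seqnum):
            self._highest_observed_record_seqnum = max(
                self._highest_observed_record_seqnum, seqnum)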
src/wormhole/test/dilate/test_manager.py  (new file, 650 lines)
@ -0,0 +1,650 @@

from __future__ import print_function, unicode_literals
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.defer import Deferred
from twisted.internet.task import Clock, Cooperator
import mock
from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.encode import to_be4
from ..._dilation.manager import (Dilator, Manager, make_side,
                                  OldPeerCannotDilateError,
                                  UnknownDilationMessageType,
                                  UnexpectedKCM,
                                  UnknownMessageType)
from ..._dilation.subchannel import _WormholeAddress
from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong
from .common import clear_mock_calls


def make_dilator():
    reactor = object()
    clock = Clock()
    eq = EventualQueue(clock)
    term = mock.Mock(side_effect=lambda: True)  # one write per Eventual tick

    def term_factory():
        return term
    coop = Cooperator(terminationPredicateFactory=term_factory,
                      scheduler=eq.eventually)
    send = mock.Mock()
    alsoProvides(send, ISend)
    dil = Dilator(reactor, eq, coop)
    dil.wire(send)
    return dil, send, reactor, eq, clock, coop


class TestDilator(unittest.TestCase):
    def test_manager_and_endpoints(self):
        dil, send, reactor, eq, clock, coop = make_dilator()
        d1 = dil.dilate()
        d2 = dil.dilate()
        self.assertNoResult(d1)
        self.assertNoResult(d2)

        key = b"key"
        transit_key = object()
        with mock.patch("wormhole._dilation.manager.derive_key",
                        return_value=transit_key) as dk:
            dil.got_key(key)
        self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
        self.assertIdentical(dil._transit_key, transit_key)
        self.assertNoResult(d1)
        self.assertNoResult(d2)

        m = mock.Mock()
        alsoProvides(m, IDilationManager)
        m.when_first_connected.return_value = wfc_d = Deferred()
        with mock.patch("wormhole._dilation.manager.Manager",
                        return_value=m) as ml:
            with mock.patch("wormhole._dilation.manager.make_side",
                            return_value="us"):
                dil.got_wormhole_versions({"can-dilate": ["1"]})
        # that should create the Manager
        self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
                                                   None, reactor, eq, coop)])
        # and tell it to start, and get the wait-for-it-to-connect Deferred
        self.assertEqual(m.mock_calls, [mock.call.start(),
                                        mock.call.when_first_connected(),
                                        ])
        clear_mock_calls(m)
        self.assertNoResult(d1)
        self.assertNoResult(d2)

        host_addr = _WormholeAddress()
        m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress",
                          return_value=host_addr)
        peer_addr = object()
        m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
                           return_value=peer_addr)
        ce = mock.Mock()
        m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
                          return_value=ce)
        sc = mock.Mock()
        m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
                          return_value=sc)

        lep = object()
        m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
                           return_value=lep)

        with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m:
            wfc_d.callback(None)
            eq.flush_sync()
        scid0 = b"\x00\x00\x00\x00"
        self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
        self.assertEqual(m_sc_m.mock_calls,
                         [mock.call(scid0, m, host_addr, peer_addr)])
        self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
        self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
        self.assertEqual(m.mock_calls,
                         [mock.call.set_subchannel_zero(scid0, sc),
                          mock.call.set_listener_endpoint(lep),
                          ])
        clear_mock_calls(m)

        eps = self.successResultOf(d1)
        self.assertEqual(eps, self.successResultOf(d2))
        d3 = dil.dilate()
        eq.flush_sync()
        self.assertEqual(eps, self.successResultOf(d3))

        # all subsequent DILATE-n messages should get passed to the manager
        self.assertEqual(m.mock_calls, [])
        pleasemsg = dict(type="please", side="them")
        dil.received_dilate(dict_to_bytes(pleasemsg))
        self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE(pleasemsg)])
        clear_mock_calls(m)

        hintmsg = dict(type="connection-hints")
        dil.received_dilate(dict_to_bytes(hintmsg))
        self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
        clear_mock_calls(m)

        # we're nominally the LEADER, and the leader would not normally be
        # receiving a RECONNECT, but since we've mocked out the Manager it
        # won't notice
        dil.received_dilate(dict_to_bytes(dict(type="reconnect")))
        self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECT()])
        clear_mock_calls(m)

        dil.received_dilate(dict_to_bytes(dict(type="reconnecting")))
        self.assertEqual(m.mock_calls, [mock.call.rx_RECONNECTING()])
        clear_mock_calls(m)

        dil.received_dilate(dict_to_bytes(dict(type="unknown")))
        self.assertEqual(m.mock_calls, [])
        self.flushLoggedErrors(UnknownDilationMessageType)

    def test_peer_cannot_dilate(self):
        dil, send, reactor, eq, clock, coop = make_dilator()
        d1 = dil.dilate()
        self.assertNoResult(d1)

        dil._transit_key = b"\x01" * 32
        dil.got_wormhole_versions({})  # missing "can-dilate"
        eq.flush_sync()
        f = self.failureResultOf(d1)
        f.check(OldPeerCannotDilateError)

    def test_disjoint_versions(self):
        dil, send, reactor, eq, clock, coop = make_dilator()
        d1 = dil.dilate()
        self.assertNoResult(d1)

        dil._transit_key = b"key"
        dil.got_wormhole_versions({"can-dilate": [-1]})
        eq.flush_sync()
        f = self.failureResultOf(d1)
        f.check(OldPeerCannotDilateError)

    def test_early_dilate_messages(self):
        dil, send, reactor, eq, clock, coop = make_dilator()
        dil._transit_key = b"key"
        d1 = dil.dilate()
        self.assertNoResult(d1)
        pleasemsg = dict(type="please", side="them")
        dil.received_dilate(dict_to_bytes(pleasemsg))
        hintmsg = dict(type="connection-hints")
        dil.received_dilate(dict_to_bytes(hintmsg))

        m = mock.Mock()
        alsoProvides(m, IDilationManager)
        m.when_first_connected.return_value = Deferred()

        with mock.patch("wormhole._dilation.manager.Manager",
                        return_value=m) as ml:
            with mock.patch("wormhole._dilation.manager.make_side",
                            return_value="us"):
                dil.got_wormhole_versions({"can-dilate": ["1"]})
        self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
                                                   None, reactor, eq, coop)])
        self.assertEqual(m.mock_calls, [mock.call.start(),
                                        mock.call.rx_PLEASE(pleasemsg),
                                        mock.call.rx_HINTS(hintmsg),
                                        mock.call.when_first_connected()])

    def test_transit_relay(self):
        dil, send, reactor, eq, clock, coop = make_dilator()
        dil._transit_key = b"key"
        relay = object()
        d1 = dil.dilate(transit_relay_location=relay)
        self.assertNoResult(d1)

        with mock.patch("wormhole._dilation.manager.Manager") as ml:
            with mock.patch("wormhole._dilation.manager.make_side",
                            return_value="us"):
                dil.got_wormhole_versions({"can-dilate": ["1"]})
        self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
                                                   relay, reactor, eq, coop),
                                         mock.call().start(),
                                         mock.call().when_first_connected()])


LEADER = "ff3456abcdef"
FOLLOWER = "123456abcdef"


def make_manager(leader=True):
    class Holder:
        pass
    h = Holder()
    h.send = mock.Mock()
    alsoProvides(h.send, ISend)
    if leader:
        side = LEADER
    else:
        side = FOLLOWER
    h.key = b"\x00" * 32
    h.relay = None
    h.reactor = object()
    h.clock = Clock()
    h.eq = EventualQueue(h.clock)
    term = mock.Mock(side_effect=lambda: True)  # one write per Eventual tick

    def term_factory():
        return term
    h.coop = Cooperator(terminationPredicateFactory=term_factory,
                        scheduler=h.eq.eventually)
    h.inbound = mock.Mock()
    h.Inbound = mock.Mock(return_value=h.inbound)
    h.outbound = mock.Mock()
    h.Outbound = mock.Mock(return_value=h.outbound)
    h.hostaddr = object()
    with mock.patch("wormhole._dilation.manager.Inbound", h.Inbound):
        with mock.patch("wormhole._dilation.manager.Outbound", h.Outbound):
            with mock.patch("wormhole._dilation.manager._WormholeAddress",
                            return_value=h.hostaddr):
                m = Manager(h.send, side, h.key, h.relay, h.reactor, h.eq, h.coop)
    return m, h


class TestManager(unittest.TestCase):
    def test_make_side(self):
        side = make_side()
        self.assertEqual(type(side), type(u""))
        self.assertEqual(len(side), 2 * 6)

    def test_create(self):
        m, h = make_manager()

    def test_leader(self):
        m, h = make_manager(leader=True)
        self.assertEqual(h.send.mock_calls, [])
        self.assertEqual(h.Inbound.mock_calls, [mock.call(m, h.hostaddr)])
        self.assertEqual(h.Outbound.mock_calls, [mock.call(m, h.coop)])

        m.start()
        self.assertEqual(h.send.mock_calls, [
            mock.call.send("dilate-0",
                           dict_to_bytes({"type": "please", "side": LEADER}))
        ])
        clear_mock_calls(h.send)

        wfc_d = m.when_first_connected()
        self.assertNoResult(wfc_d)

        # ignore early hints
        m.rx_HINTS({})
        self.assertEqual(h.send.mock_calls, [])

        c = mock.Mock()
        connector = mock.Mock(return_value=c)
        with mock.patch("wormhole._dilation.manager.Connector", connector):
            # receiving this PLEASE triggers creation of the Connector
            m.rx_PLEASE({"side": FOLLOWER})
        self.assertEqual(h.send.mock_calls, [])
        self.assertEqual(connector.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      LEADER, roles.LEADER),
        ])
        self.assertEqual(c.mock_calls, [mock.call.start()])
        clear_mock_calls(connector, c)

        self.assertNoResult(wfc_d)

        # now any inbound hints should get passed to our Connector
        with mock.patch("wormhole._dilation.manager.parse_hint",
                        side_effect=["p1", None, "p3"]) as ph:
            m.rx_HINTS({"hints": [1, 2, 3]})
        self.assertEqual(ph.mock_calls, [mock.call(1), mock.call(2), mock.call(3)])
        self.assertEqual(c.mock_calls, [mock.call.got_hints(["p1", "p3"])])
        clear_mock_calls(ph, c)

        # and we send out any (listening) hints from our Connector
        m.send_hints([1, 2])
        self.assertEqual(h.send.mock_calls, [
            mock.call.send("dilate-1",
                           dict_to_bytes({"type": "connection-hints",
                                          "hints": [1, 2]}))
        ])
        clear_mock_calls(h.send)

        # the first successful connection fires when_first_connected(), so
        # the Dilator can create and return the endpoints
        c1 = mock.Mock()
        m.connector_connection_made(c1)

        self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)])
        self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
        clear_mock_calls(h.inbound, h.outbound)

        h.eq.flush_sync()
        self.successResultOf(wfc_d)  # fires with None
        wfc_d2 = m.when_first_connected()
        h.eq.flush_sync()
        self.successResultOf(wfc_d2)

        scid0 = b"\x00\x00\x00\x00"
        sc0 = mock.Mock()
        m.set_subchannel_zero(scid0, sc0)
        listen_ep = mock.Mock()
        m.set_listener_endpoint(listen_ep)
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.set_subchannel_zero(scid0, sc0),
            mock.call.set_listener_endpoint(listen_ep),
        ])
        clear_mock_calls(h.inbound)

        # the Leader making a new outbound channel should get scid=1
        scid1 = to_be4(1)
        self.assertEqual(m.allocate_subchannel_id(), scid1)
        r1 = Open(10, scid1)  # seqnum=10
        h.outbound.build_record = mock.Mock(return_value=r1)
        m.send_open(scid1)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.build_record(Open, scid1),
            mock.call.queue_and_send_record(r1),
        ])
        clear_mock_calls(h.outbound)

        r2 = Data(11, scid1, b"data")
        h.outbound.build_record = mock.Mock(return_value=r2)
        m.send_data(scid1, b"data")
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.build_record(Data, scid1, b"data"),
            mock.call.queue_and_send_record(r2),
        ])
        clear_mock_calls(h.outbound)

        r3 = Close(12, scid1)
        h.outbound.build_record = mock.Mock(return_value=r3)
        m.send_close(scid1)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.build_record(Close, scid1),
            mock.call.queue_and_send_record(r3),
        ])
        clear_mock_calls(h.outbound)

        # ack the OPEN
        m.got_record(Ack(10))
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.handle_ack(10)
        ])
        clear_mock_calls(h.outbound)

        # test that inbound records get acked and routed to Inbound
        h.inbound.is_record_old = mock.Mock(return_value=False)
        scid2 = to_be4(2)
        o200 = Open(200, scid2)
        m.got_record(o200)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.send_if_connected(Ack(200))
        ])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.is_record_old(o200),
            mock.call.update_ack_watermark(200),
            mock.call.handle_open(scid2),
        ])
        clear_mock_calls(h.outbound, h.inbound)

        # old (duplicate) records should provoke new Acks, but not get
        # forwarded
        h.inbound.is_record_old = mock.Mock(return_value=True)
        m.got_record(o200)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.send_if_connected(Ack(200))
        ])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.is_record_old(o200),
        ])
        clear_mock_calls(h.outbound, h.inbound)

        # check Data and Close too
        h.inbound.is_record_old = mock.Mock(return_value=False)
        d201 = Data(201, scid2, b"data")
        m.got_record(d201)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.send_if_connected(Ack(201))
        ])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.is_record_old(d201),
            mock.call.update_ack_watermark(201),
            mock.call.handle_data(scid2, b"data"),
        ])
        clear_mock_calls(h.outbound, h.inbound)

        c202 = Close(202, scid2)
        m.got_record(c202)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.send_if_connected(Ack(202))
        ])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.is_record_old(c202),
            mock.call.update_ack_watermark(202),
            mock.call.handle_close(scid2),
        ])
        clear_mock_calls(h.outbound, h.inbound)

        # Now we lose the connection. The Leader should tell the other side
        # that we're reconnecting.

        m.connector_connection_lost()
        self.assertEqual(h.send.mock_calls, [
            mock.call.send("dilate-2",
                           dict_to_bytes({"type": "reconnect"}))
        ])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.stop_using_connection()
        ])
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.stop_using_connection()
        ])
        clear_mock_calls(h.send, h.inbound, h.outbound)

        # leader does nothing (stays in FLUSHING) until the follower acks by
        # sending RECONNECTING

        # inbound hints should be ignored during FLUSHING
        with mock.patch("wormhole._dilation.manager.parse_hint",
                        return_value=None) as ph:
            m.rx_HINTS({"hints": [1, 2, 3]})
        self.assertEqual(ph.mock_calls, [])  # ignored

        c2 = mock.Mock()
        connector2 = mock.Mock(return_value=c2)
        with mock.patch("wormhole._dilation.manager.Connector", connector2):
            # this triggers creation of a new Connector
            m.rx_RECONNECTING()
        self.assertEqual(h.send.mock_calls, [])
        self.assertEqual(connector2.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      LEADER, roles.LEADER),
        ])
        self.assertEqual(c2.mock_calls, [mock.call.start()])
        clear_mock_calls(connector2, c2)

        self.assertEqual(h.inbound.mock_calls, [])
        self.assertEqual(h.outbound.mock_calls, [])

        # and a new connection should re-register with Inbound/Outbound,
        # which are responsible for re-sending unacked queued messages
        c3 = mock.Mock()
        m.connector_connection_made(c3)

        self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c3)])
        self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c3)])
        clear_mock_calls(h.inbound, h.outbound)

    def test_follower(self):
        m, h = make_manager(leader=False)

        m.start()
        self.assertEqual(h.send.mock_calls, [
            mock.call.send("dilate-0",
                           dict_to_bytes({"type": "please", "side": FOLLOWER}))
        ])
        clear_mock_calls(h.send)

        c = mock.Mock()
        connector = mock.Mock(return_value=c)
        with mock.patch("wormhole._dilation.manager.Connector", connector):
            # receiving this PLEASE triggers creation of the Connector
            m.rx_PLEASE({"side": LEADER})
        self.assertEqual(h.send.mock_calls, [])
        self.assertEqual(connector.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      FOLLOWER, roles.FOLLOWER),
        ])
        self.assertEqual(c.mock_calls, [mock.call.start()])
        clear_mock_calls(connector, c)

        # get connected, then lose the connection
        c1 = mock.Mock()
        m.connector_connection_made(c1)
        self.assertEqual(h.inbound.mock_calls, [mock.call.use_connection(c1)])
        self.assertEqual(h.outbound.mock_calls, [mock.call.use_connection(c1)])
        clear_mock_calls(h.inbound, h.outbound)

        # now lose the connection. As the follower, we don't notify the
        # leader, we just wait for them to notice
        m.connector_connection_lost()
        self.assertEqual(h.send.mock_calls, [])
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.stop_using_connection()
        ])
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.stop_using_connection()
        ])
        clear_mock_calls(h.send, h.inbound, h.outbound)

        # now we get a RECONNECT: we should send RECONNECTING
        c2 = mock.Mock()
        connector2 = mock.Mock(return_value=c2)
        with mock.patch("wormhole._dilation.manager.Connector", connector2):
            m.rx_RECONNECT()
        self.assertEqual(h.send.mock_calls, [
            mock.call.send("dilate-1",
                           dict_to_bytes({"type": "reconnecting"}))
        ])
        self.assertEqual(connector2.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      FOLLOWER, roles.FOLLOWER),
        ])
        self.assertEqual(c2.mock_calls, [mock.call.start()])
        clear_mock_calls(connector2, c2)

        # while we're trying to connect, we get told to stop again, so we
        # should abandon the connection attempt and start another
        c3 = mock.Mock()
        connector3 = mock.Mock(return_value=c3)
        with mock.patch("wormhole._dilation.manager.Connector", connector3):
            m.rx_RECONNECT()
        self.assertEqual(c2.mock_calls, [mock.call.stop()])
        self.assertEqual(connector3.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      FOLLOWER, roles.FOLLOWER),
        ])
        self.assertEqual(c3.mock_calls, [mock.call.start()])
        clear_mock_calls(c2, connector3, c3)

        m.connector_connection_made(c3)
        # finally if we're already connected, rx_RECONNECT means we should
        # abandon this connection (even though it still looks ok to us), then
        # when the attempt is finished stopping, we should start another

        m.rx_RECONNECT()

        c4 = mock.Mock()
        connector4 = mock.Mock(return_value=c4)
        with mock.patch("wormhole._dilation.manager.Connector", connector4):
            m.connector_connection_lost()
        self.assertEqual(c3.mock_calls, [mock.call.disconnect()])
        self.assertEqual(connector4.mock_calls, [
            mock.call(b"\x00" * 32, None, m, h.reactor, h.eq,
                      False,  # no_listen
                      None,  # tor
                      None,  # timing
                      FOLLOWER, roles.FOLLOWER),
        ])
        self.assertEqual(c4.mock_calls, [mock.call.start()])
        clear_mock_calls(c3, connector4, c4)

    def test_mirror(self):
        # receive a PLEASE with the same side as us: shouldn't happen
        m, h = make_manager(leader=True)

        m.start()
        clear_mock_calls(h.send)
        e = self.assertRaises(ValueError, m.rx_PLEASE, {"side": LEADER})
        self.assertEqual(str(e), "their side shouldn't be equal: reflection?")

    def test_ping_pong(self):
        m, h = make_manager(leader=False)

        m.got_record(KCM())
        self.flushLoggedErrors(UnexpectedKCM)

        m.got_record(Ping(1))
        self.assertEqual(h.outbound.mock_calls,
                         [mock.call.send_if_connected(Pong(1))])
        clear_mock_calls(h.outbound)

        m.got_record(Pong(2))
        # currently ignored, will eventually update a timer

        m.got_record("not recognized")
        e = self.flushLoggedErrors(UnknownMessageType)
        self.assertEqual(len(e), 1)
        self.assertEqual(str(e[0].value), "not recognized")

        m.send_ping(3)
        self.assertEqual(h.outbound.mock_calls,
                         [mock.call.send_if_connected(Pong(3))])
        clear_mock_calls(h.outbound)

    def test_subchannel(self):
        m, h = make_manager(leader=True)
        sc = object()

        m.subchannel_pauseProducing(sc)
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.subchannel_pauseProducing(sc)])
        clear_mock_calls(h.inbound)

        m.subchannel_resumeProducing(sc)
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.subchannel_resumeProducing(sc)])
        clear_mock_calls(h.inbound)

        m.subchannel_stopProducing(sc)
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.subchannel_stopProducing(sc)])
        clear_mock_calls(h.inbound)

        p = object()
        streaming = object()

        m.subchannel_registerProducer(sc, p, streaming)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.subchannel_registerProducer(sc, p, streaming)])
        clear_mock_calls(h.outbound)

        m.subchannel_unregisterProducer(sc)
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.subchannel_unregisterProducer(sc)])
        clear_mock_calls(h.outbound)

        m.subchannel_closed("scid", sc)
        self.assertEqual(h.inbound.mock_calls, [
            mock.call.subchannel_closed("scid", sc)])
        self.assertEqual(h.outbound.mock_calls, [
            mock.call.subchannel_closed("scid", sc)])
        clear_mock_calls(h.inbound, h.outbound)
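
[Editor's note] The receive path that test_leader asserts, reduced to a sketch: every inbound payload record is acked (even duplicates), but only records the watermark hasn't already covered are delivered, and inbound Acks retire entries from the outbound queue. The _inbound/_outbound attribute names and the resp_seqnum field name are assumptions; the control flow mirrors the mock assertions above.

    def got_record_sketch(manager, r):
        if isinstance(r, (Open, Data, Close)):
            # ack everything, duplicates included, so the peer can retire it
            manager._outbound.send_if_connected(Ack(r.seqnum))
            if manager._inbound.is_record_old(r):
                return  # already processed: ack it, but don't re-deliver
            manager._inbound.update_ack_watermark(r.seqnum)
            if isinstance(r, Open):
                manager._inbound.handle_open(r.scid)
            elif isinstance(r, Data):
                manager._inbound.handle_data(r.scid, r.data)
            else:
                manager._inbound.handle_close(r.scid)
        elif isinstance(r, Ack):
            # an ack retires the matching record from the outbound queue
            manager._outbound.handle_ack(r.resp_seqnum)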
src/wormhole/test/dilate/test_outbound.py  (new file, 655 lines)
@ -0,0 +1,655 @@

from __future__ import print_function, unicode_literals
from collections import namedtuple
from itertools import cycle
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IPullProducer
from ...eventual import EventualQueue
from ..._interfaces import IDilationManager
from ..._dilation.connection import KCM, Open, Data, Close, Ack
from ..._dilation.outbound import Outbound, PullToPush
from .common import clear_mock_calls

Pauser = namedtuple("Pauser", ["seqnum"])
NonPauser = namedtuple("NonPauser", ["seqnum"])
Stopper = namedtuple("Stopper", ["sc"])


def make_outbound():
    m = mock.Mock()
    alsoProvides(m, IDilationManager)
    clock = Clock()
    eq = EventualQueue(clock)
    term = mock.Mock(side_effect=lambda: True)  # one write per Eventual tick

    def term_factory():
        return term
    coop = Cooperator(terminationPredicateFactory=term_factory,
                      scheduler=eq.eventually)
    o = Outbound(m, coop)
    c = mock.Mock()  # Connection

    def maybe_pause(r):
        if isinstance(r, Pauser):
            o.pauseProducing()
        elif isinstance(r, Stopper):
            o.subchannel_unregisterProducer(r.sc)
    c.send_record = mock.Mock(side_effect=maybe_pause)
    o._test_eq = eq
    o._test_term = term
    return o, m, c


class OutboundTest(unittest.TestCase):
    def test_build_record(self):
        o, m, c = make_outbound()
        scid1 = b"scid"
        self.assertEqual(o.build_record(Open, scid1),
                         Open(seqnum=0, scid=b"scid"))
        self.assertEqual(o.build_record(Data, scid1, b"dataaa"),
                         Data(seqnum=1, scid=b"scid", data=b"dataaa"))
        self.assertEqual(o.build_record(Close, scid1),
                         Close(seqnum=2, scid=b"scid"))
        self.assertEqual(o.build_record(Close, scid1),
                         Close(seqnum=3, scid=b"scid"))

    def test_outbound_queue(self):
        o, m, c = make_outbound()
        scid1 = b"scid"
        r1 = o.build_record(Open, scid1)
        r2 = o.build_record(Data, scid1, b"data1")
        r3 = o.build_record(Data, scid1, b"data2")
        o.queue_and_send_record(r1)
        o.queue_and_send_record(r2)
        o.queue_and_send_record(r3)
        self.assertEqual(list(o._outbound_queue), [r1, r2, r3])

        # we would never normally receive an ACK without first getting a
        # connection
        o.handle_ack(r2.seqnum)
        self.assertEqual(list(o._outbound_queue), [r3])

        o.handle_ack(r3.seqnum)
        self.assertEqual(list(o._outbound_queue), [])

        o.handle_ack(r3.seqnum)  # ignored
        self.assertEqual(list(o._outbound_queue), [])

        o.handle_ack(r1.seqnum)  # ignored
        self.assertEqual(list(o._outbound_queue), [])

    def test_duplicate_registerProducer(self):
        o, m, c = make_outbound()
        sc1 = object()
        p1 = mock.Mock()
        o.subchannel_registerProducer(sc1, p1, True)
        with self.assertRaises(ValueError) as ar:
            o.subchannel_registerProducer(sc1, p1, True)
        s = str(ar.exception)
        self.assertIn("registering producer", s)
        self.assertIn("before previous one", s)
        self.assertIn("was unregistered", s)

    def test_connection_send_queued_unpaused(self):
        o, m, c = make_outbound()
        scid1 = b"scid"
        r1 = o.build_record(Open, scid1)
        r2 = o.build_record(Data, scid1, b"data1")
        r3 = o.build_record(Data, scid1, b"data2")
        o.queue_and_send_record(r1)
        o.queue_and_send_record(r2)
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [])

        # as soon as the connection is established, everything is sent
        o.use_connection(c)
        self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
                                        mock.call.send_record(r1),
                                        mock.call.send_record(r2)])
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [])
        clear_mock_calls(c)

        o.queue_and_send_record(r3)
        self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
        self.assertEqual(list(o._queued_unsent), [])
        self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])

    def test_connection_send_queued_paused(self):
        o, m, c = make_outbound()
        r1 = Pauser(seqnum=1)
        r2 = Pauser(seqnum=2)
        r3 = Pauser(seqnum=3)
        o.queue_and_send_record(r1)
        o.queue_and_send_record(r2)
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [])

        # pausing=True, so our mock Manager will pause the Outbound producer
        # after each write. So only r1 should have been sent before getting
        # paused
        o.use_connection(c)
        self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
                                        mock.call.send_record(r1)])
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [r2])
        clear_mock_calls(c)

        # Outbound is responsible for sending all records, so when Manager
        # wants to send a new one, and Outbound is still in the middle of
        # draining the beginning-of-connection queue, the new message gets
        # queued behind the rest (in addition to being queued in
        # _outbound_queue until an ACK retires it).
        o.queue_and_send_record(r3)
        self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
        self.assertEqual(list(o._queued_unsent), [r2, r3])
        self.assertEqual(c.mock_calls, [])

        o.handle_ack(r1.seqnum)
        self.assertEqual(list(o._outbound_queue), [r2, r3])
        self.assertEqual(list(o._queued_unsent), [r2, r3])
        self.assertEqual(c.mock_calls, [])

    def test_premptive_ack(self):
        # one mode I have in mind is for each side to send an immediate ACK,
        # with everything they've ever seen, as the very first message on each
        # new connection. The idea is that you might preempt sending stuff from
        # the _queued_unsent list if it arrives fast enough (in practice this
        # is more likely to be delivered via the DILATE mailbox message, but
        # the effects might be vaguely similar, so it seems worth testing
        # here). A similar situation would be if each side sends ACKs with the
        # highest seqnum they've ever seen, instead of merely ACKing the
        # message which was just received.
        o, m, c = make_outbound()
        r1 = Pauser(seqnum=1)
        r2 = Pauser(seqnum=2)
        r3 = Pauser(seqnum=3)
        o.queue_and_send_record(r1)
        o.queue_and_send_record(r2)
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [])

        o.use_connection(c)
        self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
                                        mock.call.send_record(r1)])
        self.assertEqual(list(o._outbound_queue), [r1, r2])
        self.assertEqual(list(o._queued_unsent), [r2])
        clear_mock_calls(c)

        o.queue_and_send_record(r3)
        self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
        self.assertEqual(list(o._queued_unsent), [r2, r3])
        self.assertEqual(c.mock_calls, [])

        o.handle_ack(r2.seqnum)
        self.assertEqual(list(o._outbound_queue), [r3])
        self.assertEqual(list(o._queued_unsent), [r3])
        self.assertEqual(c.mock_calls, [])

    def test_pause(self):
        o, m, c = make_outbound()
        o.use_connection(c)
        self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
        self.assertEqual(list(o._outbound_queue), [])
        self.assertEqual(list(o._queued_unsent), [])
        clear_mock_calls(c)

        sc1, sc2, sc3 = object(), object(), object()
        p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(
            name="p2"), mock.Mock(name="p3")

        # we aren't paused yet, since we haven't sent any data
        o.subchannel_registerProducer(sc1, p1, True)
        self.assertEqual(p1.mock_calls, [])

        r1 = Pauser(seqnum=1)
        o.queue_and_send_record(r1)
        # now we should be paused
        self.assertTrue(o._paused)
        self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
        self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
        clear_mock_calls(p1, c)

        # so an IPushProducer will be paused right away
        o.subchannel_registerProducer(sc2, p2, True)
        self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()])
        clear_mock_calls(p2)

        o.subchannel_registerProducer(sc3, p3, True)
        self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()])
        self.assertEqual(o._paused_producers, set([p1, p2, p3]))
        self.assertEqual(list(o._all_producers), [p1, p2, p3])
        clear_mock_calls(p3)

        # one resumeProducing should cause p1 to get a turn, since p2 was added
        # after we were paused and p1 was at the "end" of a one-element list.
        # If it writes anything, it will get paused again immediately.
        r2 = Pauser(seqnum=2)
        p1.resumeProducing.side_effect = lambda: c.send_record(r2)
        o.resumeProducing()
        self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(p2.mock_calls, [])
        self.assertEqual(p3.mock_calls, [])
        self.assertEqual(c.mock_calls, [mock.call.send_record(r2)])
        clear_mock_calls(p1, p2, p3, c)
        # p2 should now be at the head of the queue
        self.assertEqual(list(o._all_producers), [p2, p3, p1])

        # next turn: p2 has nothing to send, but p3 does. we should see p3
        # called but not p1. The actual sequence of expected calls is:
        # p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause)
        r3 = Pauser(seqnum=3)
        p2.resumeProducing.side_effect = lambda: None
        p3.resumeProducing.side_effect = lambda: c.send_record(r3)
        o.resumeProducing()
        self.assertEqual(p1.mock_calls, [])
        self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
        clear_mock_calls(p1, p2, p3, c)
        # p1 should now be at the head of the queue
        self.assertEqual(list(o._all_producers), [p1, p2, p3])

        # next turn: p1 has data to send, but not enough to cause a pause. same
        # for p2. p3 causes a pause
        r4 = NonPauser(seqnum=4)
        r5 = NonPauser(seqnum=5)
        r6 = Pauser(seqnum=6)
        p1.resumeProducing.side_effect = lambda: c.send_record(r4)
        p2.resumeProducing.side_effect = lambda: c.send_record(r5)
        p3.resumeProducing.side_effect = lambda: c.send_record(r6)
        o.resumeProducing()
        self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
                                         mock.call.pauseProducing(),
                                         ])
        self.assertEqual(c.mock_calls, [mock.call.send_record(r4),
                                        mock.call.send_record(r5),
                                        mock.call.send_record(r6),
                                        ])
        clear_mock_calls(p1, p2, p3, c)
        # p1 should now be at the head of the queue again
        self.assertEqual(list(o._all_producers), [p1, p2, p3])

        # now we let it catch up. p1 and p2 send non-pausing data, p3 sends
        # nothing.
        r7 = NonPauser(seqnum=4)
        r8 = NonPauser(seqnum=5)
        p1.resumeProducing.side_effect = lambda: c.send_record(r7)
        p2.resumeProducing.side_effect = lambda: c.send_record(r8)
        p3.resumeProducing.side_effect = lambda: None

        o.resumeProducing()
        self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
                                         ])
        self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
                                         ])
        self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
                                         ])
        self.assertEqual(c.mock_calls, [mock.call.send_record(r7),
|
||||
mock.call.send_record(r8),
|
||||
])
|
||||
clear_mock_calls(p1, p2, p3, c)
|
||||
# p1 should now be at the head of the queue again
|
||||
self.assertEqual(list(o._all_producers), [p1, p2, p3])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
# now a producer disconnects itself (spontaneously, not from inside a
|
||||
# resumeProducing)
|
||||
o.subchannel_unregisterProducer(sc1)
|
||||
self.assertEqual(list(o._all_producers), [p2, p3])
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
# and another disconnects itself when called
|
||||
p2.resumeProducing.side_effect = lambda: None
|
||||
p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(
|
||||
sc3)
|
||||
o.pauseProducing()
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(),
|
||||
mock.call.resumeProducing()])
|
||||
self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(),
|
||||
mock.call.resumeProducing()])
|
||||
clear_mock_calls(p2, p3)
|
||||
self.assertEqual(list(o._all_producers), [p2])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_subchannel_closed(self):
|
||||
o, m, c = make_outbound()
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
clear_mock_calls(p1)
|
||||
|
||||
o.subchannel_closed(sc1)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [])
|
||||
|
||||
sc2 = mock.Mock()
|
||||
o.subchannel_closed(sc2)
|
||||
|
||||
def test_disconnect(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
o.stop_using_connection()
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
|
||||
def OFF_test_push_pull(self):
|
||||
# use one IPushProducer and one IPullProducer. They should take turns
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
|
||||
sc1, sc2 = object(), object()
|
||||
p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2")
|
||||
r1 = Pauser(seqnum=1)
|
||||
r2 = NonPauser(seqnum=2)
|
||||
|
||||
# we aren't paused yet, since we haven't sent any data
|
||||
o.subchannel_registerProducer(sc1, p1, True) # push
|
||||
o.queue_and_send_record(r1)
|
||||
# now we're paused
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(r1)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(r2)
|
||||
o.subchannel_registerProducer(sc2, p2, False) # pull: always ready
|
||||
|
||||
# p1 is still first, since p2 was just added (at the end)
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [p1, p2])
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
# resume should send r1, which should pause everything
|
||||
o.resumeProducing()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r1),
|
||||
])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
# next should fire p2, then p1
|
||||
o.resumeProducing()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r2),
|
||||
mock.call.send_record(r1),
|
||||
])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
|
||||
])
|
||||
self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
def test_pull_producer(self):
|
||||
# a single pull producer should write until it is paused, rate-limited
|
||||
# by the cooperator (so we'll see back-to-back resumeProducing calls
|
||||
# until the Connection is paused, or 10ms have passed, whichever comes
|
||||
# first, and if it's stopped by the timer, then the next EventualQueue
|
||||
# turn will start it off again)
|
||||
|
||||
o, m, c = make_outbound()
|
||||
eq = o._test_eq
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
|
||||
records = [NonPauser(seqnum=1)] * 10
|
||||
records.append(Pauser(seqnum=2))
|
||||
records.append(Stopper(sc1))
|
||||
it = iter(records)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(next(it))
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
eq.flush_sync() # fast forward into the glorious (paused) future
|
||||
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in records[:-1]])
|
||||
self.assertEqual(p1.mock_calls,
|
||||
[mock.call.resumeProducing()] * (len(records) - 1))
|
||||
clear_mock_calls(c, p1)
|
||||
|
||||
# next resumeProducing should cause it to disconnect
|
||||
o.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()])
|
||||
self.assertEqual(len(o._all_producers), 0)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_two_pull_producers(self):
|
||||
# we should alternate between them until paused
|
||||
p1_records = ([NonPauser(seqnum=i) for i in range(5)] +
|
||||
[Pauser(seqnum=5)] +
|
||||
[NonPauser(seqnum=i) for i in range(6, 10)])
|
||||
p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] +
|
||||
[Pauser(seqnum=19)])
|
||||
expected1 = [NonPauser(0), NonPauser(10),
|
||||
NonPauser(1), NonPauser(11),
|
||||
NonPauser(2), NonPauser(12),
|
||||
NonPauser(3), NonPauser(13),
|
||||
NonPauser(4), NonPauser(14),
|
||||
Pauser(5)]
|
||||
expected2 = [NonPauser(15),
|
||||
NonPauser(6), NonPauser(16),
|
||||
NonPauser(7), NonPauser(17),
|
||||
NonPauser(8), NonPauser(18),
|
||||
NonPauser(9), Pauser(19),
|
||||
]
|
||||
|
||||
o, m, c = make_outbound()
|
||||
eq = o._test_eq
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
it1 = iter(p1_records)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(next(it1))
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
|
||||
sc2 = mock.Mock()
|
||||
p2 = mock.Mock(name="p2")
|
||||
alsoProvides(p2, IPullProducer)
|
||||
it2 = iter(p2_records)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(next(it2))
|
||||
o.subchannel_registerProducer(sc2, p2, False)
|
||||
|
||||
eq.flush_sync() # fast forward into the glorious (paused) future
|
||||
|
||||
sends = [mock.call.resumeProducing()]
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in expected1])
|
||||
self.assertEqual(p1.mock_calls, 6 * sends)
|
||||
self.assertEqual(p2.mock_calls, 5 * sends)
|
||||
clear_mock_calls(c, p1, p2)
|
||||
|
||||
o.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in expected2])
|
||||
self.assertEqual(p1.mock_calls, 4 * sends)
|
||||
self.assertEqual(p2.mock_calls, 5 * sends)
|
||||
clear_mock_calls(c, p1, p2)
|
||||
|
||||
def test_send_if_connected(self):
|
||||
o, m, c = make_outbound()
|
||||
o.send_if_connected(Ack(1)) # not connected yet
|
||||
|
||||
o.use_connection(c)
|
||||
o.send_if_connected(KCM())
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
mock.call.send_record(KCM())])
|
||||
|
||||
def test_tolerate_duplicate_pause_resume(self):
|
||||
o, m, c = make_outbound()
|
||||
self.assertTrue(o._paused) # no connection
|
||||
o.use_connection(c)
|
||||
self.assertFalse(o._paused)
|
||||
o.pauseProducing()
|
||||
self.assertTrue(o._paused)
|
||||
o.pauseProducing()
|
||||
self.assertTrue(o._paused)
|
||||
o.resumeProducing()
|
||||
self.assertFalse(o._paused)
|
||||
o.resumeProducing()
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_stopProducing(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
self.assertFalse(o._paused)
|
||||
o.stopProducing() # connection does this before loss
|
||||
self.assertTrue(o._paused)
|
||||
o.stop_using_connection()
|
||||
self.assertTrue(o._paused)
|
||||
|
||||
def test_resume_error(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
p1.resumeProducing.side_effect = PretendResumptionError
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
o._test_eq.flush_sync()
|
||||
# the error is supposed to automatically unregister the producer
|
||||
self.assertEqual(list(o._all_producers), [])
|
||||
self.flushLoggedErrors(PretendResumptionError)
|
||||
|
||||
|
||||
def make_pushpull(pauses):
|
||||
p = mock.Mock()
|
||||
alsoProvides(p, IPullProducer)
|
||||
unregister = mock.Mock()
|
||||
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
|
||||
|
||||
def term_factory():
|
||||
return term
|
||||
coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=eq.eventually)
|
||||
pp = PullToPush(p, unregister, coop)
|
||||
|
||||
it = cycle(pauses)
|
||||
|
||||
def action(i):
|
||||
if isinstance(i, Exception):
|
||||
raise i
|
||||
elif i:
|
||||
pp.pauseProducing()
|
||||
p.resumeProducing.side_effect = lambda: action(next(it))
|
||||
return p, unregister, pp, eq
|
||||
|
||||
|
||||
class PretendResumptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PretendUnregisterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PushPull(unittest.TestCase):
|
||||
# test our wrapper utility, which I copied from
|
||||
# twisted.internet._producer_helpers since it isn't publically exposed
|
||||
|
||||
def test_start_unpaused(self):
|
||||
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
|
||||
# if it starts unpaused, it gets one write before being halted
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
|
||||
clear_mock_calls(p)
|
||||
|
||||
# now each time we call resumeProducing, we should see one delivered to
|
||||
# the underlying IPullProducer
|
||||
pp.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
|
||||
|
||||
pp.stopStreaming()
|
||||
pp.stopStreaming() # should tolerate this
|
||||
|
||||
def test_start_unpaused_two_writes(self):
|
||||
p, unr, pp, eq = make_pushpull([False, True]) # pause every other time
|
||||
# it should get two writes, since the first didn't pause
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2)
|
||||
|
||||
def test_start_paused(self):
|
||||
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
|
||||
pp.startStreaming(True)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
pp.stopStreaming()
|
||||
|
||||
def test_stop(self):
|
||||
p, unr, pp, eq = make_pushpull([True])
|
||||
pp.startStreaming(True)
|
||||
pp.stopProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.stopProducing()])
|
||||
|
||||
def test_error(self):
|
||||
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
|
||||
unr.side_effect = lambda: pp.stopStreaming()
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(unr.mock_calls, [mock.call()])
|
||||
self.flushLoggedErrors(PretendResumptionError)
|
||||
|
||||
def test_error_during_unregister(self):
|
||||
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
|
||||
unr.side_effect = PretendUnregisterError()
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(unr.mock_calls, [mock.call()])
|
||||
self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError)
|
||||
|
||||
# TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I
|
||||
# could capture the inter-call ordering that way
|
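The round-robin behavior these pause tests pin down can be summarized in a small model. This is an illustrative sketch only, not the real Outbound class: the RoundRobin name and its deque are invented here, and the real code must also handle pull producers via the cooperator. It shows the invariant the assertions encode: a producer that takes a turn rotates to the back, and a write that fills the connection ends the round.

from collections import deque

class RoundRobin(object):
    # Toy model (hypothetical) of Outbound's producer scheduling.
    def __init__(self):
        self.producers = deque()  # registered subchannel producers, in turn order
        self.paused = True

    def resumeProducing(self):
        self.paused = False
        for _ in range(len(self.producers)):
            if self.paused:
                break                  # someone's write paused the connection
            p = self.producers[0]
            self.producers.rotate(-1)  # the producer taking a turn goes to the back
            p.resumeProducing()        # may call pauseProducing() re-entrantly

    def pauseProducing(self):
        self.paused = True
        for p in self.producers:
            p.pauseProducing()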
src/wormhole/test/dilate/test_parse.py (new file, 44 lines)
@ -0,0 +1,44 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from ..._dilation.connection import (parse_record, encode_record,
                                     KCM, Ping, Pong, Open, Data, Close, Ack)


class Parse(unittest.TestCase):
    def test_parse(self):
        self.assertEqual(parse_record(b"\x00"), KCM())
        self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"),
                         Ping(ping_id=b"\x55\x44\x33\x22"))
        self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
                         Pong(ping_id=b"\x55\x44\x33\x22"))
        self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
                         Open(scid=513, seqnum=256))
        self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
                         Data(scid=514, seqnum=257, data=b"dataaa"))
        self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
                         Close(scid=515, seqnum=258))
        self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
                         Ack(resp_seqnum=259))
        with mock.patch("wormhole._dilation.connection.log.err") as le:
            with self.assertRaises(ValueError):
                parse_record(b"\x07unknown")
        self.assertEqual(le.mock_calls,
                         [mock.call("received unknown message type: {}".format(
                             b"\x07unknown"))])

    def test_encode(self):
        self.assertEqual(encode_record(KCM()), b"\x00")
        self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
        self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
        self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
                         b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
        self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
                         b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
        self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
                         b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
        self.assertEqual(encode_record(Ack(resp_seqnum=19)),
                         b"\x06\x00\x00\x00\x13")
        with self.assertRaises(TypeError) as ar:
            encode_record("not a record")
        self.assertEqual(str(ar.exception), "not a record")
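The byte strings above pin down the wire layout: one type byte, then big-endian 4-byte scid/seqnum fields where the record has them, then any payload. As a cross-check, the Open and Data vectors can be reproduced with nothing but struct.pack; the T_* constants below are local names for the tag bytes, not identifiers from the wormhole code.

from struct import pack

T_KCM, T_PING, T_PONG, T_OPEN, T_DATA, T_CLOSE, T_ACK = range(7)

def encode_open(scid, seqnum):
    # tag byte, then scid and seqnum as big-endian uint32
    return pack(">BII", T_OPEN, scid, seqnum)

def encode_data(scid, seqnum, payload):
    # Data uses the same header, with the payload appended verbatim
    return pack(">BII", T_DATA, scid, seqnum) + payload

assert encode_open(65536, 16) == b"\x03\x00\x01\x00\x00\x00\x00\x00\x10"
assert (encode_data(65537, 17, b"dataaa") ==
        b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")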
src/wormhole/test/dilate/test_record.py (new file, 269 lines)
@ -0,0 +1,269 @@
from __future__ import print_function, unicode_literals
import mock
from zope.interface import alsoProvides
from twisted.trial import unittest
from ..._dilation._noise import NoiseInvalidMessage
from ..._dilation.connection import (IFramer, Frame, Prologue,
                                     _Record, Handshake,
                                     Disconnect, Ping)


def make_record():
    f = mock.Mock()
    alsoProvides(f, IFramer)
    n = mock.Mock()  # pretends to be a Noise object
    r = _Record(f, n)
    return r, f, n


class Record(unittest.TestCase):
    def test_good2(self):
        f = mock.Mock()
        alsoProvides(f, IFramer)
        f.add_and_parse = mock.Mock(side_effect=[
            [],
            [Prologue()],
            [Frame(frame=b"rx-handshake")],
            [Frame(frame=b"frame1"), Frame(frame=b"frame2")],
        ])
        n = mock.Mock()
        n.write_message = mock.Mock(return_value=b"tx-handshake")
        p1, p2 = object(), object()
        n.decrypt = mock.Mock(side_effect=[p1, p2])
        r = _Record(f, n)
        self.assertEqual(f.mock_calls, [])
        r.connectionMade()
        self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
        n.mock_calls[:] = []

        # Pretend to deliver the prologue in two parts. The text we send in
        # doesn't matter: the side_effect= is what causes the prologue to be
        # recognized by the second call.
        self.assertEqual(list(r.add_and_unframe(b"pro")), [])
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [])

        # recognizing the prologue causes a handshake frame to be sent
        self.assertEqual(list(r.add_and_unframe(b"logue")), [])
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
                                        mock.call.send_frame(b"tx-handshake")])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [mock.call.write_message()])
        n.mock_calls[:] = []

        # next add_and_unframe is recognized as the Handshake
        self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()])
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")])
        n.mock_calls[:] = []

        # next is a pair of Records
        r1, r2 = object(), object()
        with mock.patch("wormhole._dilation.connection.parse_record",
                        side_effect=[r1, r2]) as pr:
            self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2])
        self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"),
                                        mock.call.decrypt(b"frame2")])
        self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)])

    def test_bad_handshake(self):
        f = mock.Mock()
        alsoProvides(f, IFramer)
        f.add_and_parse = mock.Mock(return_value=[Prologue(),
                                                  Frame(frame=b"rx-handshake")])
        n = mock.Mock()
        n.write_message = mock.Mock(return_value=b"tx-handshake")
        nvm = NoiseInvalidMessage()
        n.read_message = mock.Mock(side_effect=nvm)
        r = _Record(f, n)
        self.assertEqual(f.mock_calls, [])
        r.connectionMade()
        self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
        n.mock_calls[:] = []

        with mock.patch("wormhole._dilation.connection.log.err") as le:
            with self.assertRaises(Disconnect):
                list(r.add_and_unframe(b"data"))
        self.assertEqual(le.mock_calls,
                         [mock.call(nvm, "bad inbound noise handshake")])

    def test_bad_message(self):
        f = mock.Mock()
        alsoProvides(f, IFramer)
        f.add_and_parse = mock.Mock(return_value=[Prologue(),
                                                  Frame(frame=b"rx-handshake"),
                                                  Frame(frame=b"bad-message")])
        n = mock.Mock()
        n.write_message = mock.Mock(return_value=b"tx-handshake")
        nvm = NoiseInvalidMessage()
        n.decrypt = mock.Mock(side_effect=nvm)
        r = _Record(f, n)
        self.assertEqual(f.mock_calls, [])
        r.connectionMade()
        self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
        f.mock_calls[:] = []
        self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
        n.mock_calls[:] = []

        with mock.patch("wormhole._dilation.connection.log.err") as le:
            with self.assertRaises(Disconnect):
                list(r.add_and_unframe(b"data"))
        self.assertEqual(le.mock_calls,
                         [mock.call(nvm, "bad inbound noise frame")])

    def test_send_record(self):
        f = mock.Mock()
        alsoProvides(f, IFramer)
        n = mock.Mock()
        f1 = object()
        n.encrypt = mock.Mock(return_value=f1)
        r1 = Ping(b"pingid")
        r = _Record(f, n)
        self.assertEqual(f.mock_calls, [])
        m1 = object()
        with mock.patch("wormhole._dilation.connection.encode_record",
                        return_value=m1) as er:
            r.send_record(r1)
        self.assertEqual(er.mock_calls, [mock.call(r1)])
        self.assertEqual(n.mock_calls, [mock.call.start_handshake(),
                                        mock.call.encrypt(m1)])
        self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)])

    def test_good(self):
        # Exercise the success path. The Record instance is given each chunk
        # of data as it arrives on Protocol.dataReceived, and is supposed to
        # return a series of Tokens (maybe none, if the chunk was incomplete,
        # or more than one, if the chunk was larger). Internally, it delivers
        # the chunks to the Framer for unframing (which returns 0 or more
        # frames), manages the Noise decryption object, and parses any
        # decrypted messages into tokens (some of which are consumed
        # internally, others for delivery upstairs).
        #
        # in the normal flow, we get:
        #
        # |   | Inbound   | NoiseAction   | Outbound  | ToUpstairs |
        # |   | -         | -             | -         | -          |
        # | 1 |           |               | prologue  |            |
        # | 2 | prologue  |               |           |            |
        # | 3 |           | write_message | handshake |            |
        # | 4 | handshake | read_message  |           | Handshake  |
        # | 5 |           | encrypt       | KCM       |            |
        # | 6 | KCM       | decrypt       |           | KCM        |
        # | 7 | msg1      | decrypt       |           | msg1       |

        # 1: instantiating the Record instance causes the outbound prologue
        # to be sent

        # 2+3: receipt of the inbound prologue triggers creation of the
        # ephemeral key (the "handshake") by calling noise.write_message()
        # and then writes the handshake to the outbound transport

        # 4: when the peer's handshake is received, it is delivered to
        # noise.read_message(), which generates the shared key (enabling
        # noise.send() and noise.decrypt()). It also delivers the Handshake
        # token upstairs, which might (on the Follower) trigger immediate
        # transmission of the Key Confirmation Message (KCM)

        # 5: the outbound KCM is framed and fed into noise.encrypt(), then
        # sent outbound

        # 6: the peer's KCM is decrypted then delivered upstairs. The
        # Follower treats this as a signal that it should use this connection
        # (and drop all others).

        # 7: the peer's first message is decrypted, parsed, and delivered
        # upstairs. This might be an Open or a Data, depending upon what
        # queued messages were left over from the previous connection

        r, f, n = make_record()
        outbound_handshake = object()
        kcm, msg1 = object(), object()
        f_kcm, f_msg1 = object(), object()
        n.write_message = mock.Mock(return_value=outbound_handshake)
        n.decrypt = mock.Mock(side_effect=[kcm, msg1])
        n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1])
        f.add_and_parse = mock.Mock(side_effect=[[],  # no tokens yet
                                                 [Prologue()],
                                                 [Frame("f_handshake")],
                                                 [Frame("f_kcm"),
                                                  Frame("f_msg1")],
                                                 ])

        self.assertEqual(f.mock_calls, [])
        self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
        n.mock_calls[:] = []

        # 1. The Framer is responsible for sending the prologue, so we don't
        # have to check that here, we just check that the Framer was told
        # about connectionMade properly.
        r.connectionMade()
        self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
        self.assertEqual(n.mock_calls, [])
        f.mock_calls[:] = []

        # 2
        # we dribble the prologue in over two messages, to make sure we can
        # handle a dataReceived that doesn't complete the token

        # remember, add_and_unframe is a generator
        self.assertEqual(list(r.add_and_unframe(b"pro")), [])
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
        self.assertEqual(n.mock_calls, [])
        f.mock_calls[:] = []

        self.assertEqual(list(r.add_and_unframe(b"logue")), [])
        # 3: write_message, send outbound handshake
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
                                        mock.call.send_frame(outbound_handshake),
                                        ])
        self.assertEqual(n.mock_calls, [mock.call.write_message()])
        f.mock_calls[:] = []
        n.mock_calls[:] = []

        # 4
        # Now deliver the Noise "handshake", the ephemeral public key. This
        # is framed, but not a record, so it shouldn't decrypt or parse
        # anything, but the handshake is delivered to the Noise object, and
        # it does return a Handshake token so we can let the next layer up
        # react (by sending the KCM frame if we're a Follower, or not if
        # we're the Leader)

        self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()])
        self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")])
        self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")])
        f.mock_calls[:] = []
        n.mock_calls[:] = []

        # 5: at this point we ought to be able to send a message, the KCM
        with mock.patch("wormhole._dilation.connection.encode_record",
                        side_effect=[b"r-kcm"]) as er:
            r.send_record(kcm)
        self.assertEqual(er.mock_calls, [mock.call(kcm)])
        self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")])
        self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)])
        n.mock_calls[:] = []
        f.mock_calls[:] = []

        # 6: Now we deliver two messages stacked up: the KCM (Key
        # Confirmation Message) and the first real message. Concatenating
        # them tests that we can handle more than one token in a single
        # chunk. We need to mock parse_record() because everything past the
        # handshake is decrypted and parsed.

        with mock.patch("wormhole._dilation.connection.parse_record",
                        side_effect=[kcm, msg1]) as pr:
            self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")),
                             [kcm, msg1])
        self.assertEqual(f.mock_calls,
                         [mock.call.add_and_parse(b"kcm,msg1")])
        self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"),
                                        mock.call.decrypt("f_msg1")])
        self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)])
        n.mock_calls[:] = []
        f.mock_calls[:] = []
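Stitching the assertions of test_send_record and test_good together, the outbound path is a three-step pipeline. The function below is a paraphrase for orientation, not the actual _Record.send_record body; framer and noise stand in for the IFramer and Noise objects that these tests mock.

from wormhole._dilation.connection import encode_record

def send_record(framer, noise, record):
    msg = encode_record(record)  # tag byte + fields (see test_parse.py above)
    frame = noise.encrypt(msg)   # encrypted under the established Noise session
    framer.send_frame(frame)     # the Framer applies the wire framing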
src/wormhole/test/dilate/test_subchannel.py (new file, 144 lines)
@ -0,0 +1,144 @@
from __future__ import print_function, unicode_literals
import mock
from twisted.trial import unittest
from twisted.internet.interfaces import ITransport
from twisted.internet.error import ConnectionDone
from ..._dilation.subchannel import (Once, SubChannel,
                                     _WormholeAddress, _SubchannelAddress,
                                     AlreadyClosedError)
from .common import mock_manager


def make_sc(set_protocol=True):
    scid = b"scid"
    hostaddr = _WormholeAddress()
    peeraddr = _SubchannelAddress(scid)
    m = mock_manager()
    sc = SubChannel(scid, m, hostaddr, peeraddr)
    p = mock.Mock()
    if set_protocol:
        sc._set_protocol(p)
    return sc, m, scid, hostaddr, peeraddr, p


class SubChannelAPI(unittest.TestCase):
    def test_once(self):
        o = Once(ValueError)
        o()
        with self.assertRaises(ValueError):
            o()

    def test_create(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        self.assert_(ITransport.providedBy(sc))
        self.assertEqual(m.mock_calls, [])
        self.assertIdentical(sc.getHost(), hostaddr)
        self.assertIdentical(sc.getPeer(), peeraddr)

    def test_write(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.write(b"data")
        self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
        m.mock_calls[:] = []
        sc.writeSequence([b"more", b"data"])
        self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])

    def test_write_when_closing(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        with self.assertRaises(AlreadyClosedError) as e:
            sc.write(b"data")
        self.assertEqual(str(e.exception),
                         "write not allowed on closed subchannel")

    def test_local_close(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        # late arriving data is still delivered
        sc.remote_data(b"late")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")])
        p.mock_calls[:] = []

        sc.remote_close()
        self.assert_connectionDone(p.mock_calls)

    def test_local_double_close(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.loseConnection()
        self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
        m.mock_calls[:] = []

        with self.assertRaises(AlreadyClosedError) as e:
            sc.loseConnection()
        self.assertEqual(str(e.exception),
                         "loseConnection not allowed on closed subchannel")

    def assert_connectionDone(self, mock_calls):
        self.assertEqual(len(mock_calls), 1)
        self.assertEqual(mock_calls[0][0], "connectionLost")
        self.assertEqual(len(mock_calls[0][1]), 1)
        self.assertIsInstance(mock_calls[0][1][0], ConnectionDone)

    def test_remote_close(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        sc.remote_close()
        self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
        self.assert_connectionDone(p.mock_calls)

    def test_data(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()
        sc.remote_data(b"data")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
        p.mock_calls[:] = []
        sc.remote_data(b"not")
        sc.remote_data(b"coalesced")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"),
                                        mock.call.dataReceived(b"coalesced"),
                                        ])

    def test_data_before_open(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
        sc.remote_data(b"data")
        self.assertEqual(p.mock_calls, [])
        sc._set_protocol(p)
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
        p.mock_calls[:] = []
        sc.remote_data(b"more")
        self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])

    def test_close_before_open(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
        sc.remote_close()
        self.assertEqual(p.mock_calls, [])
        sc._set_protocol(p)
        self.assert_connectionDone(p.mock_calls)

    def test_producer(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        sc.pauseProducing()
        self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)])
        m.mock_calls[:] = []
        sc.resumeProducing()
        self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)])
        m.mock_calls[:] = []
        sc.stopProducing()
        self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)])
        m.mock_calls[:] = []

    def test_consumer(self):
        sc, m, scid, hostaddr, peeraddr, p = make_sc()

        # TODO: more, once this is implemented
        sc.registerProducer(None, True)
        sc.unregisterProducer()
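Since SubChannel provides ITransport (asserted in test_create), an ordinary Twisted Protocol can sit on top of it. A minimal usage sketch, assuming a subchannel built the way make_sc() above builds one; the Echo class is invented for illustration.

from twisted.internet.protocol import Protocol

class Echo(Protocol):
    def dataReceived(self, data):
        # writes are forwarded to the manager as send_data(scid, data)
        self.transport.write(data)

    def connectionLost(self, reason):
        # a remote_close() surfaces here as ConnectionDone
        pass

# glue, mirroring make_sc(set_protocol=False) above:
#   proto = Echo()
#   proto.makeConnection(sc)  # or sc._set_protocol(proto), as the tests do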
src/wormhole/test/test_hints.py (new file, 196 lines)
@ -0,0 +1,196 @@
from __future__ import print_function, unicode_literals
import io
from collections import namedtuple
import mock
from twisted.internet import endpoints, reactor
from twisted.trial import unittest
from .._hints import (endpoint_from_hint_obj, parse_hint_argv, parse_tcp_v1_hint,
                      describe_hint_obj, parse_hint, encode_hint,
                      DirectTCPV1Hint, TorTCPV1Hint, RelayV1Hint)

UnknownHint = namedtuple("UnknownHint", ["stuff"])


class Hints(unittest.TestCase):
    def test_endpoint_from_hint_obj(self):
        def efho(hint, tor=None):
            return endpoint_from_hint_obj(hint, tor, reactor)
        self.assertIsInstance(efho(DirectTCPV1Hint("host", 1234, 0.0)),
                              endpoints.HostnameEndpoint)
        self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)

        # tor=None
        self.assertEqual(efho(TorTCPV1Hint("host", "port", 0)), None)
        self.assertEqual(efho(UnknownHint("foo")), None)

        tor = mock.Mock()

        def tor_ep(hostname, port):
            if hostname == "non-public":
                raise ValueError
            return ("tor_ep", hostname, port)
        tor.stream_via = mock.Mock(side_effect=tor_ep)

        self.assertEqual(efho(DirectTCPV1Hint("host", 1234, 0.0), tor),
                         ("tor_ep", "host", 1234))
        self.assertEqual(efho(TorTCPV1Hint("host2.onion", 1234, 0.0), tor),
                         ("tor_ep", "host2.onion", 1234))
        self.assertEqual(efho(DirectTCPV1Hint("non-public", 1234, 0.0), tor),
                         None)

        self.assertEqual(efho(UnknownHint("foo"), tor), None)

    def test_comparable(self):
        h1 = DirectTCPV1Hint("hostname", "port1", 0.0)
        h1b = DirectTCPV1Hint("hostname", "port1", 0.0)
        h2 = DirectTCPV1Hint("hostname", "port2", 0.0)
        r1 = RelayV1Hint(tuple(sorted([h1, h2])))
        r2 = RelayV1Hint(tuple(sorted([h2, h1])))
        r3 = RelayV1Hint(tuple(sorted([h1b, h2])))
        self.assertEqual(r1, r2)
        self.assertEqual(r2, r3)
        self.assertEqual(len(set([r1, r2, r3])), 1)

    def test_parse_tcp_v1_hint(self):
        p = parse_tcp_v1_hint
        self.assertEqual(p({"type": "unknown"}), None)
        h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
        self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 0.0))
        h = p({
            "type": "direct-tcp-v1",
            "hostname": "foo",
            "port": 1234,
            "priority": 2.5
        })
        self.assertEqual(h, DirectTCPV1Hint("foo", 1234, 2.5))
        h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
        self.assertEqual(h, TorTCPV1Hint("foo", 1234, 0.0))
        h = p({
            "type": "tor-tcp-v1",
            "hostname": "foo",
            "port": 1234,
            "priority": 2.5
        })
        self.assertEqual(h, TorTCPV1Hint("foo", 1234, 2.5))
        self.assertEqual(p({
            "type": "direct-tcp-v1"
        }), None)  # missing hostname
        self.assertEqual(p({
            "type": "direct-tcp-v1",
            "hostname": 12
        }), None)  # invalid hostname
        self.assertEqual(
            p({
                "type": "direct-tcp-v1",
                "hostname": "foo"
            }), None)  # missing port
        self.assertEqual(
            p({
                "type": "direct-tcp-v1",
                "hostname": "foo",
                "port": "not a number"
            }), None)  # invalid port

    def test_parse_hint(self):
        p = parse_hint
        self.assertEqual(p({"type": "direct-tcp-v1",
                            "hostname": "foo",
                            "port": 12}),
                         DirectTCPV1Hint("foo", 12, 0.0))
        self.assertEqual(p({"type": "relay-v1",
                            "hints": [
                                {"type": "direct-tcp-v1",
                                 "hostname": "foo",
                                 "port": 12},
                                {"type": "unrecognized"},
                                {"type": "direct-tcp-v1",
                                 "hostname": "bar",
                                 "port": 13}]}),
                         RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0),
                                      DirectTCPV1Hint("bar", 13, 0.0)]))

    def test_parse_hint_argv(self):
        def p(hint):
            stderr = io.StringIO()
            value = parse_hint_argv(hint, stderr=stderr)
            return value, stderr.getvalue()

        h, stderr = p("tcp:host:1234")
        self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
        self.assertEqual(stderr, "")

        h, stderr = p("tcp:host:1234:priority=2.6")
        self.assertEqual(h, DirectTCPV1Hint("host", 1234, 2.6))
        self.assertEqual(stderr, "")

        h, stderr = p("tcp:host:1234:unknown=stuff")
        self.assertEqual(h, DirectTCPV1Hint("host", 1234, 0.0))
        self.assertEqual(stderr, "")

        h, stderr = p("$!@#^")
        self.assertEqual(h, None)
        self.assertEqual(stderr, "unparseable hint '$!@#^'\n")

        h, stderr = p("unknown:stuff")
        self.assertEqual(h, None)
        self.assertEqual(stderr,
                         "unknown hint type 'unknown' in 'unknown:stuff'\n")

        h, stderr = p("tcp:just-a-hostname")
        self.assertEqual(h, None)
        self.assertEqual(
            stderr,
            "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")

        h, stderr = p("tcp:host:number")
        self.assertEqual(h, None)
        self.assertEqual(stderr,
                         "non-numeric port in TCP hint 'tcp:host:number'\n")

        h, stderr = p("tcp:host:1234:priority=bad")
        self.assertEqual(h, None)
        self.assertEqual(
            stderr,
            "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")

    def test_describe_hint_obj(self):
        d = describe_hint_obj
        self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, False),
                         "->tcp:host:1234")
        self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, False),
                         "->relay:tcp:host:1234")
        self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), False, True),
                         "tor->tcp:host:1234")
        self.assertEqual(d(DirectTCPV1Hint("host", 1234, 0.0), True, True),
                         "tor->relay:tcp:host:1234")
        self.assertEqual(d(TorTCPV1Hint("host", 1234, 0.0), False, False),
                         "->tor:host:1234")
        self.assertEqual(d(UnknownHint("stuff"), False, False),
                         "->%s" % str(UnknownHint("stuff")))

    def test_encode_hint(self):
        e = encode_hint
        self.assertEqual(e(DirectTCPV1Hint("host", 1234, 1.0)),
                         {"type": "direct-tcp-v1",
                          "priority": 1.0,
                          "hostname": "host",
                          "port": 1234})
        self.assertEqual(e(RelayV1Hint([DirectTCPV1Hint("foo", 12, 0.0),
                                        DirectTCPV1Hint("bar", 13, 0.0)])),
                         {"type": "relay-v1",
                          "hints": [
                              {"type": "direct-tcp-v1",
                               "hostname": "foo",
                               "port": 12,
                               "priority": 0.0},
                              {"type": "direct-tcp-v1",
                               "hostname": "bar",
                               "port": 13,
                               "priority": 0.0},
                          ]})
        self.assertEqual(e(TorTCPV1Hint("host", 1234, 1.0)),
                         {"type": "tor-tcp-v1",
                          "priority": 1.0,
                          "hostname": "host",
                          "port": 1234})
        e = self.assertRaises(ValueError, e, "not a Hint")
        self.assertIn("unknown hint type", str(e))
        self.assertIn("not a Hint", str(e))
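The encode/parse pair exercised above round-trips, which is the property the hint-dict scheme relies on: hint objects (namedtuples) serialize to JSON-able dicts and parse back to equal objects. A quick check, using a made-up hostname and port:

from wormhole._hints import DirectTCPV1Hint, encode_hint, parse_hint

hint = DirectTCPV1Hint("example.invalid", 4001, 0.0)
d = encode_hint(hint)
assert d == {"type": "direct-tcp-v1", "priority": 0.0,
             "hostname": "example.invalid", "port": 4001}
assert parse_hint(d) == hint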
@ -12,8 +12,8 @@ import mock
from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
                _mailbox, _nameplate, _order, _receive, _rendezvous, _send,
                _terminator, errors, timing)
from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister,
                           IMailbox, INameplate, IOrder, IReceive,
from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey,
                           ILister, IMailbox, INameplate, IOrder, IReceive,
                           IRendezvousConnector, ISend, ITerminator, IWordlist)
from .._key import derive_key, derive_phase_key, encrypt_data
from ..journal import ImmediateJournal
@ -1286,17 +1286,21 @@ class Boss(unittest.TestCase):
                     "closed")
        versions = {"app": "version1"}
        reactor = None
        eq = None
        cooperator = None
        journal = ImmediateJournal()
        tor_manager = None
        client_version = ("python", __version__)
        b = MockBoss(wormhole, "side", "url", "appid", versions,
                     client_version, reactor, journal, tor_manager,
                     client_version, reactor, eq, cooperator, journal,
                     tor_manager,
                     timing.DebugTiming())
        b._T = Dummy("t", events, ITerminator, "close")
        b._S = Dummy("s", events, ISend, "send")
        b._RC = Dummy("rc", events, IRendezvousConnector, "start")
        b._C = Dummy("c", events, ICode, "allocate_code", "input_code",
                     "set_code")
        b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key")
        return b, events

    def test_basic(self):

@ -1324,7 +1328,9 @@ class Boss(unittest.TestCase):
        b.got_message("0", b"msg1")
        self.assertEqual(events, [
            ("w.got_key", b"key"),
            ("d.got_key", b"key"),
            ("w.got_verifier", b"verifier"),
            ("d.got_wormhole_versions", {}),
            ("w.got_versions", {}),
            ("w.received", b"msg1"),
        ])
@ -3,7 +3,7 @@ from twisted.python.failure import Failure
from twisted.trial import unittest

from ..eventual import EventualQueue
from ..observer import OneShotObserver, SequenceObserver
from ..observer import OneShotObserver, SequenceObserver, EmptyableSet


class OneShot(unittest.TestCase):

@ -121,3 +121,28 @@ class Sequence(unittest.TestCase):
        d2 = o.when_next_event()
        eq.flush_sync()
        self.assertIdentical(self.failureResultOf(d2), f)


class Empty(unittest.TestCase):
    def test_set(self):
        eq = EventualQueue(Clock())
        s = EmptyableSet(_eventual_queue=eq)
        d1 = s.when_next_empty()
        eq.flush_sync()
        self.assertNoResult(d1)
        s.add(1)
        eq.flush_sync()
        self.assertNoResult(d1)
        s.add(2)
        s.discard(1)
        d2 = s.when_next_empty()
        eq.flush_sync()
        self.assertNoResult(d1)
        self.assertNoResult(d2)
        s.discard(2)
        eq.flush_sync()
        self.assertEqual(self.successResultOf(d1), None)
        self.assertEqual(self.successResultOf(d2), None)

        s.add(3)
        s.discard(3)
@ -3,7 +3,6 @@ from __future__ import print_function, unicode_literals
import gc
import io
from binascii import hexlify, unhexlify
from collections import namedtuple

import six
from nacl.exceptions import CryptoError

@ -18,7 +17,9 @@ import mock
from wormhole_transit_relay import transit_server

from .. import transit
from .._hints import DirectTCPV1Hint
from ..errors import InternalError
from ..util import HKDF
from .common import ServerBase
@ -140,141 +141,6 @@ class Misc(unittest.TestCase):
        self.assertIsInstance(portno, int)


UnknownHint = namedtuple("UnknownHint", ["stuff"])


class Hints(unittest.TestCase):
    def test_endpoint_from_hint_obj(self):
        c = transit.Common("")
        efho = c._endpoint_from_hint_obj
        self.assertIsInstance(
            efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
            endpoints.HostnameEndpoint)
        self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
        # c._tor is currently None
        self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
        c._tor = mock.Mock()

        def tor_ep(hostname, port):
            if hostname == "non-public":
                return None
            return ("tor_ep", hostname, port)

        c._tor.stream_via = mock.Mock(side_effect=tor_ep)
        self.assertEqual(
            efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
            ("tor_ep", "host", 1234))
        self.assertEqual(
            efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
            ("tor_ep", "host2.onion", 1234))
        self.assertEqual(
            efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None)
        self.assertEqual(efho(UnknownHint("foo")), None)

    def test_comparable(self):
        h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
        h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
        h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
        r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
        r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
        r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
        self.assertEqual(r1, r2)
        self.assertEqual(r2, r3)
        self.assertEqual(len(set([r1, r2, r3])), 1)

    def test_parse_tcp_v1_hint(self):
        c = transit.Common("")
        p = c._parse_tcp_v1_hint
        self.assertEqual(p({"type": "unknown"}), None)
        h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
        self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
        h = p({
            "type": "direct-tcp-v1",
            "hostname": "foo",
            "port": 1234,
            "priority": 2.5
        })
        self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
        h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
        self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
        h = p({
            "type": "tor-tcp-v1",
            "hostname": "foo",
            "port": 1234,
            "priority": 2.5
        })
        self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
        self.assertEqual(p({
            "type": "direct-tcp-v1"
        }), None)  # missing hostname
        self.assertEqual(p({
            "type": "direct-tcp-v1",
            "hostname": 12
        }), None)  # invalid hostname
        self.assertEqual(
            p({
                "type": "direct-tcp-v1",
                "hostname": "foo"
            }), None)  # missing port
        self.assertEqual(
            p({
                "type": "direct-tcp-v1",
                "hostname": "foo",
                "port": "not a number"
            }), None)  # invalid port

    def test_parse_hint_argv(self):
        def p(hint):
            stderr = io.StringIO()
            value = transit.parse_hint_argv(hint, stderr=stderr)
            return value, stderr.getvalue()

        h, stderr = p("tcp:host:1234")
        self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
        self.assertEqual(stderr, "")

        h, stderr = p("tcp:host:1234:priority=2.6")
        self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
        self.assertEqual(stderr, "")

        h, stderr = p("tcp:host:1234:unknown=stuff")
        self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
        self.assertEqual(stderr, "")

        h, stderr = p("$!@#^")
        self.assertEqual(h, None)
        self.assertEqual(stderr, "unparseable hint '$!@#^'\n")

        h, stderr = p("unknown:stuff")
        self.assertEqual(h, None)
        self.assertEqual(stderr,
                         "unknown hint type 'unknown' in 'unknown:stuff'\n")

        h, stderr = p("tcp:just-a-hostname")
        self.assertEqual(h, None)
        self.assertEqual(
            stderr,
            "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")

        h, stderr = p("tcp:host:number")
        self.assertEqual(h, None)
        self.assertEqual(stderr,
                         "non-numeric port in TCP hint 'tcp:host:number'\n")

        h, stderr = p("tcp:host:1234:priority=bad")
        self.assertEqual(h, None)
        self.assertEqual(
            stderr,
            "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")

    def test_describe_hint_obj(self):
        d = transit.describe_hint_obj
        self.assertEqual(
            d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234")
        self.assertEqual(
            d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234")
        self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))


# ipaddrs.py currently uses native strings: bytes on py2, unicode on
# py3
@ -437,7 +303,7 @@ class Listener(unittest.TestCase):
        hints, ep = c._build_listener()
        self.assertIsInstance(hints, (list, set))
        if hints:
            self.assertIsInstance(hints[0], transit.DirectTCPV1Hint)
            self.assertIsInstance(hints[0], DirectTCPV1Hint)
        self.assertIsInstance(ep, endpoints.TCP4ServerEndpoint)

    def test_get_direct_hints(self):
@ -1507,8 +1373,8 @@ class Transit(unittest.TestCase):
|
|||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"])
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint):
|
||||
if isinstance(hint, transit.DirectTCPV1Hint):
|
||||
def _endpoint_from_hint_obj(self, hint, _tor, _reactor):
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
if hint.hostname == "unavailable":
|
||||
return None
|
||||
return hint.hostname
|
||||
|
@ -1523,20 +1389,21 @@ class Transit(unittest.TestCase):
|
|||
del hints
|
||||
s.add_connection_hints(
|
||||
[DIRECT_HINT_JSON, UNRECOGNIZED_HINT_JSON, RELAY_HINT_JSON])
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# the direct connectors are tried right away, but the relay
|
||||
# connectors are stalled for a few seconds
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# the direct connectors are tried right away, but the relay
|
||||
# connectors are stalled for a few seconds
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay"])
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay"])
|
||||
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_priorities(self):
|
||||
|
@ -1586,31 +1453,32 @@ class Transit(unittest.TestCase):
|
|||
}]
|
||||
},
|
||||
])
|
||||
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||
s._start_connector = self._start_connector
|
||||
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# direct connector should be used first, then the priority=3.0 relay,
|
||||
# then the two 2.0 relays, then the (default) 0.0 relay
|
||||
with mock.patch("wormhole.transit.endpoint_from_hint_obj",
|
||||
self._endpoint_from_hint_obj):
|
||||
d = s.connect()
|
||||
self.assertNoResult(d)
|
||||
# direct connector should be used first, then the priority=3.0 relay,
|
||||
# then the two 2.0 relays, then the (default) 0.0 relay
|
||||
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
self.assertEqual(self._connectors, ["direct"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay3"])
|
||||
clock.advance(s.RELAY_DELAY + 1.0)
|
||||
self.assertEqual(self._connectors, ["direct", "relay3"])
|
||||
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4"],
|
||||
["direct", "relay3", "relay4", "relay2"]))
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4"],
|
||||
["direct", "relay3", "relay4", "relay2"]))
|
||||
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4", "relay"],
|
||||
["direct", "relay3", "relay4", "relay2", "relay"]))
|
||||
clock.advance(s.RELAY_DELAY)
|
||||
self.assertIn(self._connectors,
|
||||
(["direct", "relay3", "relay2", "relay4", "relay"],
|
||||
["direct", "relay3", "relay4", "relay2", "relay"]))
|
||||
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
self._waiters[0].callback("winner")
|
||||
self.assertEqual(self.successResultOf(d), "winner")
|
||||
|
||||
@inlineCallbacks
|
||||
def test_no_direct_hints(self):
|
||||
|
@@ -1624,20 +1492,21 @@ class Transit(unittest.TestCase):
            UNRECOGNIZED_HINT_JSON, UNAVAILABLE_HINT_JSON, RELAY_HINT2_JSON,
            UNAVAILABLE_RELAY_HINT_JSON
        ])
        s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
        s._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = s.connect()
            self.assertNoResult(d)
            # since there are no usable direct hints, the relay connector will
            # only be stalled for 0 seconds
            self.assertEqual(self._connectors, [])

            clock.advance(0)
            self.assertEqual(self._connectors, ["relay"])

            self._waiters[0].callback("winner")
            self.assertEqual(self.successResultOf(d), "winner")

    @inlineCallbacks
    def test_no_contenders(self):

@@ -1647,17 +1516,18 @@ class Transit(unittest.TestCase):
        hints = yield s.get_connection_hints()  # start the listener
        del hints
        s.add_connection_hints([])  # no hints at all
        s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
        s._start_connector = self._start_connector

        with mock.patch("wormhole.transit.endpoint_from_hint_obj",
                        self._endpoint_from_hint_obj):
            d = s.connect()
            f = self.failureResultOf(d, transit.TransitError)
            self.assertEqual(str(f.value), "No contenders for connection")


class RelayHandshake(unittest.TestCase):
    def old_build_relay_handshake(self, key):
        token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
        token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
        return (token, b"please relay " + hexlify(token) + b"\n")

    def test_old(self):

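For reference, the relay handshake that old_build_relay_handshake reconstructs is just an HKDF-derived token, hex-encoded into one line of ASCII. A worked example with a stand-in key:

from binascii import hexlify
from hkdf import Hkdf

def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    return Hkdf(salt, skm).expand(CTXinfo, outlen)

key = b"\x00" * 32  # stand-in; really derived from the session key
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
handshake = b"please relay " + hexlify(token) + b"\n"
# a 32-byte token becomes 64 hex characters on the wire
assert len(handshake) == len(b"please relay ") + 64 + 1
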
src/wormhole/transit.py

@@ -2,15 +2,13 @@
from __future__ import absolute_import, print_function

import os
import re
import socket
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque, namedtuple
from collections import deque

import six
from hkdf import Hkdf
from nacl.secret import SecretBox
from twisted.internet import (address, defer, endpoints, error, interfaces,
                              protocol, reactor, task)

@@ -23,11 +21,10 @@ from zope.interface import implementer
from . import ipaddrs
from .errors import InternalError
from .timing import DebugTiming
from .util import bytes_to_hexstr


def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    return Hkdf(salt, skm).expand(CTXinfo, outlen)
from .util import bytes_to_hexstr, HKDF
from ._hints import (DirectTCPV1Hint, RelayV1Hint,
                     parse_hint_argv, describe_hint_obj, endpoint_from_hint_obj,
                     parse_tcp_v1_hint)


class TransitError(Exception):

@@ -95,72 +92,6 @@ def build_sided_relay_handshake(key, side):
        "ascii") + b"\n"


# These namedtuples are "hint objects". The JSON-serializable dictionaries
# are "hint dicts".

# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
# * make a TCP connection (possibly via Tor)
# * send the sender/receiver handshake bytes first
# * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint",
                             ["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each
# one, make the TCP connection, send the relay handshake, then complete the
# rest of the V1 protocol. Only one hint per relay is useful.
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])


def describe_hint_obj(hint):
    if isinstance(hint, DirectTCPV1Hint):
        return u"tcp:%s:%d" % (hint.hostname, hint.port)
    elif isinstance(hint, TorTCPV1Hint):
        return u"tor:%s:%d" % (hint.hostname, hint.port)
    else:
        return str(hint)


def parse_hint_argv(hint, stderr=sys.stderr):
    assert isinstance(hint, type(u""))
    # return tuple or None for an unparseable hint
    priority = 0.0
    mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
    if not mo:
        print("unparseable hint '%s'" % (hint, ), file=stderr)
        return None
    hint_type = mo.group(1)
    if hint_type != "tcp":
        print(
            "unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
        return None
    hint_value = mo.group(2)
    pieces = hint_value.split(":")
    if len(pieces) < 2:
        print(
            "unparseable TCP hint (need more colons) '%s'" % (hint, ),
            file=stderr)
        return None
    mo = re.search(r'^(\d+)$', pieces[1])
    if not mo:
        print("non-numeric port in TCP hint '%s'" % (hint, ), file=stderr)
        return None
    hint_host = pieces[0]
    hint_port = int(pieces[1])
    for more in pieces[2:]:
        if more.startswith("priority="):
            more_pieces = more.split("=")
            try:
                priority = float(more_pieces[1])
            except ValueError:
                print(
                    "non-float priority= in TCP hint '%s'" % (hint, ),
                    file=stderr)
                return None
    return DirectTCPV1Hint(hint_host, hint_port, priority)


TIMEOUT = 60  # seconds

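This whole block (the hint objects, describe_hint_obj, parse_hint_argv) moves into the new _hints.py module so the dilation code can reuse it. Assuming the relocated parse_hint_argv keeps the behavior shown above, usage looks like:

from wormhole._hints import DirectTCPV1Hint, parse_hint_argv

# "tcp:" hints with an optional priority= field parse into hint objects:
assert (parse_hint_argv(u"tcp:example.com:4001:priority=2.5") ==
        DirectTCPV1Hint(u"example.com", 4001, 2.5))
# anything else is reported to stderr and yields None:
assert parse_hint_argv(u"udp:example.com:4001") is None
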
@@ -746,30 +677,11 @@ class Common:
            self._listener_d.addErrback(lambda f: None)
            self._listener_d.cancel()

    def _parse_tcp_v1_hint(self, hint):  # hint_struct -> hint_obj
        hint_type = hint.get(u"type", u"")
        if hint_type not in [u"direct-tcp-v1", u"tor-tcp-v1"]:
            log.msg("unknown hint type: %r" % (hint, ))
            return None
        if not (u"hostname" in hint and
                isinstance(hint[u"hostname"], type(u""))):
            log.msg("invalid hostname in hint: %r" % (hint, ))
            return None
        if not (u"port" in hint and
                isinstance(hint[u"port"], six.integer_types)):
            log.msg("invalid port in hint: %r" % (hint, ))
            return None
        priority = hint.get(u"priority", 0.0)
        if hint_type == u"direct-tcp-v1":
            return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
        else:
            return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)

    def add_connection_hints(self, hints):
        for h in hints:  # hint structs
            hint_type = h.get(u"type", u"")
            if hint_type in [u"direct-tcp-v1", u"tor-tcp-v1"]:
                dh = self._parse_tcp_v1_hint(h)
                dh = parse_tcp_v1_hint(h)
                if dh:
                    self._their_direct_hints.append(dh)  # hint_obj
            elif hint_type == u"relay-v1":

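add_connection_hints consumes JSON-friendly "hint dicts"; parse_tcp_v1_hint (also relocated to _hints.py) turns each one into a hint object, presumably with the same validation as the removed method. For example:

from wormhole._hints import DirectTCPV1Hint, parse_tcp_v1_hint

hint = {u"type": u"direct-tcp-v1", u"hostname": u"192.0.2.7", u"port": 4001}
# priority defaults to 0.0 when the dict omits it
assert parse_tcp_v1_hint(hint) == DirectTCPV1Hint(u"192.0.2.7", 4001, 0.0)
# unknown types are logged and ignored:
assert parse_tcp_v1_hint({u"type": u"smoke-signals-v1"}) is None
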
@@ -779,7 +691,7 @@ class Common:
                # together like this.
                relay_hints = []
                for rhs in h.get(u"hints", []):
                    h = self._parse_tcp_v1_hint(rhs)
                    h = parse_tcp_v1_hint(rhs)
                    if h:
                        relay_hints.append(h)
                if relay_hints:

@@ -875,13 +787,11 @@ class Common:
            # Check the hint type to see if we can support it (e.g. skip
            # onion hints on a non-Tor client). Do not increase relay_delay
            # unless we have at least one viable hint.
            ep = self._endpoint_from_hint_obj(hint_obj)
            ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
            if not ep:
                continue
            description = "->%s" % describe_hint_obj(hint_obj)
            if self._tor:
                description = "tor" + description
            d = self._start_connector(ep, description)
            d = self._start_connector(ep,
                                      describe_hint_obj(hint_obj, False, self._tor))
            contenders.append(d)
            relay_delay = self.RELAY_DELAY

@@ -902,18 +812,15 @@ class Common:

        for priority in sorted(prioritized_relays, reverse=True):
            for hint_obj in prioritized_relays[priority]:
                ep = self._endpoint_from_hint_obj(hint_obj)
                ep = endpoint_from_hint_obj(hint_obj, self._tor, self._reactor)
                if not ep:
                    continue
                description = "->relay:%s" % describe_hint_obj(hint_obj)
                if self._tor:
                    description = "tor" + description
                d = task.deferLater(
                    self._reactor,
                    relay_delay,
                    self._start_connector,
                    ep,
                    description,
                    describe_hint_obj(hint_obj, True, self._tor),
                    is_relay=True)
                contenders.append(d)
                relay_delay += self.RELAY_DELAY

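describe_hint_obj grows two extra arguments in _hints.py: the "->" / "->relay:" / "tor" prefixes that the removed lines built inline are, presumably, folded into the helper itself. A sketch of that presumed contract (output strings inferred from the removed code, not confirmed):

from wormhole._hints import DirectTCPV1Hint, describe_hint_obj

hint = DirectTCPV1Hint(u"relay.example.com", 4001, 0.0)
print(describe_hint_obj(hint, False, None))  # "->tcp:relay.example.com:4001"
print(describe_hint_obj(hint, True, None))   # "->relay:tcp:relay.example.com:4001"
# passing a truthy tor object would prepend "tor" to the description
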
@@ -951,21 +858,6 @@ class Common:
        d.addCallback(lambda p: p.startNegotiation())
        return d

    def _endpoint_from_hint_obj(self, hint):
        if self._tor:
            if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
                # this Tor object will throw ValueError for non-public IPv4
                # addresses and any IPv6 address
                try:
                    return self._tor.stream_via(hint.hostname, hint.port)
                except ValueError:
                    return None
            return None
        if isinstance(hint, DirectTCPV1Hint):
            return endpoints.HostnameEndpoint(self._reactor, hint.hostname,
                                              hint.port)
        return None

    def connection_ready(self, p):
        # inbound/outbound Connection protocols call this when they finish
        # negotiation. The first one wins and gets a "go". Any subsequent

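The removed method survives as the free function endpoint_from_hint_obj(hint, tor, reactor) in _hints.py, so it no longer needs a Common instance. Assuming the logic above is carried over unchanged:

from twisted.internet import reactor
from wormhole._hints import DirectTCPV1Hint, endpoint_from_hint_obj

hint = DirectTCPV1Hint(u"example.com", 4001, 0.0)
ep = endpoint_from_hint_obj(hint, None, reactor)
# with tor=None this is a plain HostnameEndpoint; with a txtorcon Tor
# object it would be tor.stream_via(...), and unusable hints yield None
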
src/wormhole/util.py

@@ -3,11 +3,19 @@
import os
import unicodedata
from binascii import hexlify, unhexlify
from hkdf import Hkdf


def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    return Hkdf(salt, skm).expand(CTXinfo, outlen)


def to_bytes(u):
    return unicodedata.normalize("NFC", u).encode("utf-8")


def to_unicode(any):
    if isinstance(any, type(u"")):
        return any
    return any.decode("ascii")


def bytes_to_hexstr(b):
    assert isinstance(b, type(b""))

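The HKDF wrapper moves from transit.py into util.py so it can be imported from one place (transit.py now pulls it from here, see above). It collapses the hkdf package's two-step API into a single call:

from wormhole.util import HKDF

# derive 32 bytes from some input keying material; the context string
# here is illustrative
key = HKDF(b"\x01" * 32, 32, CTXinfo=b"example purpose")
assert len(key) == 32
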
src/wormhole/wormhole.py

@@ -5,9 +5,12 @@ import sys

from attr import attrib, attrs
from twisted.python import failure
from twisted.internet.task import Cooperator
from zope.interface import implementer

from ._boss import Boss
from ._dilation.manager import DILATION_VERSIONS
from ._dilation.connector import Connector
from ._interfaces import IDeferredWormhole, IWormhole
from ._key import derive_key
from .errors import NoKeyError, WormholeClosed

@@ -122,7 +125,8 @@ class _DelegatedWormhole(object):

@implementer(IWormhole, IDeferredWormhole)
class _DeferredWormhole(object):
    def __init__(self, eq):
    def __init__(self, reactor, eq):
        self._reactor = reactor
        self._welcome_observer = OneShotObserver(eq)
        self._code_observer = OneShotObserver(eq)
        self._key = None

@@ -187,6 +191,10 @@ class _DeferredWormhole(object):
            raise NoKeyError()
        return derive_key(self._key, to_bytes(purpose), length)

    def dilate(self):
        raise NotImplementedError
        return self._boss.dilate()  # fires with (endpoints)

    def close(self):
        # fails with WormholeError unless we established a connection
        # (state=="happy"). Fails with WrongPasswordError (a subclass of

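dilate() stops raising NotImplementedError and forwards to the Boss; the returned Deferred fires with the dilated endpoints once the peers negotiate. A hedged sketch of the intended call shape (the exact contents of the endpoint tuple are still settling):

from twisted.internet.defer import inlineCallbacks

@inlineCallbacks
def use_dilated(w):                # w: a _DeferredWormhole with a peer
    endpoints = yield w.dilate()   # fires with (endpoints)
    # use the endpoints to build connections to the peer
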
@@ -258,18 +266,24 @@ def create(
    side = bytes_to_hexstr(os.urandom(5))
    journal = journal or ImmediateJournal()
    eq = _eventual_queue or EventualQueue(reactor)
    cooperator = Cooperator(scheduler=eq.eventually)
    if delegate:
        w = _DelegatedWormhole(delegate)
    else:
        w = _DeferredWormhole(eq)
    wormhole_versions = {}  # will be used to indicate Wormhole capabilities
        w = _DeferredWormhole(reactor, eq)
    # this indicates Wormhole capabilities
    wormhole_versions = {
        "can-dilate": DILATION_VERSIONS,
        "dilation-abilities": Connector.get_connection_abilities(),
    }
    wormhole_versions = {}  # don't advertise Dilation yet: not ready
    wormhole_versions["app_versions"] = versions  # app-specific capabilities
    v = __version__
    if isinstance(v, type(b"")):
        v = v.decode("utf-8", errors="replace")
    client_version = ("python", v)
    b = Boss(w, side, relay_url, appid, wormhole_versions, client_version,
             reactor, journal, tor, timing)
             reactor, eq, cooperator, journal, tor, timing)
    w._set_boss(b)
    b.start()
    return w

4
tox.ini
@@ -10,7 +10,7 @@ minversion = 2.4.0

[testenv]
usedevelop = True
extras = dev
extras = dev,dilate
deps =
    pyflakes >= 1.2.3
commands =

@@ -18,6 +18,8 @@ commands =
    wormhole --version
    python -m wormhole.test.run_trial {posargs:wormhole}

[testenv:no-dilate]
extras = dev

# on windows, trial is installed as venv/bin/trial.py, not .exe, but (at
# least appveyor) adds .PY to $PATHEXT. So "trial wormhole" might work on