add dilation code
(this compresses several months of false starts and rearchitecting)
This commit is contained in:
parent
cd6ae6390f
commit
34686a346a
136
docs/api.md
136
docs/api.md
|
@ -524,25 +524,31 @@ object twice.
|
|||
|
||||
## Dilation
|
||||
|
||||
(NOTE: this section is speculative: this code has not yet been written)
|
||||
|
||||
In the longer term, the Wormhole object will incorporate the "Transit"
|
||||
functionality (see transit.md) directly, removing the need to instantiate a
|
||||
second object. A Wormhole can be "dilated" into a form that is suitable for
|
||||
bulk data transfer.
|
||||
To send bulk data, or anything more than a handful of messages, a Wormhole
|
||||
can be "dilated" into a form that uses a direct TCP connection between the
|
||||
two endpoints.
|
||||
|
||||
All wormholes start out "undilated". In this state, all messages are queued
|
||||
on the Rendezvous Server for the lifetime of the wormhole, and server-imposed
|
||||
number/size/rate limits apply. Calling `w.dilate()` initiates the dilation
|
||||
process, and success is signalled via either `d=w.when_dilated()` firing, or
|
||||
`dg.wormhole_dilated()` being called. Once dilated, the Wormhole can be used
|
||||
as an IConsumer/IProducer, and messages will be sent on a direct connection
|
||||
(if possible) or through the transit relay (if not).
|
||||
process, and eventually yields a set of Endpoints. Once dilated, the usual
|
||||
`.send_message()`/`.get_message()` APIs are disabled (TODO: really?), and
|
||||
these endpoints can be used to establish multiple (encrypted) "subchannel"
|
||||
connections to the other side.
|
||||
|
||||
Each subchannel behaves like a regular Twisted `ITransport`, so they can be
|
||||
glued to the Protocol instance of your choice. They also implement the
|
||||
IConsumer/IProducer interfaces.
|
||||
|
||||
These subchannels are *durable*: as long as the processes on both sides keep
|
||||
running, the subchannel will survive the network connection being dropped.
|
||||
For example, a file transfer can be started from a laptop, then while it is
|
||||
running, the laptop can be closed, moved to a new wifi network, opened back
|
||||
up, and the transfer will resume from the new IP address.
|
||||
|
||||
What's good about a non-dilated wormhole?:
|
||||
|
||||
* setup is faster: no delay while it tries to make a direct connection
|
||||
* survives temporary network outages, since messages are queued
|
||||
* works with "journaled mode", allowing progress to be made even when both
|
||||
sides are never online at the same time, by serializing the wormhole
|
||||
|
||||
|
@ -556,21 +562,103 @@ Use non-dilated wormholes when your application only needs to exchange a
|
|||
couple of messages, for example to set up public keys or provision access
|
||||
tokens. Use a dilated wormhole to move files.
|
||||
|
||||
Dilated wormholes can provide multiple "channels": these are multiplexed
|
||||
through the single (encrypted) TCP connection. Each channel is a separate
|
||||
stream (offering IProducer/IConsumer)
|
||||
Dilated wormholes can provide multiple "subchannels": these are multiplexed
|
||||
through the single (encrypted) TCP connection. Each subchannel is a separate
|
||||
stream (offering IProducer/IConsumer for flow control), and is opened and
|
||||
closed independently. A special "control channel" is available to both sides
|
||||
so they can coordinate how they use the subchannels.
|
||||
|
||||
To create a channel, call `c = w.create_channel()` on a dilated wormhole. The
|
||||
"channel ID" can be obtained with `c.get_id()`. This ID will be a short
|
||||
(unicode) string, which can be sent to the other side via a normal
|
||||
`w.send()`, or any other means. On the other side, use `c =
|
||||
w.open_channel(channel_id)` to get a matching channel object.
|
||||
The `d = w.dilate()` Deferred fires with a triple of Endpoints:
|
||||
|
||||
Then use `c.send(data)` and `d=c.when_received()` to exchange data, or wire
|
||||
them up with `c.registerProducer()`. Note that channels do not close until
|
||||
the wormhole connection is closed, so they do not have separate `close()`
|
||||
methods or events. Therefore if you plan to send files through them, you'll
|
||||
need to inform the recipient ahead of time about how many bytes to expect.
|
||||
```python
|
||||
d = w.dilate()
|
||||
def _dilated(res):
|
||||
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
|
||||
d.addCallback(_dilated)
|
||||
```
|
||||
|
||||
The `control_channel_ep` endpoint is a client-style endpoint, so both sides
|
||||
will connect to it with `ep.connect(factory)`. This endpoint is single-use:
|
||||
calling `.connect()` a second time will fail. The control channel is
|
||||
symmetric: it doesn't matter which side is the application-level
|
||||
client/server or initiator/responder, if the application even has such
|
||||
concepts. The two applications can use the control channel to negotiate who
|
||||
goes first, if necessary.
|
||||
|
||||
The subchannel endpoints are *not* symmetric: for each subchannel, one side
|
||||
must listen as a server, and the other must connect as a client. Subchannels
|
||||
can be established by either side at any time. This supports e.g.
|
||||
bidirectional file transfer, where either user of a GUI app can drop files
|
||||
into the "wormhole" whenever they like.
|
||||
|
||||
The `subchannel_client_ep` on one side is used to connect to the other side's
|
||||
`subchannel_server_ep`, and vice versa. The client endpoint is reusable. The
|
||||
server endpoint is single-use: `.listen(factory)` may only be called once.
|
||||
|
||||
Applications are under no obligation to use subchannels: for many use cases,
|
||||
the control channel is enough.
|
||||
|
||||
To use subchannels, once the wormhole is dilated and the endpoints are
|
||||
available, the listening-side application should attach a listener to the
|
||||
`subchannel_server_ep` endpoint:
|
||||
|
||||
```python
|
||||
def _dilated(res):
|
||||
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
|
||||
f = Factory(MyListeningProtocol)
|
||||
subchannel_server_ep.listen(f)
|
||||
```
|
||||
|
||||
When the connecting-side application wants to connect to that listening
|
||||
protocol, it should use `.connect()` with a suitable connecting protocol
|
||||
factory:
|
||||
|
||||
```python
|
||||
def _connect():
|
||||
f = Factory(MyConnectingProtocol)
|
||||
subchannel_client_ep.connect(f)
|
||||
```
|
||||
|
||||
For a bidirectional file-transfer application, both sides will establish a
|
||||
listening protocol. Later, if/when the user drops a file on the application
|
||||
window, that side will initiate a connection, use the resulting subchannel to
|
||||
transfer the single file, and then close the subchannel.
|
||||
|
||||
```python
|
||||
def FileSendingProtocol(internet.Protocol):
|
||||
def __init__(self, metadata, filename):
|
||||
self.file_metadata = metadata
|
||||
self.file_name = filename
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.file_metadata)
|
||||
sender = protocols.basic.FileSender()
|
||||
f = open(self.file_name,"rb")
|
||||
d = sender.beginFileTransfer(f, self.transport)
|
||||
d.addBoth(self._done, f)
|
||||
def _done(res, f):
|
||||
self.transport.loseConnection()
|
||||
f.close()
|
||||
def _send(metadata, filename):
|
||||
f = protocol.ClientCreator(reactor,
|
||||
FileSendingProtocol, metadata, filename)
|
||||
subchannel_client_ep.connect(f)
|
||||
def FileReceivingProtocol(internet.Protocol):
|
||||
state = INITIAL
|
||||
def dataReceived(self, data):
|
||||
if state == INITIAL:
|
||||
self.state = DATA
|
||||
metadata = parse(data)
|
||||
self.f = open(metadata.filename, "wb")
|
||||
else:
|
||||
# local file writes are blocking, so don't bother with IConsumer
|
||||
self.f.write(data)
|
||||
def connectionLost(self, reason):
|
||||
self.f.close()
|
||||
def _dilated(res):
|
||||
(control_channel_ep, subchannel_client_ep, subchannel_server_ep) = res
|
||||
f = Factory(FileReceivingProtocol)
|
||||
subchannel_server_ep.listen(f)
|
||||
```
|
||||
|
||||
## Bytes, Strings, Unicode, and Python 3
|
||||
|
||||
|
|
500
docs/dilation-protocol.md
Normal file
500
docs/dilation-protocol.md
Normal file
|
@ -0,0 +1,500 @@
|
|||
# Dilation Internals
|
||||
|
||||
Wormhole dilation involves several moving parts. Both sides exchange messages
|
||||
through the Mailbox server to coordinate the establishment of a more direct
|
||||
connection. This connection might flow in either direction, so they trade
|
||||
"connection hints" to point at potential listening ports. This process might
|
||||
succeed in making multiple connections at about the same time, so one side
|
||||
must select the best one to use, and cleanly shut down the others. To make
|
||||
the dilated connection *durable*, this side must also decide when the
|
||||
connection has been lost, and then coordinate the construction of a
|
||||
replacement. Within this connection, a series of queued-and-acked subchannel
|
||||
messages are used to open/use/close the application-visible subchannels.
|
||||
|
||||
## Leaders and Followers
|
||||
|
||||
Each side of a Wormhole has a randomly-generated "side" string. When the
|
||||
wormhole is dilated, the side with the lexicographically-higher "side" value
|
||||
is named the "Leader", and the other side is named the "Follower". The general
|
||||
wormhole protocol treats both sides identically, but the distinction matters
|
||||
for the dilation protocol.
|
||||
|
||||
Either side can trigger dilation, but the Follower does so by asking the
|
||||
Leader to start the process, whereas the Leader just starts the process
|
||||
unilaterally. The Leader has exclusive control over whether a given
|
||||
connection is considered established or not: if there are multiple potential
|
||||
connections to use, the Leader decides which one to use, and the Leader gets
|
||||
to decide when the connection is no longer viable (and triggers the
|
||||
establishment of a new one).
|
||||
|
||||
## Connection Layers
|
||||
|
||||
We describe the protocol as a series of layers. Messages sent on one layer
|
||||
may be encoded or transformed before being delivered on some other layer.
|
||||
|
||||
L1 is the mailbox channel (queued store-and-forward messages that always go
|
||||
to the mailbox server, and then are forwarded to other clients subscribed to
|
||||
the same mailbox). Both clients remain connected to the mailbox server until
|
||||
the Wormhole is closed. They send DILATE-n messages to each other to manage
|
||||
the dilation process, including records like `please-dilate`,
|
||||
`start-dilation`, `ok-dilation`, and `connection-hints`
|
||||
|
||||
L2 is the set of competing connection attempts for a given generation of
|
||||
connection. Each time the Leader decides to establish a new connection, a new
|
||||
generation number is used. Hopefully these are direct TCP connections between
|
||||
the two peers, but they may also include connections through the transit
|
||||
relay. Each connection must go through an encrypted handshake process before
|
||||
it is considered viable. Viable connections are then submitted to a selection
|
||||
process (on the Leader side), which chooses exactly one to use, and drops the
|
||||
others. It may wait an extra few seconds in the hopes of getting a "better"
|
||||
connection (faster, cheaper, etc), but eventually it will select one.
|
||||
|
||||
L3 is the current selected connection. There is one L3 for each generation.
|
||||
At all times, the wormhole will have exactly zero or one L3 connection. L3 is
|
||||
responsible for the selection process, connection monitoring/keepalives, and
|
||||
serialization/deserialization of the plaintext frames. L3 delivers decoded
|
||||
frames and connection-establishment events up to L4.
|
||||
|
||||
L4 is the persistent higher-level channel. It is created as soon as the first
|
||||
L3 connection is selected, and lasts until wormhole is closed entirely. L4
|
||||
contains OPEN/DATA/CLOSE/ACK messages: OPEN/DATA/CLOSE have a sequence number
|
||||
(scoped to the L4 connection and the direction of travel), and the ACK
|
||||
messages reference those sequence numbers. When a message is given to the L4
|
||||
channel for delivery to the remote side, it is always queued, then
|
||||
transmitted if there is an L3 connection available. This message remains in
|
||||
the queue until an ACK is received to release it. If a new L3 connection is
|
||||
made, all queued messages will be re-sent (in seqnum order).
|
||||
|
||||
L5 are subchannels. There is one pre-established subchannel 0 known as the
|
||||
"control channel", which does not require an OPEN message. All other
|
||||
subchannels are created by the receipt of an OPEN message with the subchannel
|
||||
number. DATA frames are delivered to a specific subchannel. When the
|
||||
subchannel is no longer needed, one side will invoke the ``close()`` API
|
||||
(``loseConnection()`` in Twisted), which will cause a CLOSE message to be
|
||||
sent, and the local L5 object will be put into the "closing "state. When the
|
||||
other side receives the CLOSE, it will send its own CLOSE for the same
|
||||
subchannel, and fully close its local object (``connectionLost()``). When the
|
||||
first side receives CLOSE in the "closing" state, it will fully close its
|
||||
local object too.
|
||||
|
||||
All L5 subchannels will be paused (``pauseProducing()``) when the L3
|
||||
connection is paused or lost. They are resumed when the L3 connection is
|
||||
resumed or reestablished.
|
||||
|
||||
## Initiating Dilation
|
||||
|
||||
Dilation is triggered by calling the `w.dilate()` API. This returns a
|
||||
Deferred that will fire once the first L3 connection is established. It fires
|
||||
with a 3-tuple of endpoints that can be used to establish subchannels.
|
||||
|
||||
For dilation to succeed, both sides must call `w.dilate()`, since the
|
||||
resulting endpoints are the only way to access the subchannels. If the other
|
||||
side never calls `w.dilate()`, the Deferred will never fire.
|
||||
|
||||
The L1 (mailbox) path is used to deliver dilation requests and connection
|
||||
hints. The current mailbox protocol uses named "phases" to distinguish
|
||||
messages (rather than behaving like a regular ordered channel of arbitrary
|
||||
frames or bytes), and all-number phase names are reserved for application
|
||||
data (sent via `w.send_message()`). Therefore the dilation control messages
|
||||
use phases named `DILATE-0`, `DILATE-1`, etc. Each side maintains its own
|
||||
counter, so one side might be up to e.g. `DILATE-5` while the other has only
|
||||
gotten as far as `DILATE-2`. This effectively creates a unidirectional stream
|
||||
of `DILATE-n` messages, each containing one or more dilation record, of
|
||||
various types described below. Note that all phases beyond the initial
|
||||
VERSION and PAKE phases are encrypted by the shared session key.
|
||||
|
||||
A future mailbox protocol might provide a simple ordered stream of messages,
|
||||
with application records and dilation records mixed together.
|
||||
|
||||
Each `DILATE-n` message is a JSON-encoded dictionary with a `type` field that
|
||||
has a string value. The dictionary will have other keys that depend upon the
|
||||
type.
|
||||
|
||||
`w.dilate()` triggers a `please-dilate` record with a set of versions that
|
||||
can be accepted. Both Leader and Follower emit this record, although the
|
||||
Leader is responsible for version decisions. Versions use strings, rather
|
||||
than integers, to support experimental protocols, however there is still a
|
||||
total ordering of preferability.
|
||||
|
||||
```
|
||||
{ "type": "please-dilate",
|
||||
"accepted-versions": ["1"]
|
||||
}
|
||||
```
|
||||
|
||||
The Leader then sends a `start-dilation` message with a `version` field (the
|
||||
"best" mutually-supported value) and the new "L2 generation" number in the
|
||||
`generation` field. Generation numbers are integers, monotonically increasing
|
||||
by 1 each time.
|
||||
|
||||
```
|
||||
{ "type": start-dilation,
|
||||
"version": "1",
|
||||
"generation": 1,
|
||||
}
|
||||
```
|
||||
|
||||
The Follower responds with a `ok-dilation` message with matching `version`
|
||||
and `generation` fields.
|
||||
|
||||
The Leader decides when a new dilation connection is necessary, both for the
|
||||
initial connection and any subsequent reconnects. Therefore the Leader has
|
||||
the exclusive right to send the `start-dilation` record. It won't send this
|
||||
until after it has sent its own `please-dilate`, and after it has received
|
||||
the Follower's `please-dilate`. As a result, local preparations may begin as
|
||||
soon as `w.dilate()` is called, but L2 connections do not begin until the
|
||||
Leader declares the start of a new L2 generation with the `start-dilation`
|
||||
message.
|
||||
|
||||
Generations are non-overlapping. The Leader will drop all connections from
|
||||
generation 1 before sending the `start-dilation` for generation 2, and will
|
||||
not initiate any gen-2 connections until it receives the matching
|
||||
`ok-dilation` from the Follower. The Follower must drop all gen-1 connections
|
||||
before it sends the `ok-dilation` response (even if it thinks they are still
|
||||
functioning: if the Leader thought the gen-1 connection still worked, it
|
||||
wouldn't have started gen-2). Listening sockets can be retained, but any
|
||||
previous connection made through them must be dropped. This should avoid a
|
||||
race.
|
||||
|
||||
(TODO: what about a follower->leader connection that was started before
|
||||
start-dilation is received, and gets established on the Leader side after
|
||||
start-dilation is sent? the follower will drop it after it receives
|
||||
start-dilation, but meanwhile the leader may accept it as gen2)
|
||||
|
||||
(probably need to include the generation number in the handshake, or in the
|
||||
derived key)
|
||||
|
||||
(TODO: reduce the number of round-trip stalls here, I've added too many)
|
||||
|
||||
"Connection hints" are type/address/port records that tell the other side of
|
||||
likely targets for L2 connections. Both sides will try to determine their
|
||||
external IP addresses, listen on a TCP port, and advertise `(tcp,
|
||||
external-IP, port)` as a connection hint. The Transit Relay is also used as a
|
||||
(lower-priority) hint. These are sent in `connection-hint` records, which can
|
||||
be sent by the Leader any time after the `start-dilation` record, and by the
|
||||
Follower after the `ok-dilation` record. Each side will initiate connections
|
||||
upon receipt of the hints.
|
||||
|
||||
```
|
||||
{ "type": "connection-hints",
|
||||
"hints": [ ... ]
|
||||
}
|
||||
```
|
||||
|
||||
Hints can arrive at any time. One side might immediately send hints that can
|
||||
be computed quickly, then send additional hints later as they become
|
||||
available. For example, it might enumerate the local network interfaces and
|
||||
send hints for all of the LAN addresses first, then send port-forwarding
|
||||
(UPnP) requests to the local router. When the forwarding is established
|
||||
(providing an externally-visible IP address and port), it can send additional
|
||||
hints for that new endpoint. If the other peer happens to be on the same LAN,
|
||||
the local connection can be established without waiting for the router's
|
||||
response.
|
||||
|
||||
|
||||
### Connection Hint Format
|
||||
|
||||
Each member of the `hints` field describes a potential L2 connection target
|
||||
endpoint, with an associated priority and a set of hints.
|
||||
|
||||
The priority is a number (positive or negative float), where larger numbers
|
||||
indicate that the client supplying that hint would prefer to use this
|
||||
connection over others of lower number. This indicates a sense of cost or
|
||||
performance. For example, the Transit Relay is lower priority than a direct
|
||||
TCP connection, because it incurs a bandwidth cost (on the relay operator),
|
||||
as well as adding latency.
|
||||
|
||||
Each endpoint has a set of hints, because the same target might be reachable
|
||||
by multiple hints. Once one hint succeeds, there is no point in using the
|
||||
other hints.
|
||||
|
||||
TODO: think this through some more. What's the example of a single endpoint
|
||||
reachable by multiple hints? Should each hint have its own priority, or just
|
||||
each endpoint?
|
||||
|
||||
## L2 protocol
|
||||
|
||||
Upon ``connectionMade()``, both sides send their handshake message. The
|
||||
Leader sends "Magic-Wormhole Dilation Handshake v1 Leader\n\n". The Follower
|
||||
sends "Magic-Wormhole Dilation Handshake v1 Follower\n\n". This should
|
||||
trigger an immediate error for most non-magic-wormhole listeners (e.g. HTTP
|
||||
servers that were contacted by accident). If the wrong handshake is received,
|
||||
the connection will be dropped. For debugging purposes, the node might want
|
||||
to keep looking at data beyond the first incorrect character and log
|
||||
everything until the first newline.
|
||||
|
||||
Everything beyond that point is a Noise protocol message, which consists of a
|
||||
4-byte big-endian length field, followed by the indicated number of bytes.
|
||||
This ises the `NNpsk0` pattern with the Leader as the first party ("-> psk,
|
||||
e" in the Noise spec), and the Follower as the second ("<- e, ee"). The
|
||||
pre-shared-key is the "dilation key", which is statically derived from the
|
||||
master PAKE key using HKDF. Each L2 connection uses the same dilation key,
|
||||
but different ephemeral keys, so each gets a different session key.
|
||||
|
||||
The Leader sends the first message, which is a psk-encrypted ephemeral key.
|
||||
The Follower sends the next message, its own psk-encrypted ephemeral key. The
|
||||
Follower then sends an empty packet as the "key confirmation message", which
|
||||
will be encrypted by the shared key.
|
||||
|
||||
The Leader sees the KCM and knows the connection is viable. It delivers the
|
||||
protocol object to the L3 manager, which will decide which connection to
|
||||
select. When the L2 connection is selected to be the new L3, it will send an
|
||||
empty KCM of its own, to let the Follower know the connection being selected.
|
||||
All other L2 connections (either viable or still in handshake) are dropped,
|
||||
all other connection attempts are cancelled, and all listening sockets are
|
||||
shut down.
|
||||
|
||||
The Follower will wait for either an empty KCM (at which point the L2
|
||||
connection is delivered to the Dilation manager as the new L3), a
|
||||
disconnection, or an invalid message (which causes the connection to be
|
||||
dropped). Other connections and/or listening sockets are stopped.
|
||||
|
||||
Internally, the L2Protocol object manages the Noise session itself. It knows
|
||||
(via a constructor argument) whether it is on the Leader or Follower side,
|
||||
which affects both the role is plays in the Noise pattern, and the reaction
|
||||
to receiving the ephemeral key (for which only the Follower sends an empty
|
||||
KCM message). After that, the L2Protocol notifies the L3 object in three
|
||||
situations:
|
||||
|
||||
* the Noise session produces a valid decrypted frame (for Leader, this
|
||||
includes the Follower's KCM, and thus indicates a viable candidate for
|
||||
connection selection)
|
||||
* the Noise session reports a failed decryption
|
||||
* the TCP session is lost
|
||||
|
||||
All notifications include a reference to the L2Protocol object (`self`). The
|
||||
L3 object uses this reference to either close the connection (for errors or
|
||||
when the selection process chooses someone else), to send the KCM message
|
||||
(after selection, only for the Leader), or to send other L4 messages. The L3
|
||||
object will retain a reference to the winning L2 object.
|
||||
|
||||
## L3 protocol
|
||||
|
||||
The L3 layer is responsible for connection selection, monitoring/keepalives,
|
||||
and message (de)serialization. Framing is handled by L2, so the inbound L3
|
||||
codepath receives single-message bytestrings, and delivers the same down to
|
||||
L2 for encryption, framing, and transmission.
|
||||
|
||||
Connection selection takes place exclusively on the Leader side, and includes
|
||||
the following:
|
||||
|
||||
* receipt of viable L2 connections from below (indicated by the first valid
|
||||
decrypted frame received for any given connection)
|
||||
* expiration of a timer
|
||||
* comparison of TBD quality/desirability/cost metrics of viable connections
|
||||
* selection of winner
|
||||
* instructions to losing connections to disconnect
|
||||
* delivery of KCM message through winning connection
|
||||
* retain reference to winning connection
|
||||
|
||||
On the Follower side, the L3 manager just waits for the first connection to
|
||||
receive the Leader's KCM, at which point it is retained and all others are
|
||||
dropped.
|
||||
|
||||
The L3 manager knows which "generation" of connection is being established.
|
||||
Each generation uses a different dilation key (?), and is triggered by a new
|
||||
set of L1 messages. Connections from one generation should not be confused
|
||||
with those of a different generation.
|
||||
|
||||
Each time a new L3 connection is established, the L4 protocol is notified. It
|
||||
will will immediately send all the L4 messages waiting in its outbound queue.
|
||||
The L3 protocol simply wraps these in Noise frames and sends them to the
|
||||
other side.
|
||||
|
||||
The L3 manager monitors the viability of the current connection, and declares
|
||||
it as lost when bidirectional traffic cannot be maintained. It uses PING and
|
||||
PONG messages to detect this. These also serve to keep NAT entries alive,
|
||||
since many firewalls will stop forwarding packets if they don't observe any
|
||||
traffic for e.g. 5 minutes.
|
||||
|
||||
Our goals are:
|
||||
|
||||
* don't allow more than 30? seconds to pass without at least *some* data
|
||||
being sent along each side of the connection
|
||||
* allow the Leader to detect silent connection loss within 60? seconds
|
||||
* minimize overhead
|
||||
|
||||
We need both sides to:
|
||||
|
||||
* maintain a 30-second repeating timer
|
||||
* set a flag each time we write to the connection
|
||||
* each time the timer fires, if the flag was clear then send a PONG,
|
||||
otherwise clear the flag
|
||||
|
||||
In addition, the Leader must:
|
||||
|
||||
* run a 60-second repeating timer (ideally somewhat offset from the other)
|
||||
* set a flag each time we receive data from the connection
|
||||
* each time the timer fires, if the flag was clear then drop the connection,
|
||||
otherwise clear the flag
|
||||
|
||||
In the future, we might have L2 links that are less connection-oriented,
|
||||
which might have a unidirectional failure mode, at which point we'll need to
|
||||
monitor full roundtrips. To accomplish this, the Leader will send periodic
|
||||
unconditional PINGs, and the Follower will respond with PONGs. If the
|
||||
Leader->Follower connection is down, the PINGs won't arrive and no PONGs will
|
||||
be produced. If the Follower->Leader direction has failed, the PONGs won't
|
||||
arrive. The delivery of both will be delayed by actual data, so the timeouts
|
||||
should be adjusted if we see regular data arriving.
|
||||
|
||||
If the connection is dropped before the wormhole is closed (either the other
|
||||
end explicitly dropped it, we noticed a problem and told TCP to drop it, or
|
||||
TCP noticed a problem itself), the Leader-side L3 manager will initiate a
|
||||
reconnection attempt. This uses L1 to send a new DILATE message through the
|
||||
mailbox server, along with new connection hints. Eventually this will result
|
||||
in a new L3 connection being established.
|
||||
|
||||
Finally, L3 is responsible for message serialization and deserialization. L2
|
||||
performs decryption and delivers plaintext frames to L3. Each frame starts
|
||||
with a one-byte type indicator. The rest of the message depends upon the
|
||||
type:
|
||||
|
||||
* 0x00 PING, 4-byte ping-id
|
||||
* 0x01 PONG, 4-byte ping-id
|
||||
* 0x02 OPEN, 4-byte subchannel-id, 4-byte seqnum
|
||||
* 0x03 DATA, 4-byte subchannel-id, 4-byte seqnum, variable-length payload
|
||||
* 0x04 CLOSE, 4-byte subchannel-id, 4-byte seqnum
|
||||
* 0x05 ACK, 4-byte response-seqnum
|
||||
|
||||
All seqnums are big-endian, and are provided by the L4 protocol. The other
|
||||
fields are arbitrary and not interpreted as integers. The subchannel-ids must
|
||||
be allocated by both sides without collision, but otherwise they are only
|
||||
used to look up L5 objects for dispatch. The response-seqnum is always copied
|
||||
from the OPEN/DATA/CLOSE packet being acknowledged.
|
||||
|
||||
L3 consumes the PING and PONG messages. Receiving any PING will provoke a
|
||||
PONG in response, with a copy of the ping-id field. The 30-second timer will
|
||||
produce unprovoked PONGs with a ping-id of all zeros. A future viability
|
||||
protocol will use PINGs to test for roundtrip functionality.
|
||||
|
||||
All other messages (OPEN/DATA/CLOSE/ACK) are deserialized and delivered
|
||||
"upstairs" to the L4 protocol handler.
|
||||
|
||||
The current L3 connection's `IProducer`/`IConsumer` interface is made
|
||||
available to the L4 flow-control manager.
|
||||
|
||||
## L4 protocol
|
||||
|
||||
The L4 protocol manages a durable stream of OPEN/DATA/CLOSE/ACK messages.
|
||||
Since each will be enclosed in a Noise frame before they pass to L3, they do
|
||||
not need length fields or other framing.
|
||||
|
||||
Each OPEN/DATA/CLOSE has a sequence number, starting at 0, and monotonically
|
||||
increasing by 1 for each message. Each direction has a separate number space.
|
||||
|
||||
The L4 manager maintains a double-ended queue of unacknowledged outbound
|
||||
messages. Subchannel activity (opening, closing, sending data) cause messages
|
||||
to be added to this queue. If an L3 connection is available, these messages
|
||||
are also sent over that connection, but they remain in the queue in case the
|
||||
connection is lost and they must be retransmitted on some future replacement
|
||||
connection. Messages stay in the queue until they can be retired by the
|
||||
receipt of an ACK with a matching response-sequence-number. This provides
|
||||
reliable message delivery that survives the L3 connection being replaced.
|
||||
|
||||
ACKs are not acked, nor do they have seqnums of their own. Each inbound side
|
||||
remembers the highest ACK it has sent, and ignores incoming OPEN/DATA/CLOSE
|
||||
messages with that sequence number or higher. This ensures in-order
|
||||
at-most-once processing of OPEN/DATA/CLOSE messages.
|
||||
|
||||
Each inbound OPEN message causes a new L5 subchannel object to be created.
|
||||
Subsequent DATA/CLOSE messages for the same subchannel-id are delivered to
|
||||
that object.
|
||||
|
||||
Each time an L3 connection is established, the side will immediately send all
|
||||
L4 messages waiting in the outbound queue. A future protocol might reduce
|
||||
this duplication by including the highest received sequence number in the L1
|
||||
PLEASE-DILATE message, which would effectively retire queued messages before
|
||||
initiating the L2 connection process. On any given L3 connection, all
|
||||
messages are sent in-order. The receipt of an ACK for seqnum `N` allows all
|
||||
messages with `seqnum <= N` to be retired.
|
||||
|
||||
The L4 layer is also responsible for managing flow control among the L3
|
||||
connection and the various L5 subchannels.
|
||||
|
||||
## L5 subchannels
|
||||
|
||||
The L5 layer consists of a collection of "subchannel" objects, a dispatcher,
|
||||
and the endpoints that provide the Twisted-flavored API.
|
||||
|
||||
Other than the "control channel", all subchannels are created by a client
|
||||
endpoint connection API. The side that calls this API is named the Initiator,
|
||||
and the other side is named the Acceptor. Subchannels can be initiated in
|
||||
either direction, independent of the Leader/Follower distinction. For a
|
||||
typical file-transfer application, the subchannel would be initiated by the
|
||||
side seeking to send a file.
|
||||
|
||||
Each subchannel uses a distinct subchannel-id, which is a four-byte
|
||||
identifier. Both directions share a number space (unlike L4 seqnums), so the
|
||||
rule is that the Leader side sets the last bit of the last byte to a 0, while
|
||||
the Follower sets it to a 1. These are not generally treated as integers,
|
||||
however for the sake of debugging, the implementation generates them with a
|
||||
simple big-endian-encoded counter (`next(counter)*2` for the Leader,
|
||||
`next(counter)*2+1` for the Follower).
|
||||
|
||||
When the `client_ep.connect()` API is called, the Initiator allocates a
|
||||
subchannel-id and sends an OPEN. It can then immediately send DATA messages
|
||||
with the outbound data (there is no special response to an OPEN, so there is
|
||||
no need to wait). The Acceptor will trigger their `.connectionMade` handler
|
||||
upon receipt of the OPEN.
|
||||
|
||||
Subchannels are durable: they do not close until one side calls
|
||||
`.loseConnection` on the subchannel object (or the enclosing Wormhole is
|
||||
closed). Either the Initiator or the Acceptor can call `.loseConnection`.
|
||||
This causes a CLOSE message to be sent (with the subchannel-id). The other
|
||||
side will send its own CLOSE message in response. Each side will signal the
|
||||
`.connectionLost()` event upon receipt of a CLOSE.
|
||||
|
||||
There is no equivalent to TCP's "half-closed" state, however if only one side
|
||||
calls `close()`, then all data written before that call will be delivered
|
||||
before the other side observes `.connectionLost()`. Any inbound data that was
|
||||
queued for delivery before the other side sees the CLOSE will still be
|
||||
delivered to the side that called `close()` before it sees
|
||||
`.connectionLost()`. Internally, the side which called `.loseConnection` will
|
||||
remain in a special "closing" state until the CLOSE response arrives, during
|
||||
which time DATA payloads are still delivered. After calling `close()` (or
|
||||
receiving CLOSE), any outbound `.write()` calls will trigger an error.
|
||||
|
||||
DATA payloads that arrive for a non-open subchannel are logged and discarded.
|
||||
|
||||
This protocol calls for one OPEN and two CLOSE messages for each subchannel,
|
||||
with some arbitrary number of DATA messages in between. Subchannel-ids should
|
||||
not be reused (it would probably work, the protocol hasn't been analyzed
|
||||
enough to be sure).
|
||||
|
||||
The "control channel" is special. It uses a subchannel-id of all zeros, and
|
||||
is opened implicitly by both sides as soon as the first L3 connection is
|
||||
selected. It is routed to a special client-on-both-sides endpoint, rather
|
||||
than causing the listening endpoint to accept a new connection. This avoids
|
||||
the need for application-level code to negotiate who should be the one to
|
||||
open it (the Leader/Follower distinction is private to the Wormhole
|
||||
internals: applications are not obligated to pick a side).
|
||||
|
||||
OPEN and CLOSE messages for the control channel are logged and discarded. The
|
||||
control-channel client endpoints can only be used once, and does not close
|
||||
until the Wormhole itself is closed.
|
||||
|
||||
Each OPEN/DATA/CLOSE message is delivered to the L4 object for queueing,
|
||||
delivery, and eventual retirement. The L5 layer does not keep track of old
|
||||
messages.
|
||||
|
||||
### Flow Control
|
||||
|
||||
Subchannels are flow-controlled by pausing their writes when the L3
|
||||
connection is paused, and pausing the L3 connection when the subchannel
|
||||
signals a pause. When the outbound L3 connection is full, *all* subchannels
|
||||
are paused. Likewise the inbound connection is paused if *any* of the
|
||||
subchannels asks for a pause. This is much easier to implement and improves
|
||||
our utilization factor (we can use TCP's window-filling algorithm, instead of
|
||||
rolling our own), but will block all subchannels even if only one of them
|
||||
gets full. This shouldn't matter for many applications, but might be
|
||||
noticeable when combining very different kinds of traffic (e.g. a chat
|
||||
conversation sharing a wormhole with file-transfer might prefer the IM text
|
||||
to take priority).
|
||||
|
||||
Each subchannel implements Twisted's `ITransport`, `IProducer`, and
|
||||
`IConsumer` interfaces. The Endpoint API causes a new `IProtocol` object to
|
||||
be created (by the caller's factory) and glued to the subchannel object in
|
||||
the `.transport` property, as is standard in Twisted-based applications.
|
||||
|
||||
All subchannels are also paused when the L3 connection is lost, and are
|
||||
unpaused when a new replacement connection is selected.
|
2000
docs/new-protocol.svg
Normal file
2000
docs/new-protocol.svg
Normal file
File diff suppressed because it is too large
Load Diff
After Width: | Height: | Size: 91 KiB |
|
@ -23,12 +23,13 @@ digraph {
|
|||
Terminator [shape="box" color="blue" fontcolor="blue"]
|
||||
InputHelperAPI [shape="oval" label="input\nhelper\nAPI"
|
||||
color="blue" fontcolor="blue"]
|
||||
Dilator [shape="box" label="Dilator" color="blue" fontcolor="blue"]
|
||||
|
||||
#Connection -> websocket [color="blue"]
|
||||
#Connection -> Order [color="blue"]
|
||||
|
||||
Wormhole -> Boss [style="dashed"
|
||||
label="allocate_code\ninput_code\nset_code\nsend\nclose\n(once)"
|
||||
label="allocate_code\ninput_code\nset_code\ndilate\nsend\nclose\n(once)"
|
||||
color="red" fontcolor="red"]
|
||||
#Wormhole -> Boss [color="blue"]
|
||||
Boss -> Wormhole [style="dashed" label="got_code\ngot_key\ngot_verifier\ngot_version\nreceived (seq)\nclosed\n(once)"]
|
||||
|
@ -112,4 +113,7 @@ digraph {
|
|||
Terminator -> Boss [style="dashed" label="closed\n(once)"]
|
||||
Boss -> Terminator [style="dashed" color="red" fontcolor="red"
|
||||
label="close"]
|
||||
|
||||
Boss -> Dilator [style="dashed" label="dilate\nreceived_dilate\ngot_wormhole_versions"]
|
||||
Dilator -> Send [style="dashed" label="send(dilate-N)"]
|
||||
}
|
||||
|
|
1
setup.py
1
setup.py
|
@ -48,6 +48,7 @@ setup(name="magic-wormhole",
|
|||
"click",
|
||||
"humanize",
|
||||
"txtorcon >= 18.0.2", # 18.0.2 fixes py3.4 support
|
||||
"noiseprotocol",
|
||||
],
|
||||
extras_require={
|
||||
':sys_platform=="win32"': ["pywin32"],
|
||||
|
|
|
@ -12,6 +12,7 @@ from zope.interface import implementer
|
|||
from . import _interfaces
|
||||
from ._allocator import Allocator
|
||||
from ._code import Code, validate_code
|
||||
from ._dilation.manager import Dilator
|
||||
from ._input import Input
|
||||
from ._key import Key
|
||||
from ._lister import Lister
|
||||
|
@ -66,6 +67,7 @@ class Boss(object):
|
|||
self._I = Input(self._timing)
|
||||
self._C = Code(self._timing)
|
||||
self._T = Terminator()
|
||||
self._D = Dilator(self._reactor, self._eventual_queue, self._cooperator)
|
||||
|
||||
self._N.wire(self._M, self._I, self._RC, self._T)
|
||||
self._M.wire(self._N, self._RC, self._O, self._T)
|
||||
|
@ -79,6 +81,7 @@ class Boss(object):
|
|||
self._I.wire(self._C, self._L)
|
||||
self._C.wire(self, self._A, self._N, self._K, self._I)
|
||||
self._T.wire(self, self._RC, self._N, self._M)
|
||||
self._D.wire(self._S)
|
||||
|
||||
def _init_other_state(self):
|
||||
self._did_start_code = False
|
||||
|
@ -86,6 +89,9 @@ class Boss(object):
|
|||
self._next_rx_phase = 0
|
||||
self._rx_phases = {} # phase -> plaintext
|
||||
|
||||
self._next_rx_dilate_seqnum = 0
|
||||
self._rx_dilate_seqnums = {} # seqnum -> plaintext
|
||||
|
||||
self._result = "empty"
|
||||
|
||||
# these methods are called from outside
|
||||
|
@ -198,6 +204,9 @@ class Boss(object):
|
|||
self._did_start_code = True
|
||||
self._C.set_code(code)
|
||||
|
||||
def dilate(self):
|
||||
return self._D.dilate() # fires with endpoints
|
||||
|
||||
@m.input()
|
||||
def send(self, plaintext):
|
||||
pass
|
||||
|
@ -258,8 +267,11 @@ class Boss(object):
|
|||
# this is only called for side != ours
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
d_mo = re.search(r'^dilate-(\d+)$', phase)
|
||||
if phase == "version":
|
||||
self._got_version(side, plaintext)
|
||||
elif d_mo:
|
||||
self._got_dilate(int(d_mo.group(1)), plaintext)
|
||||
elif re.search(r'^\d+$', phase):
|
||||
self._got_phase(int(phase), plaintext)
|
||||
else:
|
||||
|
@ -275,6 +287,10 @@ class Boss(object):
|
|||
def _got_phase(self, phase, plaintext):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def _got_dilate(self, seqnum, plaintext):
|
||||
pass
|
||||
|
||||
@m.input()
|
||||
def got_key(self, key):
|
||||
pass
|
||||
|
@ -298,6 +314,8 @@ class Boss(object):
|
|||
# in the future, this is how Dilation is signalled
|
||||
self._their_side = side
|
||||
self._their_versions = bytes_to_dict(plaintext)
|
||||
self._D.got_wormhole_versions(self._side, self._their_side,
|
||||
self._their_versions)
|
||||
# but this part is app-to-app
|
||||
app_versions = self._their_versions.get("app_versions", {})
|
||||
self._W.got_versions(app_versions)
|
||||
|
@ -339,6 +357,10 @@ class Boss(object):
|
|||
def W_got_key(self, key):
|
||||
self._W.got_key(key)
|
||||
|
||||
@m.output()
|
||||
def D_got_key(self, key):
|
||||
self._D.got_key(key)
|
||||
|
||||
@m.output()
|
||||
def W_got_verifier(self, verifier):
|
||||
self._W.got_verifier(verifier)
|
||||
|
@ -352,6 +374,16 @@ class Boss(object):
|
|||
self._W.received(self._rx_phases.pop(self._next_rx_phase))
|
||||
self._next_rx_phase += 1
|
||||
|
||||
@m.output()
|
||||
def D_received_dilate(self, seqnum, plaintext):
|
||||
assert isinstance(seqnum, six.integer_types), type(seqnum)
|
||||
# strict phase order, no gaps
|
||||
self._rx_dilate_seqnums[seqnum] = plaintext
|
||||
while self._next_rx_dilate_seqnum in self._rx_dilate_seqnums:
|
||||
m = self._rx_dilate_seqnums.pop(self._next_rx_dilate_seqnum)
|
||||
self._D.received_dilate(m)
|
||||
self._next_rx_dilate_seqnum += 1
|
||||
|
||||
@m.output()
|
||||
def W_close_with_error(self, err):
|
||||
self._result = err # exception
|
||||
|
@ -374,7 +406,7 @@ class Boss(object):
|
|||
S1_lonely.upon(scared, enter=S3_closing, outputs=[close_scared])
|
||||
S1_lonely.upon(close, enter=S3_closing, outputs=[close_lonely])
|
||||
S1_lonely.upon(send, enter=S1_lonely, outputs=[S_send])
|
||||
S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key])
|
||||
S1_lonely.upon(got_key, enter=S1_lonely, outputs=[W_got_key, D_got_key])
|
||||
S1_lonely.upon(rx_error, enter=S3_closing, outputs=[close_error])
|
||||
S1_lonely.upon(error, enter=S4_closed, outputs=[W_close_with_error])
|
||||
|
||||
|
@ -382,6 +414,7 @@ class Boss(object):
|
|||
S2_happy.upon(got_verifier, enter=S2_happy, outputs=[W_got_verifier])
|
||||
S2_happy.upon(_got_phase, enter=S2_happy, outputs=[W_received])
|
||||
S2_happy.upon(_got_version, enter=S2_happy, outputs=[process_version])
|
||||
S2_happy.upon(_got_dilate, enter=S2_happy, outputs=[D_received_dilate])
|
||||
S2_happy.upon(scared, enter=S3_closing, outputs=[close_scared])
|
||||
S2_happy.upon(close, enter=S3_closing, outputs=[close_happy])
|
||||
S2_happy.upon(send, enter=S2_happy, outputs=[S_send])
|
||||
|
@ -393,6 +426,7 @@ class Boss(object):
|
|||
S3_closing.upon(got_verifier, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(_got_phase, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(_got_version, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(_got_dilate, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(happy, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(scared, enter=S3_closing, outputs=[])
|
||||
S3_closing.upon(close, enter=S3_closing, outputs=[])
|
||||
|
@ -404,6 +438,7 @@ class Boss(object):
|
|||
S4_closed.upon(got_verifier, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(_got_phase, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(_got_version, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(_got_dilate, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(happy, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(scared, enter=S4_closed, outputs=[])
|
||||
S4_closed.upon(close, enter=S4_closed, outputs=[])
|
||||
|
|
0
src/wormhole/_dilation/__init__.py
Normal file
0
src/wormhole/_dilation/__init__.py
Normal file
482
src/wormhole/_dilation/connection.py
Normal file
482
src/wormhole/_dilation/connection.py
Normal file
|
@ -0,0 +1,482 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from collections import namedtuple
|
||||
import six
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from automat import MethodicalMachine
|
||||
from zope.interface import Interface, implementer
|
||||
from twisted.python import log
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from .._interfaces import IDilationConnector
|
||||
from ..observer import OneShotObserver
|
||||
from .encode import to_be4, from_be4
|
||||
from .roles import FOLLOWER
|
||||
|
||||
# InboundFraming is given data and returns Frames (Noise wire-side
|
||||
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
||||
# returns are either the ephemeral key (the Noise "handshake") or ciphertext
|
||||
# messages.
|
||||
|
||||
# The next object up knows whether it's expecting a Handshake or a message. It
|
||||
# feeds the first into Noise as a handshake, it feeds the rest into Noise as a
|
||||
# message (which produces a plaintext stream). It emits tokens that are either
|
||||
# "i've finished with the handshake (so you can send the KCM if you want)", or
|
||||
# "here is a decrypted message (which might be the KCM)".
|
||||
|
||||
# the transmit direction goes directly to transport.write, and doesn't touch
|
||||
# the state machine. we can do this because the way we encode/encrypt/frame
|
||||
# things doesn't depend upon the receiver state. It would be more safe to e.g.
|
||||
# prohibit sending ciphertext frames unless we're in the received-handshake
|
||||
# state, but then we'll be in the middle of an inbound state transition ("we
|
||||
# just received the handshake, so you can send a KCM now") when we perform an
|
||||
# operation that depends upon the state (send_plaintext(kcm)), which is not a
|
||||
# coherent/safe place to touch the state machine.
|
||||
|
||||
# we could set a flag and test it from inside send_plaintext, which kind of
|
||||
# violates the state machine owning the state (ideally all "if" statements
|
||||
# would be translated into same-input transitions from different starting
|
||||
# states). For the specific question of sending plaintext frames, Noise will
|
||||
# refuse us unless it's ready anyways, so the question is probably moot.
|
||||
|
||||
class IFramer(Interface):
|
||||
pass
|
||||
class IRecord(Interface):
|
||||
pass
|
||||
|
||||
def first(l):
|
||||
return l[0]
|
||||
|
||||
class Disconnect(Exception):
|
||||
pass
|
||||
RelayOK = namedtuple("RelayOk", [])
|
||||
Prologue = namedtuple("Prologue", [])
|
||||
Frame = namedtuple("Frame", ["frame"])
|
||||
|
||||
@attrs
|
||||
@implementer(IFramer)
|
||||
class _Framer(object):
|
||||
_transport = attrib(validator=provides(ITransport))
|
||||
_outbound_prologue = attrib(validator=instance_of(bytes))
|
||||
_inbound_prologue = attrib(validator=instance_of(bytes))
|
||||
_buffer = b""
|
||||
_can_send_frames = False
|
||||
|
||||
# in: use_relay
|
||||
# in: connectionMade, dataReceived
|
||||
# out: prologue_received, frame_received
|
||||
# out (shared): transport.loseConnection
|
||||
# out (shared): transport.write (relay handshake, prologue)
|
||||
# states: want_relay, want_prologue, want_frame
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def want_relay(self): pass # pragma: no cover
|
||||
@m.state(initial=True)
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def want_frame(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def use_relay(self, relay_handshake): pass
|
||||
@m.input()
|
||||
def connectionMade(self): pass
|
||||
@m.input()
|
||||
def parse(self): pass
|
||||
@m.input()
|
||||
def got_relay_ok(self): pass
|
||||
@m.input()
|
||||
def got_prologue(self): pass
|
||||
|
||||
@m.output()
|
||||
def store_relay_handshake(self, relay_handshake):
|
||||
self._outbound_relay_handshake = relay_handshake
|
||||
self._expected_relay_handshake = b"ok\n" # TODO: make this configurable
|
||||
@m.output()
|
||||
def send_relay_handshake(self):
|
||||
self._transport.write(self._outbound_relay_handshake)
|
||||
|
||||
@m.output()
|
||||
def send_prologue(self):
|
||||
self._transport.write(self._outbound_prologue)
|
||||
|
||||
@m.output()
|
||||
def parse_relay_ok(self):
|
||||
if self._get_expected("relay_ok", self._expected_relay_handshake):
|
||||
return RelayOK()
|
||||
|
||||
@m.output()
|
||||
def parse_prologue(self):
|
||||
if self._get_expected("prologue", self._inbound_prologue):
|
||||
return Prologue()
|
||||
|
||||
@m.output()
|
||||
def can_send_frames(self):
|
||||
self._can_send_frames = True # for assertion in send_frame()
|
||||
|
||||
@m.output()
|
||||
def parse_frame(self):
|
||||
if len(self._buffer) < 4:
|
||||
return None
|
||||
frame_length = from_be4(self._buffer[0:4])
|
||||
if len(self._buffer) < 4+frame_length:
|
||||
return None
|
||||
frame = self._buffer[4:4+frame_length]
|
||||
self._buffer = self._buffer[4+frame_length:] # TODO: avoid copy
|
||||
return Frame(frame=frame)
|
||||
|
||||
want_prologue.upon(use_relay, outputs=[store_relay_handshake],
|
||||
enter=want_relay)
|
||||
|
||||
want_relay.upon(connectionMade, outputs=[send_relay_handshake],
|
||||
enter=want_relay)
|
||||
want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay,
|
||||
collector=first)
|
||||
want_relay.upon(got_relay_ok, outputs=[send_prologue], enter=want_prologue)
|
||||
|
||||
want_prologue.upon(connectionMade, outputs=[send_prologue],
|
||||
enter=want_prologue)
|
||||
want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue,
|
||||
collector=first)
|
||||
want_prologue.upon(got_prologue, outputs=[can_send_frames], enter=want_frame)
|
||||
|
||||
want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
|
||||
collector=first)
|
||||
|
||||
|
||||
def _get_expected(self, name, expected):
|
||||
lb = len(self._buffer)
|
||||
le = len(expected)
|
||||
if self._buffer.startswith(expected):
|
||||
# if the buffer starts with the expected string, consume it and
|
||||
# return True
|
||||
self._buffer = self._buffer[le:]
|
||||
return True
|
||||
if not expected.startswith(self._buffer):
|
||||
# we're not on track: the data we've received so far does not
|
||||
# match the expected value, so this can't possibly be right.
|
||||
# Don't complain until we see the expected length, or a newline,
|
||||
# so we can capture the weird input in the log for debugging.
|
||||
if (b"\n" in self._buffer or lb >= le):
|
||||
log.msg("bad {}: {}".format(name, self._buffer[:le]))
|
||||
raise Disconnect()
|
||||
return False # wait a bit longer
|
||||
# good so far, just waiting for the rest
|
||||
return False
|
||||
|
||||
# external API is: connectionMade, add_and_parse, and send_frame
|
||||
|
||||
def add_and_parse(self, data):
|
||||
# we can't make this an @m.input because we can't change the state
|
||||
# from within an input. Instead, let the state choose the parser to
|
||||
# use, and use the parsed token drive a state transition.
|
||||
self._buffer += data
|
||||
while True:
|
||||
# it'd be nice to use an iterator here, but since self.parse()
|
||||
# dispatches to a different parser (depending upon the current
|
||||
# state), we'd be using multiple iterators
|
||||
token = self.parse()
|
||||
if isinstance(token, RelayOK):
|
||||
self.got_relay_ok()
|
||||
elif isinstance(token, Prologue):
|
||||
self.got_prologue()
|
||||
yield token # triggers send_handshake
|
||||
elif isinstance(token, Frame):
|
||||
yield token
|
||||
else:
|
||||
break
|
||||
|
||||
def send_frame(self, frame):
|
||||
assert self._can_send_frames
|
||||
self._transport.write(to_be4(len(frame)) + frame)
|
||||
|
||||
# RelayOK: Newline-terminated buddy-is-connected response from Relay.
|
||||
# First data received from relay.
|
||||
# Prologue: double-newline-terminated this-is-really-wormhole response
|
||||
# from peer. First data received from peer.
|
||||
# Frame: Either handshake or encrypted message. Length-prefixed on wire.
|
||||
# Handshake: the Noise ephemeral key, first framed message
|
||||
# Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK
|
||||
# KCM: Key Confirmation Message (encrypted b"\x00"). First frame
|
||||
# from peer. Sent immediately by Follower, after Selection by Leader.
|
||||
# Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong
|
||||
|
||||
Handshake = namedtuple("Handshake", [])
|
||||
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
|
||||
KCM = namedtuple("KCM", [])
|
||||
Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
|
||||
Pong = namedtuple("Pong", ["ping_id"])
|
||||
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
|
||||
Data = namedtuple("Data", ["seqnum", "scid", "data"])
|
||||
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
|
||||
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
|
||||
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
|
||||
Handshake_or_Records = (Handshake,) + Records
|
||||
|
||||
T_KCM = b"\x00"
|
||||
T_PING = b"\x01"
|
||||
T_PONG = b"\x02"
|
||||
T_OPEN = b"\x03"
|
||||
T_DATA = b"\x04"
|
||||
T_CLOSE = b"\x05"
|
||||
T_ACK = b"\x06"
|
||||
|
||||
def parse_record(plaintext):
|
||||
msgtype = plaintext[0:1]
|
||||
if msgtype == T_KCM:
|
||||
return KCM()
|
||||
if msgtype == T_PING:
|
||||
ping_id = plaintext[1:5]
|
||||
return Ping(ping_id)
|
||||
if msgtype == T_PONG:
|
||||
ping_id = plaintext[1:5]
|
||||
return Pong(ping_id)
|
||||
if msgtype == T_OPEN:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
return Open(seqnum, scid)
|
||||
if msgtype == T_DATA:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
data = plaintext[9:]
|
||||
return Data(seqnum, scid, data)
|
||||
if msgtype == T_CLOSE:
|
||||
scid = from_be4(plaintext[1:5])
|
||||
seqnum = from_be4(plaintext[5:9])
|
||||
return Close(seqnum, scid)
|
||||
if msgtype == T_ACK:
|
||||
resp_seqnum = from_be4(plaintext[1:5])
|
||||
return Ack(resp_seqnum)
|
||||
log.err("received unknown message type: {}".format(plaintext))
|
||||
raise ValueError()
|
||||
|
||||
def encode_record(r):
|
||||
if isinstance(r, KCM):
|
||||
return b"\x00"
|
||||
if isinstance(r, Ping):
|
||||
return b"\x01" + r.ping_id
|
||||
if isinstance(r, Pong):
|
||||
return b"\x02" + r.ping_id
|
||||
if isinstance(r, Open):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
|
||||
if isinstance(r, Data):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
|
||||
if isinstance(r, Close):
|
||||
assert isinstance(r.scid, six.integer_types)
|
||||
assert isinstance(r.seqnum, six.integer_types)
|
||||
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
|
||||
if isinstance(r, Ack):
|
||||
assert isinstance(r.resp_seqnum, six.integer_types)
|
||||
return b"\x06" + to_be4(r.resp_seqnum)
|
||||
raise TypeError(r)
|
||||
|
||||
@attrs
|
||||
@implementer(IRecord)
|
||||
class _Record(object):
|
||||
_framer = attrib(validator=provides(IFramer))
|
||||
_noise = attrib()
|
||||
|
||||
n = MethodicalMachine()
|
||||
# TODO: set_trace
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._noise.start_handshake()
|
||||
|
||||
# in: role=
|
||||
# in: prologue_received, frame_received
|
||||
# out: handshake_received, record_received
|
||||
# out: transport.write (noise handshake, encrypted records)
|
||||
# states: want_prologue, want_handshake, want_record
|
||||
|
||||
@n.state(initial=True)
|
||||
def want_prologue(self): pass # pragma: no cover
|
||||
@n.state()
|
||||
def want_handshake(self): pass # pragma: no cover
|
||||
@n.state()
|
||||
def want_message(self): pass # pragma: no cover
|
||||
|
||||
@n.input()
|
||||
def got_prologue(self):
|
||||
pass
|
||||
@n.input()
|
||||
def got_frame(self, frame):
|
||||
pass
|
||||
|
||||
@n.output()
|
||||
def send_handshake(self):
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
self._framer.send_frame(handshake)
|
||||
|
||||
@n.output()
|
||||
def process_handshake(self, frame):
|
||||
from noise.exceptions import NoiseInvalidMessage
|
||||
try:
|
||||
payload = self._noise.read_message(frame)
|
||||
# Noise can include unencrypted data in the handshake, but we don't
|
||||
# use it
|
||||
del payload
|
||||
except NoiseInvalidMessage as e:
|
||||
log.err(e, "bad inbound noise handshake")
|
||||
raise Disconnect()
|
||||
return Handshake()
|
||||
|
||||
@n.output()
|
||||
def decrypt_message(self, frame):
|
||||
from noise.exceptions import NoiseInvalidMessage
|
||||
try:
|
||||
message = self._noise.decrypt(frame)
|
||||
except NoiseInvalidMessage as e:
|
||||
# if this happens during tests, flunk the test
|
||||
log.err(e, "bad inbound noise frame")
|
||||
raise Disconnect()
|
||||
return parse_record(message)
|
||||
|
||||
want_prologue.upon(got_prologue, outputs=[send_handshake],
|
||||
enter=want_handshake)
|
||||
want_handshake.upon(got_frame, outputs=[process_handshake],
|
||||
collector=first, enter=want_message)
|
||||
want_message.upon(got_frame, outputs=[decrypt_message],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
# external API is: connectionMade, dataReceived, send_record
|
||||
|
||||
def connectionMade(self):
|
||||
self._framer.connectionMade()
|
||||
|
||||
def add_and_unframe(self, data):
|
||||
for token in self._framer.add_and_parse(data):
|
||||
if isinstance(token, Prologue):
|
||||
self.got_prologue() # triggers send_handshake
|
||||
else:
|
||||
assert isinstance(token, Frame)
|
||||
yield self.got_frame(token.frame) # Handshake or a Record type
|
||||
|
||||
def send_record(self, r):
|
||||
message = encode_record(r)
|
||||
frame = self._noise.encrypt(message)
|
||||
self._framer.send_frame(frame)
|
||||
|
||||
|
||||
@attrs
|
||||
class DilatedConnectionProtocol(Protocol, object):
|
||||
"""I manage an L2 connection.
|
||||
|
||||
When a new L2 connection is needed (as determined by the Leader),
|
||||
both Leader and Follower will initiate many simultaneous connections
|
||||
(probably TCP, but conceivably others). A subset will actually
|
||||
connect. A subset of those will successfully pass negotiation by
|
||||
exchanging handshakes to demonstrate knowledge of the session key.
|
||||
One of the negotiated connections will be selected by the Leader for
|
||||
active use, and the others will be dropped.
|
||||
|
||||
At any given time, there is at most one active L2 connection.
|
||||
"""
|
||||
|
||||
_eventual_queue = attrib()
|
||||
_role = attrib()
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
_noise = attrib()
|
||||
_outbound_prologue = attrib(validator=instance_of(bytes))
|
||||
_inbound_prologue = attrib(validator=instance_of(bytes))
|
||||
|
||||
_use_relay = False
|
||||
_relay_handshake = None
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._manager = None # set if/when we are selected
|
||||
self._disconnected = OneShotObserver(self._eventual_queue)
|
||||
self._can_send_records = False
|
||||
|
||||
@m.state(initial=True)
|
||||
def unselected(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def selecting(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def selected(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def got_kcm(self):
|
||||
pass
|
||||
@m.input()
|
||||
def select(self, manager):
|
||||
pass # fires set_manager()
|
||||
@m.input()
|
||||
def got_record(self, record):
|
||||
pass
|
||||
|
||||
@m.output()
|
||||
def add_candidate(self):
|
||||
self._connector.add_candidate(self)
|
||||
|
||||
@m.output()
|
||||
def set_manager(self, manager):
|
||||
self._manager = manager
|
||||
|
||||
@m.output()
|
||||
def can_send_records(self, manager):
|
||||
self._can_send_records = True
|
||||
|
||||
@m.output()
|
||||
def deliver_record(self, record):
|
||||
self._manager.got_record(record)
|
||||
|
||||
unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting)
|
||||
selecting.upon(select, outputs=[set_manager, can_send_records], enter=selected)
|
||||
selected.upon(got_record, outputs=[deliver_record], enter=selected)
|
||||
|
||||
# called by Connector
|
||||
|
||||
def use_relay(self, relay_handshake):
|
||||
assert isinstance(relay_handshake, bytes)
|
||||
self._use_relay = True
|
||||
self._relay_handshake = relay_handshake
|
||||
|
||||
def when_disconnected(self):
|
||||
return self._disconnected.when_fired()
|
||||
|
||||
def disconnect(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
# select() called by Connector
|
||||
|
||||
# called by Manager
|
||||
def send_record(self, record):
|
||||
assert self._can_send_records
|
||||
self._record.send_record(record)
|
||||
|
||||
# IProtocol methods
|
||||
|
||||
def connectionMade(self):
|
||||
framer = _Framer(self.transport,
|
||||
self._outbound_prologue, self._inbound_prologue)
|
||||
if self._use_relay:
|
||||
framer.use_relay(self._relay_handshake)
|
||||
self._record = _Record(framer, self._noise)
|
||||
self._record.connectionMade()
|
||||
|
||||
def dataReceived(self, data):
|
||||
try:
|
||||
for token in self._record.add_and_unframe(data):
|
||||
assert isinstance(token, Handshake_or_Records)
|
||||
if isinstance(token, Handshake):
|
||||
if self._role is FOLLOWER:
|
||||
self._record.send_record(KCM())
|
||||
elif isinstance(token, KCM):
|
||||
# if we're the leader, add this connection as a candiate.
|
||||
# if we're the follower, accept this connection.
|
||||
self.got_kcm() # connector.add_candidate()
|
||||
else:
|
||||
self.got_record(token) # manager.got_record()
|
||||
except Disconnect:
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, why=None):
|
||||
self._disconnected.fire(self)
|
482
src/wormhole/_dilation/connector.py
Normal file
482
src/wormhole/_dilation/connector.py
Normal file
|
@ -0,0 +1,482 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import sys, re
|
||||
from collections import defaultdict, namedtuple
|
||||
from binascii import hexlify
|
||||
import six
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides, optional
|
||||
from automat import MethodicalMachine
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.internet.defer import DeferredList
|
||||
from twisted.internet.endpoints import HostnameEndpoint, serverFromString
|
||||
from twisted.internet.protocol import ClientFactory, ServerFactory
|
||||
from twisted.python import log
|
||||
from hkdf import Hkdf
|
||||
from .. import ipaddrs # TODO: move into _dilation/
|
||||
from .._interfaces import IDilationConnector, IDilationManager
|
||||
from ..timing import DebugTiming
|
||||
from ..observer import EmptyableSet
|
||||
from .connection import DilatedConnectionProtocol, KCM
|
||||
from .roles import LEADER
|
||||
|
||||
|
||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||
# are "hint dicts".
|
||||
|
||||
# DirectTCPV1Hint and TorTCPV1Hint mean the following protocol:
|
||||
# * make a TCP connection (possibly via Tor)
|
||||
# * send the sender/receiver handshake bytes first
|
||||
# * expect to see the receiver/sender handshake bytes from the other side
|
||||
# * the sender writes "go\n", the receiver waits for "go\n"
|
||||
# * the rest of the connection contains transit data
|
||||
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
|
||||
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
||||
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
||||
# use a tuple rather than a list so they'll be hashable into a set). For each
|
||||
# one, make the TCP connection, send the relay handshake, then complete the
|
||||
# rest of the V1 protocol. Only one hint per relay is useful.
|
||||
RelayV1Hint = namedtuple("RelayV1Hint", ["hints"])
|
||||
|
||||
def describe_hint_obj(hint, relay, tor):
|
||||
prefix = "tor->" if tor else "->"
|
||||
if relay:
|
||||
prefix = prefix + "relay:"
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return prefix + "tcp:%s:%d" % (hint.hostname, hint.port)
|
||||
elif isinstance(hint, TorTCPV1Hint):
|
||||
return prefix+"tor:%s:%d" % (hint.hostname, hint.port)
|
||||
else:
|
||||
return prefix+str(hint)
|
||||
|
||||
def parse_hint_argv(hint, stderr=sys.stderr):
|
||||
assert isinstance(hint, type(""))
|
||||
# return tuple or None for an unparseable hint
|
||||
priority = 0.0
|
||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||
if not mo:
|
||||
print("unparseable hint '%s'" % (hint,), file=stderr)
|
||||
return None
|
||||
hint_type = mo.group(1)
|
||||
if hint_type != "tcp":
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
|
||||
return None
|
||||
hint_value = mo.group(2)
|
||||
pieces = hint_value.split(":")
|
||||
if len(pieces) < 2:
|
||||
print("unparseable TCP hint (need more colons) '%s'" % (hint,),
|
||||
file=stderr)
|
||||
return None
|
||||
mo = re.search(r'^(\d+)$', pieces[1])
|
||||
if not mo:
|
||||
print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
|
||||
return None
|
||||
hint_host = pieces[0]
|
||||
hint_port = int(pieces[1])
|
||||
for more in pieces[2:]:
|
||||
if more.startswith("priority="):
|
||||
more_pieces = more.split("=")
|
||||
try:
|
||||
priority = float(more_pieces[1])
|
||||
except ValueError:
|
||||
print("non-float priority= in TCP hint '%s'" % (hint,),
|
||||
file=stderr)
|
||||
return None
|
||||
return DirectTCPV1Hint(hint_host, hint_port, priority)
|
||||
|
||||
def parse_tcp_v1_hint(hint): # hint_struct -> hint_obj
|
||||
hint_type = hint.get("type", "")
|
||||
if hint_type not in ["direct-tcp-v1", "tor-tcp-v1"]:
|
||||
log.msg("unknown hint type: %r" % (hint,))
|
||||
return None
|
||||
if not("hostname" in hint
|
||||
and isinstance(hint["hostname"], type(""))):
|
||||
log.msg("invalid hostname in hint: %r" % (hint,))
|
||||
return None
|
||||
if not("port" in hint
|
||||
and isinstance(hint["port"], six.integer_types)):
|
||||
log.msg("invalid port in hint: %r" % (hint,))
|
||||
return None
|
||||
priority = hint.get("priority", 0.0)
|
||||
if hint_type == "direct-tcp-v1":
|
||||
return DirectTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
else:
|
||||
return TorTCPV1Hint(hint["hostname"], hint["port"], priority)
|
||||
|
||||
def parse_hint(hint_struct):
|
||||
hint_type = hint_struct.get("type", "")
|
||||
if hint_type == "relay-v1":
|
||||
# the struct can include multiple ways to reach the same relay
|
||||
rhints = filter(lambda h: h, # drop None (unrecognized)
|
||||
[parse_tcp_v1_hint(rh) for rh in hint_struct["hints"]])
|
||||
return RelayV1Hint(rhints)
|
||||
return parse_tcp_v1_hint(hint_struct)
|
||||
|
||||
def encode_hint(h):
|
||||
if isinstance(h, DirectTCPV1Hint):
|
||||
return {"type": "direct-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
}
|
||||
elif isinstance(h, RelayV1Hint):
|
||||
rhint = {"type": "relay-v1", "hints": []}
|
||||
for rh in h.hints:
|
||||
rhint["hints"].append({"type": "direct-tcp-v1",
|
||||
"priority": rh.priority,
|
||||
"hostname": rh.hostname,
|
||||
"port": rh.port})
|
||||
return rhint
|
||||
elif isinstance(h, TorTCPV1Hint):
|
||||
return {"type": "tor-tcp-v1",
|
||||
"priority": h.priority,
|
||||
"hostname": h.hostname,
|
||||
"port": h.port, # integer
|
||||
}
|
||||
raise ValueError("unknown hint type", h)
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
||||
def build_sided_relay_handshake(key, side):
|
||||
assert isinstance(side, type(u""))
|
||||
assert len(side) == 8*2
|
||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
|
||||
|
||||
PROLOGUE_LEADER = b"Magic-Wormhole Dilation Handshake v1 Leader\n\n"
|
||||
PROLOGUE_FOLLOWER = b"Magic-Wormhole Dilation Handshake v1 Follower\n\n"
|
||||
NOISEPROTO = "Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s"
|
||||
|
||||
@attrs
|
||||
@implementer(IDilationConnector)
|
||||
class Connector(object):
|
||||
_dilation_key = attrib(validator=instance_of(type(b"")))
|
||||
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_no_listen = attrib(validator=instance_of(bool))
|
||||
_tor = attrib()
|
||||
_timing = attrib()
|
||||
_side = attrib(validator=instance_of(type(u"")))
|
||||
# was self._side = bytes_to_hexstr(os.urandom(8)) # unicode
|
||||
_role = attrib()
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None)
|
||||
|
||||
RELAY_DELAY = 2.0
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
if self._transit_relay_location:
|
||||
# TODO: allow multiple hints for a single relay
|
||||
relay_hint = parse_hint_argv(self._transit_relay_location)
|
||||
relay = RelayV1Hint(hints=(relay_hint,))
|
||||
self._transit_relays = [relay]
|
||||
else:
|
||||
self._transit_relays = []
|
||||
self._listeners = set() # IListeningPorts that can be stopped
|
||||
self._pending_connectors = set() # Deferreds that can be cancelled
|
||||
self._pending_connections = EmptyableSet(_eventual_queue=self._eventual_queue) # Protocols to be stopped
|
||||
self._contenders = set() # viable connections
|
||||
self._winning_connection = None
|
||||
self._timing = self._timing or DebugTiming()
|
||||
self._timing.add("transit")
|
||||
|
||||
# this describes what our Connector can do, for the initial advertisement
|
||||
@classmethod
|
||||
def get_connection_abilities(klass):
|
||||
return [{"type": "direct-tcp-v1"},
|
||||
{"type": "relay-v1"},
|
||||
]
|
||||
|
||||
def build_protocol(self, addr):
|
||||
# encryption: let's use Noise NNpsk0 (or maybe NNpsk2). That uses
|
||||
# ephemeral keys plus a pre-shared symmetric key (the Transit key), a
|
||||
# different one for each potential connection.
|
||||
from noise.connection import NoiseConnection
|
||||
noise = NoiseConnection.from_name(NOISEPROTO)
|
||||
noise.set_psks(self._dilation_key)
|
||||
if self._role is LEADER:
|
||||
noise.set_as_initiator()
|
||||
outbound_prologue = PROLOGUE_LEADER
|
||||
inbound_prologue = PROLOGUE_FOLLOWER
|
||||
else:
|
||||
noise.set_as_responder()
|
||||
outbound_prologue = PROLOGUE_FOLLOWER
|
||||
inbound_prologue = PROLOGUE_LEADER
|
||||
p = DilatedConnectionProtocol(self._eventual_queue, self._role,
|
||||
self, noise,
|
||||
outbound_prologue, inbound_prologue)
|
||||
return p
|
||||
|
||||
@m.state(initial=True)
|
||||
def connecting(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def connected(self): pass # pragma: no cover
|
||||
@m.state(terminal=True)
|
||||
def stopped(self): pass # pragma: no cover
|
||||
|
||||
# TODO: unify the tense of these method-name verbs
|
||||
@m.input()
|
||||
def listener_ready(self, hint_objs): pass
|
||||
@m.input()
|
||||
def add_relay(self, hint_objs): pass
|
||||
@m.input()
|
||||
def got_hints(self, hint_objs): pass
|
||||
@m.input()
|
||||
def add_candidate(self, c): # called by DilatedConnectionProtocol
|
||||
pass
|
||||
@m.input()
|
||||
def accept(self, c): pass
|
||||
@m.input()
|
||||
def stop(self): pass
|
||||
|
||||
@m.output()
|
||||
def use_hints(self, hint_objs):
|
||||
self._use_hints(hint_objs)
|
||||
|
||||
@m.output()
|
||||
def publish_hints(self, hint_objs):
|
||||
self._manager.send_hints([encode_hint(h) for h in hint_objs])
|
||||
|
||||
@m.output()
|
||||
def consider(self, c):
|
||||
self._contenders.add(c)
|
||||
if self._role is LEADER:
|
||||
# for now, just accept the first one. TODO: be clever.
|
||||
self._eventual_queue.eventually(self.accept, c)
|
||||
else:
|
||||
# the follower always uses the first contender, since that's the
|
||||
# only one the leader picked
|
||||
self._eventual_queue.eventually(self.accept, c)
|
||||
|
||||
@m.output()
|
||||
def select_and_stop_remaining(self, c):
|
||||
self._winning_connection = c
|
||||
self._contenders.clear() # we no longer care who else came close
|
||||
# remove this winner from the losers, so we don't shut it down
|
||||
self._pending_connections.discard(c)
|
||||
# shut down losing connections
|
||||
self.stop_listeners() # TODO: maybe keep it open? NAT/p2p assist
|
||||
self.stop_pending_connectors()
|
||||
self.stop_pending_connections()
|
||||
|
||||
c.select(self._manager) # subsequent frames go directly to the manager
|
||||
if self._role is LEADER:
|
||||
# TODO: this should live in Connection
|
||||
c.send_record(KCM()) # leader sends KCM now
|
||||
self._manager.use_connection(c) # manager sends frames to Connection
|
||||
|
||||
@m.output()
|
||||
def stop_everything(self):
|
||||
self.stop_listeners()
|
||||
self.stop_pending_connectors()
|
||||
self.stop_pending_connections()
|
||||
self.break_cycles()
|
||||
|
||||
def stop_listeners(self):
|
||||
d = DeferredList([l.stopListening() for l in self._listeners])
|
||||
self._listeners.clear()
|
||||
return d # synchronization for tests
|
||||
|
||||
def stop_pending_connectors(self):
|
||||
return DeferredList([d.cancel() for d in self._pending_connectors])
|
||||
|
||||
def stop_pending_connections(self):
|
||||
d = self._pending_connections.when_next_empty()
|
||||
[c.loseConnection() for c in self._pending_connections]
|
||||
return d
|
||||
|
||||
def stop_winner(self):
|
||||
d = self._winner.when_disconnected()
|
||||
self._winner.disconnect()
|
||||
return d
|
||||
|
||||
def break_cycles(self):
|
||||
# help GC by forgetting references to things that reference us
|
||||
self._listeners.clear()
|
||||
self._pending_connectors.clear()
|
||||
self._pending_connections.clear()
|
||||
self._winner = None
|
||||
|
||||
connecting.upon(listener_ready, enter=connecting, outputs=[publish_hints])
|
||||
connecting.upon(add_relay, enter=connecting, outputs=[use_hints,
|
||||
publish_hints])
|
||||
connecting.upon(got_hints, enter=connecting, outputs=[use_hints])
|
||||
connecting.upon(add_candidate, enter=connecting, outputs=[consider])
|
||||
connecting.upon(accept, enter=connected, outputs=[select_and_stop_remaining])
|
||||
connecting.upon(stop, enter=stopped, outputs=[stop_everything])
|
||||
|
||||
# once connected, we ignore everything except stop
|
||||
connected.upon(listener_ready, enter=connected, outputs=[])
|
||||
connected.upon(add_relay, enter=connected, outputs=[])
|
||||
connected.upon(got_hints, enter=connected, outputs=[])
|
||||
connected.upon(add_candidate, enter=connected, outputs=[])
|
||||
connected.upon(accept, enter=connected, outputs=[])
|
||||
connected.upon(stop, enter=stopped, outputs=[stop_everything])
|
||||
|
||||
|
||||
# from Manager: start, got_hints, stop
|
||||
# maybe add_candidate, accept
|
||||
def start(self):
|
||||
self._start_listener()
|
||||
if self._transit_relays:
|
||||
self.publish_hints(self._transit_relays)
|
||||
self._use_hints(self._transit_relays)
|
||||
|
||||
def _start_listener(self):
|
||||
if self._no_listen or self._tor:
|
||||
return
|
||||
addresses = ipaddrs.find_addresses()
|
||||
non_loopback_addresses = [a for a in addresses if a != "127.0.0.1"]
|
||||
if non_loopback_addresses:
|
||||
# some test hosts, including the appveyor VMs, *only* have
|
||||
# 127.0.0.1, and the tests will hang badly if we remove it.
|
||||
addresses = non_loopback_addresses
|
||||
# TODO: listen on a fixed port, if possible, for NAT/p2p benefits, also
|
||||
# to make firewall configs easier
|
||||
# TODO: retain listening port between connection generations?
|
||||
ep = serverFromString(self._reactor, "tcp:0")
|
||||
f = InboundConnectionFactory(self)
|
||||
d = ep.listen(f)
|
||||
def _listening(lp):
|
||||
# lp is an IListeningPort
|
||||
self._listeners.add(lp) # for shutdown and tests
|
||||
portnum = lp.getHost().port
|
||||
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
|
||||
for addr in addresses]
|
||||
self.listener_ready(direct_hints)
|
||||
d.addCallback(_listening)
|
||||
d.addErrback(log.err)
|
||||
|
||||
def _use_hints(self, hints):
|
||||
# first, pull out all the relays, we'll connect to them later
|
||||
relays = defaultdict(list)
|
||||
direct = defaultdict(list)
|
||||
for h in hints:
|
||||
if isinstance(h, RelayV1Hint):
|
||||
relays[h.priority].append(h)
|
||||
else:
|
||||
direct[h.priority].append(h)
|
||||
delay = 0.0
|
||||
priorities = sorted(set(direct.keys()), reverse=True)
|
||||
for p in priorities:
|
||||
for h in direct[p]:
|
||||
if isinstance(h, TorTCPV1Hint) and not self._tor:
|
||||
continue
|
||||
ep = self._endpoint_from_hint_obj(h)
|
||||
desc = describe_hint_obj(h, False, self._tor)
|
||||
d = deferLater(self._reactor, delay,
|
||||
self._connect, ep, desc, is_relay=False)
|
||||
self._pending_connectors.add(d)
|
||||
# Make all direct connections immediately. Later, we'll change
|
||||
# the add_candidate() function to look at the priority when
|
||||
# deciding whether to accept a successful connection or not,
|
||||
# and it can wait for more options if it sees a higher-priority
|
||||
# one still running. But if we bail on that, we might consider
|
||||
# putting an inter-direct-hint delay here to influence the
|
||||
# process.
|
||||
#delay += 1.0
|
||||
if delay > 0.0:
|
||||
# Start trying the relays a few seconds after we start to try the
|
||||
# direct hints. The idea is to prefer direct connections, but not
|
||||
# be afraid of using a relay when we have direct hints that don't
|
||||
# resolve quickly. Many direct hints will be to unused
|
||||
# local-network IP addresses, which won't answer, and would take
|
||||
# the full TCP timeout (30s or more) to fail. If there were no
|
||||
# direct hints, don't delay at all.
|
||||
delay += self.RELAY_DELAY
|
||||
|
||||
# prefer direct connections by stalling relay connections by a few
|
||||
# seconds, unless we're using --no-listen in which case we're probably
|
||||
# going to have to use the relay
|
||||
delay = self.RELAY_DELAY if self._no_listen else 0.0
|
||||
|
||||
# It might be nice to wire this so that a failure in the direct hints
|
||||
# causes the relay hints to be used right away (fast failover). But
|
||||
# none of our current use cases would take advantage of that: if we
|
||||
# have any viable direct hints, then they're either going to succeed
|
||||
# quickly or hang for a long time.
|
||||
for p in priorities:
|
||||
for r in relays[p]:
|
||||
for h in r.hints:
|
||||
ep = self._endpoint_from_hint_obj(h)
|
||||
desc = describe_hint_obj(h, True, self._tor)
|
||||
d = deferLater(self._reactor, delay,
|
||||
self._connect, ep, desc, is_relay=True)
|
||||
self._pending_connectors.add(d)
|
||||
# TODO:
|
||||
#if not contenders:
|
||||
# raise TransitError("No contenders for connection")
|
||||
|
||||
# TODO: add 2*TIMEOUT deadline for first generation, don't wait forever for
|
||||
# the initial connection
|
||||
|
||||
def _connect(self, h, ep, description, is_relay=False):
|
||||
relay_handshake = None
|
||||
if is_relay:
|
||||
relay_handshake = build_sided_relay_handshake(self._dilation_key,
|
||||
self._side)
|
||||
f = OutboundConnectionFactory(self, relay_handshake)
|
||||
d = ep.connect(f)
|
||||
# fires with protocol, or ConnectError
|
||||
def _connected(p):
|
||||
self._pending_connections.add(p)
|
||||
# c might not be in _pending_connections, if it turned out to be a
|
||||
# winner, which is why we use discard() and not remove()
|
||||
p.when_disconnected().addCallback(self._pending_connections.discard)
|
||||
d.addCallback(_connected)
|
||||
return d
|
||||
|
||||
def _endpoint_from_hint_obj(self, hint):
|
||||
if self._tor:
|
||||
if isinstance(hint, (DirectTCPV1Hint, TorTCPV1Hint)):
|
||||
# this Tor object will throw ValueError for non-public IPv4
|
||||
# addresses and any IPv6 address
|
||||
try:
|
||||
return self._tor.stream_via(hint.hostname, hint.port)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
if isinstance(hint, DirectTCPV1Hint):
|
||||
return HostnameEndpoint(self._reactor, hint.hostname, hint.port)
|
||||
return None
|
||||
|
||||
|
||||
# Connection selection. All instances of DilatedConnectionProtocol which
|
||||
# look viable get passed into our add_contender() method.
|
||||
|
||||
# On the Leader side, "viable" means we've seen their KCM frame, which is
|
||||
# the first Noise-encrypted packet on any given connection, and it has an
|
||||
# empty body. We gather viable connections until we see one that we like,
|
||||
# or a timer expires. Then we "select" it, close the others, and tell our
|
||||
# Manager to use it.
|
||||
|
||||
# On the Follower side, we'll only see a KCM on the one connection selected
|
||||
# by the Leader, so the first viable connection wins.
|
||||
|
||||
# our Connection protocols call: add_candidate
|
||||
|
||||
@attrs
|
||||
class OutboundConnectionFactory(ClientFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
_relay_handshake = attrib(validator=optional(instance_of(bytes)))
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self._connector.build_protocol(addr)
|
||||
p.factory = self
|
||||
if self._relay_handshake is not None:
|
||||
p.use_relay(self._relay_handshake)
|
||||
return p
|
||||
|
||||
@attrs
|
||||
class InboundConnectionFactory(ServerFactory, object):
|
||||
_connector = attrib(validator=provides(IDilationConnector))
|
||||
protocol = DilatedConnectionProtocol
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self._connector.build_protocol(addr)
|
||||
p.factory = self
|
||||
return p
|
16
src/wormhole/_dilation/encode.py
Normal file
16
src/wormhole/_dilation/encode.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import struct
|
||||
|
||||
assert len(struct.pack("<L", 0)) == 4
|
||||
assert len(struct.pack("<Q", 0)) == 8
|
||||
|
||||
def to_be4(value):
|
||||
if not 0 <= value < 2**32:
|
||||
raise ValueError
|
||||
return struct.pack(">L", value)
|
||||
def from_be4(b):
|
||||
if not isinstance(b, bytes):
|
||||
raise TypeError(repr(b))
|
||||
if len(b) != 4:
|
||||
raise ValueError
|
||||
return struct.unpack(">L", b)[0]
|
127
src/wormhole/_dilation/inbound.py
Normal file
127
src/wormhole/_dilation/inbound.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides
|
||||
from zope.interface import implementer
|
||||
from twisted.python import log
|
||||
from .._interfaces import IDilationManager, IInbound
|
||||
from .subchannel import (SubChannel, _SubchannelAddress)
|
||||
|
||||
class DuplicateOpenError(Exception):
|
||||
pass
|
||||
class DataForMissingSubchannelError(Exception):
|
||||
pass
|
||||
class CloseForMissingSubchannelError(Exception):
|
||||
pass
|
||||
|
||||
@attrs
|
||||
@implementer(IInbound)
|
||||
class Inbound(object):
|
||||
# Inbound flow control: TCP delivers data to Connection.dataReceived,
|
||||
# Connection delivers to our handle_data, we deliver to
|
||||
# SubChannel.remote_data, subchannel delivers to proto.dataReceived
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# we route inbound Data records to Subchannels .dataReceived
|
||||
self._open_subchannels = {} # scid -> Subchannel
|
||||
self._paused_subchannels = set() # Subchannels that have paused us
|
||||
# the set is non-empty, we pause the transport
|
||||
self._highest_inbound_acked = -1
|
||||
self._connection = None
|
||||
|
||||
# from our Manager
|
||||
def set_listener_endpoint(self, listener_endpoint):
|
||||
self._listener_endpoint = listener_endpoint
|
||||
|
||||
def set_subchannel_zero(self, scid0, sc0):
|
||||
self._open_subchannels[scid0] = sc0
|
||||
|
||||
|
||||
def use_connection(self, c):
|
||||
self._connection = c
|
||||
# We can pause the connection's reads when we receive too much data. If
|
||||
# this is a non-initial connection, then we might already have
|
||||
# subchannels that are paused from before, so we might need to pause
|
||||
# the new connection before it can send us any data
|
||||
if self._paused_subchannels:
|
||||
self._connection.pauseProducing()
|
||||
|
||||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
||||
|
||||
def is_record_old(self, r):
|
||||
if r.seqnum <= self._highest_inbound_acked:
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_ack_watermark(self, r):
|
||||
self._highest_inbound_acked = max(self._highest_inbound_acked,
|
||||
r.seqnum)
|
||||
|
||||
def handle_open(self, scid):
|
||||
if scid in self._open_subchannels:
|
||||
log.err(DuplicateOpenError("received duplicate OPEN for {}".format(scid)))
|
||||
return
|
||||
peer_addr = _SubchannelAddress(scid)
|
||||
sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
self._open_subchannels[scid] = sc
|
||||
self._listener_endpoint._got_open(sc, peer_addr)
|
||||
|
||||
def handle_data(self, scid, data):
|
||||
sc = self._open_subchannels.get(scid)
|
||||
if sc is None:
|
||||
log.err(DataForMissingSubchannelError("received DATA for non-existent subchannel {}".format(scid)))
|
||||
return
|
||||
sc.remote_data(data)
|
||||
|
||||
def handle_close(self, scid):
|
||||
sc = self._open_subchannels.get(scid)
|
||||
if sc is None:
|
||||
log.err(CloseForMissingSubchannelError("received CLOSE for non-existent subchannel {}".format(scid)))
|
||||
return
|
||||
sc.remote_close()
|
||||
|
||||
def subchannel_closed(self, scid, sc):
|
||||
# connectionLost has just been signalled
|
||||
assert self._open_subchannels[scid] is sc
|
||||
del self._open_subchannels[scid]
|
||||
|
||||
def stop_using_connection(self):
|
||||
self._connection = None
|
||||
|
||||
|
||||
# from our Subchannel, or rather from the Protocol above it and sent
|
||||
# through the subchannel
|
||||
|
||||
# The subchannel is an IProducer, and application protocols can always
|
||||
# thell them to pauseProducing if we're delivering inbound data too
|
||||
# quickly. They don't need to register anything.
|
||||
|
||||
def subchannel_pauseProducing(self, sc):
|
||||
was_paused = bool(self._paused_subchannels)
|
||||
self._paused_subchannels.add(sc)
|
||||
if self._connection and not was_paused:
|
||||
self._connection.pauseProducing()
|
||||
|
||||
def subchannel_resumeProducing(self, sc):
|
||||
was_paused = bool(self._paused_subchannels)
|
||||
self._paused_subchannels.discard(sc)
|
||||
if self._connection and was_paused and not self._paused_subchannels:
|
||||
self._connection.resumeProducing()
|
||||
|
||||
def subchannel_stopProducing(self, sc):
|
||||
# This protocol doesn't want any additional data. If we were a normal
|
||||
# (single-owner) Transport, we'd call .loseConnection now. But our
|
||||
# Connection is shared among many subchannels, so instead we just
|
||||
# stop letting them pause the connection.
|
||||
was_paused = bool(self._paused_subchannels)
|
||||
self._paused_subchannels.discard(sc)
|
||||
if self._connection and was_paused and not self._paused_subchannels:
|
||||
self._connection.resumeProducing()
|
||||
|
||||
# TODO: we might refactor these pause/resume/stop methods by building a
|
||||
# context manager that look at the paused/not-paused state first, then
|
||||
# lets the caller modify self._paused_subchannels, then looks at it a
|
||||
# second time, and calls c.pauseProducing/c.resumeProducing as
|
||||
# appropriate. I'm not sure it would be any cleaner, though.
|
514
src/wormhole/_dilation/manager.py
Normal file
514
src/wormhole/_dilation/manager.py
Normal file
|
@ -0,0 +1,514 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from collections import deque
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides, instance_of, optional
|
||||
from automat import MethodicalMachine
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue
|
||||
from twisted.python import log
|
||||
from .._interfaces import IDilator, IDilationManager, ISend
|
||||
from ..util import dict_to_bytes, bytes_to_dict
|
||||
from ..observer import OneShotObserver
|
||||
from .._key import derive_key
|
||||
from .encode import to_be4
|
||||
from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress,
|
||||
ControlEndpoint, SubchannelConnectorEndpoint,
|
||||
SubchannelListenerEndpoint)
|
||||
from .connector import Connector, parse_hint
|
||||
from .roles import LEADER, FOLLOWER
|
||||
from .connection import KCM, Ping, Pong, Open, Data, Close, Ack
|
||||
from .inbound import Inbound
|
||||
from .outbound import Outbound
|
||||
|
||||
class OldPeerCannotDilateError(Exception):
|
||||
pass
|
||||
class UnknownDilationMessageType(Exception):
|
||||
pass
|
||||
class ReceivedHintsTooEarly(Exception):
|
||||
pass
|
||||
|
||||
@attrs
|
||||
@implementer(IDilationManager)
|
||||
class _ManagerBase(object):
|
||||
_S = attrib(validator=provides(ISend))
|
||||
_my_side = attrib(validator=instance_of(type(u"")))
|
||||
_transit_key = attrib(validator=instance_of(bytes))
|
||||
_transit_relay_location = attrib(validator=optional(instance_of(str)))
|
||||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_cooperator = attrib()
|
||||
_no_listen = False # TODO
|
||||
_tor = None # TODO
|
||||
_timing = None # TODO
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._got_versions_d = Deferred()
|
||||
|
||||
self._my_role = None # determined upon rx_PLEASE
|
||||
|
||||
self._connection = None
|
||||
self._made_first_connection = False
|
||||
self._first_connected = OneShotObserver(self._eventual_queue)
|
||||
self._host_addr = _WormholeAddress()
|
||||
|
||||
self._next_dilation_phase = 0
|
||||
|
||||
self._next_subchannel_id = 0 # increments by 2
|
||||
|
||||
# I kept getting confused about which methods were for inbound data
|
||||
# (and thus flow-control methods go "out") and which were for
|
||||
# outbound data (with flow-control going "in"), so I split them up
|
||||
# into separate pieces.
|
||||
self._inbound = Inbound(self, self._host_addr)
|
||||
self._outbound = Outbound(self, self._cooperator) # from us to peer
|
||||
|
||||
def set_listener_endpoint(self, listener_endpoint):
|
||||
self._inbound.set_listener_endpoint(listener_endpoint)
|
||||
def set_subchannel_zero(self, scid0, sc0):
|
||||
self._inbound.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
def when_first_connected(self):
|
||||
return self._first_connected.when_fired()
|
||||
|
||||
|
||||
def send_dilation_phase(self, **fields):
|
||||
dilation_phase = self._next_dilation_phase
|
||||
self._next_dilation_phase += 1
|
||||
self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields))
|
||||
|
||||
def send_hints(self, hints): # from Connector
|
||||
self.send_dilation_phase(type="connection-hints", hints=hints)
|
||||
|
||||
|
||||
# forward inbound-ish things to _Inbound
|
||||
def subchannel_pauseProducing(self, sc):
|
||||
self._inbound.subchannel_pauseProducing(sc)
|
||||
def subchannel_resumeProducing(self, sc):
|
||||
self._inbound.subchannel_resumeProducing(sc)
|
||||
def subchannel_stopProducing(self, sc):
|
||||
self._inbound.subchannel_stopProducing(sc)
|
||||
|
||||
# forward outbound-ish things to _Outbound
|
||||
def subchannel_registerProducer(self, sc, producer, streaming):
|
||||
self._outbound.subchannel_registerProducer(sc, producer, streaming)
|
||||
def subchannel_unregisterProducer(self, sc):
|
||||
self._outbound.subchannel_unregisterProducer(sc)
|
||||
|
||||
def send_open(self, scid):
|
||||
self._queue_and_send(Open, scid)
|
||||
def send_data(self, scid, data):
|
||||
self._queue_and_send(Data, scid, data)
|
||||
def send_close(self, scid):
|
||||
self._queue_and_send(Close, scid)
|
||||
|
||||
def _queue_and_send(self, record_type, *args):
|
||||
r = self._outbound.build_record(record_type, *args)
|
||||
# Outbound owns the send_record() pipe, so that it can stall new
|
||||
# writes after a new connection is made until after all queued
|
||||
# messages are written (to preserve ordering).
|
||||
self._outbound.queue_and_send_record(r) # may trigger pauseProducing
|
||||
|
||||
def subchannel_closed(self, scid, sc):
|
||||
# let everyone clean up. This happens just after we delivered
|
||||
# connectionLost to the Protocol, except for the control channel,
|
||||
# which might get connectionLost later after they use ep.connect.
|
||||
# TODO: is this inversion a problem?
|
||||
self._inbound.subchannel_closed(scid, sc)
|
||||
self._outbound.subchannel_closed(scid, sc)
|
||||
|
||||
|
||||
def _start_connecting(self, role):
|
||||
assert self._my_role is not None
|
||||
self._connector = Connector(self._transit_key,
|
||||
self._transit_relay_location,
|
||||
self,
|
||||
self._reactor, self._eventual_queue,
|
||||
self._no_listen, self._tor,
|
||||
self._timing,
|
||||
self._side, # needed for relay handshake
|
||||
self._my_role)
|
||||
self._connector.start()
|
||||
|
||||
# our Connector calls these
|
||||
|
||||
def connector_connection_made(self, c):
|
||||
self.connection_made() # state machine update
|
||||
self._connection = c
|
||||
self._inbound.use_connection(c)
|
||||
self._outbound.use_connection(c) # does c.registerProducer
|
||||
if not self._made_first_connection:
|
||||
self._made_first_connection = True
|
||||
self._first_connected.fire(None)
|
||||
pass
|
||||
def connector_connection_lost(self):
|
||||
self._stop_using_connection()
|
||||
if self.role is LEADER:
|
||||
self.connection_lost_leader() # state machine
|
||||
else:
|
||||
self.connection_lost_follower()
|
||||
|
||||
|
||||
def _stop_using_connection(self):
|
||||
# the connection is already lost by this point
|
||||
self._connection = None
|
||||
self._inbound.stop_using_connection()
|
||||
self._outbound.stop_using_connection() # does c.unregisterProducer
|
||||
|
||||
# from our active Connection
|
||||
|
||||
def got_record(self, r):
|
||||
# records with sequence numbers: always ack, ignore old ones
|
||||
if isinstance(r, (Open, Data, Close)):
|
||||
self.send_ack(r.seqnum) # always ack, even for old ones
|
||||
if self._inbound.is_record_old(r):
|
||||
return
|
||||
self._inbound.update_ack_watermark(r.seqnum)
|
||||
if isinstance(r, Open):
|
||||
self._inbound.handle_open(r.scid)
|
||||
elif isinstance(r, Data):
|
||||
self._inbound.handle_data(r.scid, r.data)
|
||||
else: # isinstance(r, Close)
|
||||
self._inbound.handle_close(r.scid)
|
||||
if isinstance(r, KCM):
|
||||
log.err("got unexpected KCM")
|
||||
elif isinstance(r, Ping):
|
||||
self.handle_ping(r.ping_id)
|
||||
elif isinstance(r, Pong):
|
||||
self.handle_pong(r.ping_id)
|
||||
elif isinstance(r, Ack):
|
||||
self._outbound.handle_ack(r.resp_seqnum) # retire queued messages
|
||||
else:
|
||||
log.err("received unknown message type {}".format(r))
|
||||
|
||||
# pings, pongs, and acks are not queued
|
||||
def send_ping(self, ping_id):
|
||||
self._outbound.send_if_connected(Ping(ping_id))
|
||||
|
||||
def send_pong(self, ping_id):
|
||||
self._outbound.send_if_connected(Pong(ping_id))
|
||||
|
||||
def send_ack(self, resp_seqnum):
|
||||
self._outbound.send_if_connected(Ack(resp_seqnum))
|
||||
|
||||
|
||||
def handle_ping(self, ping_id):
|
||||
self.send_pong(ping_id)
|
||||
|
||||
def handle_pong(self, ping_id):
|
||||
# TODO: update is-alive timer
|
||||
pass
|
||||
|
||||
# subchannel maintenance
|
||||
def allocate_subchannel_id(self):
|
||||
raise NotImplemented # subclass knows if we're leader or follower
|
||||
|
||||
# new scheme:
|
||||
# * both sides send PLEASE as soon as they have an unverified key and
|
||||
# w.dilate has been called,
|
||||
# * PLEASE includes a dilation-specific "side" (independent of the "side"
|
||||
# used by mailbox messages)
|
||||
# * higher "side" is Leader, lower is Follower
|
||||
# * PLEASE includes can-dilate list of version integers, requires overlap
|
||||
# "1" is current
|
||||
# * dilation starts as soon as we've sent PLEASE and received PLEASE
|
||||
# (four-state two-variable IDLE/WANTING/WANTED/STARTED diamond FSM)
|
||||
# * HINTS sent after dilation starts
|
||||
# * only Leader sends RECONNECT, only Follower sends RECONNECTING. This
|
||||
# is the only difference between the two sides, and is not enforced
|
||||
# by the protocol (i.e. if the Follower sends RECONNECT to the Leader,
|
||||
# the Leader will obey, although TODO how confusing will this get?)
|
||||
# * upon receiving RECONNECT: drop Connector, start new Connector, send
|
||||
# RECONNECTING, start sending HINTS
|
||||
# * upon sending CONNECT: go into FLUSHING state and ignore all HINTS until
|
||||
# RECONNECTING received. The new Connector can be spun up earlier, and it
|
||||
# can send HINTS, but it must not be given any HINTS that arrive before
|
||||
# RECONNECTING (since they're probably stale)
|
||||
|
||||
# * after VERSIONS(KCM) received, we might learn that they other side cannot
|
||||
# dilate. w.dilate errbacks at this point
|
||||
|
||||
# * maybe signal warning if we stay in a "want" state for too long
|
||||
# * nobody sends HINTS until they're ready to receive
|
||||
# * nobody sends HINTS unless they've called w.dilate() and received PLEASE
|
||||
# * nobody connects to inbound hints unless they've called w.dilate()
|
||||
# * if leader calls w.dilate() but not follower, leader waits forever in
|
||||
# "want" (doesn't send anything)
|
||||
# * if follower calls w.dilate() but not leader, follower waits forever
|
||||
# in "want", leader waits forever in "wanted"
|
||||
|
||||
class ManagerShared(_ManagerBase):
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def IDLE(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def WANTING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def WANTED(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def CONNECTING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def CONNECTED(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def FLUSHING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def ABANDONING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def LONELY(self): pass # pragme: no cover
|
||||
@m.state()
|
||||
def STOPPING(self): pass # pragma: no cover
|
||||
@m.state(terminal=True)
|
||||
def STOPPED(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def start(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def rx_PLEASE(self, message): pass # pragma: no cover
|
||||
@m.input() # only sent by Follower
|
||||
def rx_HINTS(self, hint_message): pass # pragma: no cover
|
||||
@m.input() # only Leader sends RECONNECT, so only Follower receives it
|
||||
def rx_RECONNECT(self): pass # pragma: no cover
|
||||
@m.input() # only Follower sends RECONNECTING, so only Leader receives it
|
||||
def rx_RECONNECTING(self): pass # pragma: no cover
|
||||
|
||||
# Connector gives us connection_made()
|
||||
@m.input()
|
||||
def connection_made(self, c): pass # pragma: no cover
|
||||
|
||||
# our connection_lost() fires connection_lost_leader or
|
||||
# connection_lost_follower depending upon our role. If either side sees a
|
||||
# problem with the connection (timeouts, bad authentication) then they
|
||||
# just drop it and let connection_lost() handle the cleanup.
|
||||
@m.input()
|
||||
def connection_lost_leader(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def connection_lost_follower(self): pass
|
||||
|
||||
@m.input()
|
||||
def stop(self): pass # pragma: no cover
|
||||
|
||||
@m.output()
|
||||
def stash_side(self, message):
|
||||
their_side = message["side"]
|
||||
self.my_role = LEADER if self._my_side > their_side else FOLLOWER
|
||||
|
||||
# these Outputs behave differently for the Leader vs the Follower
|
||||
@m.output()
|
||||
def send_please(self):
|
||||
self.send_dilation_phase(type="please", side=self._my_side)
|
||||
|
||||
@m.output()
|
||||
def start_connecting(self):
|
||||
self._start_connecting() # TODO: merge
|
||||
@m.output()
|
||||
def ignore_message_start_connecting(self, message):
|
||||
self.start_connecting()
|
||||
|
||||
@m.output()
|
||||
def send_reconnect(self):
|
||||
self.send_dilation_phase(type="reconnect") # TODO: generation number?
|
||||
@m.output()
|
||||
def send_reconnecting(self):
|
||||
self.send_dilation_phase(type="reconnecting") # TODO: generation?
|
||||
|
||||
@m.output()
|
||||
def use_hints(self, hint_message):
|
||||
hint_objs = filter(lambda h: h, # ignore None, unrecognizable
|
||||
[parse_hint(hs) for hs in hint_message["hints"]])
|
||||
hint_objs = list(hint_objs)
|
||||
self._connector.got_hints(hint_objs)
|
||||
@m.output()
|
||||
def stop_connecting(self):
|
||||
self._connector.stop()
|
||||
@m.output()
|
||||
def abandon_connection(self):
|
||||
# we think we're still connected, but the Leader disagrees. Or we've
|
||||
# been told to shut down.
|
||||
self._connection.disconnect() # let connection_lost do cleanup
|
||||
|
||||
|
||||
# we don't start CONNECTING until a local start() plus rx_PLEASE
|
||||
IDLE.upon(rx_PLEASE, enter=WANTED, outputs=[stash_side])
|
||||
IDLE.upon(start, enter=WANTING, outputs=[send_please])
|
||||
WANTED.upon(start, enter=CONNECTING, outputs=[send_please, start_connecting])
|
||||
WANTING.upon(rx_PLEASE, enter=CONNECTING,
|
||||
outputs=[stash_side,
|
||||
ignore_message_start_connecting])
|
||||
|
||||
CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[])
|
||||
|
||||
# Leader
|
||||
CONNECTED.upon(connection_lost_leader, enter=FLUSHING,
|
||||
outputs=[send_reconnect])
|
||||
FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting])
|
||||
|
||||
# Follower
|
||||
# if we notice a lost connection, just wait for the Leader to notice too
|
||||
CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[])
|
||||
LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[start_connecting])
|
||||
# but if they notice it first, abandon our (seemingly functional)
|
||||
# connection, then tell them that we're ready to try again
|
||||
CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, # they noticed loss
|
||||
outputs=[abandon_connection])
|
||||
ABANDONING.upon(connection_lost_follower, enter=CONNECTING,
|
||||
outputs=[send_reconnecting, start_connecting])
|
||||
# and if they notice a problem while we're still connecting, abandon our
|
||||
# incomplete attempt and try again. in this case we don't have to wait
|
||||
# for a connection to finish shutdown
|
||||
CONNECTING.upon(rx_RECONNECT, enter=CONNECTING,
|
||||
outputs=[stop_connecting,
|
||||
send_reconnecting,
|
||||
start_connecting])
|
||||
|
||||
|
||||
# rx_HINTS never changes state, they're just accepted or ignored
|
||||
IDLE.upon(rx_HINTS, enter=IDLE, outputs=[]) # too early
|
||||
WANTED.upon(rx_HINTS, enter=WANTED, outputs=[]) # too early
|
||||
WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early
|
||||
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
|
||||
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
|
||||
FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore
|
||||
LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore
|
||||
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
|
||||
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])
|
||||
|
||||
IDLE.upon(stop, enter=STOPPED, outputs=[])
|
||||
WANTED.upon(stop, enter=STOPPED, outputs=[])
|
||||
WANTING.upon(stop, enter=STOPPED, outputs=[])
|
||||
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
|
||||
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
|
||||
ABANDONING.upon(stop, enter=STOPPING, outputs=[])
|
||||
FLUSHING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
|
||||
LONELY.upon(stop, enter=STOPPED, outputs=[])
|
||||
STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[])
|
||||
STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[])
|
||||
|
||||
|
||||
def allocate_subchannel_id(self):
|
||||
# scid 0 is reserved for the control channel. the leader uses odd
|
||||
# numbers starting with 1
|
||||
scid_num = self._next_outbound_seqnum + 1
|
||||
self._next_outbound_seqnum += 2
|
||||
return to_be4(scid_num)
|
||||
|
||||
@attrs
|
||||
@implementer(IDilator)
|
||||
class Dilator(object):
|
||||
"""I launch the dilation process.
|
||||
|
||||
I am created with every Wormhole (regardless of whether .dilate()
|
||||
was called or not), and I handle the initial phase of dilation,
|
||||
before we know whether we'll be the Leader or the Follower. Once we
|
||||
hear the other side's VERSION message (which tells us that we have a
|
||||
connection, they are capable of dilating, and which side we're on),
|
||||
then we build a DilationManager and hand control to it.
|
||||
"""
|
||||
|
||||
_reactor = attrib()
|
||||
_eventual_queue = attrib()
|
||||
_cooperator = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._got_versions_d = Deferred()
|
||||
self._started = False
|
||||
self._endpoints = OneShotObserver(self._eventual_queue)
|
||||
self._pending_inbound_dilate_messages = deque()
|
||||
self._manager = None
|
||||
|
||||
def wire(self, sender):
|
||||
self._S = ISend(sender)
|
||||
|
||||
# this is the primary entry point, called when w.dilate() is invoked
|
||||
def dilate(self, transit_relay_location=None):
|
||||
self._transit_relay_location = transit_relay_location
|
||||
if not self._started:
|
||||
self._started = True
|
||||
self._start().addBoth(self._endpoints.fire)
|
||||
return self._endpoints.when_fired()
|
||||
|
||||
@inlineCallbacks
|
||||
def _start(self):
|
||||
# first, we wait until we hear the VERSION message, which tells us 1:
|
||||
# the PAKE key works, so we can talk securely, 2: their side, so we
|
||||
# know who will lead, and 3: that they can do dilation at all
|
||||
|
||||
dilation_version = yield self._got_versions_d
|
||||
|
||||
if not dilation_version: # 1 or None
|
||||
raise OldPeerCannotDilateError()
|
||||
|
||||
my_dilation_side = TODO # random
|
||||
self._manager = Manager(self._S, my_dilation_side,
|
||||
self._transit_key,
|
||||
self._transit_relay_location,
|
||||
self._reactor, self._eventual_queue,
|
||||
self._cooperator)
|
||||
self._manager.start()
|
||||
|
||||
while self._pending_inbound_dilate_messages:
|
||||
plaintext = self._pending_inbound_dilate_messages.popleft()
|
||||
self.received_dilate(plaintext)
|
||||
|
||||
# we could probably return the endpoints earlier
|
||||
yield self._manager.when_first_connected()
|
||||
# we can open subchannels as soon as we get our first connection
|
||||
scid0 = b"\x00\x00\x00\x00"
|
||||
self._host_addr = _WormholeAddress() # TODO: share with Manager
|
||||
peer_addr0 = _SubchannelAddress(scid0)
|
||||
control_ep = ControlEndpoint(peer_addr0)
|
||||
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
|
||||
control_ep._subchannel_zero_opened(sc0)
|
||||
self._manager.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr)
|
||||
|
||||
listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr)
|
||||
self._manager.set_listener_endpoint(listen_ep)
|
||||
|
||||
endpoints = (control_ep, connect_ep, listen_ep)
|
||||
returnValue(endpoints)
|
||||
|
||||
# from Boss
|
||||
|
||||
def got_key(self, key):
|
||||
# TODO: verify this happens before got_wormhole_versions, or add a gate
|
||||
# to tolerate either ordering
|
||||
purpose = b"dilation-v1"
|
||||
LENGTH =32 # TODO: whatever Noise wants, I guess
|
||||
self._transit_key = derive_key(key, purpose, LENGTH)
|
||||
|
||||
def got_wormhole_versions(self, our_side, their_side,
|
||||
their_wormhole_versions):
|
||||
# TODO: remove our_side, their_side
|
||||
assert isinstance(our_side, str), str
|
||||
assert isinstance(their_side, str), str
|
||||
# this always happens before received_dilate
|
||||
dilation_version = None
|
||||
their_dilation_versions = their_wormhole_versions.get("can-dilate", [])
|
||||
if 1 in their_dilation_versions:
|
||||
dilation_version = 1
|
||||
self._got_versions_d.callback(dilation_version)
|
||||
|
||||
def received_dilate(self, plaintext):
|
||||
# this receives new in-order DILATE-n payloads, decrypted but not
|
||||
# de-JSONed.
|
||||
|
||||
# this can appear before our .dilate() method is called, in which case
|
||||
# we queue them for later
|
||||
if not self._manager:
|
||||
self._pending_inbound_dilate_messages.append(plaintext)
|
||||
return
|
||||
|
||||
message = bytes_to_dict(plaintext)
|
||||
type = message["type"]
|
||||
if type == "please":
|
||||
self._manager.rx_PLEASE() # message)
|
||||
elif type == "dilate":
|
||||
self._manager.rx_DILATE() #message)
|
||||
elif type == "connection-hints":
|
||||
self._manager.rx_HINTS(message)
|
||||
else:
|
||||
log.err(UnknownDilationMessageType(message))
|
||||
return
|
106
src/wormhole/_dilation/old-follower.py
Normal file
106
src/wormhole/_dilation/old-follower.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
|
||||
class ManagerFollower(_ManagerBase):
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def IDLE(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def WANTING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def CONNECTING(self): pass # pragma: no cover
|
||||
@m.state()
|
||||
def CONNECTED(self): pass # pragma: no cover
|
||||
@m.state(terminal=True)
|
||||
def STOPPED(self): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def start(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def rx_PLEASE(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def rx_DILATE(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def rx_HINTS(self, hint_message): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def connection_made(self): pass # pragma: no cover
|
||||
@m.input()
|
||||
def connection_lost(self): pass # pragma: no cover
|
||||
# follower doesn't react to connection_lost, but waits for a new LETS_DILATE
|
||||
|
||||
@m.input()
|
||||
def stop(self): pass # pragma: no cover
|
||||
|
||||
# these Outputs behave differently for the Leader vs the Follower
|
||||
@m.output()
|
||||
def send_please(self):
|
||||
self.send_dilation_phase(type="please")
|
||||
|
||||
@m.output()
|
||||
def start_connecting(self):
|
||||
self._start_connecting(FOLLOWER)
|
||||
|
||||
# these Outputs delegate to the same code in both the Leader and the
|
||||
# Follower, but they must be replicated here because the Automat instance
|
||||
# is on the subclass, not the shared superclass
|
||||
|
||||
@m.output()
|
||||
def use_hints(self, hint_message):
|
||||
hint_objs = filter(lambda h: h, # ignore None, unrecognizable
|
||||
[parse_hint(hs) for hs in hint_message["hints"]])
|
||||
self._connector.got_hints(hint_objs)
|
||||
@m.output()
|
||||
def stop_connecting(self):
|
||||
self._connector.stop()
|
||||
@m.output()
|
||||
def use_connection(self, c):
|
||||
self._use_connection(c)
|
||||
@m.output()
|
||||
def stop_using_connection(self):
|
||||
self._stop_using_connection()
|
||||
@m.output()
|
||||
def signal_error(self):
|
||||
pass # TODO
|
||||
@m.output()
|
||||
def signal_error_hints(self, hint_message):
|
||||
pass # TODO
|
||||
|
||||
IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early
|
||||
IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early
|
||||
# leader shouldn't send us DILATE before receiving our PLEASE
|
||||
IDLE.upon(stop, enter=STOPPED, outputs=[])
|
||||
IDLE.upon(start, enter=WANTING, outputs=[send_please])
|
||||
WANTING.upon(rx_DILATE, enter=CONNECTING, outputs=[start_connecting])
|
||||
WANTING.upon(stop, enter=STOPPED, outputs=[])
|
||||
|
||||
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
|
||||
CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection])
|
||||
# shouldn't happen: connection_lost
|
||||
#CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?])
|
||||
CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting,
|
||||
start_connecting])
|
||||
# receiving rx_DILATE while we're still working on the last one means the
|
||||
# leader thought we'd connected, then thought we'd been disconnected, all
|
||||
# before we heard about that connection
|
||||
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting])
|
||||
|
||||
CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection])
|
||||
CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection,
|
||||
start_connecting])
|
||||
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
|
||||
CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection])
|
||||
# shouldn't happen: connection_made
|
||||
|
||||
# we should never receive PLEASE, we're the follower
|
||||
IDLE.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
|
||||
WANTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
|
||||
CONNECTING.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
|
||||
CONNECTED.upon(rx_PLEASE, enter=STOPPED, outputs=[signal_error])
|
||||
|
||||
def allocate_subchannel_id(self):
|
||||
# the follower uses even numbers starting with 2
|
||||
scid_num = self._next_outbound_seqnum + 2
|
||||
self._next_outbound_seqnum += 2
|
||||
return to_be4(scid_num)
|
392
src/wormhole/_dilation/outbound.py
Normal file
392
src/wormhole/_dilation/outbound.py
Normal file
|
@ -0,0 +1,392 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from collections import deque
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.interfaces import IPushProducer, IPullProducer
|
||||
from twisted.python import log
|
||||
from twisted.python.reflect import safe_str
|
||||
from .._interfaces import IDilationManager, IOutbound
|
||||
from .connection import KCM, Ping, Pong, Ack
|
||||
|
||||
|
||||
# Outbound flow control: app writes to subchannel, we write to Connection
|
||||
|
||||
# The app can register an IProducer of their choice, to let us throttle their
|
||||
# outbound data. Not all subchannels will have producers registered, and the
|
||||
# producer probably won't be the IProtocol instance (it'll be something else
|
||||
# which feeds data out through the protocol, like a t.p.basic.FileSender). If
|
||||
# a producerless subchannel writes too much, we won't be able to stop them,
|
||||
# and we'll keep writing records into the Connection even though it's asked
|
||||
# us to pause. Likewise, when the connection is down (and we're busily trying
|
||||
# to reestablish a new one), registered subchannels will be paused, but
|
||||
# unregistered ones will just dump everything in _outbound_queue, and we'll
|
||||
# consume memory without bound until they stop.
|
||||
|
||||
# We need several things:
|
||||
#
|
||||
# * Add each registered IProducer to a list, whose order remains stable. We
|
||||
# want fairness under outbound throttling: each time the outbound
|
||||
# connection opens up (our resumeProducing method is called), we should let
|
||||
# just one producer have an opportunity to do transport.write, and then we
|
||||
# should pause them again, and not come back to them until everyone else
|
||||
# has gotten a turn. So we want an ordered list of producers to track this
|
||||
# rotation.
|
||||
#
|
||||
# * Remove the IProducer if/when the protocol uses unregisterProducer
|
||||
#
|
||||
# * Remove any registered IProducer when the associated Subchannel is closed.
|
||||
# This isn't a problem for normal transports, because usually there's a
|
||||
# one-to-one mapping from Protocol to Transport, so when the Transport you
|
||||
# forget the only reference to the Producer anyways. Our situation is
|
||||
# unusual because we have multiple Subchannels that get merged into the
|
||||
# same underlying Connection: each Subchannel's Protocol can register a
|
||||
# producer on the Subchannel (which is an ITransport), but that adds it to
|
||||
# a set of Producers for the Connection (which is also an ITransport). So
|
||||
# if the Subchannel is closed, we need to remove its Producer (if any) even
|
||||
# though the Connection remains open.
|
||||
#
|
||||
# * Register ourselves as an IPushProducer with each successive Connection
|
||||
# object. These connections will come and go, but there will never be more
|
||||
# than one. When the connection goes away, pause all our producers. When a
|
||||
# new one is established, write all our queued messages, then unpause our
|
||||
# producers as we would in resumeProducing.
|
||||
#
|
||||
# * Inside our resumeProducing call, we'll cycle through all producers,
|
||||
# calling their individual resumeProducing methods one at a time. If they
|
||||
# write so much data that the Connection pauses us again, we'll find out
|
||||
# because our pauseProducing will be called inside that loop. When that
|
||||
# happens, we need to stop looping. If we make it through the whole loop
|
||||
# without being paused, then all subchannel Producers are left unpaused,
|
||||
# and are free to write whenever they want. During this loop, some
|
||||
# Producers will be paused, and others will be resumed
|
||||
#
|
||||
# * If our pauseProducing is called, all Producers must be paused, and a flag
|
||||
# should be set to notify the resumeProducing loop to exit
|
||||
#
|
||||
# * In between calls to our resumeProducing method, we're in one of two
|
||||
# states.
|
||||
# * If we're writing data too fast, then we'll be left in the "paused"
|
||||
# state, in which all Subchannel producers are paused, and the aggregate
|
||||
# is paused too (our Connection told us to pauseProducing and hasn't yet
|
||||
# told us to resumeProducing). In this state, activity is driven by the
|
||||
# outbound TCP window opening up, which calls resumeProducing and allows
|
||||
# (probably just) one message to be sent. We receive pauseProducing in
|
||||
# the middle of their transport.write, so the loop exits early, and the
|
||||
# only state change is that some other Producer should get to go next
|
||||
# time.
|
||||
# * If we're writing too slowly, we'll be left in the "unpaused" state: all
|
||||
# Subchannel producers are unpaused, and the aggregate is unpaused too
|
||||
# (resumeProducing is the last thing we've been told). In this satte,
|
||||
# activity is driven by the Subchannels doing a transport.write, which
|
||||
# queues some data on the TCP connection (and then might call
|
||||
# pauseProducing if it's now full).
|
||||
#
|
||||
# * We want to guard against:
|
||||
#
|
||||
# * application protocol registering a Producer without first unregistering
|
||||
# the previous one
|
||||
#
|
||||
# * application protocols writing data despite being told to pause
|
||||
# (Subchannels without a registered Producer cannot be throttled, and we
|
||||
# can't do anything about that, but we must also handle the case where
|
||||
# they give us a pause switch and then proceed to ignore it)
|
||||
#
|
||||
# * our Connection calling resumeProducing or pauseProducing without an
|
||||
# intervening call of the other kind
|
||||
#
|
||||
# * application protocols that don't handle a resumeProducing or
|
||||
# pauseProducing call without an intervening call of the other kind (i.e.
|
||||
# we should keep track of the last thing we told them, and not repeat
|
||||
# ourselves)
|
||||
#
|
||||
# * If the Wormhole is closed, all Subchannels should close. This is not our
|
||||
# responsibility: it lives in (Manager? Inbound?)
|
||||
#
|
||||
# * If we're given an IPullProducer, we should keep calling its
|
||||
# resumeProducing until it runs out of data. We still want fairness, so we
|
||||
# won't call it a second time until everyone else has had a turn.
|
||||
|
||||
|
||||
# There are a couple of different ways to approach this. The one I've
|
||||
# selected is:
|
||||
#
|
||||
# * keep a dict that maps from Subchannel to Producer, which only contains
|
||||
# entries for Subchannels that have registered a producer. We use this to
|
||||
# remove Producers when Subchannels are closed
|
||||
#
|
||||
# * keep a Deque of Producers. This represents the fair-throttling rotation:
|
||||
# the left-most item gets the next upcoming turn, and then they'll be moved
|
||||
# to the end of the queue.
|
||||
#
|
||||
# * keep a set of IPushProducers which are paused, a second set of
|
||||
# IPushProducers which are unpaused, and a third set of IPullProducers
|
||||
# (which are always left paused) Enforce the invariant that these three
|
||||
# sets are disjoint, and that their union equals the contents of the deque.
|
||||
#
|
||||
# * keep a "paused" flag, which is cleared upon entry to resumeProducing, and
|
||||
# set upon entry to pauseProducing. The loop inside resumeProducing checks
|
||||
# this flag after each call to producer.resumeProducing, to sense whether
|
||||
# they used their turn to write data, and if that write was large enough to
|
||||
# fill the TCP window. If set, we break out of the loop. If not, we look
|
||||
# for the next producer to unpause. The loop finishes when all producers
|
||||
# are unpaused (evidenced by the two sets of paused producers being empty)
|
||||
#
|
||||
# * the "paused" flag also determines whether new IPushProducers are added to
|
||||
# the paused or unpaused set (IPullProducers are always added to the
|
||||
# pull+paused set). If we have any IPullProducers, we're always in the
|
||||
# "writing data too fast" state.
|
||||
|
||||
# other approaches that I didn't decide to do at this time (but might use in
|
||||
# the future):
|
||||
#
|
||||
# * use one set instead of two. pros: fewer moving parts. cons: harder to
|
||||
# spot decoherence bugs like adding a producer to the deque but forgetting
|
||||
# to add it to one of the
|
||||
#
|
||||
# * use zero sets, and keep the paused-vs-unpaused state in the Subchannel as
|
||||
# a visible boolean flag. This conflates Subchannels with their associated
|
||||
# Producer (so if we went this way, we should also let them track their own
|
||||
# Producer). Our resumeProducing loop ends when 'not any(sc.paused for sc
|
||||
# in self._subchannels_with_producers)'. Pros: fewer subchannel->producer
|
||||
# mappings lying around to disagree with one another. Cons: exposes a bit
|
||||
# too much of the Subchannel internals
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(IOutbound)
|
||||
class Outbound(object):
|
||||
# Manage outbound data: subchannel writes to us, we write to transport
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_cooperator = attrib()
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# _outbound_queue holds all messages we've ever sent but not retired
|
||||
self._outbound_queue = deque()
|
||||
self._next_outbound_seqnum = 0
|
||||
# _queued_unsent are messages to retry with our new connection
|
||||
self._queued_unsent = deque()
|
||||
|
||||
# outbound flow control: the Connection throttles our writes
|
||||
self._subchannel_producers = {} # Subchannel -> IProducer
|
||||
self._paused = True # our Connection called our pauseProducing
|
||||
self._all_producers = deque() # rotates, left-is-next
|
||||
self._paused_producers = set()
|
||||
self._unpaused_producers = set()
|
||||
self._check_invariants()
|
||||
|
||||
self._connection = None
|
||||
|
||||
def _check_invariants(self):
|
||||
assert self._unpaused_producers.isdisjoint(self._paused_producers)
|
||||
assert (self._paused_producers.union(self._unpaused_producers) ==
|
||||
set(self._all_producers))
|
||||
|
||||
def build_record(self, record_type, *args):
|
||||
seqnum = self._next_outbound_seqnum
|
||||
self._next_outbound_seqnum += 1
|
||||
r = record_type(seqnum, *args)
|
||||
assert hasattr(r, "seqnum"), r # only Open/Data/Close
|
||||
return r
|
||||
|
||||
def queue_and_send_record(self, r):
|
||||
# we always queue it, to resend on a subsequent connection if
|
||||
# necessary
|
||||
self._outbound_queue.append(r)
|
||||
|
||||
if self._connection:
|
||||
if self._queued_unsent:
|
||||
# to maintain correct ordering, queue this instead of sending it
|
||||
self._queued_unsent.append(r)
|
||||
else:
|
||||
# we're allowed to send it immediately
|
||||
self._connection.send_record(r)
|
||||
|
||||
def send_if_connected(self, r):
|
||||
assert isinstance(r, (KCM, Ping, Pong, Ack)), r # nothing with seqnum
|
||||
if self._connection:
|
||||
self._connection.send_record(r)
|
||||
|
||||
# our subchannels call these to register a producer
|
||||
|
||||
def subchannel_registerProducer(self, sc, producer, streaming):
|
||||
# streaming==True: IPushProducer (pause/resume)
|
||||
# streaming==False: IPullProducer (just resume)
|
||||
if sc in self._subchannel_producers:
|
||||
raise ValueError(
|
||||
"registering producer %s before previous one (%s) was "
|
||||
"unregistered" % (producer,
|
||||
self._subchannel_producers[sc]))
|
||||
# our underlying Connection uses streaming==True, so to make things
|
||||
# easier, use an adapter when the Subchannel asks for streaming=False
|
||||
if not streaming:
|
||||
def unregister():
|
||||
self.subchannel_unregisterProducer(sc)
|
||||
producer = PullToPush(producer, unregister, self._cooperator)
|
||||
|
||||
self._subchannel_producers[sc] = producer
|
||||
self._all_producers.append(producer)
|
||||
if self._paused:
|
||||
self._paused_producers.add(producer)
|
||||
else:
|
||||
self._unpaused_producers.add(producer)
|
||||
self._check_invariants()
|
||||
if streaming:
|
||||
if self._paused:
|
||||
# IPushProducers need to be paused immediately, before they
|
||||
# speak
|
||||
producer.pauseProducing() # you wake up sleeping
|
||||
else:
|
||||
# our PullToPush adapter must be started, but if we're paused then
|
||||
# we tell it to pause before it gets a chance to write anything
|
||||
producer.startStreaming(self._paused)
|
||||
|
||||
def subchannel_unregisterProducer(self, sc):
|
||||
# TODO: what if the subchannel closes, so we unregister their
|
||||
# producer for them, then the application reacts to connectionLost
|
||||
# with a duplicate unregisterProducer?
|
||||
p = self._subchannel_producers.pop(sc)
|
||||
if isinstance(p, PullToPush):
|
||||
p.stopStreaming()
|
||||
self._all_producers.remove(p)
|
||||
self._paused_producers.discard(p)
|
||||
self._unpaused_producers.discard(p)
|
||||
self._check_invariants()
|
||||
|
||||
def subchannel_closed(self, sc):
|
||||
self._check_invariants()
|
||||
if sc in self._subchannel_producers:
|
||||
self.subchannel_unregisterProducer(sc)
|
||||
|
||||
# our Manager tells us when we've got a new Connection to work with
|
||||
|
||||
def use_connection(self, c):
|
||||
self._connection = c
|
||||
assert not self._queued_unsent
|
||||
self._queued_unsent.extend(self._outbound_queue)
|
||||
# the connection can tell us to pause when we send too much data
|
||||
c.registerProducer(self, True) # IPushProducer: pause+resume
|
||||
# send our queued messages
|
||||
self.resumeProducing()
|
||||
|
||||
def stop_using_connection(self):
|
||||
self._connection.unregisterProducer()
|
||||
self._connection = None
|
||||
self._queued_unsent.clear()
|
||||
self.pauseProducing()
|
||||
# TODO: I expect this will call pauseProducing twice: the first time
|
||||
# when we get stopProducing (since we're registere with the
|
||||
# underlying connection as the producer), and again when the manager
|
||||
# notices the connectionLost and calls our _stop_using_connection
|
||||
|
||||
def handle_ack(self, resp_seqnum):
|
||||
# we've received an inbound ack, so retire something
|
||||
while (self._outbound_queue and
|
||||
self._outbound_queue[0].seqnum <= resp_seqnum):
|
||||
self._outbound_queue.popleft()
|
||||
while (self._queued_unsent and
|
||||
self._queued_unsent[0].seqnum <= resp_seqnum):
|
||||
self._queued_unsent.popleft()
|
||||
# Inbound is responsible for tracking the high watermark and deciding
|
||||
# whether to ignore inbound messages or not
|
||||
|
||||
|
||||
# IProducer: the active connection calls these because we used
|
||||
# c.registerProducer to ask for them
|
||||
def pauseProducing(self):
|
||||
if self._paused:
|
||||
return # someone is confused and called us twice
|
||||
self._paused = True
|
||||
for p in self._all_producers:
|
||||
if p in self._unpaused_producers:
|
||||
self._unpaused_producers.remove(p)
|
||||
self._paused_producers.add(p)
|
||||
p.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
if not self._paused:
|
||||
return # someone is confused and called us twice
|
||||
self._paused = False
|
||||
|
||||
while not self._paused:
|
||||
if self._queued_unsent:
|
||||
r = self._queued_unsent.popleft()
|
||||
self._connection.send_record(r)
|
||||
continue
|
||||
p = self._get_next_unpaused_producer()
|
||||
if not p:
|
||||
break
|
||||
self._paused_producers.remove(p)
|
||||
self._unpaused_producers.add(p)
|
||||
p.resumeProducing()
|
||||
|
||||
def _get_next_unpaused_producer(self):
|
||||
self._check_invariants()
|
||||
if not self._paused_producers:
|
||||
return None
|
||||
while True:
|
||||
p = self._all_producers[0]
|
||||
self._all_producers.rotate(-1) # p moves to the end of the line
|
||||
# the only unpaused Producers are at the end of the list
|
||||
assert p in self._paused_producers
|
||||
return p
|
||||
|
||||
def stopProducing(self):
|
||||
# we'll hopefully have a new connection to work with in the future,
|
||||
# so we don't shut anything down. We do pause everyone, though.
|
||||
self.pauseProducing()
|
||||
|
||||
|
||||
# modelled after twisted.internet._producer_helper._PullToPush , but with a
|
||||
# configurable Cooperator, a pause-immediately argument to startStreaming()
|
||||
@implementer(IPushProducer)
|
||||
@attrs(cmp=False)
|
||||
class PullToPush(object):
|
||||
_producer = attrib(validator=provides(IPullProducer))
|
||||
_unregister = attrib(validator=lambda _a,_b,v: callable(v))
|
||||
_cooperator = attrib()
|
||||
_finished = False
|
||||
|
||||
def _pull(self):
|
||||
while True:
|
||||
try:
|
||||
self._producer.resumeProducing()
|
||||
except:
|
||||
log.err(None, "%s failed, producing will be stopped:" %
|
||||
(safe_str(self._producer),))
|
||||
try:
|
||||
self._unregister()
|
||||
# The consumer should now call stopStreaming() on us,
|
||||
# thus stopping the streaming.
|
||||
except:
|
||||
# Since the consumer blew up, we may not have had
|
||||
# stopStreaming() called, so we just stop on our own:
|
||||
log.err(None, "%s failed to unregister producer:" %
|
||||
(safe_str(self._unregister),))
|
||||
self._finished = True
|
||||
return
|
||||
yield None
|
||||
|
||||
def startStreaming(self, paused):
|
||||
self._coopTask = self._cooperator.cooperate(self._pull())
|
||||
if paused:
|
||||
self.pauseProducing() # timer is scheduled, but task is removed
|
||||
|
||||
def stopStreaming(self):
|
||||
if self._finished:
|
||||
return
|
||||
self._finished = True
|
||||
self._coopTask.stop()
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
self._coopTask.pause()
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
self._coopTask.resume()
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
self.stopStreaming()
|
||||
self._producer.stopProducing()
|
1
src/wormhole/_dilation/roles.py
Normal file
1
src/wormhole/_dilation/roles.py
Normal file
|
@ -0,0 +1 @@
|
|||
LEADER, FOLLOWER = object(), object()
|
269
src/wormhole/_dilation/subchannel.py
Normal file
269
src/wormhole/_dilation/subchannel.py
Normal file
|
@ -0,0 +1,269 @@
|
|||
from attr import attrs, attrib
|
||||
from attr.validators import instance_of, provides
|
||||
from zope.interface import implementer
|
||||
from twisted.internet.defer import Deferred, inlineCallbacks, returnValue, succeed
|
||||
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
|
||||
IAddress, IListeningPort,
|
||||
IStreamClientEndpoint,
|
||||
IStreamServerEndpoint)
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from automat import MethodicalMachine
|
||||
from .._interfaces import ISubChannel, IDilationManager
|
||||
|
||||
@attrs
|
||||
class Once(object):
|
||||
_errtype = attrib()
|
||||
def __attrs_post_init__(self):
|
||||
self._called = False
|
||||
|
||||
def __call__(self):
|
||||
if self._called:
|
||||
raise self._errtype()
|
||||
self._called = True
|
||||
|
||||
class SingleUseEndpointError(Exception):
|
||||
pass
|
||||
|
||||
# created in the (OPEN) state, by either:
|
||||
# * receipt of an OPEN message
|
||||
# * or local client_endpoint.connect()
|
||||
# then transitions are:
|
||||
# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN)
|
||||
# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED)
|
||||
# (OPEN) local .write(): send DATA, -> (OPEN)
|
||||
# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING)
|
||||
# (CLOSING) local .write(): error
|
||||
# (CLOSING) local .loseConnection(): error
|
||||
# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING)
|
||||
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
|
||||
# object is deleted upon transition to (CLOSED)
|
||||
|
||||
class AlreadyClosedError(Exception):
|
||||
pass
|
||||
|
||||
@implementer(IAddress)
|
||||
class _WormholeAddress(object):
|
||||
pass
|
||||
|
||||
@implementer(IAddress)
|
||||
@attrs
|
||||
class _SubchannelAddress(object):
|
||||
_scid = attrib()
|
||||
|
||||
|
||||
@attrs
|
||||
@implementer(ITransport)
|
||||
@implementer(IProducer)
|
||||
@implementer(IConsumer)
|
||||
@implementer(ISubChannel)
|
||||
class SubChannel(object):
|
||||
_id = attrib(validator=instance_of(bytes))
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
|
||||
|
||||
m = MethodicalMachine()
|
||||
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
#self._mailbox = None
|
||||
#self._pending_outbound = {}
|
||||
#self._processed = set()
|
||||
self._protocol = None
|
||||
self._pending_dataReceived = []
|
||||
self._pending_connectionLost = (False, None)
|
||||
|
||||
@m.state(initial=True)
|
||||
def open(self): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closing(): pass # pragma: no cover
|
||||
|
||||
@m.state()
|
||||
def closed(): pass # pragma: no cover
|
||||
|
||||
@m.input()
|
||||
def remote_data(self, data): pass
|
||||
@m.input()
|
||||
def remote_close(self): pass
|
||||
|
||||
@m.input()
|
||||
def local_data(self, data): pass
|
||||
@m.input()
|
||||
def local_close(self): pass
|
||||
|
||||
|
||||
@m.output()
|
||||
def send_data(self, data):
|
||||
self._manager.send_data(self._id, data)
|
||||
|
||||
@m.output()
|
||||
def send_close(self):
|
||||
self._manager.send_close(self._id)
|
||||
|
||||
@m.output()
|
||||
def signal_dataReceived(self, data):
|
||||
if self._protocol:
|
||||
self._protocol.dataReceived(data)
|
||||
else:
|
||||
self._pending_dataReceived.append(data)
|
||||
|
||||
@m.output()
|
||||
def signal_connectionLost(self):
|
||||
if self._protocol:
|
||||
self._protocol.connectionLost(ConnectionDone())
|
||||
else:
|
||||
self._pending_connectionLost = (True, ConnectionDone())
|
||||
self._manager.subchannel_closed(self)
|
||||
# we're deleted momentarily
|
||||
|
||||
@m.output()
|
||||
def error_closed_write(self, data):
|
||||
raise AlreadyClosedError("write not allowed on closed subchannel")
|
||||
@m.output()
|
||||
def error_closed_close(self):
|
||||
raise AlreadyClosedError("loseConnection not allowed on closed subchannel")
|
||||
|
||||
# primary transitions
|
||||
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
|
||||
open.upon(local_data, enter=open, outputs=[send_data])
|
||||
open.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
open.upon(local_close, enter=closing, outputs=[send_close])
|
||||
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
|
||||
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])
|
||||
|
||||
# error cases
|
||||
# we won't ever see an OPEN, since L4 will log+ignore those for us
|
||||
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
|
||||
closing.upon(local_close, enter=closing, outputs=[error_closed_close])
|
||||
# the CLOSED state won't ever see messages, since we'll be deleted
|
||||
|
||||
# our endpoints use this
|
||||
|
||||
def _set_protocol(self, protocol):
|
||||
assert not self._protocol
|
||||
self._protocol = protocol
|
||||
if self._pending_dataReceived:
|
||||
for data in self._pending_dataReceived:
|
||||
self._protocol.dataReceived(data)
|
||||
self._pending_dataReceived = []
|
||||
cl, what = self._pending_connectionLost
|
||||
if cl:
|
||||
self._protocol.connectionLost(what)
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
assert isinstance(data, type(b""))
|
||||
self.local_data(data)
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
def loseConnection(self):
|
||||
self.local_close()
|
||||
def getHost(self):
|
||||
# we define "host addr" as the overall wormhole
|
||||
return self._host_addr
|
||||
def getPeer(self):
|
||||
# and "peer addr" as the subchannel within that wormhole
|
||||
return self._peer_addr
|
||||
|
||||
# IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
|
||||
def stopProducing(self):
|
||||
self._manager.subchannel_stopProducing(self)
|
||||
def pauseProducing(self):
|
||||
self._manager.subchannel_pauseProducing(self)
|
||||
def resumeProducing(self):
|
||||
self._manager.subchannel_resumeProducing(self)
|
||||
|
||||
# IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
|
||||
def registerProducer(self, producer, streaming):
|
||||
self._manager.subchannel_registerProducer(self, producer, streaming)
|
||||
def unregisterProducer(self):
|
||||
self._manager.subchannel_unregisterProducer(self)
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class ControlEndpoint(object):
|
||||
_used = False
|
||||
def __init__(self, peer_addr):
|
||||
self._subchannel_zero = Deferred()
|
||||
self._peer_addr = peer_addr
|
||||
self._once = Once(SingleUseEndpointError)
|
||||
|
||||
# from manager
|
||||
def _subchannel_zero_opened(self, subchannel):
|
||||
assert ISubChannel.providedBy(subchannel), subchannel
|
||||
self._subchannel_zero.callback(subchannel)
|
||||
|
||||
@inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
self._once()
|
||||
t = yield self._subchannel_zero
|
||||
p = protocolFactory.buildProtocol(self._peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
returnValue(p)
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
@attrs
|
||||
class SubchannelConnectorEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=instance_of(_WormholeAddress))
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
# return Deferred that fires with IProtocol or Failure(ConnectError)
|
||||
scid = self._manager.allocate_subchannel_id()
|
||||
self._manager.send_open(scid)
|
||||
peer_addr = _SubchannelAddress(scid)
|
||||
# ? f.doStart()
|
||||
# ? f.startedConnecting(CONNECTOR) # ??
|
||||
t = SubChannel(scid, self._manager, self._host_addr, peer_addr)
|
||||
p = protocolFactory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t) # set p.transport = t and call connectionMade()
|
||||
return succeed(p)
|
||||
|
||||
@implementer(IStreamServerEndpoint)
|
||||
@attrs
|
||||
class SubchannelListenerEndpoint(object):
|
||||
_manager = attrib(validator=provides(IDilationManager))
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self._factory = None
|
||||
self._pending_opens = []
|
||||
|
||||
# from manager
|
||||
def _got_open(self, t, peer_addr):
|
||||
if self._factory:
|
||||
self._connect(t, peer_addr)
|
||||
else:
|
||||
self._pending_opens.append( (t, peer_addr) )
|
||||
|
||||
def _connect(self, t, peer_addr):
|
||||
p = self._factory.buildProtocol(peer_addr)
|
||||
t._set_protocol(p)
|
||||
p.makeConnection(t)
|
||||
|
||||
# IStreamServerEndpoint
|
||||
|
||||
def listen(self, protocolFactory):
|
||||
self._factory = protocolFactory
|
||||
for (t, peer_addr) in self._pending_opens:
|
||||
self._connect(t, peer_addr)
|
||||
self._pending_opens = []
|
||||
lp = SubchannelListeningPort(self._host_addr)
|
||||
return succeed(lp)
|
||||
|
||||
@implementer(IListeningPort)
|
||||
@attrs
|
||||
class SubchannelListeningPort(object):
|
||||
_host_addr = attrib(validator=provides(IAddress))
|
||||
|
||||
def startListening(self):
|
||||
pass
|
||||
def stopListening(self):
|
||||
# TODO
|
||||
pass
|
||||
def getHost(self):
|
||||
return self._host_addr
|
|
@ -433,3 +433,16 @@ class IInputHelper(Interface):
|
|||
|
||||
class IJournal(Interface): # TODO: this needs to be public
|
||||
pass
|
||||
|
||||
class IDilator(Interface):
|
||||
pass
|
||||
class IDilationManager(Interface):
|
||||
pass
|
||||
class IDilationConnector(Interface):
|
||||
pass
|
||||
class ISubChannel(Interface):
|
||||
pass
|
||||
class IInbound(Interface):
|
||||
pass
|
||||
class IOutbound(Interface):
|
||||
pass
|
||||
|
|
0
src/wormhole/test/dilate/__init__.py
Normal file
0
src/wormhole/test/dilate/__init__.py
Normal file
18
src/wormhole/test/dilate/common.py
Normal file
18
src/wormhole/test/dilate/common.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from ..._interfaces import IDilationManager, IWormhole
|
||||
|
||||
def mock_manager():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
return m
|
||||
|
||||
def mock_wormhole():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IWormhole)
|
||||
return m
|
||||
|
||||
def clear_mock_calls(*args):
|
||||
for a in args:
|
||||
a.mock_calls[:] = []
|
216
src/wormhole/test/dilate/test_connection.py
Normal file
216
src/wormhole/test/dilate/test_connection.py
Normal file
|
@ -0,0 +1,216 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import IDilationConnector
|
||||
from ..._dilation.roles import LEADER, FOLLOWER
|
||||
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
|
||||
KCM, Open, Ack)
|
||||
from .common import clear_mock_calls
|
||||
|
||||
def make_con(role, use_relay=False):
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
connector = mock.Mock()
|
||||
alsoProvides(connector, IDilationConnector)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
n.write_message = mock.Mock(side_effect=[b"handshake"])
|
||||
c = DilatedConnectionProtocol(eq, role, connector, n,
|
||||
b"outbound_prologue\n", b"inbound_prologue\n")
|
||||
if use_relay:
|
||||
c.use_relay(b"relay_handshake\n")
|
||||
t = mock.Mock()
|
||||
alsoProvides(t, ITransport)
|
||||
return c, n, connector, t, eq
|
||||
|
||||
class Connection(unittest.TestCase):
|
||||
def test_bad_prologue(self):
|
||||
c, n, connector, t, eq = make_con(LEADER)
|
||||
c.makeConnection(t)
|
||||
d = c.when_disconnected()
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"prologue\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
|
||||
eq.flush_sync()
|
||||
self.assertNoResult(d)
|
||||
c.connectionLost(b"why")
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d), c)
|
||||
|
||||
def _test_no_relay(self, role):
|
||||
c, n, connector, t, eq = make_con(role)
|
||||
t_kcm = KCM()
|
||||
t_open = Open(seqnum=1, scid=0x11223344)
|
||||
t_ack = Ack(resp_seqnum=2)
|
||||
n.decrypt = mock.Mock(side_effect=[
|
||||
encode_record(t_kcm),
|
||||
encode_record(t_open),
|
||||
])
|
||||
exp_kcm = b"\x00\x00\x00\x03kcm"
|
||||
n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"])
|
||||
m = mock.Mock() # Manager
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
||||
if role is LEADER:
|
||||
# we're the leader, so we don't send the KCM right away
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2")])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(c._manager, None)
|
||||
else:
|
||||
# we're the follower, so we encrypt and send the KCM immediately
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2"),
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [
|
||||
mock.call.write(exp_kcm)])
|
||||
self.assertEqual(c._manager, None)
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x03KCM")
|
||||
# leader: inbound KCM means we add the candidate
|
||||
# follower: inbound KCM means we've been selected.
|
||||
# in both cases we notify Connector.add_candidate(), and the Connector
|
||||
# decides if/when to call .select()
|
||||
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"KCM")])
|
||||
self.assertEqual(connector.mock_calls, [mock.call.add_candidate(c)])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
# now pretend this connection wins (either the Leader decides to use
|
||||
# this one among all the candiates, or we're the Follower and the
|
||||
# Connector is reacting to add_candidate() by recognizing we're the
|
||||
# only candidate there is)
|
||||
c.select(m)
|
||||
self.assertIdentical(c._manager, m)
|
||||
if role is LEADER:
|
||||
# TODO: currently Connector.select_and_stop_remaining() is
|
||||
# responsible for sending the KCM just before calling c.select()
|
||||
# iff we're the LEADER, therefore Connection.select won't send
|
||||
# anything. This should be moved to c.select().
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
|
||||
c.send_record(KCM())
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
else:
|
||||
# follower: we already sent the KCM, do nothing
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x04msg1")
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"msg1")])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(m.mock_calls, [mock.call.got_record(t_open)])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.send_record(t_ack)
|
||||
exp_ack = b"\x06\x00\x00\x00\x02"
|
||||
self.assertEqual(n.mock_calls, [mock.call.encrypt(exp_ack)])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"\x00\x00\x00\x04ack1")])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.disconnect()
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
def test_no_relay_leader(self):
|
||||
return self._test_no_relay(LEADER)
|
||||
|
||||
def test_no_relay_follower(self):
|
||||
return self._test_no_relay(FOLLOWER)
|
||||
|
||||
|
||||
def test_relay(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"ok\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
def test_relay_jilted(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
d = c.when_disconnected()
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.connectionLost(b"why")
|
||||
eq.flush_sync()
|
||||
self.assertIdentical(self.successResultOf(d), c)
|
||||
|
||||
def test_relay_bad_response(self):
|
||||
c, n, connector, t, eq = make_con(LEADER, use_relay=True)
|
||||
|
||||
c.makeConnection(t)
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay_handshake\n")])
|
||||
clear_mock_calls(n, connector, t)
|
||||
|
||||
c.dataReceived(b"not ok\n")
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [mock.call.loseConnection()])
|
||||
clear_mock_calls(n, connector, t)
|
25
src/wormhole/test/dilate/test_encoding.py
Normal file
25
src/wormhole/test/dilate/test_encoding.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from twisted.trial import unittest
|
||||
from ..._dilation.encode import to_be4, from_be4
|
||||
|
||||
class Encoding(unittest.TestCase):
|
||||
|
||||
def test_be4(self):
|
||||
self.assertEqual(to_be4(0), b"\x00\x00\x00\x00")
|
||||
self.assertEqual(to_be4(1), b"\x00\x00\x00\x01")
|
||||
self.assertEqual(to_be4(256), b"\x00\x00\x01\x00")
|
||||
self.assertEqual(to_be4(257), b"\x00\x00\x01\x01")
|
||||
with self.assertRaises(ValueError):
|
||||
to_be4(-1)
|
||||
with self.assertRaises(ValueError):
|
||||
to_be4(2**32)
|
||||
|
||||
self.assertEqual(from_be4(b"\x00\x00\x00\x00"), 0)
|
||||
self.assertEqual(from_be4(b"\x00\x00\x00\x01"), 1)
|
||||
self.assertEqual(from_be4(b"\x00\x00\x01\x00"), 256)
|
||||
self.assertEqual(from_be4(b"\x00\x00\x01\x01"), 257)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
from_be4(0)
|
||||
with self.assertRaises(ValueError):
|
||||
from_be4(b"\x01\x00\x00\x00\x00")
|
97
src/wormhole/test/dilate/test_endpoints.py
Normal file
97
src/wormhole/test/dilate/test_endpoints.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from ..._interfaces import ISubChannel
|
||||
from ..._dilation.subchannel import (ControlEndpoint,
|
||||
SubchannelConnectorEndpoint,
|
||||
SubchannelListenerEndpoint,
|
||||
SubchannelListeningPort,
|
||||
_WormholeAddress, _SubchannelAddress,
|
||||
SingleUseEndpointError)
|
||||
from .common import mock_manager
|
||||
|
||||
class Endpoints(unittest.TestCase):
|
||||
def test_control(self):
|
||||
scid0 = b"scid0"
|
||||
peeraddr = _SubchannelAddress(scid0)
|
||||
ep = ControlEndpoint(peeraddr)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
d = ep.connect(f)
|
||||
self.assertNoResult(d)
|
||||
|
||||
t = mock.Mock()
|
||||
alsoProvides(t, ISubChannel)
|
||||
ep._subchannel_zero_opened(t)
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||
|
||||
d = ep.connect(f)
|
||||
self.failureResultOf(d, SingleUseEndpointError)
|
||||
|
||||
def assert_makeConnection(self, mock_calls):
|
||||
self.assertEqual(len(mock_calls), 1)
|
||||
self.assertEqual(mock_calls[0][0], "makeConnection")
|
||||
self.assertEqual(len(mock_calls[0][1]), 1)
|
||||
return mock_calls[0][1][0]
|
||||
|
||||
def test_connector(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(b"scid")
|
||||
ep = SubchannelConnectorEndpoint(m, hostaddr)
|
||||
|
||||
f = mock.Mock()
|
||||
p = mock.Mock()
|
||||
t = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(return_value=p)
|
||||
with mock.patch("wormhole._dilation.subchannel.SubChannel",
|
||||
return_value=t) as sc:
|
||||
d = ep.connect(f)
|
||||
self.assertIdentical(self.successResultOf(d), p)
|
||||
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
|
||||
self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)])
|
||||
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
|
||||
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
|
||||
|
||||
def test_listener(self):
|
||||
m = mock_manager()
|
||||
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
|
||||
hostaddr = _WormholeAddress()
|
||||
ep = SubchannelListenerEndpoint(m, hostaddr)
|
||||
|
||||
f = mock.Mock()
|
||||
p1 = mock.Mock()
|
||||
p2 = mock.Mock()
|
||||
f.buildProtocol = mock.Mock(side_effect=[p1, p2])
|
||||
|
||||
# OPEN that arrives before we ep.listen() should be queued
|
||||
|
||||
t1 = mock.Mock()
|
||||
peeraddr1 = _SubchannelAddress(b"peer1")
|
||||
ep._got_open(t1, peeraddr1)
|
||||
|
||||
d = ep.listen(f)
|
||||
lp = self.successResultOf(d)
|
||||
self.assertIsInstance(lp, SubchannelListeningPort)
|
||||
|
||||
self.assertEqual(lp.getHost(), hostaddr)
|
||||
lp.startListening()
|
||||
|
||||
self.assertEqual(t1.mock_calls, [mock.call._set_protocol(p1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
|
||||
|
||||
t2 = mock.Mock()
|
||||
peeraddr2 = _SubchannelAddress(b"peer2")
|
||||
ep._got_open(t2, peeraddr2)
|
||||
|
||||
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
|
||||
|
||||
lp.stopListening() # TODO: should this do more?
|
110
src/wormhole/test/dilate/test_framer.py
Normal file
110
src/wormhole/test/dilate/test_framer.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect
|
||||
|
||||
def make_framer():
|
||||
t = mock.Mock()
|
||||
alsoProvides(t, ITransport)
|
||||
f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n")
|
||||
return f, t
|
||||
|
||||
class Framer(unittest.TestCase):
|
||||
def test_bad_prologue_length(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
t.mock_calls[:] = []
|
||||
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
with mock.patch("wormhole._dilation.connection.log.msg") as m:
|
||||
with self.assertRaises(Disconnect):
|
||||
list(f.add_and_parse(b"not the prologue after all"))
|
||||
self.assertEqual(m.mock_calls,
|
||||
[mock.call("bad prologue: {}".format(
|
||||
b"inbound_not the p"))])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
def test_bad_prologue_newline(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
t.mock_calls[:] = []
|
||||
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
|
||||
|
||||
self.assertEqual([], list(f.add_and_parse(b"not")))
|
||||
with mock.patch("wormhole._dilation.connection.log.msg") as m:
|
||||
with self.assertRaises(Disconnect):
|
||||
list(f.add_and_parse(b"\n"))
|
||||
self.assertEqual(m.mock_calls,
|
||||
[mock.call("bad prologue: {}".format(
|
||||
b"inbound_not\n"))])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
def test_good_prologue(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
t.mock_calls[:] = []
|
||||
self.assertEqual([Prologue()],
|
||||
list(f.add_and_parse(b"inbound_prologue\n")))
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
# now send_frame should work
|
||||
f.send_frame(b"frame")
|
||||
self.assertEqual(t.mock_calls,
|
||||
[mock.call.write(b"\x00\x00\x00\x05frame")])
|
||||
|
||||
def test_bad_relay(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
f.use_relay(b"relay handshake\n")
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
|
||||
t.mock_calls[:] = []
|
||||
with mock.patch("wormhole._dilation.connection.log.msg") as m:
|
||||
with self.assertRaises(Disconnect):
|
||||
list(f.add_and_parse(b"goodbye\n"))
|
||||
self.assertEqual(m.mock_calls,
|
||||
[mock.call("bad relay_ok: {}".format(b"goo"))])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
def test_good_relay(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
f.use_relay(b"relay handshake\n")
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"relay handshake\n")])
|
||||
t.mock_calls[:] = []
|
||||
|
||||
self.assertEqual([], list(f.add_and_parse(b"ok\n")))
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
|
||||
def test_frame(self):
|
||||
f, t = make_framer()
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
f.connectionMade()
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
|
||||
t.mock_calls[:] = []
|
||||
self.assertEqual([Prologue()],
|
||||
list(f.add_and_parse(b"inbound_prologue\n")))
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
|
||||
encoded_frame = b"\x00\x00\x00\x05frame"
|
||||
self.assertEqual([], list(f.add_and_parse(encoded_frame[:2])))
|
||||
self.assertEqual([], list(f.add_and_parse(encoded_frame[2:6])))
|
||||
self.assertEqual([Frame(frame=b"frame")],
|
||||
list(f.add_and_parse(encoded_frame[6:])))
|
172
src/wormhole/test/dilate/test_inbound.py
Normal file
172
src/wormhole/test/dilate/test_inbound.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from ..._interfaces import IDilationManager
|
||||
from ..._dilation.connection import Open, Data, Close
|
||||
from ..._dilation.inbound import (Inbound, DuplicateOpenError,
|
||||
DataForMissingSubchannelError,
|
||||
CloseForMissingSubchannelError)
|
||||
|
||||
def make_inbound():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
host_addr = object()
|
||||
i = Inbound(m, host_addr)
|
||||
return i, m, host_addr
|
||||
|
||||
class InboundTest(unittest.TestCase):
|
||||
def test_seqnum(self):
|
||||
i, m, host_addr = make_inbound()
|
||||
r1 = Open(scid=513, seqnum=1)
|
||||
r2 = Data(scid=513, seqnum=2, data=b"")
|
||||
r3 = Close(scid=513, seqnum=3)
|
||||
self.assertFalse(i.is_record_old(r1))
|
||||
self.assertFalse(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
||||
i.update_ack_watermark(r1)
|
||||
self.assertTrue(i.is_record_old(r1))
|
||||
self.assertFalse(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
||||
i.update_ack_watermark(r2)
|
||||
self.assertTrue(i.is_record_old(r1))
|
||||
self.assertTrue(i.is_record_old(r2))
|
||||
self.assertFalse(i.is_record_old(r3))
|
||||
|
||||
def test_open_data_close(self):
|
||||
i, m, host_addr = make_inbound()
|
||||
scid1 = b"scid"
|
||||
scid2 = b"scXX"
|
||||
c = mock.Mock()
|
||||
lep = mock.Mock()
|
||||
i.set_listener_endpoint(lep)
|
||||
i.use_connection(c)
|
||||
sc1 = mock.Mock()
|
||||
peer_addr = object()
|
||||
with mock.patch("wormhole._dilation.inbound.SubChannel",
|
||||
side_effect=[sc1]) as sc:
|
||||
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
|
||||
side_effect=[peer_addr]) as sca:
|
||||
i.handle_open(scid1)
|
||||
self.assertEqual(lep.mock_calls, [mock.call._got_open(sc1, peer_addr)])
|
||||
self.assertEqual(sc.mock_calls, [mock.call(scid1, m, host_addr, peer_addr)])
|
||||
self.assertEqual(sca.mock_calls, [mock.call(scid1)])
|
||||
lep.mock_calls[:] = []
|
||||
|
||||
# a subsequent duplicate OPEN should be ignored
|
||||
with mock.patch("wormhole._dilation.inbound.SubChannel",
|
||||
side_effect=[sc1]) as sc:
|
||||
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
|
||||
side_effect=[peer_addr]) as sca:
|
||||
i.handle_open(scid1)
|
||||
self.assertEqual(lep.mock_calls, [])
|
||||
self.assertEqual(sc.mock_calls, [])
|
||||
self.assertEqual(sca.mock_calls, [])
|
||||
self.flushLoggedErrors(DuplicateOpenError)
|
||||
|
||||
i.handle_data(scid1, b"data")
|
||||
self.assertEqual(sc1.mock_calls, [mock.call.remote_data(b"data")])
|
||||
sc1.mock_calls[:] = []
|
||||
|
||||
i.handle_data(scid2, b"for non-existent subchannel")
|
||||
self.assertEqual(sc1.mock_calls, [])
|
||||
self.flushLoggedErrors(DataForMissingSubchannelError)
|
||||
|
||||
i.handle_close(scid1)
|
||||
self.assertEqual(sc1.mock_calls, [mock.call.remote_close()])
|
||||
sc1.mock_calls[:] = []
|
||||
|
||||
i.handle_close(scid2)
|
||||
self.assertEqual(sc1.mock_calls, [])
|
||||
self.flushLoggedErrors(CloseForMissingSubchannelError)
|
||||
|
||||
# after the subchannel is closed, the Manager will notify Inbound
|
||||
i.subchannel_closed(scid1, sc1)
|
||||
|
||||
i.stop_using_connection()
|
||||
|
||||
def test_control_channel(self):
|
||||
i, m, host_addr = make_inbound()
|
||||
lep = mock.Mock()
|
||||
i.set_listener_endpoint(lep)
|
||||
|
||||
scid0 = b"scid"
|
||||
sc0 = mock.Mock()
|
||||
i.set_subchannel_zero(scid0, sc0)
|
||||
|
||||
# OPEN on the control channel identifier should be ignored as a
|
||||
# duplicate, since the control channel is already registered
|
||||
sc1 = mock.Mock()
|
||||
peer_addr = object()
|
||||
with mock.patch("wormhole._dilation.inbound.SubChannel",
|
||||
side_effect=[sc1]) as sc:
|
||||
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
|
||||
side_effect=[peer_addr]) as sca:
|
||||
i.handle_open(scid0)
|
||||
self.assertEqual(lep.mock_calls, [])
|
||||
self.assertEqual(sc.mock_calls, [])
|
||||
self.assertEqual(sca.mock_calls, [])
|
||||
self.flushLoggedErrors(DuplicateOpenError)
|
||||
|
||||
# and DATA to it should be delivered correctly
|
||||
i.handle_data(scid0, b"data")
|
||||
self.assertEqual(sc0.mock_calls, [mock.call.remote_data(b"data")])
|
||||
sc0.mock_calls[:] = []
|
||||
|
||||
def test_pause(self):
|
||||
i, m, host_addr = make_inbound()
|
||||
c = mock.Mock()
|
||||
lep = mock.Mock()
|
||||
i.set_listener_endpoint(lep)
|
||||
|
||||
# add two subchannels, pause one, then add a connection
|
||||
scid1 = b"sci1"
|
||||
scid2 = b"sci2"
|
||||
sc1 = mock.Mock()
|
||||
sc2 = mock.Mock()
|
||||
peer_addr = object()
|
||||
with mock.patch("wormhole._dilation.inbound.SubChannel",
|
||||
side_effect=[sc1, sc2]):
|
||||
with mock.patch("wormhole._dilation.inbound._SubchannelAddress",
|
||||
return_value=peer_addr):
|
||||
i.handle_open(scid1)
|
||||
i.handle_open(scid2)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
i.subchannel_pauseProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
i.subchannel_resumeProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
i.subchannel_pauseProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
i.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
|
||||
c.mock_calls[:] = []
|
||||
|
||||
i.subchannel_resumeProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
|
||||
c.mock_calls[:] = []
|
||||
|
||||
# consumers aren't really supposed to do this, but tolerate it
|
||||
i.subchannel_resumeProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
i.subchannel_pauseProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
|
||||
c.mock_calls[:] = []
|
||||
i.subchannel_pauseProducing(sc2)
|
||||
self.assertEqual(c.mock_calls, []) # was already paused
|
||||
|
||||
# tolerate duplicate pauseProducing
|
||||
i.subchannel_pauseProducing(sc2)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
# stopProducing is treated like a terminal resumeProducing
|
||||
i.subchannel_stopProducing(sc1)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
i.subchannel_stopProducing(sc2)
|
||||
self.assertEqual(c.mock_calls, [mock.call.resumeProducing()])
|
||||
c.mock_calls[:] = []
|
205
src/wormhole/test/dilate/test_manager.py
Normal file
205
src/wormhole/test/dilate/test_manager.py
Normal file
|
@ -0,0 +1,205 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.task import Clock, Cooperator
|
||||
import mock
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import ISend, IDilationManager
|
||||
from ...util import dict_to_bytes
|
||||
from ..._dilation.manager import (Dilator,
|
||||
OldPeerCannotDilateError,
|
||||
UnknownDilationMessageType)
|
||||
from ..._dilation.subchannel import _WormholeAddress
|
||||
from .common import clear_mock_calls
|
||||
|
||||
def make_dilator():
|
||||
reactor = object()
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
|
||||
term_factory = lambda: term
|
||||
coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=eq.eventually)
|
||||
send = mock.Mock()
|
||||
alsoProvides(send, ISend)
|
||||
dil = Dilator(reactor, eq, coop)
|
||||
dil.wire(send)
|
||||
return dil, send, reactor, eq, clock, coop
|
||||
|
||||
class TestDilator(unittest.TestCase):
|
||||
def test_leader(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
d2 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
key = b"key"
|
||||
transit_key = object()
|
||||
with mock.patch("wormhole._dilation.manager.derive_key",
|
||||
return_value=transit_key) as dk:
|
||||
dil.got_key(key)
|
||||
self.assertEqual(dk.mock_calls, [mock.call(key, b"dilation-v1", 32)])
|
||||
self.assertIdentical(dil._transit_key, transit_key)
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
m.when_first_connected.return_value = wfc_d = Deferred()
|
||||
# TODO: test missing can-dilate, and no-overlap
|
||||
with mock.patch("wormhole._dilation.manager.ManagerLeader",
|
||||
return_value=m) as ml:
|
||||
dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
|
||||
# that should create the Manager. Because "us" > "them", we're
|
||||
# the leader
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", transit_key,
|
||||
None, reactor, eq, coop)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.start(),
|
||||
mock.call.when_first_connected(),
|
||||
])
|
||||
clear_mock_calls(m)
|
||||
self.assertNoResult(d1)
|
||||
self.assertNoResult(d2)
|
||||
|
||||
host_addr = _WormholeAddress()
|
||||
m_wa = mock.patch("wormhole._dilation.manager._WormholeAddress",
|
||||
return_value=host_addr)
|
||||
peer_addr = object()
|
||||
m_sca = mock.patch("wormhole._dilation.manager._SubchannelAddress",
|
||||
return_value=peer_addr)
|
||||
ce = mock.Mock()
|
||||
m_ce = mock.patch("wormhole._dilation.manager.ControlEndpoint",
|
||||
return_value=ce)
|
||||
sc = mock.Mock()
|
||||
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
|
||||
return_value=sc)
|
||||
|
||||
lep = object()
|
||||
m_sle = mock.patch("wormhole._dilation.manager.SubchannelListenerEndpoint",
|
||||
return_value=lep)
|
||||
|
||||
with m_wa, m_sca, m_ce as m_ce_m, m_sc as m_sc_m, m_sle as m_sle_m:
|
||||
wfc_d.callback(None)
|
||||
eq.flush_sync()
|
||||
scid0 = b"\x00\x00\x00\x00"
|
||||
self.assertEqual(m_ce_m.mock_calls, [mock.call(peer_addr)])
|
||||
self.assertEqual(m_sc_m.mock_calls,
|
||||
[mock.call(scid0, m, host_addr, peer_addr)])
|
||||
self.assertEqual(ce.mock_calls, [mock.call._subchannel_zero_opened(sc)])
|
||||
self.assertEqual(m_sle_m.mock_calls, [mock.call(m, host_addr)])
|
||||
self.assertEqual(m.mock_calls,
|
||||
[mock.call.set_subchannel_zero(scid0, sc),
|
||||
mock.call.set_listener_endpoint(lep),
|
||||
])
|
||||
clear_mock_calls(m)
|
||||
|
||||
eps = self.successResultOf(d1)
|
||||
self.assertEqual(eps, self.successResultOf(d2))
|
||||
d3 = dil.dilate()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(eps, self.successResultOf(d3))
|
||||
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
dil.received_dilate(dict_to_bytes(dict(type="please")))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_PLEASE()])
|
||||
clear_mock_calls(m)
|
||||
|
||||
hintmsg = dict(type="connection-hints")
|
||||
dil.received_dilate(dict_to_bytes(hintmsg))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_HINTS(hintmsg)])
|
||||
clear_mock_calls(m)
|
||||
|
||||
dil.received_dilate(dict_to_bytes(dict(type="dilate")))
|
||||
self.assertEqual(m.mock_calls, [mock.call.rx_DILATE()])
|
||||
clear_mock_calls(m)
|
||||
|
||||
dil.received_dilate(dict_to_bytes(dict(type="unknown")))
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.flushLoggedErrors(UnknownDilationMessageType)
|
||||
|
||||
def test_follower(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
|
||||
key = b"key"
|
||||
transit_key = object()
|
||||
with mock.patch("wormhole._dilation.manager.derive_key",
|
||||
return_value=transit_key):
|
||||
dil.got_key(key)
|
||||
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
m.when_first_connected.return_value = Deferred()
|
||||
with mock.patch("wormhole._dilation.manager.ManagerFollower",
|
||||
return_value=m) as mf:
|
||||
dil.got_wormhole_versions("me", "you", {"can-dilate": [1]})
|
||||
# "me" < "you", so we're the follower
|
||||
self.assertEqual(mf.mock_calls, [mock.call(send, "me", transit_key,
|
||||
None, reactor, eq, coop)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.start(),
|
||||
mock.call.when_first_connected(),
|
||||
])
|
||||
|
||||
def test_peer_cannot_dilate(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
|
||||
dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate"
|
||||
eq.flush_sync()
|
||||
f = self.failureResultOf(d1)
|
||||
f.check(OldPeerCannotDilateError)
|
||||
|
||||
|
||||
def test_disjoint_versions(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
|
||||
dil.got_wormhole_versions("me", "you", {"can-dilate": [-1]})
|
||||
eq.flush_sync()
|
||||
f = self.failureResultOf(d1)
|
||||
f.check(OldPeerCannotDilateError)
|
||||
|
||||
|
||||
def test_early_dilate_messages(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
dil._transit_key = b"key"
|
||||
d1 = dil.dilate()
|
||||
self.assertNoResult(d1)
|
||||
dil.received_dilate(dict_to_bytes(dict(type="please")))
|
||||
hintmsg = dict(type="connection-hints")
|
||||
dil.received_dilate(dict_to_bytes(hintmsg))
|
||||
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
m.when_first_connected.return_value = Deferred()
|
||||
|
||||
with mock.patch("wormhole._dilation.manager.ManagerLeader",
|
||||
return_value=m) as ml:
|
||||
dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
None, reactor, eq, coop)])
|
||||
self.assertEqual(m.mock_calls, [mock.call.start(),
|
||||
mock.call.rx_PLEASE(),
|
||||
mock.call.rx_HINTS(hintmsg),
|
||||
mock.call.when_first_connected()])
|
||||
|
||||
|
||||
|
||||
def test_transit_relay(self):
|
||||
dil, send, reactor, eq, clock, coop = make_dilator()
|
||||
dil._transit_key = b"key"
|
||||
relay = object()
|
||||
d1 = dil.dilate(transit_relay_location=relay)
|
||||
self.assertNoResult(d1)
|
||||
|
||||
with mock.patch("wormhole._dilation.manager.ManagerLeader") as ml:
|
||||
dil.got_wormhole_versions("us", "them", {"can-dilate": [1]})
|
||||
self.assertEqual(ml.mock_calls, [mock.call(send, "us", b"key",
|
||||
relay, reactor, eq, coop),
|
||||
mock.call().start(),
|
||||
mock.call().when_first_connected()])
|
645
src/wormhole/test/dilate/test_outbound.py
Normal file
645
src/wormhole/test/dilate/test_outbound.py
Normal file
|
@ -0,0 +1,645 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
from collections import namedtuple
|
||||
from itertools import cycle
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.task import Clock, Cooperator
|
||||
from twisted.internet.interfaces import IPullProducer
|
||||
from ...eventual import EventualQueue
|
||||
from ..._interfaces import IDilationManager
|
||||
from ..._dilation.connection import KCM, Open, Data, Close, Ack
|
||||
from ..._dilation.outbound import Outbound, PullToPush
|
||||
from .common import clear_mock_calls
|
||||
|
||||
Pauser = namedtuple("Pauser", ["seqnum"])
|
||||
NonPauser = namedtuple("NonPauser", ["seqnum"])
|
||||
Stopper = namedtuple("Stopper", ["sc"])
|
||||
|
||||
def make_outbound():
|
||||
m = mock.Mock()
|
||||
alsoProvides(m, IDilationManager)
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
|
||||
term_factory = lambda: term
|
||||
coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=eq.eventually)
|
||||
o = Outbound(m, coop)
|
||||
c = mock.Mock() # Connection
|
||||
def maybe_pause(r):
|
||||
if isinstance(r, Pauser):
|
||||
o.pauseProducing()
|
||||
elif isinstance(r, Stopper):
|
||||
o.subchannel_unregisterProducer(r.sc)
|
||||
c.send_record = mock.Mock(side_effect=maybe_pause)
|
||||
o._test_eq = eq
|
||||
o._test_term = term
|
||||
return o, m, c
|
||||
|
||||
class OutboundTest(unittest.TestCase):
|
||||
def test_build_record(self):
|
||||
o, m, c = make_outbound()
|
||||
scid1 = b"scid"
|
||||
self.assertEqual(o.build_record(Open, scid1),
|
||||
Open(seqnum=0, scid=b"scid"))
|
||||
self.assertEqual(o.build_record(Data, scid1, b"dataaa"),
|
||||
Data(seqnum=1, scid=b"scid", data=b"dataaa"))
|
||||
self.assertEqual(o.build_record(Close, scid1),
|
||||
Close(seqnum=2, scid=b"scid"))
|
||||
self.assertEqual(o.build_record(Close, scid1),
|
||||
Close(seqnum=3, scid=b"scid"))
|
||||
|
||||
def test_outbound_queue(self):
|
||||
o, m, c = make_outbound()
|
||||
scid1 = b"scid"
|
||||
r1 = o.build_record(Open, scid1)
|
||||
r2 = o.build_record(Data, scid1, b"data1")
|
||||
r3 = o.build_record(Data, scid1, b"data2")
|
||||
o.queue_and_send_record(r1)
|
||||
o.queue_and_send_record(r2)
|
||||
o.queue_and_send_record(r3)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
|
||||
|
||||
# we would never normally receive an ACK without first getting a
|
||||
# connection
|
||||
o.handle_ack(r2.seqnum)
|
||||
self.assertEqual(list(o._outbound_queue), [r3])
|
||||
|
||||
o.handle_ack(r3.seqnum)
|
||||
self.assertEqual(list(o._outbound_queue), [])
|
||||
|
||||
o.handle_ack(r3.seqnum) # ignored
|
||||
self.assertEqual(list(o._outbound_queue), [])
|
||||
|
||||
o.handle_ack(r1.seqnum) # ignored
|
||||
self.assertEqual(list(o._outbound_queue), [])
|
||||
|
||||
def test_duplicate_registerProducer(self):
|
||||
o, m, c = make_outbound()
|
||||
sc1 = object()
|
||||
p1 = mock.Mock()
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
with self.assertRaises(ValueError) as ar:
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
s = str(ar.exception)
|
||||
self.assertIn("registering producer", s)
|
||||
self.assertIn("before previous one", s)
|
||||
self.assertIn("was unregistered", s)
|
||||
|
||||
def test_connection_send_queued_unpaused(self):
|
||||
o, m, c = make_outbound()
|
||||
scid1 = b"scid"
|
||||
r1 = o.build_record(Open, scid1)
|
||||
r2 = o.build_record(Data, scid1, b"data1")
|
||||
r3 = o.build_record(Data, scid1, b"data2")
|
||||
o.queue_and_send_record(r1)
|
||||
o.queue_and_send_record(r2)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
|
||||
# as soon as the connection is established, everything is sent
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
mock.call.send_record(r1),
|
||||
mock.call.send_record(r2)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
clear_mock_calls(c)
|
||||
|
||||
o.queue_and_send_record(r3)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
|
||||
|
||||
def test_connection_send_queued_paused(self):
|
||||
o, m, c = make_outbound()
|
||||
r1 = Pauser(seqnum=1)
|
||||
r2 = Pauser(seqnum=2)
|
||||
r3 = Pauser(seqnum=3)
|
||||
o.queue_and_send_record(r1)
|
||||
o.queue_and_send_record(r2)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
|
||||
# pausing=True, so our mock Manager will pause the Outbound producer
|
||||
# after each write. So only r1 should have been sent before getting
|
||||
# paused
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
mock.call.send_record(r1)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [r2])
|
||||
clear_mock_calls(c)
|
||||
|
||||
# Outbound is responsible for sending all records, so when Manager
|
||||
# wants to send a new one, and Outbound is still in the middle of
|
||||
# draining the beginning-of-connection queue, the new message gets
|
||||
# queued behind the rest (in addition to being queued in
|
||||
# _outbound_queue until an ACK retires it).
|
||||
o.queue_and_send_record(r3)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
|
||||
self.assertEqual(list(o._queued_unsent), [r2, r3])
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
o.handle_ack(r1.seqnum)
|
||||
self.assertEqual(list(o._outbound_queue), [r2, r3])
|
||||
self.assertEqual(list(o._queued_unsent), [r2, r3])
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
def test_premptive_ack(self):
|
||||
# one mode I have in mind is for each side to send an immediate ACK,
|
||||
# with everything they've ever seen, as the very first message on each
|
||||
# new connection. The idea is that you might preempt sending stuff from
|
||||
# the _queued_unsent list if it arrives fast enough (in practice this
|
||||
# is more likely to be delivered via the DILATE mailbox message, but
|
||||
# the effects might be vaguely similar, so it seems worth testing
|
||||
# here). A similar situation would be if each side sends ACKs with the
|
||||
# highest seqnum they've ever seen, instead of merely ACKing the
|
||||
# message which was just received.
|
||||
o, m, c = make_outbound()
|
||||
r1 = Pauser(seqnum=1)
|
||||
r2 = Pauser(seqnum=2)
|
||||
r3 = Pauser(seqnum=3)
|
||||
o.queue_and_send_record(r1)
|
||||
o.queue_and_send_record(r2)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
mock.call.send_record(r1)])
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2])
|
||||
self.assertEqual(list(o._queued_unsent), [r2])
|
||||
clear_mock_calls(c)
|
||||
|
||||
o.queue_and_send_record(r3)
|
||||
self.assertEqual(list(o._outbound_queue), [r1, r2, r3])
|
||||
self.assertEqual(list(o._queued_unsent), [r2, r3])
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
o.handle_ack(r2.seqnum)
|
||||
self.assertEqual(list(o._outbound_queue), [r3])
|
||||
self.assertEqual(list(o._queued_unsent), [r3])
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
|
||||
def test_pause(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True)])
|
||||
self.assertEqual(list(o._outbound_queue), [])
|
||||
self.assertEqual(list(o._queued_unsent), [])
|
||||
clear_mock_calls(c)
|
||||
|
||||
sc1, sc2, sc3 = object(), object(), object()
|
||||
p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3")
|
||||
|
||||
# we aren't paused yet, since we haven't sent any data
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
|
||||
r1 = Pauser(seqnum=1)
|
||||
o.queue_and_send_record(r1)
|
||||
# now we should be paused
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
clear_mock_calls(p1, c)
|
||||
|
||||
# so an IPushProducer will be paused right away
|
||||
o.subchannel_registerProducer(sc2, p2, True)
|
||||
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing()])
|
||||
clear_mock_calls(p2)
|
||||
|
||||
o.subchannel_registerProducer(sc3, p3, True)
|
||||
self.assertEqual(p3.mock_calls, [mock.call.pauseProducing()])
|
||||
self.assertEqual(o._paused_producers, set([p1, p2, p3]))
|
||||
self.assertEqual(list(o._all_producers), [p1, p2, p3])
|
||||
clear_mock_calls(p3)
|
||||
|
||||
# one resumeProducing should cause p1 to get a turn, since p2 was added
|
||||
# after we were paused and p1 was at the "end" of a one-element list.
|
||||
# If it writes anything, it will get paused again immediately.
|
||||
r2 = Pauser(seqnum=2)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(r2)
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
self.assertEqual(p3.mock_calls, [])
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r2)])
|
||||
clear_mock_calls(p1, p2, p3, c)
|
||||
# p2 should now be at the head of the queue
|
||||
self.assertEqual(list(o._all_producers), [p2, p3, p1])
|
||||
|
||||
# next turn: p2 has nothing to send, but p3 does. we should see p3
|
||||
# called but not p1. The actual sequence of expected calls is:
|
||||
# p2.resume, p3.resume, pauseProducing, set(p2.pause, p3.pause)
|
||||
r3 = Pauser(seqnum=3)
|
||||
p2.resumeProducing.side_effect = lambda: None
|
||||
p3.resumeProducing.side_effect = lambda: c.send_record(r3)
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r3)])
|
||||
clear_mock_calls(p1, p2, p3, c)
|
||||
# p1 should now be at the head of the queue
|
||||
self.assertEqual(list(o._all_producers), [p1, p2, p3])
|
||||
|
||||
# next turn: p1 has data to send, but not enough to cause a pause. same
|
||||
# for p2. p3 causes a pause
|
||||
r4 = NonPauser(seqnum=4)
|
||||
r5 = NonPauser(seqnum=5)
|
||||
r6 = Pauser(seqnum=6)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(r4)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(r5)
|
||||
p3.resumeProducing.side_effect = lambda: c.send_record(r6)
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r4),
|
||||
mock.call.send_record(r5),
|
||||
mock.call.send_record(r6),
|
||||
])
|
||||
clear_mock_calls(p1, p2, p3, c)
|
||||
# p1 should now be at the head of the queue again
|
||||
self.assertEqual(list(o._all_producers), [p1, p2, p3])
|
||||
|
||||
# now we let it catch up. p1 and p2 send non-pausing data, p3 sends
|
||||
# nothing.
|
||||
r7 = NonPauser(seqnum=4)
|
||||
r8 = NonPauser(seqnum=5)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(r7)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(r8)
|
||||
p3.resumeProducing.side_effect = lambda: None
|
||||
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
|
||||
])
|
||||
self.assertEqual(p3.mock_calls, [mock.call.resumeProducing(),
|
||||
])
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r7),
|
||||
mock.call.send_record(r8),
|
||||
])
|
||||
clear_mock_calls(p1, p2, p3, c)
|
||||
# p1 should now be at the head of the queue again
|
||||
self.assertEqual(list(o._all_producers), [p1, p2, p3])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
# now a producer disconnects itself (spontaneously, not from inside a
|
||||
# resumeProducing)
|
||||
o.subchannel_unregisterProducer(sc1)
|
||||
self.assertEqual(list(o._all_producers), [p2, p3])
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
# and another disconnects itself when called
|
||||
p2.resumeProducing.side_effect = lambda: None
|
||||
p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3)
|
||||
o.pauseProducing()
|
||||
o.resumeProducing()
|
||||
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(),
|
||||
mock.call.resumeProducing()])
|
||||
self.assertEqual(p3.mock_calls, [mock.call.pauseProducing(),
|
||||
mock.call.resumeProducing()])
|
||||
clear_mock_calls(p2, p3)
|
||||
self.assertEqual(list(o._all_producers), [p2])
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_subchannel_closed(self):
|
||||
o, m, c = make_outbound()
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
clear_mock_calls(p1)
|
||||
|
||||
o.subchannel_closed(sc1)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [])
|
||||
|
||||
sc2 = mock.Mock()
|
||||
o.subchannel_closed(sc2)
|
||||
|
||||
def test_disconnect(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
o.subchannel_registerProducer(sc1, p1, True)
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
o.stop_using_connection()
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
|
||||
def OFF_test_push_pull(self):
|
||||
# use one IPushProducer and one IPullProducer. They should take turns
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
|
||||
sc1, sc2 = object(), object()
|
||||
p1, p2 = mock.Mock(name="p1"), mock.Mock(name="p2")
|
||||
r1 = Pauser(seqnum=1)
|
||||
r2 = NonPauser(seqnum=2)
|
||||
|
||||
# we aren't paused yet, since we haven't sent any data
|
||||
o.subchannel_registerProducer(sc1, p1, True) # push
|
||||
o.queue_and_send_record(r1)
|
||||
# now we're paused
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r1)])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.pauseProducing()])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(r1)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(r2)
|
||||
o.subchannel_registerProducer(sc2, p2, False) # pull: always ready
|
||||
|
||||
# p1 is still first, since p2 was just added (at the end)
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [])
|
||||
self.assertEqual(p1.mock_calls, [])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [p1, p2])
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
# resume should send r1, which should pause everything
|
||||
o.resumeProducing()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r1),
|
||||
])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [])
|
||||
self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
# next should fire p2, then p1
|
||||
o.resumeProducing()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(r2),
|
||||
mock.call.send_record(r1),
|
||||
])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing(),
|
||||
mock.call.pauseProducing(),
|
||||
])
|
||||
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
|
||||
])
|
||||
self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat
|
||||
clear_mock_calls(p1, p2, c)
|
||||
|
||||
def test_pull_producer(self):
|
||||
# a single pull producer should write until it is paused, rate-limited
|
||||
# by the cooperator (so we'll see back-to-back resumeProducing calls
|
||||
# until the Connection is paused, or 10ms have passed, whichever comes
|
||||
# first, and if it's stopped by the timer, then the next EventualQueue
|
||||
# turn will start it off again)
|
||||
|
||||
o, m, c = make_outbound()
|
||||
eq = o._test_eq
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
|
||||
records = [NonPauser(seqnum=1)] * 10
|
||||
records.append(Pauser(seqnum=2))
|
||||
records.append(Stopper(sc1))
|
||||
it = iter(records)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(next(it))
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
eq.flush_sync() # fast forward into the glorious (paused) future
|
||||
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in records[:-1]])
|
||||
self.assertEqual(p1.mock_calls,
|
||||
[mock.call.resumeProducing()]*(len(records)-1))
|
||||
clear_mock_calls(c, p1)
|
||||
|
||||
# next resumeProducing should cause it to disconnect
|
||||
o.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(c.mock_calls, [mock.call.send_record(records[-1])])
|
||||
self.assertEqual(p1.mock_calls, [mock.call.resumeProducing()])
|
||||
self.assertEqual(len(o._all_producers), 0)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_two_pull_producers(self):
|
||||
# we should alternate between them until paused
|
||||
p1_records = ([NonPauser(seqnum=i) for i in range(5)] +
|
||||
[Pauser(seqnum=5)] +
|
||||
[NonPauser(seqnum=i) for i in range(6, 10)])
|
||||
p2_records = ([NonPauser(seqnum=i) for i in range(10, 19)] +
|
||||
[Pauser(seqnum=19)])
|
||||
expected1 = [NonPauser(0), NonPauser(10),
|
||||
NonPauser(1), NonPauser(11),
|
||||
NonPauser(2), NonPauser(12),
|
||||
NonPauser(3), NonPauser(13),
|
||||
NonPauser(4), NonPauser(14),
|
||||
Pauser(5)]
|
||||
expected2 = [ NonPauser(15),
|
||||
NonPauser(6), NonPauser(16),
|
||||
NonPauser(7), NonPauser(17),
|
||||
NonPauser(8), NonPauser(18),
|
||||
NonPauser(9), Pauser(19),
|
||||
]
|
||||
|
||||
o, m, c = make_outbound()
|
||||
eq = o._test_eq
|
||||
o.use_connection(c)
|
||||
clear_mock_calls(c)
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
it1 = iter(p1_records)
|
||||
p1.resumeProducing.side_effect = lambda: c.send_record(next(it1))
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
|
||||
sc2 = mock.Mock()
|
||||
p2 = mock.Mock(name="p2")
|
||||
alsoProvides(p2, IPullProducer)
|
||||
it2 = iter(p2_records)
|
||||
p2.resumeProducing.side_effect = lambda: c.send_record(next(it2))
|
||||
o.subchannel_registerProducer(sc2, p2, False)
|
||||
|
||||
eq.flush_sync() # fast forward into the glorious (paused) future
|
||||
|
||||
sends = [mock.call.resumeProducing()]
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in expected1])
|
||||
self.assertEqual(p1.mock_calls, 6*sends)
|
||||
self.assertEqual(p2.mock_calls, 5*sends)
|
||||
clear_mock_calls(c, p1, p2)
|
||||
|
||||
o.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertTrue(o._paused)
|
||||
self.assertEqual(c.mock_calls,
|
||||
[mock.call.send_record(r) for r in expected2])
|
||||
self.assertEqual(p1.mock_calls, 4*sends)
|
||||
self.assertEqual(p2.mock_calls, 5*sends)
|
||||
clear_mock_calls(c, p1, p2)
|
||||
|
||||
def test_send_if_connected(self):
|
||||
o, m, c = make_outbound()
|
||||
o.send_if_connected(Ack(1)) # not connected yet
|
||||
|
||||
o.use_connection(c)
|
||||
o.send_if_connected(KCM())
|
||||
self.assertEqual(c.mock_calls, [mock.call.registerProducer(o, True),
|
||||
mock.call.send_record(KCM())])
|
||||
|
||||
def test_tolerate_duplicate_pause_resume(self):
|
||||
o, m, c = make_outbound()
|
||||
self.assertTrue(o._paused) # no connection
|
||||
o.use_connection(c)
|
||||
self.assertFalse(o._paused)
|
||||
o.pauseProducing()
|
||||
self.assertTrue(o._paused)
|
||||
o.pauseProducing()
|
||||
self.assertTrue(o._paused)
|
||||
o.resumeProducing()
|
||||
self.assertFalse(o._paused)
|
||||
o.resumeProducing()
|
||||
self.assertFalse(o._paused)
|
||||
|
||||
def test_stopProducing(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
self.assertFalse(o._paused)
|
||||
o.stopProducing() # connection does this before loss
|
||||
self.assertTrue(o._paused)
|
||||
o.stop_using_connection()
|
||||
self.assertTrue(o._paused)
|
||||
|
||||
def test_resume_error(self):
|
||||
o, m, c = make_outbound()
|
||||
o.use_connection(c)
|
||||
sc1 = mock.Mock()
|
||||
p1 = mock.Mock(name="p1")
|
||||
alsoProvides(p1, IPullProducer)
|
||||
p1.resumeProducing.side_effect = PretendResumptionError
|
||||
o.subchannel_registerProducer(sc1, p1, False)
|
||||
o._test_eq.flush_sync()
|
||||
# the error is supposed to automatically unregister the producer
|
||||
self.assertEqual(list(o._all_producers), [])
|
||||
self.flushLoggedErrors(PretendResumptionError)
|
||||
|
||||
|
||||
def make_pushpull(pauses):
|
||||
p = mock.Mock()
|
||||
alsoProvides(p, IPullProducer)
|
||||
unregister = mock.Mock()
|
||||
|
||||
clock = Clock()
|
||||
eq = EventualQueue(clock)
|
||||
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
|
||||
term_factory = lambda: term
|
||||
coop = Cooperator(terminationPredicateFactory=term_factory,
|
||||
scheduler=eq.eventually)
|
||||
pp = PullToPush(p, unregister, coop)
|
||||
|
||||
it = cycle(pauses)
|
||||
def action(i):
|
||||
if isinstance(i, Exception):
|
||||
raise i
|
||||
elif i:
|
||||
pp.pauseProducing()
|
||||
p.resumeProducing.side_effect = lambda: action(next(it))
|
||||
return p, unregister, pp, eq
|
||||
|
||||
class PretendResumptionError(Exception):
|
||||
pass
|
||||
class PretendUnregisterError(Exception):
|
||||
pass
|
||||
|
||||
class PushPull(unittest.TestCase):
|
||||
# test our wrapper utility, which I copied from
|
||||
# twisted.internet._producer_helpers since it isn't publically exposed
|
||||
|
||||
def test_start_unpaused(self):
|
||||
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
|
||||
# if it starts unpaused, it gets one write before being halted
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1)
|
||||
clear_mock_calls(p)
|
||||
|
||||
# now each time we call resumeProducing, we should see one delivered to
|
||||
# the underlying IPullProducer
|
||||
pp.resumeProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1)
|
||||
|
||||
pp.stopStreaming()
|
||||
pp.stopStreaming() # should tolerate this
|
||||
|
||||
def test_start_unpaused_two_writes(self):
|
||||
p, unr, pp, eq = make_pushpull([False, True]) # pause every other time
|
||||
# it should get two writes, since the first didn't pause
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2)
|
||||
|
||||
def test_start_paused(self):
|
||||
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
|
||||
pp.startStreaming(True)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
pp.stopStreaming()
|
||||
|
||||
def test_stop(self):
|
||||
p, unr, pp, eq = make_pushpull([True])
|
||||
pp.startStreaming(True)
|
||||
pp.stopProducing()
|
||||
eq.flush_sync()
|
||||
self.assertEqual(p.mock_calls, [mock.call.stopProducing()])
|
||||
|
||||
def test_error(self):
|
||||
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
|
||||
unr.side_effect = lambda: pp.stopStreaming()
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(unr.mock_calls, [mock.call()])
|
||||
self.flushLoggedErrors(PretendResumptionError)
|
||||
|
||||
def test_error_during_unregister(self):
|
||||
p, unr, pp, eq = make_pushpull([PretendResumptionError()])
|
||||
unr.side_effect = PretendUnregisterError()
|
||||
pp.startStreaming(False)
|
||||
eq.flush_sync()
|
||||
self.assertEqual(unr.mock_calls, [mock.call()])
|
||||
self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I
|
||||
# could capture the inter-call ordering that way
|
43
src/wormhole/test/dilate/test_parse.py
Normal file
43
src/wormhole/test/dilate/test_parse.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from twisted.trial import unittest
|
||||
from ..._dilation.connection import (parse_record, encode_record,
|
||||
KCM, Ping, Pong, Open, Data, Close, Ack)
|
||||
|
||||
class Parse(unittest.TestCase):
|
||||
def test_parse(self):
|
||||
self.assertEqual(parse_record(b"\x00"), KCM())
|
||||
self.assertEqual(parse_record(b"\x01\x55\x44\x33\x22"),
|
||||
Ping(ping_id=b"\x55\x44\x33\x22"))
|
||||
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
|
||||
Pong(ping_id=b"\x55\x44\x33\x22"))
|
||||
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
|
||||
Open(scid=513, seqnum=256))
|
||||
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
|
||||
Data(scid=514, seqnum=257, data=b"dataaa"))
|
||||
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
|
||||
Close(scid=515, seqnum=258))
|
||||
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
|
||||
Ack(resp_seqnum=259))
|
||||
with mock.patch("wormhole._dilation.connection.log.err") as le:
|
||||
with self.assertRaises(ValueError):
|
||||
parse_record(b"\x07unknown")
|
||||
self.assertEqual(le.mock_calls,
|
||||
[mock.call("received unknown message type: {}".format(
|
||||
b"\x07unknown"))])
|
||||
|
||||
def test_encode(self):
|
||||
self.assertEqual(encode_record(KCM()), b"\x00")
|
||||
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
|
||||
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
|
||||
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
|
||||
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
|
||||
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
|
||||
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
|
||||
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
|
||||
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
|
||||
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
|
||||
b"\x06\x00\x00\x00\x13")
|
||||
with self.assertRaises(TypeError) as ar:
|
||||
encode_record("not a record")
|
||||
self.assertEqual(str(ar.exception), "not a record")
|
268
src/wormhole/test/dilate/test_record.py
Normal file
268
src/wormhole/test/dilate/test_record.py
Normal file
|
@ -0,0 +1,268 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from zope.interface import alsoProvides
|
||||
from twisted.trial import unittest
|
||||
from noise.exceptions import NoiseInvalidMessage
|
||||
from ..._dilation.connection import (IFramer, Frame, Prologue,
|
||||
_Record, Handshake,
|
||||
Disconnect, Ping)
|
||||
|
||||
def make_record():
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
r = _Record(f, n)
|
||||
return r, f, n
|
||||
|
||||
class Record(unittest.TestCase):
|
||||
def test_good2(self):
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
f.add_and_parse = mock.Mock(side_effect=[
|
||||
[],
|
||||
[Prologue()],
|
||||
[Frame(frame=b"rx-handshake")],
|
||||
[Frame(frame=b"frame1"), Frame(frame=b"frame2")],
|
||||
])
|
||||
n = mock.Mock()
|
||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
p1, p2 = object(), object()
|
||||
n.decrypt = mock.Mock(side_effect=[p1, p2])
|
||||
r = _Record(f, n)
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
# Pretend to deliver the prologue in two parts. The text we send in
|
||||
# doesn't matter: the side_effect= is what causes the prologue to be
|
||||
# recognized by the second call.
|
||||
self.assertEqual(list(r.add_and_unframe(b"pro")), [])
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
|
||||
# recognizing the prologue causes a handshake frame to be sent
|
||||
self.assertEqual(list(r.add_and_unframe(b"logue")), [])
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
|
||||
mock.call.send_frame(b"tx-handshake")])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
# next add_and_unframe is recognized as the Handshake
|
||||
self.assertEqual(list(r.add_and_unframe(b"blah")), [Handshake()])
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"blah")])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [mock.call.read_message(b"rx-handshake")])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
# next is a pair of Records
|
||||
r1, r2 = object() , object()
|
||||
with mock.patch("wormhole._dilation.connection.parse_record",
|
||||
side_effect=[r1,r2]) as pr:
|
||||
self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2])
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"),
|
||||
mock.call.decrypt(b"frame2")])
|
||||
self.assertEqual(pr.mock_calls, [mock.call(p1), mock.call(p2)])
|
||||
|
||||
def test_bad_handshake(self):
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
f.add_and_parse = mock.Mock(return_value=[Prologue(),
|
||||
Frame(frame=b"rx-handshake")])
|
||||
n = mock.Mock()
|
||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.read_message = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
with mock.patch("wormhole._dilation.connection.log.err") as le:
|
||||
with self.assertRaises(Disconnect):
|
||||
list(r.add_and_unframe(b"data"))
|
||||
self.assertEqual(le.mock_calls,
|
||||
[mock.call(nvm, "bad inbound noise handshake")])
|
||||
|
||||
def test_bad_message(self):
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
f.add_and_parse = mock.Mock(return_value=[Prologue(),
|
||||
Frame(frame=b"rx-handshake"),
|
||||
Frame(frame=b"bad-message")])
|
||||
n = mock.Mock()
|
||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.decrypt = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
f.mock_calls[:] = []
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
with mock.patch("wormhole._dilation.connection.log.err") as le:
|
||||
with self.assertRaises(Disconnect):
|
||||
list(r.add_and_unframe(b"data"))
|
||||
self.assertEqual(le.mock_calls,
|
||||
[mock.call(nvm, "bad inbound noise frame")])
|
||||
|
||||
def test_send_record(self):
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
n = mock.Mock()
|
||||
f1 = object()
|
||||
n.encrypt = mock.Mock(return_value=f1)
|
||||
r1 = Ping(b"pingid")
|
||||
r = _Record(f, n)
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
m1 = object()
|
||||
with mock.patch("wormhole._dilation.connection.encode_record",
|
||||
return_value=m1) as er:
|
||||
r.send_record(r1)
|
||||
self.assertEqual(er.mock_calls, [mock.call(r1)])
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake(),
|
||||
mock.call.encrypt(m1)])
|
||||
self.assertEqual(f.mock_calls, [mock.call.send_frame(f1)])
|
||||
|
||||
def test_good(self):
|
||||
# Exercise the success path. The Record instance is given each chunk
|
||||
# of data as it arrives on Protocol.dataReceived, and is supposed to
|
||||
# return a series of Tokens (maybe none, if the chunk was incomplete,
|
||||
# or more than one, if the chunk was larger). Internally, it delivers
|
||||
# the chunks to the Framer for unframing (which returns 0 or more
|
||||
# frames), manages the Noise decryption object, and parses any
|
||||
# decrypted messages into tokens (some of which are consumed
|
||||
# internally, others for delivery upstairs).
|
||||
#
|
||||
# in the normal flow, we get:
|
||||
#
|
||||
# | | Inbound | NoiseAction | Outbound | ToUpstairs |
|
||||
# | | - | - | - | - |
|
||||
# | 1 | | | prologue | |
|
||||
# | 2 | prologue | | | |
|
||||
# | 3 | | write_message | handshake | |
|
||||
# | 4 | handshake | read_message | | Handshake |
|
||||
# | 5 | | encrypt | KCM | |
|
||||
# | 6 | KCM | decrypt | | KCM |
|
||||
# | 7 | msg1 | decrypt | | msg1 |
|
||||
|
||||
# 1: instantiating the Record instance causes the outbound prologue
|
||||
# to be sent
|
||||
|
||||
# 2+3: receipt of the inbound prologue triggers creation of the
|
||||
# ephemeral key (the "handshake") by calling noise.write_message()
|
||||
# and then writes the handshake to the outbound transport
|
||||
|
||||
# 4: when the peer's handshake is received, it is delivered to
|
||||
# noise.read_message(), which generates the shared key (enabling
|
||||
# noise.send() and noise.decrypt()). It also delivers the Handshake
|
||||
# token upstairs, which might (on the Follower) trigger immediate
|
||||
# transmission of the Key Confirmation Message (KCM)
|
||||
|
||||
# 5: the outbound KCM is framed and fed into noise.encrypt(), then
|
||||
# sent outbound
|
||||
|
||||
# 6: the peer's KCM is decrypted then delivered upstairs. The
|
||||
# Follower treats this as a signal that it should use this connection
|
||||
# (and drop all others).
|
||||
|
||||
# 7: the peer's first message is decrypted, parsed, and delivered
|
||||
# upstairs. This might be an Open or a Data, depending upon what
|
||||
# queued messages were left over from the previous connection
|
||||
|
||||
r, f, n = make_record()
|
||||
outbound_handshake = object()
|
||||
kcm, msg1 = object(), object()
|
||||
f_kcm, f_msg1 = object(), object()
|
||||
n.write_message = mock.Mock(return_value=outbound_handshake)
|
||||
n.decrypt = mock.Mock(side_effect=[kcm, msg1])
|
||||
n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1])
|
||||
f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet
|
||||
[Prologue()],
|
||||
[Frame("f_handshake")],
|
||||
[Frame("f_kcm"),
|
||||
Frame("f_msg1")],
|
||||
])
|
||||
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
|
||||
n.mock_calls[:] = []
|
||||
|
||||
# 1. The Framer is responsible for sending the prologue, so we don't
|
||||
# have to check that here, we just check that the Framer was told
|
||||
# about connectionMade properly.
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
f.mock_calls[:] = []
|
||||
|
||||
# 2
|
||||
# we dribble the prologue in over two messages, to make sure we can
|
||||
# handle a dataReceived that doesn't complete the token
|
||||
|
||||
# remember, add_and_unframe is a generator
|
||||
self.assertEqual(list(r.add_and_unframe(b"pro")), [])
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"pro")])
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
f.mock_calls[:] = []
|
||||
|
||||
self.assertEqual(list(r.add_and_unframe(b"logue")), [])
|
||||
# 3: write_message, send outbound handshake
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"logue"),
|
||||
mock.call.send_frame(outbound_handshake),
|
||||
])
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
f.mock_calls[:] = []
|
||||
n.mock_calls[:] = []
|
||||
|
||||
# 4
|
||||
# Now deliver the Noise "handshake", the ephemeral public key. This
|
||||
# is framed, but not a record, so it shouldn't decrypt or parse
|
||||
# anything, but the handshake is delivered to the Noise object, and
|
||||
# it does return a Handshake token so we can let the next layer up
|
||||
# react (by sending the KCM frame if we're a Follower, or not if
|
||||
# we're the Leader)
|
||||
|
||||
self.assertEqual(list(r.add_and_unframe(b"handshake")), [Handshake()])
|
||||
self.assertEqual(f.mock_calls, [mock.call.add_and_parse(b"handshake")])
|
||||
self.assertEqual(n.mock_calls, [mock.call.read_message("f_handshake")])
|
||||
f.mock_calls[:] = []
|
||||
n.mock_calls[:] = []
|
||||
|
||||
|
||||
# 5: at this point we ought to be able to send a messge, the KCM
|
||||
with mock.patch("wormhole._dilation.connection.encode_record",
|
||||
side_effect=[b"r-kcm"]) as er:
|
||||
r.send_record(kcm)
|
||||
self.assertEqual(er.mock_calls, [mock.call(kcm)])
|
||||
self.assertEqual(n.mock_calls, [mock.call.encrypt(b"r-kcm")])
|
||||
self.assertEqual(f.mock_calls, [mock.call.send_frame(f_kcm)])
|
||||
n.mock_calls[:] = []
|
||||
f.mock_calls[:] = []
|
||||
|
||||
# 6: Now we deliver two messages stacked up: the KCM (Key
|
||||
# Confirmation Message) and the first real message. Concatenating
|
||||
# them tests that we can handle more than one token in a single
|
||||
# chunk. We need to mock parse_record() because everything past the
|
||||
# handshake is decrypted and parsed.
|
||||
|
||||
with mock.patch("wormhole._dilation.connection.parse_record",
|
||||
side_effect=[kcm, msg1]) as pr:
|
||||
self.assertEqual(list(r.add_and_unframe(b"kcm,msg1")),
|
||||
[kcm, msg1])
|
||||
self.assertEqual(f.mock_calls,
|
||||
[mock.call.add_and_parse(b"kcm,msg1")])
|
||||
self.assertEqual(n.mock_calls, [mock.call.decrypt("f_kcm"),
|
||||
mock.call.decrypt("f_msg1")])
|
||||
self.assertEqual(pr.mock_calls, [mock.call(kcm), mock.call(msg1)])
|
||||
n.mock_calls[:] = []
|
||||
f.mock_calls[:] = []
|
142
src/wormhole/test/dilate/test_subchannel.py
Normal file
142
src/wormhole/test/dilate/test_subchannel.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import mock
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.interfaces import ITransport
|
||||
from twisted.internet.error import ConnectionDone
|
||||
from ..._dilation.subchannel import (Once, SubChannel,
|
||||
_WormholeAddress, _SubchannelAddress,
|
||||
AlreadyClosedError)
|
||||
from .common import mock_manager
|
||||
|
||||
def make_sc(set_protocol=True):
|
||||
scid = b"scid"
|
||||
hostaddr = _WormholeAddress()
|
||||
peeraddr = _SubchannelAddress(scid)
|
||||
m = mock_manager()
|
||||
sc = SubChannel(scid, m, hostaddr, peeraddr)
|
||||
p = mock.Mock()
|
||||
if set_protocol:
|
||||
sc._set_protocol(p)
|
||||
return sc, m, scid, hostaddr, peeraddr, p
|
||||
|
||||
class SubChannelAPI(unittest.TestCase):
|
||||
def test_once(self):
|
||||
o = Once(ValueError)
|
||||
o()
|
||||
with self.assertRaises(ValueError):
|
||||
o()
|
||||
|
||||
def test_create(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
self.assert_(ITransport.providedBy(sc))
|
||||
self.assertEqual(m.mock_calls, [])
|
||||
self.assertIdentical(sc.getHost(), hostaddr)
|
||||
self.assertIdentical(sc.getPeer(), peeraddr)
|
||||
|
||||
def test_write(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
sc.write(b"data")
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"data")])
|
||||
m.mock_calls[:] = []
|
||||
sc.writeSequence([b"more", b"data"])
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_data(scid, b"moredata")])
|
||||
|
||||
def test_write_when_closing(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
sc.loseConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.write(b"data")
|
||||
self.assertEqual(str(e.exception),
|
||||
"write not allowed on closed subchannel")
|
||||
|
||||
def test_local_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
sc.loseConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
# late arriving data is still delivered
|
||||
sc.remote_data(b"late")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"late")])
|
||||
p.mock_calls[:] = []
|
||||
|
||||
sc.remote_close()
|
||||
self.assert_connectionDone(p.mock_calls)
|
||||
|
||||
def test_local_double_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
sc.loseConnection()
|
||||
self.assertEqual(m.mock_calls, [mock.call.send_close(scid)])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
with self.assertRaises(AlreadyClosedError) as e:
|
||||
sc.loseConnection()
|
||||
self.assertEqual(str(e.exception),
|
||||
"loseConnection not allowed on closed subchannel")
|
||||
|
||||
def assert_connectionDone(self, mock_calls):
|
||||
self.assertEqual(len(mock_calls), 1)
|
||||
self.assertEqual(mock_calls[0][0], "connectionLost")
|
||||
self.assertEqual(len(mock_calls[0][1]), 1)
|
||||
self.assertIsInstance(mock_calls[0][1][0], ConnectionDone)
|
||||
|
||||
def test_remote_close(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
sc.remote_close()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_closed(sc)])
|
||||
self.assert_connectionDone(p.mock_calls)
|
||||
|
||||
def test_data(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
sc.remote_data(b"data")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
|
||||
p.mock_calls[:] = []
|
||||
sc.remote_data(b"not")
|
||||
sc.remote_data(b"coalesced")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"not"),
|
||||
mock.call.dataReceived(b"coalesced"),
|
||||
])
|
||||
|
||||
def test_data_before_open(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
|
||||
sc.remote_data(b"data")
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
sc._set_protocol(p)
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"data")])
|
||||
p.mock_calls[:] = []
|
||||
sc.remote_data(b"more")
|
||||
self.assertEqual(p.mock_calls, [mock.call.dataReceived(b"more")])
|
||||
|
||||
def test_close_before_open(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc(set_protocol=False)
|
||||
sc.remote_close()
|
||||
self.assertEqual(p.mock_calls, [])
|
||||
sc._set_protocol(p)
|
||||
self.assert_connectionDone(p.mock_calls)
|
||||
|
||||
def test_producer(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
sc.pauseProducing()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_pauseProducing(sc)])
|
||||
m.mock_calls[:] = []
|
||||
sc.resumeProducing()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_resumeProducing(sc)])
|
||||
m.mock_calls[:] = []
|
||||
sc.stopProducing()
|
||||
self.assertEqual(m.mock_calls, [mock.call.subchannel_stopProducing(sc)])
|
||||
m.mock_calls[:] = []
|
||||
|
||||
def test_consumer(self):
|
||||
sc, m, scid, hostaddr, peeraddr, p = make_sc()
|
||||
|
||||
# TODO: more, once this is implemented
|
||||
sc.registerProducer(None, True)
|
||||
sc.unregisterProducer()
|
|
@ -12,7 +12,7 @@ import mock
|
|||
from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
|
||||
_mailbox, _nameplate, _order, _receive, _rendezvous, _send,
|
||||
_terminator, errors, timing)
|
||||
from .._interfaces import (IAllocator, IBoss, ICode, IInput, IKey, ILister,
|
||||
from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister,
|
||||
IMailbox, INameplate, IOrder, IReceive,
|
||||
IRendezvousConnector, ISend, ITerminator, IWordlist)
|
||||
from .._key import derive_key, derive_phase_key, encrypt_data
|
||||
|
@ -1300,6 +1300,7 @@ class Boss(unittest.TestCase):
|
|||
b._RC = Dummy("rc", events, IRendezvousConnector, "start")
|
||||
b._C = Dummy("c", events, ICode, "allocate_code", "input_code",
|
||||
"set_code")
|
||||
b._D = Dummy("d", events, IDilator, "got_wormhole_versions", "got_key")
|
||||
return b, events
|
||||
|
||||
def test_basic(self):
|
||||
|
@ -1327,7 +1328,9 @@ class Boss(unittest.TestCase):
|
|||
b.got_message("side", "0", b"msg1")
|
||||
self.assertEqual(events, [
|
||||
("w.got_key", b"key"),
|
||||
("d.got_key", b"key"),
|
||||
("w.got_verifier", b"verifier"),
|
||||
("d.got_wormhole_versions", "side", "side", {}),
|
||||
("w.got_versions", {}),
|
||||
("w.received", b"msg1"),
|
||||
])
|
||||
|
|
|
@ -9,6 +9,7 @@ from twisted.internet.task import Cooperator
|
|||
from zope.interface import implementer
|
||||
|
||||
from ._boss import Boss
|
||||
from ._dilation.connector import Connector
|
||||
from ._interfaces import IDeferredWormhole, IWormhole
|
||||
from ._key import derive_key
|
||||
from .errors import NoKeyError, WormholeClosed
|
||||
|
@ -189,6 +190,9 @@ class _DeferredWormhole(object):
|
|||
raise NoKeyError()
|
||||
return derive_key(self._key, to_bytes(purpose), length)
|
||||
|
||||
def dilate(self):
|
||||
return self._boss.dilate() # fires with (endpoints)
|
||||
|
||||
def close(self):
|
||||
# fails with WormholeError unless we established a connection
|
||||
# (state=="happy"). Fails with WrongPasswordError (a subclass of
|
||||
|
@ -265,8 +269,12 @@ def create(
|
|||
w = _DelegatedWormhole(delegate)
|
||||
else:
|
||||
w = _DeferredWormhole(reactor, eq)
|
||||
wormhole_versions = {} # will be used to indicate Wormhole capabilities
|
||||
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
||||
# this indicates Wormhole capabilities
|
||||
wormhole_versions = {
|
||||
"can-dilate": [1],
|
||||
"dilation-abilities": Connector.get_connection_abilities(),
|
||||
}
|
||||
wormhole_versions["app_versions"] = versions # app-specific capabilities
|
||||
v = __version__
|
||||
if isinstance(v, type(b"")):
|
||||
v = v.decode("utf-8", errors="replace")
|
||||
|
|
Loading…
Reference in New Issue
Block a user