INCOMPATIBLE CHANGE: Merge branch 'new-proto'

This is a very large branch that replaces many aspects of the wormhole
protocol. Clients that use code before this change (including the 0.7.6
release) will not be able to talk to clients after this change. They
won't even be able to talk to the relay.

Things that have changed:

* The server protocol has changed. A new public relay has been set up,
  which listens on a different port.
* The blocking (non-Twisted) implementation has been removed. It will
  return, built on top of the Twisted flavor, using the Crochet library.
* Persistence (state = wormhole.serialize()) has been removed. It will
  return, in a form that works better for constantly-evolving Wormholes.
* API changes:
  * 'from wormhole.wormhole import wormhole', rather than from
    wormhole.twisted.transcribe (this is likely to change further)
  * Create the Wormhole with the wormhole() function, rather than the
    Wormhole() class constructor. You *must* pass reactor= to get a
    Twisted-flavor wormhole (omitting reactor= will, in the future, give
    you a blocking-flavor wormhole).
  * w.get() and w.send(data), instead of w.get_data(phase) and
    w.send_data(data, phase). Wormhole is now a sequential record pipe,
    rather than a named-record channel. Internally, these APIs produce
    numbered phases.
  * verifier = yield w.verify(), instead of get_verifier(). The new
    verify() defers until the connection has received the
    key-confirmation message, and will errback with WrongPasswordError
    if that message doesn't match.
  * w.derive_key(purpose, length) now requires a length, instead of
    defaulting to the NaCl SecretBox key size.
  * w.close() now always defers until all outbound messages have been
    delivered to the relay server, and the connection has closed. It
    always returns a Deferred. Application code should close() before
    calling sys.exit(), to make sure your ACKs have been delivered.
  * Any errors (WrongPasswordError, websocket dropped early) will cause
    all pending Deferreds to errback, the nameplate and mailbox will be
    released, and the websocket connection will be closed. w.close() is
    still the right thing to call after an error, as it will defer until
    the connection is finally dropped.
* The Wormhole object starts working as soon as wormhole() is called,
  rather than waiting until an API method is invoked.
* There are more opportunities for parallelism, which should avoid a few
  roundtrips and make things faster.
* We now use SPAKE2-0.7, which changes the key-derivation function to
  one that hopefully matches a proposed SJCL implementation, enabling
  future interoperability between python and javascript clients.
* We derive per-message keys differently, to prevent a particular kind
  of reflection attack that was mitigated differently before.
* The server now manages "nameplates" and "mailboxes" separately (the
  old server/protocol didn't make a distinction). A "nameplate" is a
  channel with a short name (the number from the wormhole code) and
  which only contains one value (a pointer to a mailbox). A "mailbox"
  has a long random name and contains the usual queue of messages being
  sent from one client to the other. This lets us drop the nameplate as
  soon as the second side has connected (and switches to the mailbox),
  so long file transfers don't hog the short wormhole codes for longer
  than necessary.
* There is room for "nameplate attributes", which will (in the future)
  be used to indicate the wordlist being used for the wormhole code,
  allowing tab-completion for alternate wordlists, including languages
  other than English.
* The new expectation is that nameplates and mailboxes will be deleted
  if nobody is connected to them for a while (although this is not yet
  implemented in the server). Applications which need extended offline
  persistent channels will be able to ask for them when claiming the
  nameplate.
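A hedged sketch of what the sequential-record-pipe change amounts to (the class and attribute names here are illustrative, not the library's internals): each send() is assigned the next message number automatically, so callers no longer supply phase names.

```python
# Toy model of the new sequential record pipe: every send() implicitly gets
# the next numbered phase. Names here are hypothetical, for illustration only.

class RecordPipeSide:
    def __init__(self):
        self._next_phase = 0   # counter for our outbound messages
        self._outbox = []      # (phase, body) pairs, as handed to the relay

    def send(self, data):
        # each send() is stamped with the next numbered phase automatically
        self._outbox.append((str(self._next_phase), data))
        self._next_phase += 1

s = RecordPipeSide()
s.send(b"offer")
s.send(b"ack")
```

Contrast with the old API, where the caller chose an arbitrary phase string and each phase could be used only once.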
Commit 88696dd0ed by Brian Warner, 2016-05-24 15:26:26 -07:00
32 changed files with 3032 additions and 3282 deletions


@@ -10,6 +10,15 @@ short string that is transcribed from one machine to the other by the users
at the keyboard. This works in conjunction with a baked-in "rendezvous
server" that relays information from one machine to the other.
+The "Wormhole" object provides a secure record pipe between any two programs
+that use the same wormhole code (and are configured with the same application
+ID and rendezvous server). Each side can send multiple messages to the other,
+but the encrypted data for all messages must pass through (and be temporarily
+stored on) the rendezvous server, which is a shared resource. For this
+reason, larger data (including bulk file transfers) should use the Transit
+class instead. The Wormhole object has a method to create a Transit object
+for this purpose.
## Modes
This library will eventually offer multiple modes. For now, only "transcribe
@@ -39,26 +48,36 @@ string.
The two machines participating in the wormhole setup are not distinguished:
it doesn't matter which one goes first, and both use the same Wormhole class.
In the first variant, one side calls `get_code()` while the other calls
-`set_code()`. In the second variant, both sides call `set_code()`. Note that
+`set_code()`. In the second variant, both sides call `set_code()`. (Note that
this is not true for the "Transit" protocol used for bulk data-transfer: the
Transit class currently distinguishes "Sender" from "Receiver", so the
-programs on each side must have some way to decide (ahead of time) which is
-which.
+programs on each side must have some way to decide ahead of time which is
+which).
-Each side gets to do one `send_data()` call and one `get_data()` call per
-phase (see below). `get_data` will wait until the other side has done
-`send_data`, so the application developer must be careful to avoid deadlocks
-(don't get before you send on both sides in the same protocol). When both
-sides are done, they must call `close()`, to let the library know that the
-connection is complete and it can deallocate the channel. If you forget to
-call `close()`, the server will not free the channel, and other users will
-suffer longer invitation codes as a result. To encourage `close()`, the
-library will log an error if a Wormhole object is destroyed before being
-closed.
+Each side can then do an arbitrary number of `send()` and `get()` calls.
+`send()` writes a message into the channel. `get()` waits for a new message
+to be available, then returns it. The Wormhole is not meant as a long-term
+communication channel, but some protocols work better if they can exchange an
+initial pair of messages (perhaps offering some set of negotiable
+capabilities), and then follow up with a second pair (to reveal the results
+of the negotiation). Another use case is for an ACK that gets sent at the end
+of a file transfer: the Wormhole is held open until the Transit object
+reports completion, and the last message is a hash of the file contents to
+prove it was received correctly.
+
+Note: the application developer must be careful to avoid deadlocks (if both
+sides want to `get()`, somebody has to `send()` first).
+
+When both sides are done, they must call `close()`, to let the library know
+that the connection is complete and it can deallocate the channel. If you
+forget to call `close()`, the server will not free the channel, and other
+users will suffer longer invitation codes as a result. To encourage
+`close()`, the library will log an error if a Wormhole object is destroyed
+before being closed.
To make it easier to call `close()`, the blocking Wormhole objects can be
used as a context manager. Just put your code in the body of a `with
-Wormhole(ARGS) as w:` statement, and `close()` will automatically be called
+wormhole(ARGS) as w:` statement, and `close()` will automatically be called
when the block exits (either successfully or due to an exception).
## Examples
@@ -66,27 +85,27 @@ when the block exits (either successfully or due to an exception).
The synchronous+blocking flow looks like this:
```python
-from wormhole.blocking.transcribe import Wormhole
+from wormhole.blocking.transcribe import wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"initiator's data"
-with Wormhole(u"appid", RENDEZVOUS_RELAY) as i:
+with wormhole(u"appid", RENDEZVOUS_RELAY) as i:
    code = i.get_code()
    print("Invitation Code: %s" % code)
-    i.send_data(mydata)
-    theirdata = i.get_data()
+    i.send(mydata)
+    theirdata = i.get()
print("Their data: %s" % theirdata.decode("ascii"))
```
```python
import sys
-from wormhole.blocking.transcribe import Wormhole
+from wormhole.blocking.transcribe import wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"receiver's data"
code = sys.argv[1]
-with Wormhole(u"appid", RENDEZVOUS_RELAY) as r:
+with wormhole(u"appid", RENDEZVOUS_RELAY) as r:
    r.set_code(code)
-    r.send_data(mydata)
-    theirdata = r.get_data()
+    r.send(mydata)
+    theirdata = r.get()
print("Their data: %s" % theirdata.decode("ascii"))
```
@@ -97,18 +116,18 @@ The Twisted-friendly flow looks like this:
```python
from twisted.internet import reactor
from wormhole.public_relay import RENDEZVOUS_RELAY
-from wormhole.twisted.transcribe import Wormhole
+from wormhole.twisted.transcribe import wormhole
outbound_message = b"outbound data"
-w1 = Wormhole(u"appid", RENDEZVOUS_RELAY)
+w1 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor)
d = w1.get_code()
def _got_code(code):
    print "Invitation Code:", code
-    return w1.send_data(outbound_message)
+    return w1.send(outbound_message)
d.addCallback(_got_code)
-d.addCallback(lambda _: w1.get_data())
-def _got_data(inbound_message):
+d.addCallback(lambda _: w1.get())
+def _got(inbound_message):
    print "Inbound message:", inbound_message
-d.addCallback(_got_data)
+d.addCallback(_got)
d.addCallback(w1.close)
d.addBoth(lambda _: reactor.stop())
reactor.run()
@@ -117,9 +136,9 @@ reactor.run()
On the other side, you call `set_code()` instead of waiting for `get_code()`:
```python
-w2 = Wormhole(u"appid", RENDEZVOUS_RELAY)
+w2 = wormhole(u"appid", RENDEZVOUS_RELAY, reactor)
w2.set_code(code)
-d = w2.send_data(my_message)
+d = w2.send(my_message)
...
```
@@ -127,56 +146,45 @@ Note that the Twisted-form `close()` accepts (and returns) an optional
argument, so you can use `d.addCallback(w.close)` instead of
`d.addCallback(lambda _: w.close())`.
-## Phases
-If necessary, more than one message can be exchanged through the relay
-server. It is not meant as a long-term communication channel, but some
-protocols work better if they can exchange an initial pair of messages
-(perhaps offering some set of negotiable capabilities), and then follow up
-with a second pair (to reveal the results of the negotiation).
-To support this, `send_data()/get_data()` accept a "phase" argument: an
-arbitrary (unicode) string. It must match the other side: calling
-`send_data(data, phase=u"offer")` on one side will deliver that data to
-`get_data(phase=u"offer")` on the other.
-It is a UsageError to call `send_data()` or `get_data()` twice with the same
-phase name. The relay server may limit the number of phases that may be
-exchanged, however it will always allow at least two.
## Verifier
-You can call `w.get_verifier()` before `send_data()/get_data()`: this will
-perform the first half of the PAKE negotiation, then return a verifier object
-(bytes) which can be converted into a printable representation and manually
-compared. When the users are convinced that `get_verifier()` from both sides
-are the same, call `send_data()/get_data()` to continue the transfer. If you
-call `send_data()/get_data()` before `get_verifier()`, it will perform the
-complete transfer without pausing.
+For extra protection against guessing attacks, Wormhole can provide a
+"Verifier". This is a moderate-length series of bytes (a SHA256 hash) that is
+derived from the supposedly-shared session key. If desired, both sides can
+display this value, and the humans can manually compare them before allowing
+the rest of the protocol to proceed. If they do not match, then the two
+programs are not talking to each other (they may both be talking to a
+man-in-the-middle attacker), and the protocol should be abandoned.
+
+To retrieve the verifier, you call `w.get_verifier()` before any calls to
+`send()/get()`. Turn this into hex or Base64 to print it, or render it as
+ASCII-art, etc. Once the users are convinced that `get_verifier()` from both
+sides are the same, call `send()/get()` to continue the protocol. If you call
+`send()/get()` before `get_verifier()`, it will perform the complete protocol
+without pausing.
The Twisted form of `get_verifier()` returns a Deferred that fires with the
verifier bytes.
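The display step can be sketched in a few lines. Note the derivation here is a stand-in: a plain SHA256 over a made-up session key, not the library's real HKDF-based derivation; only the rendering pattern is the point.

```python
# Sketch of rendering a verifier for manual comparison. The real verifier is
# derived from the PAKE session key; SHA256 over a fixed key is a stand-in.
import hashlib

def render_verifier(session_key):
    # hash the (supposedly shared) session key and print it as hex;
    # Base64 or ASCII-art would work equally well
    return hashlib.sha256(session_key).hexdigest()

key = bytes(range(32))  # placeholder for a derived session key
# both sides hold the same key, so both render the identical string
assert render_verifier(key) == render_verifier(key)
```

If the two displayed strings differ, the two programs do not share a key and the protocol should be abandoned.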
## Generating the Invitation Code
-In most situations, the "sending" or "initiating" side will call
-`i.get_code()` to generate the invitation code. This returns a string in the
-form `NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a
+In most situations, the "sending" or "initiating" side will call `get_code()`
+to generate the invitation code. This returns a string in the form
+`NNN-code-words`. The numeric "NNN" prefix is the "channel id", and is a
short integer allocated by talking to the rendezvous server. The rest is a
randomly-generated selection from the PGP wordlist, providing a default of 16
bits of entropy. The initiating program should display this code to the user,
who should transcribe it to the receiving user, who gives it to the Receiver
-object by calling `r.set_code()`. The receiving program can also use
+object by calling `set_code()`. The receiving program can also use
`input_code_with_completion()` to use a readline-based input function: this
offers tab completion of allocated channel-ids and known codewords.
Alternatively, the human users can agree upon an invitation code themselves,
-and provide it to both programs later (with `i.set_code()` and
-`r.set_code()`). They should choose a channel-id that is unlikely to already
-be in use (3 or more digits are recommended), append a hyphen, and then
-include randomly-selected words or characters. Dice, coin flips, shuffled
-cards, or repeated sampling of a high-resolution stopwatch are all useful
-techniques.
+and provide it to both programs later (both sides call `set_code()`). They
+should choose a channel-id that is unlikely to already be in use (3 or more
+digits are recommended), append a hyphen, and then include randomly-selected
+words or characters. Dice, coin flips, shuffled cards, or repeated sampling
+of a high-resolution stopwatch are all useful techniques.
Note that the code is a human-readable string (the python "unicode" type in
python2, "str" in python3).
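The `NNN-code-words` format can be sketched as follows. The four-entry wordlist and the `make_code`/`parse_channelid` helpers are illustrative stand-ins (the real library draws two words from the PGP even/odd wordlists for its default 16 bits of entropy); only the channel-id prefix check mirrors the documented behavior.

```python
# Sketch of the "NNN-code-words" invitation-code format.
import os, re

WORDS = ["purple", "sausages", "absurd", "hamburger"]  # placeholder wordlist

def make_code(channelid, num_words=2):
    # pick words at random and prepend the numeric channel-id
    picks = [WORDS[os.urandom(1)[0] % len(WORDS)] for _ in range(num_words)]
    return "%d-%s" % (channelid, "-".join(picks))

def parse_channelid(code):
    # codes must start with "NN-", as the library enforces
    mo = re.search(r'^(\d+)-', code)
    if not mo:
        raise ValueError("code (%s) must start with NN-" % code)
    return int(mo.group(1))

code = make_code(4)
assert parse_channelid(code) == 4
```

A human-chosen code works the same way, as long as it keeps the `NNN-` prefix.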
@@ -192,8 +200,8 @@ invitation codes are scoped to the app-id. Note that the app-id must be
unicode, not bytes, so on python2 use `u"appid"`.
Distinct app-ids reduce the size of the connection-id numbers. If fewer than
-ten initiators are active for a given app-id, the connection-id will only
-need to contain a single digit, even if some other app-id is currently using
+ten Wormholes are active for a given app-id, the connection-id will only need
+to contain a single digit, even if some other app-id is currently using
thousands of concurrent sessions.
## Rendezvous Relays
@@ -245,16 +253,14 @@ You may not be able to hold the Wormhole object in memory for the whole sync
process: maybe you allow it to wait for several days, but the program will be
restarted during that time. To support this, you can persist the state of the
object by calling `data = w.serialize()`, which will return a printable
-bytestring (the JSON-encoding of a small dictionary). To restore, use the
-`from_serialized(data)` classmethod (e.g. `w =
-Wormhole.from_serialized(data)`).
+bytestring (the JSON-encoding of a small dictionary). To restore, use `w =
+wormhole_from_serialized(data, reactor)`.
There is exactly one point at which you can serialize the wormhole: *after*
establishing the invitation code, but before waiting for `get_verifier()` or
-`get_data()`, or calling `send_data()`. If you are creating a new invitation
-code, the correct time is during the callback fired by `get_code()`. If you
-are accepting a pre-generated code, the time is just after calling
-`set_code()`.
+`get()`, or calling `send()`. If you are creating a new invitation code, the
+correct time is during the callback fired by `get_code()`. If you are
+accepting a pre-generated code, the time is just after calling `set_code()`.
To properly checkpoint the process, you should store the first message
(returned by `start()`) next to the serialized wormhole instance, so you can
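The checkpoint pattern can be sketched with a plain JSON round-trip. The field names below are illustrative only, not the library's real serialized format; the point is that the state is a small printable dictionary you can write to disk before continuing.

```python
# Sketch of the serialize-then-restore checkpoint pattern. Field names are
# hypothetical; the real wormhole serialization format is library-internal.
import json

def serialize_state(appid, code, side):
    # a small JSON dict, safe to store on disk across a restart
    return json.dumps({"appid": appid, "code": code, "side": side})

def restore_state(data):
    d = json.loads(data)
    return d["appid"], d["code"], d["side"]

blob = serialize_state(u"appid", u"4-purple-sausages", u"abcd1234")
# ...write blob to disk; after the program restarts:
assert restore_state(blob) == (u"appid", u"4-purple-sausages", u"abcd1234")
```

Serializing at the wrong moment (after key agreement has started) would lose in-flight protocol state, which is why the window above is so narrow.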
@@ -278,9 +284,3 @@ in python3):
* transit connection hints (e.g. "host:port")
* application identifier
* derived-key "purpose" string: `w.derive_key(PURPOSE)`
-## Detailed Example
-```python
-```

events.dot (new file, 93 lines)

@@ -0,0 +1,93 @@
digraph {
api_get_code [label="get_code" shape="hexagon" color="red"]
api_input_code [label="input_code" shape="hexagon" color="red"]
api_set_code [label="set_code" shape="hexagon" color="red"]
send [label="API\nsend" shape="hexagon" color="red"]
get [label="API\nget" shape="hexagon" color="red"]
close [label="API\nclose" shape="hexagon" color="red"]
event_connected [label="connected" shape="box"]
event_learned_code [label="learned\ncode" shape="box"]
event_learned_nameplate [label="learned\nnameplate" shape="box"]
event_received_mailbox [label="received\nmailbox" shape="box"]
event_opened_mailbox [label="opened\nmailbox" shape="box"]
event_built_msg1 [label="built\nmsg1" shape="box"]
event_mailbox_used [label="mailbox\nused" shape="box"]
event_learned_PAKE [label="learned\nmsg2" shape="box"]
event_established_key [label="established\nkey" shape="box"]
event_computed_verifier [label="computed\nverifier" shape="box"]
event_received_confirm [label="received\nconfirm" shape="box"]
event_received_message [label="received\nmessage" shape="box"]
event_received_released [label="ack\nreleased" shape="box"]
event_received_closed [label="ack\nclosed" shape="box"]
event_connected -> api_get_code
event_connected -> api_input_code
api_get_code -> event_learned_code
api_input_code -> event_learned_code
api_set_code -> event_learned_code
maybe_build_msg1 [label="build\nmsg1"]
maybe_claim_nameplate [label="claim\nnameplate"]
maybe_send_pake [label="send\npake"]
maybe_send_phase_messages [label="send\nphase\nmessages"]
event_connected -> maybe_claim_nameplate
event_connected -> maybe_send_pake
event_built_msg1 -> maybe_send_pake
event_learned_code -> maybe_build_msg1
event_learned_code -> event_learned_nameplate
maybe_build_msg1 -> event_built_msg1
event_learned_nameplate -> maybe_claim_nameplate
maybe_claim_nameplate -> event_received_mailbox [style="dashed"]
event_received_mailbox -> event_opened_mailbox
maybe_claim_nameplate -> event_learned_PAKE [style="dashed"]
maybe_claim_nameplate -> event_received_confirm [style="dashed"]
event_opened_mailbox -> event_learned_PAKE [style="dashed"]
event_learned_PAKE -> event_mailbox_used [style="dashed"]
event_learned_PAKE -> event_received_confirm [style="dashed"]
event_received_confirm -> event_received_message [style="dashed"]
send -> maybe_send_phase_messages
release_nameplate [label="release\nnameplate"]
event_mailbox_used -> release_nameplate
event_opened_mailbox -> maybe_send_pake
event_opened_mailbox -> maybe_send_phase_messages
event_learned_PAKE -> event_established_key
event_established_key -> event_computed_verifier
event_established_key -> maybe_send_phase_messages
check_verifier [label="check\nverifier"]
event_computed_verifier -> check_verifier
event_received_confirm -> check_verifier
check_verifier -> error
event_received_message -> error
event_received_message -> get
event_established_key -> get
close -> close_mailbox
close -> release_nameplate
error [label="signal\nerror"]
error -> close_mailbox
error -> release_nameplate
release_nameplate -> event_received_released [style="dashed"]
close_mailbox [label="close\nmailbox"]
close_mailbox -> event_received_closed [style="dashed"]
maybe_close_websocket [label="close\nwebsocket"]
event_received_released -> maybe_close_websocket
event_received_closed -> maybe_close_websocket
maybe_close_websocket -> event_websocket_closed [style="dashed"]
event_websocket_closed [label="websocket\nclosed"]
}


@@ -14,7 +14,6 @@ setup(name="magic-wormhole",
      url="https://github.com/warner/magic-wormhole",
      package_dir={"": "src"},
      packages=["wormhole",
-                "wormhole.blocking",
                "wormhole.cli",
                "wormhole.server",
                "wormhole.test",
@@ -25,11 +24,11 @@ setup(name="magic-wormhole",
                    ["wormhole = wormhole.cli.runner:entry",
                     "wormhole-server = wormhole.server.runner:entry",
                     ]},
-      install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
-                        "six", "twisted >= 16.1.0", "hkdf", "tqdm",
-                        "autobahn[twisted]", "pytrie",
-                        # autobahn seems to have a bug, and one plugin throws
-                        # errors unless pytrie is installed
+      install_requires=["spake2==0.7", "pynacl", "argparse",
+                        "six",
+                        "twisted==16.1.1", # since autobahn pins it
+                        "autobahn[twisted]",
+                        "hkdf", "tqdm",
                        ],
      extras_require={"tor": ["txtorcon", "ipaddr"]},
      test_suite="wormhole.test",


@@ -1,49 +0,0 @@
from __future__ import print_function, unicode_literals
import requests

class EventSourceFollower:
    def __init__(self, url, timeout):
        self._resp = requests.get(url,
                                  headers={"accept": "text/event-stream"},
                                  stream=True,
                                  timeout=timeout)
        self._resp.raise_for_status()
        self._lines_iter = self._resp.iter_lines(chunk_size=1,
                                                 decode_unicode=True)

    def close(self):
        self._resp.close()

    def iter_events(self):
        # I think Request.iter_lines and .iter_content use chunk_size= in a
        # funny way, and nothing happens until at least that much data has
        # arrived. So unless we set chunk_size=1, we won't hear about lines
        # for a long time. I'd prefer that chunk_size behaved like
        # read(size), and gave you 1<=x<=size bytes in response.
        eventtype = "message"
        current_lines = []
        for line in self._lines_iter:
            assert isinstance(line, type(u"")), type(line)
            if not line:
                # blank line ends the field: deliver event, reset for next
                yield (eventtype, "\n".join(current_lines))
                eventtype = "message"
                current_lines[:] = []
                continue
            if ":" in line:
                fieldname, data = line.split(":", 1)
                if data.startswith(" "):
                    data = data[1:]
            else:
                fieldname = line
                data = ""
            if fieldname == "event":
                eventtype = data
            elif fieldname == "data":
                current_lines.append(data)
            elif fieldname in ("id", "retry"):
                # documented but unhandled
                pass
            else:
                #log.msg("weird fieldname", fieldname, data)
                pass
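The field-parsing loop in iter_events() can be exercised stand-alone. This sketch replaces the HTTP stream with a plain list of lines, so it runs without a network; the parsing rules (strip one leading space after the colon, blank line delivers the event) follow the code above.

```python
# Stand-alone sketch of the SSE field-parsing loop, with the network layer
# replaced by a list of already-decoded lines.
def parse_sse(lines):
    events = []
    eventtype = "message"
    current_lines = []
    for line in lines:
        if not line:
            # blank line ends the event: deliver and reset
            events.append((eventtype, "\n".join(current_lines)))
            eventtype = "message"
            current_lines = []
            continue
        if ":" in line:
            fieldname, data = line.split(":", 1)
            if data.startswith(" "):
                data = data[1:]
        else:
            fieldname, data = line, ""
        if fieldname == "event":
            eventtype = data
        elif fieldname == "data":
            current_lines.append(data)

    return events

evs = parse_sse(["event: welcome", "data: {}", "", "data: hello", ""])
assert evs == [("welcome", "{}"), ("message", "hello")]
```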


@@ -1,413 +0,0 @@
from __future__ import print_function
import os, sys, time, re, requests, json, unicodedata
from six.moves.urllib_parse import urlencode
from binascii import hexlify, unhexlify
from spake2 import SPAKE2_Symmetric
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from .eventsource import EventSourceFollower
from .. import __version__
from .. import codes
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..timing import DebugTiming
from hkdf import Hkdf
from ..channel_monitor import monitor
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
SECOND = 1
MINUTE = 60*SECOND
CONFMSG_NONCE_LENGTH = 128//8
CONFMSG_MAC_LENGTH = 256//8
def make_confmsg(confkey, nonce):
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
class Channel:
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
wait, timeout, timing):
self._relay_url = relay_url
self._appid = appid
self._channelid = channelid
self._side = side
self._handle_welcome = handle_welcome
self._messages = set() # (phase,body) , body is bytes
self._sent_messages = set() # (phase,body)
self._started = time.time()
self._wait = wait
self._timeout = timeout
self._timing = timing
def _add_inbound_messages(self, messages):
for msg in messages:
phase = msg["phase"]
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phases):
their_messages = self._messages - self._sent_messages
for phase in phases:
for (their_phase,body) in their_messages:
if their_phase == phase:
return (phase, body)
return None
def send(self, phase, msg):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if not isinstance(msg, type(b"")): raise TypeError(type(msg))
self._sent_messages.add( (phase,msg) )
payload = {"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"phase": phase,
"body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8")
with self._timing.add("send %s" % phase):
r = requests.post(self._relay_url+"add", data=data,
timeout=self._timeout)
r.raise_for_status()
resp = r.json()
if "welcome" in resp:
self._handle_welcome(resp["welcome"])
self._add_inbound_messages(resp["messages"])
def get_first_of(self, phases):
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
for phase in phases:
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# For now, server errors cause the client to fail. TODO: don't. This
# will require changing the client to re-post messages when the
# server comes back up.
# fire with a bytestring of the first message for any 'phase' that
# wasn't one of our own messages. It will either come from
# previously-received messages, or from an EventSource that we attach
# to the corresponding URL
with self._timing.add("get %s" % "/".join(sorted(phases))):
phase_and_body = self._find_inbound_message(phases)
while phase_and_body is None:
remaining = self._started + self._timeout - time.time()
if remaining < 0:
raise Timeout
queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)])
f = EventSourceFollower(self._relay_url+"watch?%s" % queryargs,
remaining)
# we loop here until the connection is lost, or we see the
# message we want
for (eventtype, line) in f.iter_events():
if eventtype == "welcome":
self._handle_welcome(json.loads(line))
if eventtype == "message":
data = json.loads(line)
self._add_inbound_messages([data])
phase_and_body = self._find_inbound_message(phases)
if phase_and_body:
f.close()
break
if not phase_and_body:
time.sleep(self._wait)
return phase_and_body
def get(self, phase):
(got_phase, body) = self.get_first_of([phase])
assert got_phase == phase
return body
def deallocate(self, mood=None):
# only try once, no retries
data = json.dumps({"appid": self._appid,
"channelid": self._channelid,
"side": self._side,
"mood": mood}).encode("utf-8")
try:
# ignore POST failure, don't call r.raise_for_status(), set a
# short timeout and ignore failures
with self._timing.add("close"):
r = requests.post(self._relay_url+"deallocate", data=data,
timeout=5)
r.json()
except requests.exceptions.RequestException:
pass
class ChannelManager:
def __init__(self, relay_url, appid, side, handle_welcome, timing=None,
wait=0.5*SECOND, timeout=3*MINUTE):
self._relay_url = relay_url
self._appid = appid
self._side = side
self._handle_welcome = handle_welcome
self._timing = timing or DebugTiming()
self._wait = wait
self._timeout = timeout
def list_channels(self):
queryargs = urlencode([("appid", self._appid)])
with self._timing.add("list"):
r = requests.get(self._relay_url+"list?%s" % queryargs,
timeout=self._timeout)
r.raise_for_status()
data = r.json()
if "welcome" in data:
self._handle_welcome(data["welcome"])
channelids = data["channelids"]
return channelids
def allocate(self):
data = json.dumps({"appid": self._appid,
"side": self._side}).encode("utf-8")
with self._timing.add("allocate"):
r = requests.post(self._relay_url+"allocate", data=data,
timeout=self._timeout)
r.raise_for_status()
data = r.json()
if "welcome" in data:
self._handle_welcome(data["welcome"])
channelid = data["channelid"]
return channelid
def connect(self, channelid):
return Channel(self._relay_url, self._appid, channelid, self._side,
self._handle_welcome, self._wait, self._timeout,
self._timing)
def close_on_error(f): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _f(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except Timeout:
self.close(u"lonely")
raise
except WrongPasswordError:
self.close(u"scary")
raise
except (TypeError, UsageError):
# preconditions don't warrant _close_with_error()
raise
except:
self.close(u"errory")
raise
return _f
class Wormhole:
motd_displayed = False
version_warning_displayed = False
_send_confirm = True
def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE,
timing=None):
if not isinstance(appid, type(u"")): raise TypeError(type(appid))
if not isinstance(relay_url, type(u"")):
raise TypeError(type(relay_url))
if not relay_url.endswith(u"/"): raise UsageError
self._appid = appid
self._relay_url = relay_url
self._wait = wait
self._timeout = timeout
self._timing = timing or DebugTiming()
side = hexlify(os.urandom(5)).decode("ascii")
self._channel_manager = ChannelManager(relay_url, appid, side,
self.handle_welcome,
self._timing,
self._wait, self._timeout)
self._channel = None
self.code = None
self.key = None
self.verifier = None
self._sent_data = set() # phases
self._got_data = set()
self._got_confirmation = False
self._closed = False
self._timing_started = self._timing.add("wormhole")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted),
file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("-" not in __version__ and
not self.version_warning_displayed and
welcome["current_version"] != __version__):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], __version__), file=sys.stderr)
self.version_warning_displayed = True
if "error" in welcome:
raise ServerError(welcome["error"], self._relay_url)
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
channelid = self._channel_manager.allocate()
code = codes.make_code(channelid, code_length)
assert isinstance(code, type(u"")), type(code)
self._set_code_and_channelid(code)
self._start()
return code
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
lister = self._channel_manager.list_channels
# fetch the list of channels ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete
# client)
initial_channelids = lister()
with self._timing.add("input code", waiting="user"):
code = codes.input_code_with_completion(prompt,
initial_channelids, lister,
code_length)
return code
def set_code(self, code): # used for human-made pre-generated codes
if not isinstance(code, type(u"")): raise TypeError(type(code))
if self.code is not None: raise UsageError
self._set_code_and_channelid(code)
self._start()
def _set_code_and_channelid(self, code):
if self.code is not None: raise UsageError
self._timing.add("code established")
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.code = code
channelid = int(mo.group(1))
self._channel = self._channel_manager.connect(channelid)
monitor.add(self._channel)
def _start(self):
# allocate the rest now too, so it can be serialized
self.sp = SPAKE2_Symmetric(to_bytes(self.code),
idSymmetric=to_bytes(self._appid))
self.msg1 = self.sp.start()
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
return HKDF(self.key, length, CTXinfo=to_bytes(purpose))
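derive_key() expands the shared PAKE key into independent per-purpose subkeys via HKDF. The real code delegates to an HKDF library; this is a minimal stdlib sketch of the same extract-then-expand pattern, assuming HKDF-SHA256 with an empty salt (the hash choice here is an illustration, not a statement about the library's default):

```python
import hashlib
import hmac

def hkdf_sha256(skm, outlen, info, salt=b""):
    """Minimal RFC 5869 HKDF: extract a PRK, then expand to outlen bytes."""
    salt = salt or b"\x00" * hashlib.sha256().digest_size
    prk = hmac.new(salt, skm, hashlib.sha256).digest()  # extract
    okm, t, counter = b"", b"", 1
    while len(okm) < outlen:
        # expand: T(i) = HMAC(PRK, T(i-1) | info | i)
        t = hmac.new(prk, t + info + bytes([counter]), hashlib.sha256).digest()
        okm += t
        counter += 1
    return okm[:outlen]

# distinct "purpose" strings yield unrelated subkeys from one master key
master = b"\x00" * 32
verifier_key = hkdf_sha256(master, 32, b"wormhole:verifier")
phase_key = hkdf_sha256(master, 32, b"wormhole:phase:1")
assert verifier_key != phase_key and len(verifier_key) == 32
```

Because the purpose string is fed in as HKDF's `info` parameter, neither side ever needs to transmit the derived subkeys.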
def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
data = box.decrypt(encrypted)
return data
def _get_key(self):
if not self.key:
self._channel.send(u"pake", self.msg1)
pake_msg = self._channel.get(u"pake")
self.key = self.sp.finish(pake_msg)
self.verifier = self.derive_key(u"wormhole:verifier")
self._timing.add("key established")
if not self._send_confirm:
return
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
self._channel.send(u"_confirm", confmsg)
@close_on_error
def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError
if self._channel is None: raise UsageError
self._get_key()
return self.verifier
@close_on_error
def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if phase in self._sent_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals
if self.code is None: raise UsageError
if self._channel is None: raise UsageError
with self._timing.add("API send data", phase=phase):
# Without predefined roles, we can't derive predictably unique
# keys for each side, so we use the same key for both. We use
# random nonces to keep the messages distinct, and the Channel
# automatically ignores reflections.
self._sent_data.add(phase)
self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
self._channel.send(phase, outbound_encrypted)
@close_on_error
def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once
if phase.startswith(u"_"): raise UsageError # reserved for internals
if self._closed: raise UsageError
if self.code is None: raise UsageError
if self._channel is None: raise UsageError
with self._timing.add("API get data", phase=phase):
self._got_data.add(phase)
self._get_key()
phases = []
if not self._got_confirmation:
phases.append(u"_confirm")
phases.append(phase)
(got_phase, body) = self._channel.get_first_of(phases)
if got_phase == u"_confirm":
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
raise WrongPasswordError
self._got_confirmation = True
(got_phase, body) = self._channel.get_first_of([phase])
assert got_phase == phase
try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body)
return inbound_data
except CryptoError:
raise WrongPasswordError
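The `_confirm` handling above recomputes make_confmsg(confkey, nonce) from the received body and raises WrongPasswordError on mismatch. A self-contained sketch of that check, assuming a hypothetical layout of nonce-followed-by-MAC (the actual make_confmsg and CONFMSG_NONCE_LENGTH are defined elsewhere in this patch; both names below are illustrative stand-ins):

```python
import hashlib
import hmac
import os

NONCE_LENGTH = 16  # illustrative stand-in for CONFMSG_NONCE_LENGTH

def make_confmsg(confkey, nonce):
    # assumed layout: nonce followed by a MAC over that nonce
    return nonce + hmac.new(confkey, nonce, hashlib.sha256).digest()

def check_confmsg(confkey, body):
    # receiver side: peel off the nonce, recompute, constant-time compare
    nonce = body[:NONCE_LENGTH]
    return hmac.compare_digest(body, make_confmsg(confkey, nonce))

good_key = os.urandom(32)
msg = make_confmsg(good_key, os.urandom(NONCE_LENGTH))
assert check_confmsg(good_key, msg)        # matching key: accepted
assert not check_confmsg(os.urandom(32), msg)  # wrong key: WrongPasswordError in the real code
```

The point of the check is early key confirmation: a wrong code produces a different confkey, so the mismatch is detected before any application data is decrypted.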
def close(self, mood=u"happy"):
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
self._closed = True
if self._channel:
self._timing_started.finish(mood=mood)
c, self._channel = self._channel, None
monitor.close(c)
c.deallocate(mood)


@@ -3,7 +3,7 @@ import os, sys, json, binascii, six, tempfile, zipfile
from tqdm import tqdm
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole
from ..wormhole import wormhole
from ..twisted.transit import TransitReceiver
from ..errors import TransferError
@@ -45,9 +45,8 @@ class TwistedReceiver:
# can lazy-provide an endpoint, and overlap the startup process
# with the user handing off the wormhole code
yield tor_manager.start()
w = Wormhole(APPID, self.args.relay_url, tor_manager,
timing=self.args.timing,
reactor=self._reactor)
w = wormhole(APPID, self.args.relay_url, self._reactor,
tor_manager, timing=self.args.timing)
# I wanted to do this instead:
#
# try:
@@ -65,12 +64,12 @@ class TwistedReceiver:
@inlineCallbacks
def _go(self, w, tor_manager):
yield self.handle_code(w)
verifier = yield w.get_verifier()
verifier = yield w.verify()
self.show_verifier(verifier)
them_d = yield self.get_data(w)
try:
if "message" in them_d:
yield self.handle_text(them_d, w)
self.handle_text(them_d, w)
returnValue(None)
if "file" in them_d:
f = self.handle_file(them_d)
@@ -90,7 +89,7 @@ class TwistedReceiver:
raise RespondError("unknown offer type")
except RespondError as r:
data = json.dumps({"error": r.response}).encode("utf-8")
yield w.send_data(data)
w.send(data)
raise TransferError(r.response)
returnValue(None)
@@ -100,10 +99,11 @@
if self.args.zeromode:
assert not code
code = u"0-"
if not code:
code = yield w.input_code("Enter receive wormhole code: ",
self.args.code_length)
yield w.set_code(code)
if code:
w.set_code(code)
else:
yield w.input_code("Enter receive wormhole code: ",
self.args.code_length)
def show_verifier(self, verifier):
verifier_hex = binascii.hexlify(verifier).decode("ascii")
@@ -113,18 +113,17 @@
@inlineCallbacks
def get_data(self, w):
# this may raise WrongPasswordError
them_bytes = yield w.get_data()
them_bytes = yield w.get()
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
raise TransferError(them_d["error"])
returnValue(them_d)
@inlineCallbacks
def handle_text(self, them_d, w):
# we're receiving a text message
self.msg(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
yield w.send_data(data, wait=True)
w.send(data)
def handle_file(self, them_d):
file_data = them_d["file"]
@@ -183,12 +182,13 @@ class TwistedReceiver:
@inlineCallbacks
def establish_transit(self, w, them_d, tor_manager):
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen,
tor_manager=tor_manager,
reactor=self._reactor,
timing=self.args.timing)
transit_key = w.derive_key(APPID+u"/transit-key",
transit_receiver.TRANSIT_KEY_LENGTH)
transit_receiver.set_transit_key(transit_key)
direct_hints = yield transit_receiver.get_direct_hints()
relay_hints = yield transit_receiver.get_relay_hints()
@@ -199,7 +199,7 @@
"relay_connection_hints": relay_hints,
},
}).encode("utf-8")
yield w.send_data(data)
w.send(data)
# now receive the rest of the owl
tdata = them_d["transit"]


@@ -5,7 +5,7 @@ from twisted.protocols import basic
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError
from ..twisted.transcribe import Wormhole
from ..wormhole import wormhole
from ..twisted.transit import TransitSender
APPID = u"lothar.com/wormhole/text-or-file-xfer"
@@ -49,8 +49,8 @@ def send(args, reactor=reactor):
# user handing off the wormhole code
yield tor_manager.start()
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
reactor=reactor)
w = wormhole(APPID, args.relay_url, reactor, tor_manager,
timing=args.timing)
d = _send(reactor, w, args, phase1, fd_to_send, tor_manager)
d.addBoth(w.close)
@@ -83,7 +83,7 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
# get the verifier, because that also lets us derive the transit key,
# which we want to set before revealing the connection hints to the far
# side, so we'll be ready for them when they connect
verifier_bytes = yield w.get_verifier()
verifier_bytes = yield w.verify()
verifier = binascii.hexlify(verifier_bytes).decode("ascii")
if args.verify:
@@ -94,17 +94,18 @@ def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
if ok.lower() == "no":
err = "sender rejected verification check, abandoned transfer"
reject_data = json.dumps({"error": err}).encode("utf-8")
yield w.send_data(reject_data)
w.send(reject_data)
raise TransferError(err)
if fd_to_send is not None:
transit_key = w.derive_key(APPID+"/transit-key")
transit_key = w.derive_key(APPID+"/transit-key",
transit_sender.TRANSIT_KEY_LENGTH)
transit_sender.set_transit_key(transit_key)
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
yield w.send_data(my_phase1_bytes)
w.send(my_phase1_bytes)
# this may raise WrongPasswordError
them_phase1_bytes = yield w.get_data()
them_phase1_bytes = yield w.get()
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))


@@ -1,5 +1,5 @@
# This is a relay I run on a personal server. If it gets too expensive to
# run, I'll shut it down.
RENDEZVOUS_RELAY = u"http://wormhole-relay.petmail.org:3000/wormhole-relay/"
TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:3001"
RENDEZVOUS_RELAY = u"ws://wormhole-relay.petmail.org:4000/"
TRANSIT_RELAY = u"tcp:wormhole-transit-relay.petmail.org:4001"


@@ -4,6 +4,7 @@ from .wordlist import (byte_to_even_word, byte_to_odd_word,
even_words_lowercase, odd_words_lowercase)
def make_code(channel_id, code_length):
assert isinstance(channel_id, type(u"")), type(channel_id)
words = []
for i in range(code_length):
# we start with an "odd word"
@@ -11,7 +12,7 @@ def make_code(channel_id, code_length):
words.append(byte_to_odd_word[os.urandom(1)].lower())
else:
words.append(byte_to_even_word[os.urandom(1)].lower())
return u"%d-%s" % (channel_id, u"-".join(words))
return u"%s-%s" % (channel_id, u"-".join(words))
def extract_channel_id(code):
channel_id = int(code.split("-")[0])
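A wormhole code is a numeric channel-id (a string "nameplate" after this change) followed by code_length words, e.g. "4-purple-sausages". The parsing that _set_code_and_channelid and extract_channel_id perform can be restated self-contained (split_code is an illustrative name, not part of the API):

```python
import re

def split_code(code):
    # same check as _set_code_and_channelid: the code must start with "NN-"
    mo = re.search(r'^(\d+)-', code)
    if not mo:
        raise ValueError("code (%s) must start with NN-" % code)
    return int(mo.group(1)), code.split("-")[1:]

assert split_code(u"4-purple-sausages") == (4, ["purple", "sausages"])
```

Only the leading number is protocol-relevant (it selects the channel); the words exist solely to feed the PAKE exchange.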


@@ -20,6 +20,10 @@ def handle_server_error(func):
class Timeout(Exception):
pass
class WelcomeError(Exception):
"""The server told us to signal an error, probably because our version is
too old to possibly work."""
class WrongPasswordError(Exception):
"""
Key confirmation failed. Either you or your correspondent typed the code
@@ -37,5 +41,8 @@ class ReflectionAttack(Exception):
class UsageError(Exception):
"""The programmer did something wrong."""
class WormholeClosedError(UsageError):
"""API calls may not be made after close() is called."""
class TransferError(Exception):
"""Something bad happened and the transfer failed."""


@@ -18,9 +18,9 @@ s = parser.add_subparsers(title="subcommands", dest="subcommand")
# CLI: run-server
sp_start = s.add_parser("start", description="Start a relay server",
usage="wormhole server start [opts] [TWISTD-ARGS..]")
sp_start.add_argument("--rendezvous", default="tcp:3000", metavar="tcp:PORT",
sp_start.add_argument("--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port")
sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
sp_start.add_argument("--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")


@@ -22,7 +22,7 @@ def get_db(dbfile, stderr=sys.stderr):
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
db.row_factory = sqlite3.Row
VERSION = 1
VERSION = 2
if must_create:
schema = get_schema(VERSION)
db.executescript(schema)


@@ -1,43 +0,0 @@
-- note: anything which isn't a boolean, integer, or human-readable unicode
-- string (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 1
);
CREATE TABLE `messages`
(
`appid` VARCHAR,
`channelid` INTEGER,
`side` VARCHAR,
`phase` VARCHAR, -- not numeric, more of a PAKE-phase indicator string
-- phase="_allocate" and "_deallocate" are used internally
`body` VARCHAR,
`server_rx` INTEGER,
`msgid` VARCHAR
);
CREATE INDEX `messages_idx` ON `messages` (`appid`, `channelid`);
CREATE TABLE `usage`
(
`type` VARCHAR, -- "rendezvous" or "transit"
`started` INTEGER, -- seconds since epoch, rounded to one day
`result` VARCHAR, -- happy, scary, lonely, errory, pruney
-- rendezvous moods:
-- "happy": both sides close with mood=happy
-- "scary": any side closes with mood=scary (bad MAC, probably wrong pw)
-- "lonely": any side closes with mood=lonely (no response from 2nd side)
-- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
-- transit moods:
-- "errory": this side gave the wrong handshake
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
`total_bytes` INTEGER, -- for transit, total bytes relayed (both directions)
`total_time` INTEGER, -- seconds from start to closed, or None
`waiting_time` INTEGER -- seconds from start to 2nd side appearing, or None
);
CREATE INDEX `usage_idx` ON `usage` (`started`);


@@ -0,0 +1,103 @@
-- note: anything which isn't a boolean, integer, or human-readable unicode
-- string (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 2
);
-- Wormhole codes use a "nameplate": a short identifier which is only used to
-- reference a specific (long-named) mailbox. The codes only use numeric
-- nameplates, but the protocol and server allow arbitrary strings.
CREATE TABLE `nameplates`
(
`app_id` VARCHAR,
`id` VARCHAR,
`mailbox_id` VARCHAR, -- really a foreign key
`side1` VARCHAR, -- side name, or NULL
`side2` VARCHAR, -- side name, or NULL
`crowded` BOOLEAN, -- at some point, three or more sides were involved
`updated` INTEGER, -- time of last activity, used for pruning
-- timing data
`started` INTEGER, -- time when nameplate was opened
`second` INTEGER -- time when second side opened
);
CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`);
CREATE INDEX `nameplates_updated_idx` ON `nameplates` (`app_id`, `updated`);
CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`);
-- Clients exchange messages through a "mailbox", which has a long (randomly
-- unique) identifier and a queue of messages.
CREATE TABLE `mailboxes`
(
`app_id` VARCHAR,
`id` VARCHAR,
`side1` VARCHAR, -- side name, or NULL
`side2` VARCHAR, -- side name, or NULL
`crowded` BOOLEAN, -- at some point, three or more sides were involved
`first_mood` VARCHAR,
-- timing data for the mailbox itself
`started` INTEGER, -- time when opened
`second` INTEGER -- time when second side opened
);
CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`);
CREATE TABLE `messages`
(
`app_id` VARCHAR,
`mailbox_id` VARCHAR,
`side` VARCHAR,
`phase` VARCHAR, -- numeric or string
`body` VARCHAR,
`server_rx` INTEGER,
`msg_id` VARCHAR
);
CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`);
CREATE TABLE `nameplate_usage`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_time` INTEGER, -- seconds from open to last close/prune
`result` VARCHAR -- happy, lonely, pruney, crowded
-- nameplate moods:
-- "happy": two sides open and close
-- "lonely": one side opens and closes (no response from 2nd side)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`app_id`, `started`);
CREATE TABLE `mailbox_usage`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- rendezvous moods:
-- "happy": both sides close with mood=happy
-- "scary": any side closes with mood=scary (bad MAC, probably wrong pw)
-- "lonely": any side closes with mood=lonely (no response from 2nd side)
-- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`app_id`, `started`);
CREATE TABLE `transit_usage`
(
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_bytes` INTEGER, -- total bytes relayed (both directions)
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- transit moods:
-- "errory": one side gave the wrong handshake
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
);
CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`);


@@ -1,5 +1,6 @@
from __future__ import print_function
import time, random
import os, time, random, base64
from collections import namedtuple
from twisted.python import log
from twisted.application import service, internet
@@ -12,92 +13,209 @@ MB = 1000*1000
CHANNEL_EXPIRATION_TIME = 3*DAY
EXPIRATION_CHECK_PERIOD = 2*HOUR
ALLOCATE = u"_allocate"
DEALLOCATE = u"_deallocate"
def get_sides(row):
return set([s for s in [row["side1"], row["side2"]] if s])
def make_sides(sides):
return list(sides) + [None] * (2 - len(sides))
def generate_mailbox_id():
return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii")
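generate_mailbox_id() turns 8 random bytes into a lowercase, unpadded base32 string: 64 bits fill 13 base32 characters, and the trailing "===" padding is stripped. Restated standalone:

```python
import base64
import os

def generate_mailbox_id():
    # 8 random bytes -> 64 bits -> 13 base32 chars once "===" padding is stripped
    return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii")

mid = generate_mailbox_id()
assert len(mid) == 13 and mid == mid.lower()
```

Random 64-bit ids are what lets claim_nameplate() hand out a mailbox_id without creating the mailbox row: collisions are negligible, so creation can be deferred until open().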
class Channel:
def __init__(self, app, db, welcome, blur_usage, log_requests,
appid, channelid):
SideResult = namedtuple("SideResult", ["changed", "empty", "side1", "side2"])
Unchanged = SideResult(changed=False, empty=False, side1=None, side2=None)
class CrowdedError(Exception):
pass
def add_side(row, new_side):
old_sides = [s for s in [row["side1"], row["side2"]] if s]
assert old_sides
if new_side in old_sides:
return Unchanged
if len(old_sides) == 2:
raise CrowdedError("too many sides for this thing")
return SideResult(changed=True, empty=False,
side1=old_sides[0], side2=new_side)
def remove_side(row, side):
old_sides = [s for s in [row["side1"], row["side2"]] if s]
if side not in old_sides:
return Unchanged
remaining_sides = old_sides[:]
remaining_sides.remove(side)
if remaining_sides:
return SideResult(changed=True, empty=False, side1=remaining_sides[0],
side2=None)
return SideResult(changed=True, empty=True, side1=None, side2=None)
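The two helpers above replace the old per-message ALLOCATE/DEALLOCATE scanning with explicit two-slot bookkeeping. Their behavior can be exercised standalone; here they are restated self-contained, with plain dicts standing in for sqlite rows:

```python
from collections import namedtuple

SideResult = namedtuple("SideResult", ["changed", "empty", "side1", "side2"])
Unchanged = SideResult(changed=False, empty=False, side1=None, side2=None)

class CrowdedError(Exception):
    pass

def add_side(row, new_side):
    # a row always has at least side1 populated when it exists
    old_sides = [s for s in [row["side1"], row["side2"]] if s]
    if new_side in old_sides:
        return Unchanged
    if len(old_sides) == 2:
        raise CrowdedError("too many sides for this thing")
    return SideResult(changed=True, empty=False,
                      side1=old_sides[0], side2=new_side)

def remove_side(row, side):
    old_sides = [s for s in [row["side1"], row["side2"]] if s]
    if side not in old_sides:
        return Unchanged
    remaining = [s for s in old_sides if s != side]
    if remaining:
        return SideResult(changed=True, empty=False,
                          side1=remaining[0], side2=None)
    return SideResult(changed=True, empty=True, side1=None, side2=None)

row = {"side1": u"abc", "side2": None}
sr = add_side(row, u"def")                 # second side claims: changed
assert sr.changed and sr.side2 == u"def"
assert add_side(row, u"abc") is Unchanged  # re-claim by the same side: no-op
try:
    add_side({"side1": u"abc", "side2": u"def"}, u"ghi")
    raise AssertionError("expected CrowdedError")
except CrowdedError:
    pass                                   # a third side is rejected
sr = remove_side({"side1": u"abc", "side2": u"def"}, u"abc")
assert sr.changed and not sr.empty and sr.side1 == u"def"
```

Re-claims being no-ops makes claim idempotent under reconnects, and the empty flag is what triggers summarize-and-delete in close()/release_nameplate().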
Usage = namedtuple("Usage", ["started", "waiting_time", "total_time", "result"])
TransitUsage = namedtuple("TransitUsage",
["started", "waiting_time", "total_time",
"total_bytes", "result"])
SidedMessage = namedtuple("SidedMessage", ["side", "phase", "body",
"server_rx", "msg_id"])
class Mailbox:
def __init__(self, app, db, blur_usage, log_requests, app_id, mailbox_id):
self._app = app
self._db = db
self._blur_usage = blur_usage
self._log_requests = log_requests
self._appid = appid
self._channelid = channelid
self._listeners = set() # instances with .send_rendezvous_event (that
# takes a JSONable object) and
# .stop_rendezvous_watcher()
self._app_id = app_id
self._mailbox_id = mailbox_id
self._listeners = {} # handle -> (send_f, stop_f)
# "handle" is a hashable object, for deregistration
# send_f() takes a JSONable object, stop_f() has no args
def get_channelid(self):
return self._channelid
def open(self, side, when):
# requires caller to db.commit()
assert isinstance(side, type(u"")), type(side)
db = self._db
row = db.execute("SELECT * FROM `mailboxes`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, self._mailbox_id)).fetchone()
try:
sr = add_side(row, side)
except CrowdedError:
db.execute("UPDATE `mailboxes` SET `crowded`=?"
" WHERE `app_id`=? AND `id`=?",
(True, self._app_id, self._mailbox_id))
db.commit()
raise
if sr.changed:
db.execute("UPDATE `mailboxes` SET"
" `side1`=?, `side2`=?, `second`=?"
" WHERE `app_id`=? AND `id`=?",
(sr.side1, sr.side2, when,
self._app_id, self._mailbox_id))
def get_messages(self):
messages = []
db = self._db
for row in db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" WHERE `app_id`=? AND `mailbox_id`=?"
" ORDER BY `server_rx` ASC",
(self._appid, self._channelid)).fetchall():
if row["phase"] in (u"_allocate", u"_deallocate"):
continue
messages.append({"phase": row["phase"], "body": row["body"],
"server_rx": row["server_rx"], "id": row["msgid"]})
(self._app_id, self._mailbox_id)).fetchall():
sm = SidedMessage(side=row["side"], phase=row["phase"],
body=row["body"], server_rx=row["server_rx"],
msg_id=row["msg_id"])
messages.append(sm)
return messages
def add_listener(self, ep):
self._listeners.add(ep)
def add_listener(self, handle, send_f, stop_f):
self._listeners[handle] = (send_f, stop_f)
return self.get_messages()
def remove_listener(self, ep):
self._listeners.discard(ep)
def remove_listener(self, handle):
self._listeners.pop(handle)
def broadcast_message(self, phase, body, server_rx, msgid):
for ep in self._listeners:
ep.send_rendezvous_event({"phase": phase, "body": body,
"server_rx": server_rx, "id": msgid})
def broadcast_message(self, sm):
for (send_f, stop_f) in self._listeners.values():
send_f(sm)
def _add_message(self, side, phase, body, server_rx, msgid):
def _add_message(self, sm):
self._db.execute("INSERT INTO `messages`"
" (`app_id`, `mailbox_id`, `side`, `phase`, `body`,"
" `server_rx`, `msg_id`)"
" VALUES (?,?,?,?,?, ?,?)",
(self._app_id, self._mailbox_id, sm.side,
sm.phase, sm.body, sm.server_rx, sm.msg_id))
self._db.commit()
def add_message(self, sm):
assert isinstance(sm, SidedMessage)
self._add_message(sm)
self.broadcast_message(sm)
def close(self, side, mood, when):
assert isinstance(side, type(u"")), type(side)
db = self._db
db.execute("INSERT INTO `messages`"
" (`appid`, `channelid`, `side`, `phase`, `body`,"
" `server_rx`, `msgid`)"
" VALUES (?,?,?,?,?, ?,?)",
(self._appid, self._channelid, side, phase, body,
server_rx, msgid))
db.commit()
row = db.execute("SELECT * FROM `mailboxes`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, self._mailbox_id)).fetchone()
if not row:
return
sr = remove_side(row, side)
if sr.empty:
rows = db.execute("SELECT DISTINCT(`side`) FROM `messages`"
" WHERE `app_id`=? AND `mailbox_id`=?",
(self._app_id, self._mailbox_id)).fetchall()
num_sides = len(rows)
self._summarize_and_store(row, num_sides, mood, when, pruned=False)
self._delete()
db.commit()
elif sr.changed:
db.execute("UPDATE `mailboxes`"
" SET `side1`=?, `side2`=?, `first_mood`=?"
" WHERE `app_id`=? AND `id`=?",
(sr.side1, sr.side2, mood,
self._app_id, self._mailbox_id))
db.commit()
def allocate(self, side):
self._add_message(side, ALLOCATE, None, time.time(), None)
def _delete(self):
# requires caller to db.commit()
self._db.execute("DELETE FROM `mailboxes`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, self._mailbox_id))
self._db.execute("DELETE FROM `messages`"
" WHERE `app_id`=? AND `mailbox_id`=?",
(self._app_id, self._mailbox_id))
def add_message(self, side, phase, body, server_rx, msgid):
self._add_message(side, phase, body, server_rx, msgid)
self.broadcast_message(phase, body, server_rx, msgid)
return self.get_messages() # for rendezvous_web.py POST /add
# Shut down any listeners, just in case they're still lingering
# around.
for (send_f, stop_f) in self._listeners.values():
stop_f()
def deallocate(self, side, mood):
self._add_message(side, DEALLOCATE, mood, time.time(), None)
db = self._db
seen = set([row["side"] for row in
db.execute("SELECT `side` FROM `messages`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, self._channelid))])
freed = set([row["side"] for row in
db.execute("SELECT `side` FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" AND `phase`=?",
(self._appid, self._channelid, DEALLOCATE))])
if seen - freed:
return False
self.delete_and_summarize()
return True
self._app.free_mailbox(self._mailbox_id)
def _summarize_and_store(self, row, num_sides, second_mood, delete_time,
pruned):
u = self._summarize(row, num_sides, second_mood, delete_time, pruned)
self._db.execute("INSERT INTO `mailbox_usage`"
" (`app_id`, "
" `started`, `total_time`, `waiting_time`, `result`)"
" VALUES (?, ?,?,?,?)",
(self._app_id,
u.started, u.total_time, u.waiting_time, u.result))
def _summarize(self, row, num_sides, second_mood, delete_time, pruned):
started = row["started"]
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
waiting_time = None
if row["second"]:
waiting_time = row["second"] - row["started"]
total_time = delete_time - row["started"]
if num_sides == 0:
result = u"quiet"
elif num_sides == 1:
result = u"lonely"
else:
result = u"happy"
moods = set([row["first_mood"], second_mood])
if u"lonely" in moods:
result = u"lonely"
if u"errory" in moods:
result = u"errory"
if u"scary" in moods:
result = u"scary"
if pruned:
result = u"pruney"
if row["crowded"]:
result = u"crowded"
return Usage(started=started, waiting_time=waiting_time,
total_time=total_time, result=result)
def is_idle(self):
if self._listeners:
return False
c = self._db.execute("SELECT `server_rx` FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" WHERE `app_id`=? AND `mailbox_id`=?"
" ORDER BY `server_rx` DESC LIMIT 1",
(self._appid, self._channelid))
(self._app_id, self._mailbox_id))
rows = c.fetchall()
if not rows:
return True
@@ -106,172 +224,224 @@ class Channel:
return True
return False
def _store_summary(self, summary):
(started, result, total_time, waiting_time) = summary
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
self._db.execute("INSERT INTO `usage`"
" (`type`, `started`, `result`,"
" `total_time`, `waiting_time`)"
" VALUES (?,?,?, ?,?)",
(u"rendezvous", started, result,
total_time, waiting_time))
self._db.commit()
def _summarize(self, messages, delete_time):
all_sides = set([m["side"] for m in messages])
if len(all_sides) == 0:
log.msg("_summarize was given zero messages") # shouldn't happen
return
started = min([m["server_rx"] for m in messages])
# 'total_time' is how long the channel was occupied. That ends now,
# both for channels that got pruned for inactivity, and for channels
# that got pruned because of two DEALLOCATE messages
total_time = delete_time - started
if len(all_sides) == 1:
return (started, "lonely", total_time, None)
if len(all_sides) > 2:
# TODO: it'll be useful to have more detail here
return (started, "crowded", total_time, None)
# exactly two sides were involved
A_side = sorted(messages, key=lambda m: m["server_rx"])[0]["side"]
B_side = list(all_sides - set([A_side]))[0]
# How long did the first side wait until the second side showed up?
first_A = min([m["server_rx"] for m in messages if m["side"] == A_side])
first_B = min([m["server_rx"] for m in messages if m["side"] == B_side])
waiting_time = first_B - first_A
# now, were all sides closed? If not, this is "pruney"
A_deallocs = [m for m in messages
if m["phase"] == DEALLOCATE and m["side"] == A_side]
B_deallocs = [m for m in messages
if m["phase"] == DEALLOCATE and m["side"] == B_side]
if not A_deallocs or not B_deallocs:
return (started, "pruney", total_time, None)
# ok, both sides closed. figure out the mood
A_mood = A_deallocs[0]["body"] # maybe None
B_mood = B_deallocs[0]["body"] # maybe None
mood = "quiet"
if A_mood == u"happy" and B_mood == u"happy":
mood = "happy"
if A_mood == u"lonely" or B_mood == u"lonely":
mood = "lonely"
if A_mood == u"errory" or B_mood == u"errory":
mood = "errory"
if A_mood == u"scary" or B_mood == u"scary":
mood = "scary"
return (started, mood, total_time, waiting_time)
def delete_and_summarize(self):
db = self._db
c = self._db.execute("SELECT * FROM `messages`"
" WHERE `appid`=? AND `channelid`=?"
" ORDER BY `server_rx`",
(self._appid, self._channelid))
messages = c.fetchall()
summary = self._summarize(messages, time.time())
self._store_summary(summary)
db.execute("DELETE FROM `messages`"
" WHERE `appid`=? AND `channelid`=?",
(self._appid, self._channelid))
db.commit()
# Shut down any listeners, just in case they're still lingering
# around.
for ep in self._listeners:
ep.stop_rendezvous_watcher()
self._app.free_channel(self._channelid)
def _shutdown(self):
# used at test shutdown to accelerate client disconnects
for ep in self._listeners:
ep.stop_rendezvous_watcher()
for (send_f, stop_f) in self._listeners.values():
stop_f()
class AppNamespace:
def __init__(self, db, welcome, blur_usage, log_requests, appid):
def __init__(self, db, welcome, blur_usage, log_requests, app_id):
self._db = db
self._welcome = welcome
self._blur_usage = blur_usage
self._log_requests = log_requests
self._appid = appid
self._channels = {}
self._app_id = app_id
self._mailboxes = {}
def get_allocated(self):
def get_nameplate_ids(self):
db = self._db
c = db.execute("SELECT DISTINCT `channelid` FROM `messages`"
" WHERE `appid`=?", (self._appid,))
return set([row["channelid"] for row in c.fetchall()])
# TODO: filter this to numeric ids?
c = db.execute("SELECT DISTINCT `id` FROM `nameplates`"
" WHERE `app_id`=?", (self._app_id,))
return set([row["id"] for row in c.fetchall()])
def find_available_channelid(self):
allocated = self.get_allocated()
def _find_available_nameplate_id(self):
claimed = self.get_nameplate_ids()
for size in range(1,4): # stick to 1-999 for now
available = set()
for cid in range(10**(size-1), 10**size):
if cid not in allocated:
available.add(cid)
for id_int in range(10**(size-1), 10**size):
id = u"%d" % id_int
if id not in claimed:
available.add(id)
if available:
return random.choice(list(available))
# ouch, 999 currently allocated. Try random ones for a while.
# ouch, 999 currently claimed. Try random ones for a while.
for tries in range(1000):
cid = random.randrange(1000, 1000*1000)
if cid not in allocated:
return cid
raise ValueError("unable to find a free channel-id")
id_int = random.randrange(1000, 1000*1000)
id = u"%d" % id_int
if id not in claimed:
return id
raise ValueError("unable to find a free nameplate-id")
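_find_available_nameplate_id prefers the shortest free numeric id so users get short codes while the namespace is quiet, then falls back to random six-digit-or-less ids. The strategy restated self-contained, with a plain set standing in for the sqlite query (find_available_nameplate_id below is an illustrative name):

```python
import random

def find_available_nameplate_id(claimed):
    # prefer the shortest unclaimed numeric id: 1-9, then 10-99, then 100-999
    for size in range(1, 4):
        available = [u"%d" % i for i in range(10 ** (size - 1), 10 ** size)
                     if u"%d" % i not in claimed]
        if available:
            return random.choice(available)
    # ouch, all 999 short ids are claimed: try random longer ones for a while
    for _ in range(1000):
        id_str = u"%d" % random.randrange(1000, 1000 * 1000)
        if id_str not in claimed:
            return id_str
    raise ValueError("unable to find a free nameplate-id")

assert find_available_nameplate_id(set()) in set(u"%d" % i for i in range(1, 10))
```

Choosing randomly within the smallest non-empty size, rather than always the lowest id, avoids handing two concurrent clients the same nameplate most of the time.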
def allocate_channel(self, channelid, side):
channel = self.get_channel(channelid)
channel.allocate(side)
return channel
def allocate_nameplate(self, side, when):
nameplate_id = self._find_available_nameplate_id()
mailbox_id = self.claim_nameplate(nameplate_id, side, when)
del mailbox_id # ignored, they'll learn it from claim()
return nameplate_id
def get_channel(self, channelid):
assert isinstance(channelid, int)
if not channelid in self._channels:
def claim_nameplate(self, nameplate_id, side, when):
# when we're done:
# * there will be one row for the nameplate
# * side1 or side2 will be populated
# * started or second will be populated
# * a mailbox id will be created, but not a mailbox row
# (ids are randomly unique, so we can defer creation until 'open')
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
assert isinstance(side, type(u"")), type(side)
db = self._db
row = db.execute("SELECT * FROM `nameplates`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, nameplate_id)).fetchone()
if row:
mailbox_id = row["mailbox_id"]
try:
sr = add_side(row, side)
except CrowdedError:
db.execute("UPDATE `nameplates` SET `crowded`=?"
" WHERE `app_id`=? AND `id`=?",
(True, self._app_id, nameplate_id))
db.commit()
raise
if sr.changed:
db.execute("UPDATE `nameplates` SET"
" `side1`=?, `side2`=?, `updated`=?, `second`=?"
" WHERE `app_id`=? AND `id`=?",
(sr.side1, sr.side2, when, when,
self._app_id, nameplate_id))
else:
if self._log_requests:
log.msg("spawning #%d for appid %s" % (channelid, self._appid))
self._channels[channelid] = Channel(self, self._db, self._welcome,
self._blur_usage,
self._log_requests,
self._appid, channelid)
return self._channels[channelid]
log.msg("creating nameplate#%s for app_id %s" %
(nameplate_id, self._app_id))
mailbox_id = generate_mailbox_id()
db.execute("INSERT INTO `nameplates`"
" (`app_id`, `id`, `mailbox_id`, `side1`, `crowded`,"
" `updated`, `started`)"
" VALUES(?,?,?,?,?, ?,?)",
(self._app_id, nameplate_id, mailbox_id, side, False,
when, when))
db.commit()
return mailbox_id
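claim_nameplate above leans on an add_side helper (defined elsewhere in this file) that fills side1/side2 and raises CrowdedError when a third side tries to claim. A minimal sketch of that behavior, with the result-tuple shape inferred from the calls above (an assumption, not the exact implementation):

```python
from collections import namedtuple

SideResult = namedtuple("SideResult", ["changed", "side1", "side2"])

class CrowdedError(Exception):
    """Raised when a third distinct side tries to claim a nameplate."""

def add_side(row, new_side):
    # row behaves like a dict with "side1"/"side2" keys (e.g. sqlite3.Row)
    old = (row["side1"], row["side2"])
    if new_side in old:
        # re-claim by a known side is idempotent: nothing to update
        return SideResult(changed=False, side1=old[0], side2=old[1])
    if old[0] is None:
        return SideResult(changed=True, side1=new_side, side2=old[1])
    if old[1] is None:
        return SideResult(changed=True, side1=old[0], side2=new_side)
    raise CrowdedError("too many sides for this nameplate")
```

The caller only touches the DB when `sr.changed` is true, which is why the idempotent re-claim path above reports `changed=False`.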
def free_channel(self, channelid):
# called from Channel.delete_and_summarize(), which deletes any
def release_nameplate(self, nameplate_id, side, when):
# when we're done:
# * in the nameplate row, side1 or side2 will be removed
# * if the nameplate is now unused:
# * mailbox.nameplate_closed will be populated
# * the nameplate row will be removed
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
assert isinstance(side, type(u"")), type(side)
db = self._db
row = db.execute("SELECT * FROM `nameplates`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, nameplate_id)).fetchone()
if not row:
return
sr = remove_side(row, side)
if sr.empty:
db.execute("DELETE FROM `nameplates`"
" WHERE `app_id`=? AND `id`=?",
(self._app_id, nameplate_id))
self._summarize_nameplate_and_store(row, when, pruned=False)
db.commit()
elif sr.changed:
db.execute("UPDATE `nameplates`"
" SET `side1`=?, `side2`=?, `updated`=?"
" WHERE `app_id`=? AND `id`=?",
(sr.side1, sr.side2, when,
self._app_id, nameplate_id))
db.commit()
def _summarize_nameplate_and_store(self, row, delete_time, pruned):
# requires caller to db.commit()
u = self._summarize_nameplate_usage(row, delete_time, pruned)
self._db.execute("INSERT INTO `nameplate_usage`"
" (`app_id`,"
" `started`, `total_time`, `waiting_time`, `result`)"
" VALUES (?, ?,?,?,?)",
(self._app_id,
u.started, u.total_time, u.waiting_time, u.result))
def _summarize_nameplate_usage(self, row, delete_time, pruned):
started = row["started"]
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
waiting_time = None
if row["second"]:
waiting_time = row["second"] - row["started"]
total_time = delete_time - row["started"]
result = u"lonely"
if row["second"]:
result = u"happy"
if pruned:
result = u"pruney"
if row["crowded"]:
result = u"crowded"
return Usage(started=started, waiting_time=waiting_time,
total_time=total_time, result=result)
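The blur_usage arithmetic in _summarize_nameplate_usage rounds each timestamp down to a multiple of the blur interval, so stored usage records don't reveal precise connection times. As a standalone helper (hypothetical name, same formula):

```python
def blur_timestamp(started, blur_usage):
    # Round down to the nearest multiple of blur_usage seconds
    # (e.g. blur_usage=3600 keeps only the containing hour).
    if not blur_usage:
        return started
    return blur_usage * (started // blur_usage)
```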
def _prune_nameplate(self, row, delete_time):
# requires caller to db.commit()
db = self._db
db.execute("DELETE FROM `nameplates` WHERE `app_id`=? AND `id`=?",
(self._app_id, row["id"]))
self._summarize_nameplate_and_store(row, delete_time, pruned=True)
# TODO: make a Nameplate object, keep track of when there's a
# websocket that's watching it, don't prune a nameplate that someone
# is watching, even if they started watching a long time ago
def prune_nameplates(self, old):
db = self._db
for row in db.execute("SELECT * FROM `nameplates`"
" WHERE `updated` < ?",
(old,)).fetchall():
self._prune_nameplate(row, old)
count = db.execute("SELECT COUNT(*) FROM `nameplates`").fetchone()[0]
return count
def open_mailbox(self, mailbox_id, side, when):
assert isinstance(mailbox_id, type(u"")), type(mailbox_id)
db = self._db
if not mailbox_id in self._mailboxes:
if self._log_requests:
log.msg("spawning #%s for app_id %s" % (mailbox_id,
self._app_id))
db.execute("INSERT INTO `mailboxes`"
" (`app_id`, `id`, `side1`, `crowded`, `started`)"
" VALUES(?,?,?,?,?)",
(self._app_id, mailbox_id, side, False, when))
db.commit() # XXX
# mailbox.open() does a SELECT to find the old sides
self._mailboxes[mailbox_id] = Mailbox(self, self._db,
self._blur_usage,
self._log_requests,
self._app_id, mailbox_id)
mailbox = self._mailboxes[mailbox_id]
mailbox.open(side, when)
db.commit()
return mailbox
def free_mailbox(self, mailbox_id):
# called from Mailbox.delete_and_summarize(), which deletes any
# messages
if channelid in self._channels:
self._channels.pop(channelid)
if self._log_requests:
log.msg("freed+killed #%d, now have %d DB channels, %d live" %
(channelid, len(self.get_allocated()), len(self._channels)))
if mailbox_id in self._mailboxes:
self._mailboxes.pop(mailbox_id)
#if self._log_requests:
# log.msg("freed+killed #%s, now have %d DB mailboxes, %d live" %
# (mailbox_id, len(self.get_claimed()), len(self._mailboxes)))
def prune_old_channels(self):
def prune_mailboxes(self, old):
# For now, pruning is logged even if log_requests is False, to debug
# the pruning process, and since pruning is triggered by a timer
# instead of by user action. It does reveal which channels were
# instead of by user action. It does reveal which mailboxes were
# present when the pruning process began, though, so in the long run
# it should do less logging.
log.msg(" channel prune begins")
# a channel is deleted when there are no listeners and there have
# been no messages added in CHANNEL_EXPIRATION_TIME seconds
channels = set(self.get_allocated()) # these have messages
channels.update(self._channels) # these might have listeners
for channelid in channels:
log.msg(" channel prune checking %d" % channelid)
channel = self.get_channel(channelid)
mailboxes = set(self.get_claimed()) # these have messages
mailboxes.update(self._mailboxes) # these might have listeners
for mailbox_id in mailboxes:
log.msg(" channel prune checking %s" % mailbox_id)
channel = self.get_channel(mailbox_id)
if channel.is_idle():
log.msg(" channel prune expiring %d" % channelid)
log.msg(" channel prune expiring %s" % mailbox_id)
channel.delete_and_summarize() # calls self.free_channel
log.msg(" channel prune done, %r left" % (self._channels.keys(),))
return bool(self._channels)
log.msg(" channel prune done, %r left" % (self._mailboxes.keys(),))
return bool(self._mailboxes)
def _shutdown(self):
for channel in self._channels.values():
for channel in self._mailboxes.values():
channel._shutdown()
class Rendezvous(service.MultiService):
@@ -279,7 +449,7 @@ class Rendezvous(service.MultiService):
service.MultiService.__init__(self)
self._db = db
self._welcome = welcome
self._blur_usage = blur_usage
self._blur_usage = None
log_requests = blur_usage is None
self._log_requests = log_requests
self._apps = {}
@@ -291,28 +461,31 @@ class Rendezvous(service.MultiService):
def get_log_requests(self):
return self._log_requests
def get_app(self, appid):
assert isinstance(appid, type(u""))
if not appid in self._apps:
def get_app(self, app_id):
assert isinstance(app_id, type(u""))
if not app_id in self._apps:
if self._log_requests:
log.msg("spawning appid %s" % (appid,))
self._apps[appid] = AppNamespace(self._db, self._welcome,
log.msg("spawning app_id %s" % (app_id,))
self._apps[app_id] = AppNamespace(self._db, self._welcome,
self._blur_usage,
self._log_requests, appid)
return self._apps[appid]
self._log_requests, app_id)
return self._apps[app_id]
def prune(self):
# As with AppNamespace.prune_old_channels, we log for now.
def prune(self, old=None):
# As with AppNamespace.prune_old_mailboxes, we log for now.
log.msg("beginning app prune")
c = self._db.execute("SELECT DISTINCT `appid` FROM `messages`")
apps = set([row["appid"] for row in c.fetchall()]) # these have messages
if old is None:
old = time.time() - CHANNEL_EXPIRATION_TIME
c = self._db.execute("SELECT DISTINCT `app_id` FROM `messages`")
apps = set([row["app_id"] for row in c.fetchall()]) # these have messages
apps.update(self._apps) # these might have listeners
for appid in apps:
log.msg(" app prune checking %r" % (appid,))
still_active = self.get_app(appid).prune_old_channels()
for app_id in apps:
log.msg(" app prune checking %r" % (app_id,))
app = self.get_app(app_id)
still_active = app.prune_nameplates(old) + app.prune_mailboxes(old)
if not still_active:
log.msg("prune pops app %r" % (appid,))
self._apps.pop(appid)
log.msg("prune pops app %r" % (app_id,))
self._apps.pop(app_id)
log.msg("app prune ends, %d remaining apps" % len(self._apps))
def stopService(self):


@@ -1,223 +0,0 @@
import json, time
from twisted.web import server, resource
from twisted.python import log
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol:
def __init__(self, request):
self.request = request
def sendComment(self, comment):
# this is ignored by clients, but can keep the connection open in the
# face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the response body.
self.request.write(b": " + comment + b"\n\n")
def sendEvent(self, data, name=None, id=None, retry=None):
if name:
self.request.write(b"event: " + name.encode("utf-8") + b"\n")
# e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message".
if id:
self.request.write(b"id: " + id.encode("utf-8") + b"\n")
if retry:
self.request.write(b"retry: " + retry + b"\n") # milliseconds
for line in data.splitlines():
self.request.write(b"data: " + line.encode("utf-8") + b"\n")
self.request.write(b"\n")
def stop(self):
self.request.finish()
def send_rendezvous_event(self, data):
data = data.copy()
data["sent"] = time.time()
self.sendEvent(json.dumps(data))
def stop_rendezvous_watcher(self):
self.stop()
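sendEvent above writes the text/event-stream framing by hand: optional "event:"/"id:" header lines, one "data:" line per payload line, and a blank line to terminate the event. The same framing as a standalone sketch (hypothetical helper, mirroring the writes above):

```python
def format_sse_event(data, name=None, id=None):
    # Frame one server-sent event as bytes, matching EventsProtocol's
    # per-line writes: headers first, then data: lines, then a blank line.
    out = b""
    if name:
        out += b"event: " + name.encode("utf-8") + b"\n"
    if id:
        out += b"id: " + id.encode("utf-8") + b"\n"
    for line in data.splitlines():
        out += b"data: " + line.encode("utf-8") + b"\n"
    return out + b"\n"
```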
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# ("-" indicates a deprecated URL)
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class RelayResource(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self._welcome = rendezvous.get_welcome()
class ChannelLister(RelayResource):
def render_GET(self, request):
if b"appid" not in request.args:
e = NeedToUpgradeErrorResource(self._welcome)
return e.get_message()
appid = request.args[b"appid"][0].decode("utf-8")
#print("LIST", appid)
app = self._rendezvous.get_app(appid)
allocated = app.get_allocated()
data = {"welcome": self._welcome, "channelids": sorted(allocated),
"sent": time.time()}
return json_response(request, data)
class Allocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._rendezvous.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
if self._rendezvous.get_log_requests():
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
response = {"welcome": self._welcome, "channelid": channelid,
"sent": time.time()}
return json_response(request, response)
def getChild(self, path, req):
# wormhole-0.4.0 "send" started with "POST /allocate/SIDE".
# wormhole-0.5.0 changed that to "POST /allocate". We catch the old
# URL here to deliver a nicer error message (with upgrade
# instructions) than an ugly 404.
return NeedToUpgradeErrorResource(self._welcome)
class NeedToUpgradeErrorResource(resource.Resource):
def __init__(self, welcome):
resource.Resource.__init__(self)
w = welcome.copy()
w["error"] = "Sorry, you must upgrade your client to use this server."
message = {"welcome": w}
self._message = (json.dumps(message)+"\n").encode("utf-8")
def get_message(self):
return self._message
def render_POST(self, request):
return self._message
def render_GET(self, request):
return self._message
def getChild(self, path, req):
return self
class Adder(RelayResource):
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
messages = channel.add_message(side, phase, body, time.time(), None)
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
class GetterOrWatcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
messages = channel.get_messages()
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.send_rendezvous_event(old_event)
return server.NOT_DONE_YET
class Watcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
raise TypeError("/watch is for EventSource only")
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.send_rendezvous_event(old_event)
return server.NOT_DONE_YET
class Deallocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
mood = data.get("mood")
#print("DEALLOCATE", appid, channelid, side)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
deleted = channel.deallocate(side, mood)
response = {"status": "waiting", "sent": time.time()}
if deleted:
response = {"status": "deleted", "sent": time.time()}
return json_response(request, response)
class WebRendezvous(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self.putChild(b"list", ChannelLister(rendezvous))
self.putChild(b"allocate", Allocator(rendezvous))
self.putChild(b"add", Adder(rendezvous))
self.putChild(b"get", GetterOrWatcher(rendezvous))
self.putChild(b"watch", Watcher(rendezvous))
self.putChild(b"deallocate", Deallocator(rendezvous))
def getChild(self, path, req):
# 0.4.0 used "POST /CID/SIDE/post/MSGNUM"
# 0.5.0 replaced it with "POST /add (json body)"
# give a nicer error message to old clients
if (len(req.postpath) >= 2
and req.postpath[1] in (b"post", b"poll", b"deallocate")):
welcome = self._rendezvous.get_welcome()
return NeedToUpgradeErrorResource(welcome)
return resource.NoResource("No such child resource.")


@@ -2,23 +2,48 @@ import json, time
from twisted.internet import reactor
from twisted.python import log
from autobahn.twisted import websocket
from .rendezvous import CrowdedError, SidedMessage
# Each WebSocket connection is bound to one "appid", one "side", and one
# "channelid". The connection's appid and side are set by the "bind" message
# (which must be the first message on the connection). The channelid is set
# by either a "allocate" message (where the server picks the channelid), or
# by a "claim" message (where the client picks it). All three values must be
# set before any other message (watch, add, deallocate) can be sent.
# The WebSocket allows the client to send "commands" to the server, and the
# server to send "responses" to the client. Note that commands and responses
# are not necessarily one-to-one. All commands provoke an "ack" response
# (with a copy of the original message) for timing, testing, and
# synchronization purposes. All commands and responses are JSON-encoded.
# All websocket messages are JSON-encoded. The client can send us "inbound"
# messages (marked as "->" below), which may (or may not) provoke immediate
# (or delayed) "outbound" messages (marked as "<-"). There is no guaranteed
# correlation between requests and responses. In this list, "A -> B" means
# that some time after A is received, at least one message of type B will be
# sent out.
# Each WebSocket connection is bound to one "appid" and one "side", which are
# set by the "bind" command (which must be the first command on the
# connection), and must be set before any other command will be accepted.
# All outbound messages include a "sent" key, which is a float (seconds since
# epoch) with the server clock just before the outbound message was written
# Each connection can be bound to a single "mailbox" (a two-sided
# store-and-forward queue, identified by the "mailbox id": a long, randomly
# unique string identifier) by using the "open" command. This protects the
# mailbox from idle closure, enables the "add" command (to put new messages
# in the queue), and triggers delivery of past and future messages via the
# "message" response. The "close" command removes the binding (but note that
# it does not enable the subsequent binding of a second mailbox). When the
# last side closes a mailbox, its contents are deleted.
# Additionally, the connection can be bound to a single "nameplate", which is
# a short identifier that makes up the first component of a wormhole code. Each
# nameplate points to a single long-id "mailbox". The "allocate" message
# determines the shortest available numeric nameplate, reserves it, and
# returns the nameplate id. "list" returns a list of all numeric nameplates
# which currently have only one side active (i.e. they are waiting for a
# partner). The "claim" message reserves an arbitrary nameplate id (perhaps
# the receiver of a wormhole connection typed in a code they got from the
# sender, or perhaps the two sides agreed upon a code offline and are both
# typing it in), and the "release" message releases it. When every side that
# has claimed the nameplate has also released it, the nameplate is
# deallocated (but they will probably keep the underlying mailbox open).
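The claim/release lifecycle described above can be modeled with in-memory sets. This is an illustrative sketch, not the server's DB-backed implementation: every claimer learns the same mailbox id, and the nameplate is deallocated once the last claiming side releases it.

```python
import os

class NameplateModel:
    def __init__(self):
        # nameplate_id -> (mailbox_id, set of sides that claimed it)
        self._nameplates = {}

    def claim(self, nameplate_id, side):
        if nameplate_id not in self._nameplates:
            mailbox_id = os.urandom(8).hex()  # long, randomly unique
            self._nameplates[nameplate_id] = (mailbox_id, set())
        mailbox_id, sides = self._nameplates[nameplate_id]
        sides.add(side)
        return mailbox_id  # both sides learn the same mailbox

    def release(self, nameplate_id, side):
        mailbox_id, sides = self._nameplates[nameplate_id]
        sides.discard(side)
        if not sides:  # last claimer released: deallocate the nameplate
            del self._nameplates[nameplate_id]
```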
# Inbound (client to server) commands are marked as "->" below. Unrecognized
# inbound keys will be ignored. Outbound (server to client) responses use
# "<-". There is no guaranteed correlation between requests and responses. In
# this list, "A -> B" means that some time after A is received, at least one
# message of type B will be sent out (probably).
# All responses include a "server_tx" key, which is a float (seconds since
# epoch) with the server clock just before the outbound response was written
# to the socket.
# connection -> welcome
@@ -27,16 +52,24 @@ from autobahn.twisted import websocket
# motd: all clients display message, then continue normally
# error: all clients display message, then terminate with error
# -> {type: "bind", appid:, side:}
# -> {type: "list"} -> channelids
# <- {type: "channelids", channelids: [int..]}
# -> {type: "allocate"} -> allocated
# <- {type: "allocated", channelid: int}
# -> {type: "claim", channelid: int}
# -> {type: "watch"} -> message # sends old messages and more in future
# <- {type: "message", message: {phase:, body:}} # body is hex
# -> {type: "add", phase: str, body: hex} # may send echo
# -> {type: "deallocate", mood: str} -> deallocated
# <- {type: "deallocated", status: waiting|deleted}
#
# -> {type: "list"} -> nameplates
# <- {type: "nameplates", nameplates: [{id: str,..},..]}
# -> {type: "allocate"} -> nameplate, mailbox
# <- {type: "allocated", nameplate: str}
# -> {type: "claim", nameplate: str} -> mailbox
# <- {type: "claimed", mailbox: str}
# -> {type: "release"}
# <- {type: "released"}
#
# -> {type: "open", mailbox: str} -> message
# sends old messages now, and subscribes to deliver future messages
# <- {type: "message", side:, phase:, body:, msg_id:} # body is hex
# -> {type: "add", phase: str, body: hex} # will send echo in a "message"
#
# -> {type: "close", mood: str} -> closed
# <- {type: "closed"}
#
# <- {type: "error", error: str, orig: {}} # in response to malformed msgs
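Putting the message list together, a client-side sketch of one full command sequence as JSON texts (the shapes follow the list above; the real client sends these over the autobahn websocket, and the example values are hypothetical):

```python
import json

def build_opening_commands(appid, side, nameplate, mailbox):
    # "bind" must be the first command on any connection; "claim" maps
    # the short nameplate to its mailbox; "open" subscribes to messages.
    commands = [
        {"type": "bind", "appid": appid, "side": side},
        {"type": "claim", "nameplate": nameplate},
        {"type": "open", "mailbox": mailbox},
        {"type": "add", "phase": "0", "body": "6d7367"},  # body is hex
        {"type": "close", "mood": "happy"},
    ]
    return [json.dumps(c) for c in commands]
```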
# for tests that need to know when a message has been processed:
@@ -52,8 +85,9 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
websocket.WebSocketServerProtocol.__init__(self)
self._app = None
self._side = None
self._channel = None
self._watching = False
self._did_allocate = False # only one allocate() per websocket
self._nameplate_id = None
self._mailbox = None
def onConnect(self, request):
rv = self.factory.rendezvous
@@ -71,10 +105,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
try:
if "type" not in msg:
raise Error("missing 'type'")
if "id" in msg:
# Only ack clients modern enough to include [id]. Older ones
# won't recognize the message, then they'll abort.
self.send("ack", id=msg["id"])
self.send("ack", id=msg.get("id"))
mtype = msg["type"]
if mtype == "ping":
@@ -83,33 +114,27 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
return self.handle_bind(msg)
if not self._app:
raise Error("Must bind first")
raise Error("must bind first")
if mtype == "list":
return self.handle_list()
if mtype == "allocate":
return self.handle_allocate()
return self.handle_allocate(server_rx)
if mtype == "claim":
return self.handle_claim(msg)
return self.handle_claim(msg, server_rx)
if mtype == "release":
return self.handle_release(server_rx)
if not self._channel:
raise Error("Must set channel first")
if mtype == "watch":
return self.handle_watch(self._channel, msg)
if mtype == "open":
return self.handle_open(msg, server_rx)
if mtype == "add":
return self.handle_add(self._channel, msg, server_rx)
if mtype == "deallocate":
return self.handle_deallocate(self._channel, msg)
return self.handle_add(msg, server_rx)
if mtype == "close":
return self.handle_close(msg, server_rx)
raise Error("Unknown type")
raise Error("unknown type")
except Error as e:
self.send("error", error=e._explain, orig=msg)
def send_rendezvous_event(self, event):
self.send("message", message=event)
def stop_rendezvous_watcher(self):
self._reactor.callLater(0, self.transport.loseConnection)
def handle_ping(self, msg):
if "ping" not in msg:
raise Error("ping requires 'ping'")
@@ -125,46 +150,79 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
self._app = self.factory.rendezvous.get_app(msg["appid"])
self._side = msg["side"]
def handle_list(self):
channelids = sorted(self._app.get_allocated())
self.send("channelids", channelids=channelids)
nameplate_ids = sorted(self._app.get_nameplate_ids())
# provide room to add nameplate attributes later (like which wordlist
# is used for each, maybe how many words)
nameplates = [{"id": nid} for nid in nameplate_ids]
self.send("nameplates", nameplates=nameplates)
def handle_allocate(self):
if self._channel:
raise Error("Already bound to a channelid")
channelid = self._app.find_available_channelid()
self._channel = self._app.allocate_channel(channelid, self._side)
self.send("allocated", channelid=channelid)
def handle_allocate(self, server_rx):
if self._did_allocate:
raise Error("you already allocated one, don't be greedy")
nameplate_id = self._app.allocate_nameplate(self._side, server_rx)
assert isinstance(nameplate_id, type(u""))
self._did_allocate = True
self.send("allocated", nameplate=nameplate_id)
def handle_claim(self, msg):
if "channelid" not in msg:
raise Error("claim requires 'channelid'")
# we allow allocate+claim as long as they match
if self._channel is not None:
old_cid = self._channel.get_channelid()
if msg["channelid"] != old_cid:
raise Error("Already bound to channelid %d" % old_cid)
self._channel = self._app.allocate_channel(msg["channelid"], self._side)
def handle_claim(self, msg, server_rx):
if "nameplate" not in msg:
raise Error("claim requires 'nameplate'")
nameplate_id = msg["nameplate"]
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
self._nameplate_id = nameplate_id
try:
mailbox_id = self._app.claim_nameplate(nameplate_id, self._side,
server_rx)
except CrowdedError:
raise Error("crowded")
self.send("claimed", mailbox=mailbox_id)
def handle_watch(self, channel, msg):
if self._watching:
raise Error("already watching")
self._watching = True
for old_message in channel.add_listener(self):
self.send_rendezvous_event(old_message)
def handle_release(self, server_rx):
if not self._nameplate_id:
raise Error("must claim a nameplate before releasing it")
self._app.release_nameplate(self._nameplate_id, self._side, server_rx)
self._nameplate_id = None
self.send("released")
def handle_add(self, channel, msg, server_rx):
def handle_open(self, msg, server_rx):
if self._mailbox:
raise Error("you already have a mailbox open")
if "mailbox" not in msg:
raise Error("open requires 'mailbox'")
mailbox_id = msg["mailbox"]
assert isinstance(mailbox_id, type(u""))
self._mailbox = self._app.open_mailbox(mailbox_id, self._side,
server_rx)
def _send(sm):
self.send("message", side=sm.side, phase=sm.phase,
body=sm.body, server_rx=sm.server_rx, id=sm.msg_id)
def _stop():
pass
for old_sm in self._mailbox.add_listener(self, _send, _stop):
_send(old_sm)
def handle_add(self, msg, server_rx):
if not self._mailbox:
raise Error("must open mailbox before adding")
if "phase" not in msg:
raise Error("missing 'phase'")
if "body" not in msg:
raise Error("missing 'body'")
msgid = msg.get("id") # optional
channel.add_message(self._side, msg["phase"], msg["body"],
server_rx, msgid)
sm = SidedMessage(side=self._side, phase=msg["phase"],
body=msg["body"], server_rx=server_rx,
msg_id=msgid)
self._mailbox.add_message(sm)
def handle_deallocate(self, channel, msg):
deleted = channel.deallocate(self._side, msg.get("mood"))
self.send("deallocated", status="deleted" if deleted else "waiting")
def handle_close(self, msg, server_rx):
if not self._mailbox:
raise Error("must open mailbox before closing")
self._mailbox.close(self._side, msg.get("mood"), server_rx)
self._mailbox = None
self.send("closed")
def send(self, mtype, **kwargs):
kwargs["type"] = mtype


@@ -8,7 +8,6 @@ from .endpoint_service import ServerEndpointService
from .. import __version__
from .database import get_db
from .rendezvous import Rendezvous
from .rendezvous_web import WebRendezvous
from .rendezvous_websocket import WebSocketRendezvousFactory
from .transit_server import Transit
@@ -49,12 +48,8 @@ class RelayServer(service.MultiService):
rendezvous = Rendezvous(db, welcome, blur_usage)
rendezvous.setServiceParent(self) # for the pruning timer
root = Root()
wr = WebRendezvous(rendezvous)
root.putChild(b"wormhole-relay", wr)
wsrf = WebSocketRendezvousFactory(None, rendezvous)
wr.putChild(b"ws", WebSocketResource(wsrf))
root = WebSocketResource(wsrf)
site = PrivacyEnhancedSite(root)
if blur_usage:
@@ -75,7 +70,6 @@ class RelayServer(service.MultiService):
self._db = db
self._rendezvous = rendezvous
self._root = root
self._rendezvous_web = wr
self._rendezvous_web_service = rendezvous_web_service
self._rendezvous_websocket = wsrf
if transit_port:


@@ -186,12 +186,12 @@ class Transit(protocol.ServerFactory, service.MultiService):
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
self._db.execute("INSERT INTO `usage`"
" (`type`, `started`, `result`, `total_bytes`,"
" `total_time`, `waiting_time`)"
" VALUES (?,?,?,?, ?,?)",
(u"transit", started, result, total_bytes,
total_time, waiting_time))
self._db.execute("INSERT INTO `transit_usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._db.commit()
def transitFinished(self, p, token, description):


@@ -17,8 +17,7 @@ class ServerBase:
s.setServiceParent(self.sp)
self._rendezvous = s._rendezvous
self._transit_server = s._transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.rdv_ws_url = self.relayurl.replace("http:", "ws:") + "ws"
self.relayurl = u"ws://127.0.0.1:%d/" % relayport
self.rdv_ws_port = relayport
# ws://127.0.0.1:%d/wormhole-relay/ws
self.transit = u"tcp:127.0.0.1:%d" % transitport


@@ -1,446 +0,0 @@
from __future__ import print_function
import json
from twisted.trial import unittest
from twisted.internet.defer import gatherResults, succeed
from twisted.internet.threads import deferToThread
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
WrongPasswordError)
from ..blocking.eventsource import EventSourceFollower
from .common import ServerBase
APPID = u"appid"
class Channel(ServerBase, unittest.TestCase):
def ignore(self, welcome):
pass
def test_allocate(self):
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
d = deferToThread(cm.list_channels)
def _got_channels(channels):
self.failUnlessEqual(channels, [])
d.addCallback(_got_channels)
d.addCallback(lambda _: deferToThread(cm.allocate))
def _allocated(channelid):
self.failUnlessEqual(type(channelid), int)
self._channelid = channelid
d.addCallback(_allocated)
d.addCallback(lambda _: deferToThread(cm.connect, self._channelid))
def _connected(c):
self._channel = c
d.addCallback(_connected)
d.addCallback(lambda _: deferToThread(self._channel.deallocate,
u"happy"))
return d
def test_messages(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2"))
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# it's legal to fetch a phase multiple times, should be idempotent
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
# deallocating one side is not enough to destroy the channel
d.addCallback(lambda _: deferToThread(c2.deallocate))
def _not_yet(_):
self._rendezvous.prune()
self.failUnlessEqual(len(self._rendezvous._apps), 1)
d.addCallback(_not_yet)
# but deallocating both will make the messages go away
d.addCallback(lambda _: deferToThread(c1.deallocate, u"sad"))
def _gone(_):
self._rendezvous.prune()
self.failUnlessEqual(len(self._rendezvous._apps), 0)
d.addCallback(_gone)
return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
return d
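The ordering guarantee exercised by test_get_multiple_phases — when several requested phases have already arrived, get_first_of returns the one listed first in the caller's order, not the first to arrive — can be sketched as a small hypothetical helper (not part of the wormhole codebase):

```python
# Hypothetical sketch of the get_first_of preference rule: scan the
# requested phases in caller order and return the first one that has
# already been received, regardless of arrival order.
def first_of(wanted_phases, received):
    for phase in wanted_phases:
        if phase in received:
            return (phase, received[phase])
    raise KeyError("no requested phase has arrived yet")
```

With both u"phase1" and u"phase2" present, first_of([u"phase1", u"phase2"], ...) yields the phase1 message while first_of([u"phase2", u"phase1"], ...) yields the phase2 one, matching the assertions above.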
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
c1a = cm1a.connect(1)
c2a = cm2a.connect(1)
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
c1b = cm1b.connect(1)
c2b = cm2b.connect(1)
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a"))
d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b"))
d.addCallback(lambda _: deferToThread(c2a.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
d.addCallback(lambda _: deferToThread(c2b.get, u"phase1"))
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
return d
class _DoBothMixin:
def doBoth(self, call1, call2):
f1 = call1[0]
f1args = call1[1:]
f2 = call2[0]
f2args = call2[1:]
return gatherResults([deferToThread(f1, *f1args),
deferToThread(f2, *f2args)], True)
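doBoth runs two blocking calls in parallel threads (via deferToThread) and gathers both results in order. A stdlib-only analogue of the same [function, arg1, arg2, ...] calling convention, as a sketch rather than anything from the wormhole codebase:

```python
# Stdlib sketch of the doBoth pattern: run two blocking calls
# concurrently and return both results, preserving argument order.
from concurrent.futures import ThreadPoolExecutor

def do_both(call1, call2):
    with ThreadPoolExecutor(max_workers=2) as pool:
        f1 = pool.submit(call1[0], *call1[1:])
        f2 = pool.submit(call2[0], *call2[1:])
        return [f1.result(), f2.result()]
```

Running both halves concurrently matters for these tests: each blocking Wormhole call may wait on a message the other side has not sent yet.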
class Blocking(_DoBothMixin, ServerBase, unittest.TestCase):
# we need Twisted to run the server, but we run the sender and receiver
# with deferToThread()
def test_basic(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code)
return self.doBoth([w1.send_data, b"data1"],
[w2.send_data, b"data2"])
d.addCallback(_got_code)
def _sent(res):
return self.doBoth([w1.get_data], [w2.get_data])
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_same_message(self):
# the two sides use random nonces for their messages, so it's ok for
# both to try and send the same body: they'll result in distinct
# encrypted messages
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code)
return self.doBoth([w1.send_data, b"data"],
[w2.send_data, b"data"])
d.addCallback(_got_code)
def _sent(res):
return self.doBoth([w1.get_data], [w2.get_data])
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data")
self.assertEqual(dataY, b"data")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_interleaved(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code)
return self.doBoth([w1.send_data, b"data1"],
[w2.get_data])
d.addCallback(_got_code)
def _sent(res):
(_, dataY) = res
self.assertEqual(dataY, b"data1")
return self.doBoth([w1.get_data], [w2.send_data, b"data2"])
d.addCallback(_sent)
def _done(dl):
(dataX, _) = dl
self.assertEqual(dataX, b"data2")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_fixed_code(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
d = self.doBoth([w1.send_data, b"data1"], [w2.send_data, b"data2"])
def _sent(res):
return self.doBoth([w1.get_data], [w2.get_data])
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
d = self.doBoth([w1.send_data, b"data1", u"p1"],
[w2.send_data, b"data2", u"p1"])
d.addCallback(lambda _:
self.doBoth([w1.send_data, b"data3", u"p2"],
[w2.send_data, b"data4", u"p2"]))
d.addCallback(lambda _:
self.doBoth([w1.get_data, u"p2"],
[w2.get_data, u"p1"]))
def _got_1(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data1")
return self.doBoth([w1.get_data, u"p1"],
[w2.get_data, u"p2"])
d.addCallback(_got_1)
def _got_2(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data3")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_got_2)
return d
def test_wrong_password(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
# make sure we can detect WrongPasswordError even if one side only
# does get_data() and not send_data(), like "wormhole receive" does
d = deferToThread(w1.get_code)
d.addCallback(lambda code: w2.set_code(code+"not"))
# w2 can't throw WrongPasswordError until it sees a CONFIRM message,
# and w1 won't send CONFIRM until it sees a PAKE message, which w2
# won't send until we call get_data. So we need both sides to be
# running at the same time for this test.
def _w1_sends():
w1.send_data(b"data1")
def _w2_gets():
self.assertRaises(WrongPasswordError, w2.get_data)
d.addCallback(lambda _: self.doBoth([_w1_sends], [_w2_gets]))
# and now w1 should have enough information to throw too
d.addCallback(lambda _: deferToThread(self.assertRaises,
WrongPasswordError, w1.get_data))
def _done(_):
# both sides are closed automatically upon error, but it's still
# legal to call .close(), and should be idempotent
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_no_confirm(self):
# newer versions (which check confirmations) should still work with
# older versions (that don't send confirmations)
w1 = Wormhole(APPID, self.relayurl)
w1._send_confirm = False
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
d.addCallback(lambda code: w2.set_code(code))
d.addCallback(lambda _: self.doBoth([w1.send_data, b"data1"],
[w2.get_data]))
d.addCallback(lambda dl: self.assertEqual(dl[1], b"data1"))
d.addCallback(lambda _: self.doBoth([w1.get_data],
[w2.send_data, b"data2"]))
d.addCallback(lambda dl: self.assertEqual(dl[0], b"data2"))
d.addCallback(lambda _: self.doBoth([w1.close], [w2.close]))
return d
def test_verifier(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code)
return self.doBoth([w1.get_verifier], [w2.get_verifier])
d.addCallback(_got_code)
def _check_verifier(res):
v1, v2 = res
self.failUnlessEqual(type(v1), type(b""))
self.failUnlessEqual(v1, v2)
return self.doBoth([w1.send_data, b"data1"],
[w2.send_data, b"data2"])
d.addCallback(_check_verifier)
def _sent(res):
return self.doBoth([w1.get_data], [w2.get_data])
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_verifier_mismatch(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code+"not")
return self.doBoth([w1.get_verifier], [w2.get_verifier])
d.addCallback(_got_code)
def _check_verifier(res):
v1, v2 = res
self.failUnlessEqual(type(v1), type(b""))
self.failIfEqual(v1, v2)
return self.doBoth([w1.close], [w2.close])
d.addCallback(_check_verifier)
return d
def test_errors(self):
w1 = Wormhole(APPID, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data)
self.assertRaises(UsageError, w1.send_data, b"data")
w1.set_code(u"123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, u"123-nope")
self.assertRaises(UsageError, w1.get_code)
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w2.get_code)
def _done(code):
self.assertRaises(UsageError, w2.get_code)
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
def test_repeat_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2 = Wormhole(APPID, self.relayurl)
w2.set_code(u"123-purple-elephant")
# we must let them establish a key before we can send data
d = self.doBoth([w1.get_verifier], [w2.get_verifier])
d.addCallback(lambda _:
deferToThread(w1.send_data, b"data1", phase=u"1"))
def _sent(res):
# underscore-prefixed phases are reserved
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1")
self.assertRaises(UsageError, w1.get_data, phase=u"_1")
# you can't send twice to the same phase
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1")
# but you can send to a different one
return deferToThread(w1.send_data, b"data2", phase=u"2")
d.addCallback(_sent)
d.addCallback(lambda _: deferToThread(w2.get_data, phase=u"1"))
def _got1(res):
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase
self.assertRaises(UsageError, w2.get_data, phase=u"1")
# but you can read from a different one
return deferToThread(w2.get_data, phase=u"2")
d.addCallback(_got1)
def _got2(res):
self.failUnlessEqual(res, b"data2")
return self.doBoth([w1.close], [w2.close])
d.addCallback(_got2)
return d
def test_serialize(self):
w1 = Wormhole(APPID, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = Wormhole(APPID, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
self.assertRaises(UsageError, w2.serialize) # too early
w2.set_code(code)
w2.serialize() # ok
s = w1.serialize()
self.assertEqual(type(s), type(""))
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
self.new_w1 = Wormhole.from_serialized(s)
return self.doBoth([self.new_w1.send_data, b"data1"],
[w2.send_data, b"data2"])
d.addCallback(_got_code)
def _sent(res):
return self.doBoth([self.new_w1.get_data], [w2.get_data])
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late
return self.doBoth([w1.close], [w2.close])
d.addCallback(_done)
return d
test_serialize.skip = "not yet implemented for the blocking flavor"
data1 = u"""\
event: welcome
data: one and a
data: two
data:.
data: three
: this line is ignored
event: e2
: this line is ignored too
i am a dataless field name
data: four
"""
class NoNetworkESF(EventSourceFollower):
def __init__(self, text):
self._lines_iter = iter(text.splitlines())
class EventSourceClient(unittest.TestCase):
def test_parser(self):
events = []
f = NoNetworkESF(data1)
events = list(f.iter_events())
self.failUnlessEqual(events,
[(u"welcome", u"one and a\ntwo\n."),
(u"message", u"three"),
(u"e2", u"four"),
])
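The event-stream rules this parser test exercises — lines starting with ":" are comments, consecutive "data:" values are joined with newlines, a block without an "event:" field defaults to u"message", and a blank line dispatches the accumulated event — can be sketched as a minimal stdlib-only parser (a hypothetical stand-in, not the real EventSourceFollower implementation):

```python
# Minimal sketch of Server-Sent Events parsing, per the rules the
# test above exercises.
def iter_events(lines):
    event_type, data_lines = u"message", []
    for line in lines:
        if line.startswith(u":"):       # comment line: ignored
            continue
        if not line.strip():            # blank line: dispatch the event
            if data_lines:
                yield (event_type, u"\n".join(data_lines))
            event_type, data_lines = u"message", []
            continue
        if u":" in line:
            field, _, value = line.partition(u":")
            if value.startswith(u" "):  # one leading space is stripped
                value = value[1:]
            if field == u"event":
                event_type = value
            elif field == u"data":
                data_lines.append(value)
        # dataless field names carry no value and are ignored here
    if data_lines:                      # flush a final, unterminated event
        yield (event_type, u"\n".join(data_lines))
```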


@@ -1,57 +0,0 @@
from __future__ import print_function
from twisted.trial import unittest
from twisted.internet.defer import gatherResults
from twisted.internet.threads import deferToThread
from ..twisted.transcribe import Wormhole as twisted_Wormhole
from ..blocking.transcribe import Wormhole as blocking_Wormhole
from .common import ServerBase
# make sure the two implementations (Twisted-style and blocking-style) can
# interoperate
APPID = u"appid"
class Basic(ServerBase, unittest.TestCase):
def doBoth(self, call1, d2):
f1 = call1[0]
f1args = call1[1:]
return gatherResults([deferToThread(f1, *f1args), d2], True)
def test_twisted_to_blocking(self):
tw = twisted_Wormhole(APPID, self.relayurl)
bw = blocking_Wormhole(APPID, self.relayurl)
d = tw.get_code()
def _got_code(code):
bw.set_code(code)
return self.doBoth([bw.send_data, b"data2"], tw.send_data(b"data1"))
d.addCallback(_got_code)
def _sent(res):
return self.doBoth([bw.get_data], tw.get_data())
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data1")
self.assertEqual(dataY, b"data2")
return self.doBoth([bw.close], tw.close())
d.addCallback(_done)
return d
def test_blocking_to_twisted(self):
bw = blocking_Wormhole(APPID, self.relayurl)
tw = twisted_Wormhole(APPID, self.relayurl)
d = deferToThread(bw.get_code)
def _got_code(code):
tw.set_code(code)
return self.doBoth([bw.send_data, b"data1"], tw.send_data(b"data2"))
d.addCallback(_got_code)
def _sent(res):
return self.doBoth([bw.get_data], tw.get_data())
d.addCallback(_sent)
def _done(dl):
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
return self.doBoth([bw.close], tw.close())
d.addCallback(_done)
return d


@@ -453,7 +453,7 @@ class Cleanup(ServerBase, unittest.TestCase):
yield send_d
yield receive_d
-cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
+cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids()
self.assertEqual(len(cids), 0)
@inlineCallbacks
@@ -482,6 +482,7 @@ class Cleanup(ServerBase, unittest.TestCase):
yield self.assertFailure(send_d, WrongPasswordError)
yield self.assertFailure(receive_d, WrongPasswordError)
-cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
+cids = self._rendezvous.get_app(cmd_send.APPID).get_nameplate_ids()
self.assertEqual(len(cids), 0)
self.flushLoggedErrors(WrongPasswordError)

File diff suppressed because it is too large


@@ -1,243 +0,0 @@
from __future__ import print_function
import json
from twisted.trial import unittest
from twisted.internet.defer import gatherResults, inlineCallbacks
from ..twisted.transcribe import Wormhole, UsageError, WrongPasswordError
from .common import ServerBase
APPID = u"appid"
class Basic(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2):
return gatherResults([d1, d2], True)
@inlineCallbacks
def test_basic(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code)
yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2"))
dl = yield self.doBoth(w1.get_data(), w2.get_data())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_same_message(self):
# the two sides use random nonces for their messages, so it's ok for
# both to try and send the same body: they'll result in distinct
# encrypted messages
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code)
yield self.doBoth(w1.send_data(b"data"), w2.send_data(b"data"))
dl = yield self.doBoth(w1.get_data(), w2.get_data())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data")
self.assertEqual(dataY, b"data")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_interleaved(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code)
res = yield self.doBoth(w1.send_data(b"data1"), w2.get_data())
(_, dataY) = res
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2"))
(dataX, _) = dl
self.assertEqual(dataX, b"data2")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_fixed_code(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2"))
dl = yield self.doBoth(w1.get_data(), w2.get_data())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
yield self.doBoth(w1.send_data(b"data1", u"p1"),
w2.send_data(b"data2", u"p1"))
yield self.doBoth(w1.send_data(b"data3", u"p2"),
w2.send_data(b"data4", u"p2"))
dl = yield self.doBoth(w1.get_data(u"p2"),
w2.get_data(u"p1"))
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get_data(u"p1"),
w2.get_data(u"p2"))
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data3")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_wrong_password(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code+"not")
# w2 can't throw WrongPasswordError until it sees a CONFIRM message,
# and w1 won't send CONFIRM until it sees a PAKE message, which w2
# won't send until we call get_data. So we need both sides to be
# running at the same time for this test.
d1 = w1.send_data(b"data1")
# at this point, w1 should be waiting for w2.PAKE
yield self.assertFailure(w2.get_data(), WrongPasswordError)
# * w2 will send w2.PAKE, wait for (and get) w1.PAKE, compute a key,
# send w2.CONFIRM, then wait for w1.DATA.
# * w1 will get w2.PAKE, compute a key, send w1.CONFIRM.
# * w1 might also get w2.CONFIRM, and may notice the error before it
# sends w1.CONFIRM, in which case the wait=True will signal an
# error inside _get_master_key() (inside send_data), and d1 will
# errback.
# * but w1 might not see w2.CONFIRM yet, in which case it won't
# errback until we do w1.get_data()
# * w2 gets w1.CONFIRM, notices the error, records it.
# * w2 (waiting for w1.DATA) wakes up, sees the error, throws
# * meanwhile w1 finishes sending its data. w2.CONFIRM may or may not
# have arrived by then
try:
yield d1
except WrongPasswordError:
pass
# When we ask w1 to get_data(), one of two things might happen:
# * if w2.CONFIRM arrived already, it will have recorded the error.
# When w1.get_data() sleeps (waiting for w2.DATA), we'll notice the
# error before sleeping, and throw WrongPasswordError
# * if w2.CONFIRM hasn't arrived yet, we'll sleep. When w2.CONFIRM
# arrives, we notice and record the error, and wake up, and throw
# Note that we didn't do w2.send_data(), so we're hoping that w1 will
# have enough information to detect the error before it sleeps
# (waiting for w2.DATA). Checking for the error both before sleeping
# and after waking up makes this happen.
# so now w1 should have enough information to throw too
yield self.assertFailure(w1.get_data(), WrongPasswordError)
# both sides are closed automatically upon error, but it's still
# legal to call .close(), and should be idempotent
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_no_confirm(self):
# newer versions (which check confirmations) should still work with
# older versions (that don't send confirmations)
w1 = Wormhole(APPID, self.relayurl)
w1._send_confirm = False
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code)
dl = yield self.doBoth(w1.send_data(b"data1"), w2.get_data())
self.assertEqual(dl[1], b"data1")
dl = yield self.doBoth(w1.get_data(), w2.send_data(b"data2"))
self.assertEqual(dl[0], b"data2")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_verifier(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
w2.set_code(code)
res = yield self.doBoth(w1.get_verifier(), w2.get_verifier())
v1, v2 = res
self.failUnlessEqual(type(v1), type(b""))
self.failUnlessEqual(v1, v2)
yield self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2"))
dl = yield self.doBoth(w1.get_data(), w2.get_data())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_errors(self):
w1 = Wormhole(APPID, self.relayurl)
yield self.assertFailure(w1.get_verifier(), UsageError)
yield self.assertFailure(w1.send_data(b"data"), UsageError)
yield self.assertFailure(w1.get_data(), UsageError)
w1.set_code(u"123-purple-elephant")
yield self.assertRaises(UsageError, w1.set_code, u"123-nope")
yield self.assertFailure(w1.get_code(), UsageError)
w2 = Wormhole(APPID, self.relayurl)
yield w2.get_code()
yield self.assertFailure(w2.get_code(), UsageError)
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_repeat_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2 = Wormhole(APPID, self.relayurl)
w2.set_code(u"123-purple-elephant")
# we must let them establish a key before we can send data
yield self.doBoth(w1.get_verifier(), w2.get_verifier())
yield w1.send_data(b"data1", phase=u"1")
# underscore-prefixed phases are reserved
yield self.assertFailure(w1.send_data(b"data1", phase=u"_1"),
UsageError)
yield self.assertFailure(w1.get_data(phase=u"_1"), UsageError)
# you can't send twice to the same phase
yield self.assertFailure(w1.send_data(b"data1", phase=u"1"),
UsageError)
# but you can send to a different one
yield w1.send_data(b"data2", phase=u"2")
res = yield w2.get_data(phase=u"1")
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase
yield self.assertFailure(w2.get_data(phase=u"1"), UsageError)
# but you can read from a different one
res = yield w2.get_data(phase=u"2")
self.failUnlessEqual(res, b"data2")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_serialize(self):
w1 = Wormhole(APPID, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = Wormhole(APPID, self.relayurl)
code = yield w1.get_code()
self.assertRaises(UsageError, w2.serialize) # too early
w2.set_code(code)
w2.serialize() # ok
s = w1.serialize()
self.assertEqual(type(s), type(""))
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
self.new_w1 = Wormhole.from_serialized(s)
yield self.doBoth(self.new_w1.send_data(b"data1"),
w2.send_data(b"data2"))
dl = yield self.doBoth(self.new_w1.get_data(), w2.get_data())
(dataX, dataY) = dl
self.assertEqual((dataX, dataY), (b"data2", b"data1"))
self.assertRaises(UsageError, w2.serialize) # too late
yield gatherResults([w1.close(), w2.close(), self.new_w1.close()],
True)


@@ -0,0 +1,789 @@
from __future__ import print_function
import os, json, re, gc
from binascii import hexlify, unhexlify
import mock
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks
from .common import ServerBase
from .. import wormhole
from ..errors import WrongPasswordError, WelcomeError, UsageError
from spake2 import SPAKE2_Symmetric
from ..timing import DebugTiming
from nacl.secret import SecretBox
APPID = u"appid"
class MockWebSocket:
def __init__(self):
self._payloads = []
def sendMessage(self, payload, is_binary):
assert not is_binary
self._payloads.append(payload)
def outbound(self):
out = []
while self._payloads:
p = self._payloads.pop(0)
out.append(json.loads(p.decode("utf-8")))
return out
def response(w, **kwargs):
payload = json.dumps(kwargs).encode("utf-8")
w._ws_dispatch_response(payload)
class Welcome(unittest.TestCase):
def test_tolerate_no_current_version(self):
w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None)
w.handle_welcome({})
def test_print_motd(self):
w = wormhole._WelcomeHandler(u"relay_url", u"current_version", None)
with mock.patch("sys.stderr") as stderr:
w.handle_welcome({u"motd": u"message of\nthe day"})
self.assertEqual(stderr.method_calls,
[mock.call.write(u"Server (at relay_url) says:\n"
" message of\n the day"),
mock.call.write(u"\n")])
# motd is only displayed once
with mock.patch("sys.stderr") as stderr2:
w.handle_welcome({u"motd": u"second message"})
self.assertEqual(stderr2.method_calls, [])
def test_current_version(self):
w = wormhole._WelcomeHandler(u"relay_url", u"2.0", None)
with mock.patch("sys.stderr") as stderr:
w.handle_welcome({u"current_version": u"2.0"})
self.assertEqual(stderr.method_calls, [])
with mock.patch("sys.stderr") as stderr:
w.handle_welcome({u"current_version": u"3.0"})
exp1 = (u"Warning: errors may occur unless both sides are"
" running the same version")
exp2 = (u"Server claims 3.0 is current, but ours is 2.0")
self.assertEqual(stderr.method_calls,
[mock.call.write(exp1),
mock.call.write(u"\n"),
mock.call.write(exp2),
mock.call.write(u"\n"),
])
# warning is only displayed once
with mock.patch("sys.stderr") as stderr:
w.handle_welcome({u"current_version": u"3.0"})
self.assertEqual(stderr.method_calls, [])
def test_non_release_version(self):
w = wormhole._WelcomeHandler(u"relay_url", u"2.0-dirty", None)
with mock.patch("sys.stderr") as stderr:
w.handle_welcome({u"current_version": u"3.0"})
self.assertEqual(stderr.method_calls, [])
def test_signal_error(self):
se = mock.Mock()
w = wormhole._WelcomeHandler(u"relay_url", u"2.0", se)
w.handle_welcome({})
self.assertEqual(se.mock_calls, [])
w.handle_welcome({u"error": u"oops"})
self.assertEqual(len(se.mock_calls), 1)
self.assertEqual(len(se.mock_calls[0][1]), 1) # posargs
we = se.mock_calls[0][1][0]
self.assertIsInstance(we, WelcomeError)
self.assertEqual(we.args, (u"oops",))
# alas WelcomeError instances don't compare against each other
#self.assertEqual(se.mock_calls, [mock.call(WelcomeError(u"oops"))])
class InputCode(unittest.TestCase):
def test_list(self):
send_command = mock.Mock()
ic = wormhole._InputCode(None, u"prompt", 2, send_command,
DebugTiming())
d = ic._list()
self.assertNoResult(d)
self.assertEqual(send_command.mock_calls, [mock.call(u"list")])
ic._response_handle_nameplates({u"type": u"nameplates",
u"nameplates": [{u"id": u"123"}]})
res = self.successResultOf(d)
self.assertEqual(res, [u"123"])
class GetCode(unittest.TestCase):
def test_get(self):
send_command = mock.Mock()
gc = wormhole._GetCode(2, send_command, DebugTiming())
d = gc.go()
self.assertNoResult(d)
self.assertEqual(send_command.mock_calls, [mock.call(u"allocate")])
# TODO: nameplate attributes get added and checked here
gc._response_handle_allocated({u"type": u"allocated",
u"nameplate": u"123"})
code = self.successResultOf(d)
self.assertIsInstance(code, type(u""))
self.assert_(code.startswith(u"123-"))
pieces = code.split(u"-")
self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
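GetCode.test_get asserts that an allocated code has the shape nameplate-word-word (the nameplate digits plus two words). That shape check, pulled out as a hypothetical standalone predicate using the same regex as the test:

```python
# Hypothetical predicate for the "<nameplate>-<word>-<word>" code
# shape asserted in GetCode.test_get above.
import re

def looks_like_code(code):
    return re.search(r'^\d+-\w+-\w+$', code) is not None
```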
class Basic(unittest.TestCase):
def tearDown(self):
# flush out any errorful Deferreds left dangling in cycles
gc.collect()
def check_out(self, out, **kwargs):
# Assert that each kwarg is present in the 'out' dict. Ignore other
# keys ('msgid' in particular)
for key, value in kwargs.items():
self.assertIn(key, out)
self.assertEqual(out[key], value, (out, key, value))
def check_outbound(self, ws, types):
out = ws.outbound()
self.assertEqual(len(out), len(types), (out, types))
for i,t in enumerate(types):
self.assertEqual(out[i][u"type"], t, (i,t,out))
return out
def make_pake(self, code, side, msg1):
sp2 = SPAKE2_Symmetric(wormhole.to_bytes(code),
idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
key = sp2.finish(msg1)
return key, msg2_hex
def test_create(self):
wormhole._Wormhole(APPID, u"relay_url", reactor, None, None)
def test_basic(self):
# We don't call w._start(), so this doesn't create a WebSocket
# connection. We provide a mock connection instead. If we wanted to
# exercise _connect, we'd mock out WSFactory.
# w._connect = lambda self: None
# w._event_connected(mock_ws)
# w._event_ws_opened()
# w._ws_dispatch_response(payload)
timing = DebugTiming()
with mock.patch("wormhole.wormhole._WelcomeHandler") as wh_c:
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
wh = wh_c.return_value
self.assertEqual(w._ws_url, u"relay_url")
self.assertTrue(w._flag_need_nameplate)
self.assertTrue(w._flag_need_to_build_msg1)
self.assertTrue(w._flag_need_to_send_PAKE)
v = w.verify()
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
out = ws.outbound()
self.assertEqual(len(out), 0)
w._event_ws_opened(None)
out = ws.outbound()
self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"bind", appid=APPID, side=w._side)
self.assertIn(u"id", out[0])
# WelcomeHandler should get called upon 'welcome' response. Its full
# behavior is exercised in 'Welcome' above.
WELCOME = {u"foo": u"bar"}
response(w, type="welcome", welcome=WELCOME)
self.assertEqual(wh.mock_calls, [mock.call.handle_welcome(WELCOME)])
# because we're connected, setting the code also claims the mailbox
CODE = u"123-foo-bar"
w.set_code(CODE)
self.assertFalse(w._flag_need_to_build_msg1)
out = ws.outbound()
self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"claim", nameplate=u"123")
# the server reveals the linked mailbox
response(w, type=u"claimed", mailbox=u"mb456")
# that triggers event_learned_mailbox, which should send open() and
# PAKE
self.assertEqual(w._mailbox_state, wormhole.OPEN)
out = ws.outbound()
self.assertEqual(len(out), 2)
self.check_out(out[0], type=u"open", mailbox=u"mb456")
self.check_out(out[1], type=u"add", phase=u"pake")
self.assertNoResult(v)
# server echoes back all "add" messages
response(w, type=u"message", phase=u"pake", body=out[1][u"body"],
side=w._side)
self.assertNoResult(v)
# next we build the simulated peer's PAKE operation
side2 = w._side + u"other"
msg1 = unhexlify(out[1][u"body"].encode("ascii"))
key, msg2_hex = self.make_pake(CODE, side2, msg1)
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=side2)
# hearing the peer's PAKE (msg2) makes us release the nameplate, send
# the confirmation message, deliver the verifier, and send any
# queued phase messages
self.assertFalse(w._flag_need_to_see_mailbox_used)
self.assertEqual(w._key, key)
out = ws.outbound()
self.assertEqual(len(out), 2, out)
self.check_out(out[0], type=u"release")
self.check_out(out[1], type=u"add", phase=u"confirm")
verifier = self.successResultOf(v)
self.assertEqual(verifier,
w.derive_key(u"wormhole:verifier", SecretBox.KEY_SIZE))
# hearing a valid confirmation message doesn't throw an error
confkey = w.derive_key(u"wormhole:confirmation", SecretBox.KEY_SIZE)
nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH)
confirm2 = wormhole.make_confmsg(confkey, nonce)
confirm2_hex = hexlify(confirm2).decode("ascii")
response(w, type=u"message", phase=u"confirm", body=confirm2_hex,
side=side2)
# an outbound message can now be sent immediately
w.send(b"phase0-outbound")
out = ws.outbound()
self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"add", phase=u"0")
# decrypt+check the outbound message
p0_outbound = unhexlify(out[0][u"body"].encode("ascii"))
msgkey0 = w._derive_phase_key(w._side, u"0")
p0_plaintext = w._decrypt_data(msgkey0, p0_outbound)
self.assertEqual(p0_plaintext, b"phase0-outbound")
# get() waits for the inbound message to arrive
md = w.get()
self.assertNoResult(md)
self.assertIn(u"0", w._receive_waiters)
self.assertNotIn(u"0", w._received_messages)
msgkey1 = w._derive_phase_key(side2, u"0")
p0_inbound = w._encrypt_data(msgkey1, b"phase0-inbound")
p0_inbound_hex = hexlify(p0_inbound).decode("ascii")
response(w, type=u"message", phase=u"0", body=p0_inbound_hex,
side=side2)
p0_in = self.successResultOf(md)
self.assertEqual(p0_in, b"phase0-inbound")
self.assertNotIn(u"0", w._receive_waiters)
self.assertIn(u"0", w._received_messages)
# receiving an inbound message will queue it until get() is called
msgkey2 = w._derive_phase_key(side2, u"1")
p1_inbound = w._encrypt_data(msgkey2, b"phase1-inbound")
p1_inbound_hex = hexlify(p1_inbound).decode("ascii")
response(w, type=u"message", phase=u"1", body=p1_inbound_hex,
side=side2)
self.assertIn(u"1", w._received_messages)
self.assertNotIn(u"1", w._receive_waiters)
p1_in = self.successResultOf(w.get())
self.assertEqual(p1_in, b"phase1-inbound")
self.assertIn(u"1", w._received_messages)
self.assertNotIn(u"1", w._receive_waiters)
d = w.close()
self.assertNoResult(d)
out = ws.outbound()
self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"close", mood=u"happy")
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"released")
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"closed")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
w._ws_closed(True, None, None)
self.assertEqual(self.successResultOf(d), None)
def test_close_wait_0(self):
# Close before the connection is established. The connection still
# gets established, but it is then torn down before sending anything.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
d = w.close()
self.assertNoResult(d)
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_1(self):
# close before even claiming the nameplate
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
d = w.close()
self.check_outbound(ws, [u"bind"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_2(self):
# Close after claiming the nameplate, but before opening the mailbox.
# The 'claimed' response arrives before we close.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
CODE = u"123-foo-bar"
w.set_code(CODE)
self.check_outbound(ws, [u"bind", u"claim"])
response(w, type=u"claimed", mailbox=u"mb123")
d = w.close()
self.check_outbound(ws, [u"open", u"add", u"release", u"close"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"released")
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"closed")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_3(self):
# close after claiming the nameplate, but before opening the mailbox
# The 'claimed' response arrives after we start to close.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
CODE = u"123-foo-bar"
w.set_code(CODE)
self.check_outbound(ws, [u"bind", u"claim"])
d = w.close()
response(w, type=u"claimed", mailbox=u"mb123")
self.check_outbound(ws, [u"release"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"released")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
self.assertNoResult(d)
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_4(self):
# close after both claiming the nameplate and opening the mailbox
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
CODE = u"123-foo-bar"
w.set_code(CODE)
response(w, type=u"claimed", mailbox=u"mb456")
self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"])
d = w.close()
self.check_outbound(ws, [u"release", u"close"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"released")
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"closed")
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_wait_5(self):
# close after claiming the nameplate, opening the mailbox, then
# releasing the nameplate
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
CODE = u"123-foo-bar"
w.set_code(CODE)
response(w, type=u"claimed", mailbox=u"mb456")
w._key = b""
msgkey = w._derive_phase_key(u"side2", u"misc")
p1_inbound = w._encrypt_data(msgkey, b"")
p1_inbound_hex = hexlify(p1_inbound).decode("ascii")
response(w, type=u"message", phase=u"misc", side=u"side2",
body=p1_inbound_hex)
self.check_outbound(ws, [u"bind", u"claim", u"open", u"add",
u"release"])
d = w.close()
self.check_outbound(ws, [u"close"])
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"released")
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [])
response(w, type=u"closed")
self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
w._ws_closed(True, None, None)
self.successResultOf(d)
def test_close_errbacks(self):
# make sure the Deferreds returned by verify() and get() are properly
# errbacked upon close
pass
def test_get_code_mock(self):
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
ws = MockWebSocket() # TODO: mock w._ws_send_command instead
w._event_connected(ws)
w._event_ws_opened(None)
self.check_outbound(ws, [u"bind"])
gc_c = mock.Mock()
gc = gc_c.return_value = mock.Mock()
gc_d = gc.go.return_value = Deferred()
with mock.patch("wormhole.wormhole._GetCode", gc_c):
d = w.get_code()
self.assertNoResult(d)
gc_d.callback(u"123-foo-bar")
code = self.successResultOf(d)
self.assertEqual(code, u"123-foo-bar")
def test_get_code_real(self):
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
self.check_outbound(ws, [u"bind"])
d = w.get_code()
out = ws.outbound()
self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"allocate")
# TODO: nameplate attributes go here
self.assertNoResult(d)
response(w, type=u"allocated", nameplate=u"123")
code = self.successResultOf(d)
self.assertIsInstance(code, type(u""))
self.assert_(code.startswith(u"123-"))
pieces = code.split(u"-")
self.assertEqual(len(pieces), 3) # nameplate plus two words
self.assert_(re.search(r'^\d+-\w+-\w+$', code), code)
def test_verifier(self):
# make sure verify() can be called both before and after the verifier
# is computed
pass
def test_api_errors(self):
# doing things you're not supposed to do
pass
def test_welcome_error(self):
# A welcome message could arrive at any time, with an [error] key
# that should make us halt. In practice, though, this gets sent as
# soon as the connection is established, which limits the possible
# states in which we might see it.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
self.check_outbound(ws, [u"bind"])
d1 = w.get()
d2 = w.verify()
d3 = w.get_code()
# TODO (tricky): test w.input_code
self.assertNoResult(d1)
self.assertNoResult(d2)
self.assertNoResult(d3)
w._signal_error(WelcomeError(u"you are not actually welcome"), u"pouty")
self.failureResultOf(d1, WelcomeError)
self.failureResultOf(d2, WelcomeError)
self.failureResultOf(d3, WelcomeError)
# once the error is signalled, all API calls should fail
self.assertRaises(WelcomeError, w.send, u"foo")
self.assertRaises(WelcomeError,
w.derive_key, u"foo", SecretBox.KEY_SIZE)
self.failureResultOf(w.get(), WelcomeError)
self.failureResultOf(w.verify(), WelcomeError)
def test_confirm_error(self):
# we should only receive the "confirm" message after we receive the
# PAKE message, by which point we should know the key. If the
# confirmation message doesn't decrypt, we signal an error.
timing = DebugTiming()
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock()
ws = MockWebSocket()
w._event_connected(ws)
w._event_ws_opened(None)
w.set_code(u"123-foo-bar")
response(w, type=u"claimed", mailbox=u"mb456")
d1 = w.get()
d2 = w.verify()
self.assertNoResult(d1)
self.assertNoResult(d2)
out = ws.outbound()
# [u"bind", u"claim", u"open", u"add"]
self.assertEqual(len(out), 4)
self.assertEqual(out[3][u"type"], u"add")
sp2 = SPAKE2_Symmetric(b"", idSymmetric=wormhole.to_bytes(APPID))
msg2 = sp2.start()
msg2_hex = hexlify(msg2).decode("ascii")
response(w, type=u"message", phase=u"pake", body=msg2_hex, side=u"s2")
self.assertNoResult(d1)
self.successResultOf(d2) # early verify is unaffected
# TODO: change verify() to wait for "confirm"
# sending a random confirm message will cause a confirmation error
confkey = w.derive_key(u"WRONG", SecretBox.KEY_SIZE)
nonce = os.urandom(wormhole.CONFMSG_NONCE_LENGTH)
badconfirm = wormhole.make_confmsg(confkey, nonce)
badconfirm_hex = hexlify(badconfirm).decode("ascii")
response(w, type=u"message", phase=u"confirm", body=badconfirm_hex,
side=u"s2")
self.failureResultOf(d1, WrongPasswordError)
# once the error is signalled, all API calls should fail
self.assertRaises(WrongPasswordError, w.send, u"foo")
self.assertRaises(WrongPasswordError,
w.derive_key, u"foo", SecretBox.KEY_SIZE)
self.failureResultOf(w.get(), WrongPasswordError)
self.failureResultOf(w.verify(), WrongPasswordError)
# event orderings to exercise:
#
# * normal sender: set_code, send_phase1, connected, claimed, learn_msg2,
# learn_phase1
# * normal receiver (argv[2]=code): set_code, connected, learn_msg1,
# learn_phase1, send_phase1,
# * normal receiver (readline): connected, input_code
# *
# * set_code, then connected
# * connected, receive_pake, send_phase, set_code
class Wormholes(ServerBase, unittest.TestCase):
# integration test, with a real server
def doBoth(self, d1, d2):
return gatherResults([d1, d2], True)
@inlineCallbacks
def test_basic(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code)
w1.send(b"data1")
w2.send(b"data2")
dataX = yield w1.get()
dataY = yield w2.get()
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_same_message(self):
# the two sides use random nonces for their messages, so it's ok for
# both to send the same body: the encrypted messages will still be
# distinct
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code)
w1.send(b"data")
w2.send(b"data")
dataX = yield w1.get()
dataY = yield w2.get()
self.assertEqual(dataX, b"data")
self.assertEqual(dataY, b"data")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_interleaved(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code)
w1.send(b"data1")
dataY = yield w2.get()
self.assertEqual(dataY, b"data1")
d = w1.get()
w2.send(b"data2")
dataX = yield d
self.assertEqual(dataX, b"data2")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_unidirectional(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code)
w1.send(b"data1")
dataY = yield w2.get()
self.assertEqual(dataY, b"data1")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_early(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w1.send(b"data1")
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
d = w2.get()
w1.set_code(u"123-abc-def")
w2.set_code(u"123-abc-def")
dataY = yield d
self.assertEqual(dataY, b"data1")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_fixed_code(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
w1.send(b"data1"), w2.send(b"data2")
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_multiple_messages(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
w1.send(b"data1"), w2.send(b"data2")
w1.send(b"data3"), w2.send(b"data4")
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data3")
yield w1.close()
yield w2.close()
@inlineCallbacks
def test_wrong_password(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code+"not")
# That's enough to allow both sides to discover the mismatch, but
# only after the confirmation message gets through. API calls that
# don't wait will appear to work until the mismatched confirmation
# message arrives.
w1.send(b"should still work")
w2.send(b"should still work")
# API calls that wait (i.e. get) will errback
yield self.assertFailure(w2.get(), WrongPasswordError)
yield self.assertFailure(w1.get(), WrongPasswordError)
yield w1.close()
yield w2.close()
self.flushLoggedErrors(WrongPasswordError)
@inlineCallbacks
def test_verifier(self):
w1 = wormhole.wormhole(APPID, self.relayurl, reactor)
w2 = wormhole.wormhole(APPID, self.relayurl, reactor)
code = yield w1.get_code()
w2.set_code(code)
v1 = yield w1.verify()
v2 = yield w2.verify()
self.failUnlessEqual(type(v1), type(b""))
self.failUnlessEqual(v1, v2)
w1.send(b"data1")
w2.send(b"data2")
dataX = yield w1.get()
dataY = yield w2.get()
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
yield w1.close()
yield w2.close()
class Errors(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_codes_1(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor)
# definitely too early
self.assertRaises(UsageError, w.derive_key, u"purpose", 12)
w.set_code(u"123-purple-elephant")
# code can only be set once
self.assertRaises(UsageError, w.set_code, u"123-nope")
yield self.assertFailure(w.get_code(), UsageError)
yield self.assertFailure(w.input_code(), UsageError)
yield w.close()
@inlineCallbacks
def test_codes_2(self):
w = wormhole.wormhole(APPID, self.relayurl, reactor)
yield w.get_code()
self.assertRaises(UsageError, w.set_code, u"123-nope")
yield self.assertFailure(w.get_code(), UsageError)
yield self.assertFailure(w.input_code(), UsageError)
yield w.close()

@@ -1,238 +0,0 @@
#import sys
from twisted.python import log, failure
from twisted.internet import reactor, defer, protocol
from twisted.application import service
from twisted.protocols import basic
from twisted.web.client import Agent, ResponseDone
from twisted.web.http_headers import Headers
from cgi import parse_header
from .eventual import eventually
#if sys.version_info[0] == 2:
# to_unicode = unicode
#else:
# to_unicode = str
class EventSourceParser(basic.LineOnlyReceiver):
# http://www.w3.org/TR/eventsource/
delimiter = b"\n"
def __init__(self, handler):
self.current_field = None
self.current_lines = []
self.handler = handler
self.done_deferred = defer.Deferred()
self.eventtype = u"message"
self.encoding = "utf-8"
def set_encoding(self, encoding):
self.encoding = encoding
def connectionLost(self, why):
if why.check(ResponseDone):
why = None
self.done_deferred.callback(why)
def dataReceived(self, data):
# exceptions here aren't being logged properly, and tests will hang
# rather than halt. I suspect twisted.web._newclient's
# HTTP11ClientProtocol.dataReceived(), which catches everything and
# responds with self._giveUp() but doesn't log.err.
try:
basic.LineOnlyReceiver.dataReceived(self, data)
except:
log.err()
raise
def lineReceived(self, line):
#line = to_unicode(line, self.encoding)
line = line.decode(self.encoding)
if not line:
# blank line ends the field: deliver event, reset for next
self.eventReceived(self.eventtype, "\n".join(self.current_lines))
self.eventtype = u"message"
self.current_lines[:] = []
return
if u":" in line:
fieldname, data = line.split(u":", 1)
if data.startswith(u" "):
data = data[1:]
else:
fieldname = line
data = u""
if fieldname == u"event":
self.eventtype = data
elif fieldname == u"data":
self.current_lines.append(data)
elif fieldname in (u"id", u"retry"):
# documented but unhandled
pass
else:
log.msg("weird fieldname", fieldname, data)
def eventReceived(self, eventtype, data):
self.handler(eventtype, data)
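The field handling in lineReceived() above can be exercised on its own. This is a minimal standalone sketch of the same parsing rules (the function name and shape are illustrative, not part of this module):

```python
def parse_sse(stream_text):
    """Parse a text/event-stream body into (eventtype, data) pairs,
    mirroring the field handling in EventSourceParser.lineReceived()."""
    events = []
    eventtype = u"message"
    data_lines = []
    for line in stream_text.split(u"\n"):
        if not line:
            # blank line ends the event: deliver and reset
            events.append((eventtype, u"\n".join(data_lines)))
            eventtype = u"message"
            data_lines = []
            continue
        if u":" in line:
            fieldname, data = line.split(u":", 1)
            if data.startswith(u" "):
                data = data[1:]  # strip the single optional leading space
        else:
            fieldname, data = line, u""
        if fieldname == u"event":
            eventtype = data
        elif fieldname == u"data":
            data_lines.append(data)
        # "id" and "retry" are documented but unhandled, as above
    return events
```

For example, `parse_sse(u"event: welcome\ndata: hi\n\ndata: a\ndata: b\n")` yields one "welcome" event with body "hi", then a default "message" event whose two data lines are joined with a newline.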
class EventSourceError(Exception):
pass
# es = EventSource(url, handler)
# d = es.start()
# es.cancel()
class EventSource: # TODO: service.Service
def __init__(self, url, handler, when_connected=None, agent=None):
assert isinstance(url, type(u""))
self.url = url
self.handler = handler
self.when_connected = when_connected
self.started = False
self.cancelled = False
self.proto = EventSourceParser(self.handler)
if not agent:
agent = Agent(reactor)
self.agent = agent
def start(self):
assert not self.started, "single-use"
self.started = True
assert self.url
d = self.agent.request(b"GET", self.url.encode("utf-8"),
Headers({b"accept": [b"text/event-stream"]}))
d.addCallback(self._connected)
return d
def _connected(self, resp):
if resp.code != 200:
raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
if self.when_connected:
self.when_connected()
default_ct = "text/event-stream; charset=utf-8"
ct_headers = resp.headers.getRawHeaders("content-type", [default_ct])
ct, ct_params = parse_header(ct_headers[0])
assert ct == "text/event-stream", ct
self.proto.set_encoding(ct_params.get("charset", "utf-8"))
resp.deliverBody(self.proto)
if self.cancelled:
self.kill_connection()
return self.proto.done_deferred
def cancel(self):
self.cancelled = True
if not self.proto.transport:
# _connected hasn't been called yet, but the self.cancelled flag
# should take care of it when the connection is established
def kill(data):
# this should kill it as soon as any data is delivered
raise ValueError("dead")
self.proto.dataReceived = kill # just in case
return
self.kill_connection()
def kill_connection(self):
if (hasattr(self.proto.transport, "_producer")
and self.proto.transport._producer):
# This is gross and fragile. We need a clean way to stop the
# client connection. p.transport is a
# twisted.web._newclient.TransportProxyProducer , and its
# ._producer is the tcp.Port.
self.proto.transport._producer.loseConnection()
else:
log.err("get_events: unable to stop connection")
# oh well
#err = EventSourceError("unable to cancel")
try:
self.proto.done_deferred.callback(None)
except defer.AlreadyCalledError:
pass
class Connector:
# behave enough like an IConnector to appease ReconnectingClientFactory
def __init__(self, res):
self.res = res
def connect(self):
self.res._maybeStart()
def stopConnecting(self):
self.res._stop_eventsource()
class ReconnectingEventSource(service.MultiService,
protocol.ReconnectingClientFactory):
def __init__(self, url, handler, agent=None):
service.MultiService.__init__(self)
# we don't use any of the basic Factory/ClientFactory methods of
# this, just the ReconnectingClientFactory.retry, stopTrying, and
# resetDelay methods.
self.url = url
self.handler = handler
self.agent = agent
# IService provides self.running, toggled by {start,stop}Service.
# self.active is toggled by {,de}activate. If both .running and
# .active are True, then we want to have an outstanding EventSource
# and will start one if necessary. If either is False, then we don't
# want one to be outstanding, and will initiate shutdown.
self.active = False
self.connector = Connector(self)
self.es = None # set when we have an outstanding EventSource
self.when_stopped = [] # list of Deferreds
def isStopped(self):
return not self.es
def startService(self):
service.MultiService.startService(self) # sets self.running
self._maybeStart()
def stopService(self):
# clears self.running
d = defer.maybeDeferred(service.MultiService.stopService, self)
d.addCallback(self._maybeStop)
return d
def activate(self):
assert not self.active
self.active = True
self._maybeStart()
def deactivate(self):
assert self.active # XXX
self.active = False
return self._maybeStop()
def _maybeStart(self):
if not (self.active and self.running):
return
self.continueTrying = True
self.es = EventSource(self.url, self.handler, self.resetDelay,
agent=self.agent)
d = self.es.start()
d.addBoth(self._stopped)
def _stopped(self, res):
self.es = None
# we might have stopped because of a connection error, or because of
# an intentional shutdown.
if self.active and self.running:
# we still want to be connected, so schedule a reconnection
if isinstance(res, failure.Failure):
log.err(res)
self.retry() # will eventually call _maybeStart
return
# intentional shutdown
self.stopTrying()
for d in self.when_stopped:
eventually(d.callback, None)
self.when_stopped = []
def _stop_eventsource(self):
if self.es:
eventually(self.es.cancel)
def _maybeStop(self, _=None):
self.stopTrying() # cancels timer, calls _stop_eventsource()
if not self.es:
return defer.succeed(None)
d = defer.Deferred()
self.when_stopped.append(d)
return d

@@ -1,560 +0,0 @@
from __future__ import print_function
import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlparse
from binascii import hexlify, unhexlify
from twisted.internet import reactor, defer, endpoints, error
from twisted.internet.threads import deferToThread, blockingCallFromThread
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import log
from autobahn.twisted import websocket
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric
from .. import __version__
from .. import codes
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..timing import DebugTiming
from hkdf import Hkdf
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
CONFMSG_NONCE_LENGTH = 128//8
CONFMSG_MAC_LENGTH = 256//8
def make_confmsg(confkey, nonce):
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
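The confirmation message above is simply `nonce || HKDF(confkey, 32, salt=nonce)`, so the receiver can recompute the MAC from the received nonce. A self-contained sketch of that check, using a stdlib RFC 5869 HKDF-SHA256 in place of the `hkdf` package (function names here are illustrative, not part of this module):

```python
import hmac, hashlib

def hkdf_sha256(skm, outlen, salt=None, ctxinfo=b""):
    # RFC 5869 HKDF-SHA256; a stdlib stand-in for Hkdf(salt, skm).expand(...)
    salt = salt if salt is not None else b"\x00" * hashlib.sha256().digest_size
    prk = hmac.new(salt, skm, hashlib.sha256).digest()
    okm, t, i = b"", b"", 1
    while len(okm) < outlen:
        t = hmac.new(prk, t + ctxinfo + bytes(bytearray([i])), hashlib.sha256).digest()
        okm += t
        i += 1
    return okm[:outlen]

NONCE_LENGTH = 128 // 8
MAC_LENGTH = 256 // 8

def make_confmsg_sketch(confkey, nonce):
    # same shape as make_confmsg(): nonce || HKDF(confkey, 32, salt=nonce)
    return nonce + hkdf_sha256(confkey, MAC_LENGTH, salt=nonce)

def check_confmsg(confkey, confmsg):
    # recompute the MAC from the received nonce and compare in constant time
    nonce = confmsg[:NONCE_LENGTH]
    expected = make_confmsg_sketch(confkey, nonce)
    return hmac.compare_digest(confmsg, expected)
```

A message built with the wrong confkey (as in test_confirm_error above) fails this check, which is what triggers WrongPasswordError.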
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
class WSClient(websocket.WebSocketClientProtocol):
def onOpen(self):
self.wormhole_open = True
self.factory.d.callback(self)
def onMessage(self, payload, isBinary):
assert not isBinary
self.wormhole._ws_dispatch_msg(payload)
def onClose(self, wasClean, code, reason):
if self.wormhole_open:
self.wormhole._ws_closed(wasClean, code, reason)
else:
# we closed before establishing a connection (onConnect) or
# finishing WebSocket negotiation (onOpen): errback
self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto.wormhole = self.wormhole
proto.wormhole_open = False
return proto
class Wormhole:
motd_displayed = False
version_warning_displayed = False
_send_confirm = True
def __init__(self, appid, relay_url, tor_manager=None, timing=None,
reactor=reactor):
if not isinstance(appid, type(u"")): raise TypeError(type(appid))
if not isinstance(relay_url, type(u"")):
raise TypeError(type(relay_url))
if not relay_url.endswith(u"/"): raise UsageError
self._appid = appid
self._relay_url = relay_url
self._ws_url = relay_url.replace("http:", "ws:") + "ws"
self._tor_manager = tor_manager
self._timing = timing or DebugTiming()
self._reactor = reactor
self._side = hexlify(os.urandom(5)).decode("ascii")
self._code = None
self._channelid = None
self._key = None
self._started_get_code = False
self._sent_messages = set() # (phase, body_bytes)
self._delivered_messages = set() # (phase, body_bytes)
self._received_messages = {} # phase -> body_bytes
self._sent_phases = set() # phases, to prohibit double-send
self._got_phases = set() # phases, to prohibit double-read
self._sleepers = []
self._confirmation_failed = False
self._closed = False
self._deallocated_status = None
self._timing_started = self._timing.add("wormhole")
self._ws = None
self._ws_t = None # timing Event
self._ws_channel_claimed = False
self._error = None
def _make_endpoint(self, hostname, port):
if self._tor_manager:
return self._tor_manager.get_endpoint_for(hostname, port)
# note: HostnameEndpoints have a default 30s timeout
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
@inlineCallbacks
def _get_websocket(self):
if not self._ws:
# TODO: if we lose the connection, make a new one
#from twisted.python import log
#log.startLogging(sys.stderr)
assert self._side
assert not self._ws_channel_claimed
p = urlparse(self._ws_url)
f = WSFactory(self._ws_url)
f.wormhole = self
f.d = defer.Deferred()
# TODO: if hostname="localhost", I get three factories starting
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
# an error in the factory is masked.
ep = self._make_endpoint(p.hostname, p.port or 80)
# .connect errbacks if the TCP connection fails
self._ws = yield ep.connect(f)
self._ws_t = self._timing.add("websocket")
# f.d is errbacked if WebSocket negotiation fails
yield f.d # WebSocket drops data sent before onOpen() fires
self._ws_send(u"bind", appid=self._appid, side=self._side)
# the socket is connected, and bound, but no channel has been claimed
returnValue(self._ws)
@inlineCallbacks
def _ws_send(self, mtype, **kwargs):
ws = yield self._get_websocket()
# msgid is used by misc/dump-timing.py to correlate our sends with
# their receives, and vice versa. They are also correlated with the
# ACKs we get back from the server (which we otherwise ignore). There
# are so few messages, 16 bits is enough to be mostly-unique.
kwargs["id"] = hexlify(os.urandom(2)).decode("ascii")
kwargs["type"] = mtype
payload = json.dumps(kwargs).encode("utf-8")
self._timing.add("ws_send", _side=self._side, **kwargs)
ws.sendMessage(payload, False)
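The wire format produced by _ws_send() is just a small JSON object with a "type" field and a short random "id". A standalone sketch of that encoding (the helper name is illustrative, not part of this module):

```python
import json, os
from binascii import hexlify

def build_ws_payload(mtype, **kwargs):
    # mirrors _ws_send(): attach a random 16-bit "id" (used only to
    # correlate sends/receives in timing dumps) plus the message type,
    # then serialize to UTF-8 JSON for a text WebSocket frame
    kwargs["id"] = hexlify(os.urandom(2)).decode("ascii")
    kwargs["type"] = mtype
    return json.dumps(kwargs).encode("utf-8")
```

For example, `build_ws_payload(u"bind", appid=u"app", side=u"s1")` produces bytes that decode to `{"type": "bind", "appid": "app", "side": "s1", "id": <4 hex digits>}`.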
def _ws_dispatch_msg(self, payload):
msg = json.loads(payload.decode("utf-8"))
self._timing.add("ws_receive", _side=self._side, message=msg)
mtype = msg["type"]
meth = getattr(self, "_ws_handle_"+mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
return
return meth(msg)
def _ws_handle_ack(self, msg):
pass
def _ws_handle_welcome(self, msg):
welcome = msg["welcome"]
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" %
(self._ws_url, motd_formatted), file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("-" not in __version__ and
not self.version_warning_displayed and
welcome["current_version"] != __version__):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], __version__), file=sys.stderr)
self.version_warning_displayed = True
if "error" in welcome:
return self._signal_error(welcome["error"])
@inlineCallbacks
def _sleep(self, wake_on_error=True):
if wake_on_error and self._error:
# don't sleep if the bed's already on fire, unless we're waiting
# for the fire department to respond, in which case sure, keep on
# sleeping
raise self._error
d = defer.Deferred()
self._sleepers.append(d)
yield d
if wake_on_error and self._error:
raise self._error
def _wakeup(self):
sleepers = self._sleepers
self._sleepers = []
for d in sleepers:
d.callback(None)
# NOTE: callers should avoid reentrancy themselves. An
# eventual-send would be safer here, but it makes synchronizing
# unit tests annoying.
def _signal_error(self, error):
assert isinstance(error, Exception)
self._error = error
self._wakeup()
def _ws_handle_error(self, msg):
err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
self._ws_url)
return self._signal_error(err)
@inlineCallbacks
def _claim_channel_and_watch(self):
assert self._channelid is not None
yield self._get_websocket()
if not self._ws_channel_claimed:
yield self._ws_send(u"claim", channelid=self._channelid)
self._ws_channel_claimed = True
yield self._ws_send(u"watch")
# entry point 1: generate a new code
@inlineCallbacks
def get_code(self, code_length=2): # rename to allocate_code()? create_?
if self._code is not None: raise UsageError
if self._started_get_code: raise UsageError
self._started_get_code = True
with self._timing.add("API get_code"):
with self._timing.add("allocate"):
yield self._ws_send(u"allocate")
while self._channelid is None:
yield self._sleep()
code = codes.make_code(self._channelid, code_length)
assert isinstance(code, type(u"")), type(code)
self._set_code(code)
self._start()
returnValue(code)
def _ws_handle_allocated(self, msg):
if self._channelid is not None:
return self._signal_error("got duplicate channelid")
self._channelid = msg["channelid"]
self._wakeup()
def _start(self):
# allocate the rest now too, so it can be serialized
with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
idSymmetric=to_bytes(self._appid))
self._msg1 = self._sp.start()
# entry point 2a: interactively type in a code, with completion
@inlineCallbacks
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
def _lister():
return blockingCallFromThread(self._reactor, self._list_channels)
# fetch the list of channels ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete
# client)
#
# TODO: send the request early, show the prompt right away, hide the
# latency in the user's indecision and slow typing. If we're lucky
# the answer will come back before they hit TAB.
with self._timing.add("API input_code"):
initial_channelids = yield self._list_channels()
with self._timing.add("input code", waiting="user"):
t = self._reactor.addSystemEventTrigger("before", "shutdown",
self._warn_readline)
code = yield deferToThread(codes.input_code_with_completion,
prompt,
initial_channelids, _lister,
code_length)
self._reactor.removeSystemEventTrigger(t)
returnValue(code) # application will give this to set_code()
def _warn_readline(self):
# When our process receives a SIGINT, Twisted's SIGINT handler will
# stop the reactor and wait for all threads to terminate before the
# process exits. However, if we were waiting for
# input_code_with_completion() when SIGINT happened, the readline
# thread will be blocked waiting for something on stdin. Trick the
# user into satisfying the blocking read so we can exit.
print("\nCommand interrupted: please press Return to quit",
file=sys.stderr)
# Other potential approaches to this problem:
# * hard-terminate our process with os._exit(1), but make sure the
# tty gets reset to a normal mode ("cooked"?) first, so that the
# next shell command the user types is echoed correctly
# * track down the thread (t.p.threadable.getThreadID from inside the
# thread), get a cffi binding to pthread_kill, deliver SIGINT to it
# * allocate a pty pair (pty.openpty), replace sys.stdin with the
# slave, build a pty bridge that copies bytes (and other PTY
# things) from the real stdin to the master, then close the slave
# at shutdown, so readline sees EOF
# * write tab-completion and basic editing (TTY raw mode,
# backspace-is-erase) without readline, probably with curses or
# twisted.conch.insults
# * write a separate program to get codes (maybe just "wormhole
# --internal-get-code"), run it as a subprocess, let it inherit
# stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It
# needs an RPC mechanism (over some extra file descriptors) to ask
# us to fetch the current channelid list.
#
# Note that hard-terminating our process with os.kill(os.getpid(),
# signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread
# doesn't see the signal, and we must still wait for stdin to make
# readline finish.
@inlineCallbacks
def _list_channels(self):
with self._timing.add("list"):
self._latest_channelids = None
yield self._ws_send(u"list")
while self._latest_channelids is None:
yield self._sleep()
returnValue(self._latest_channelids)
def _ws_handle_channelids(self, msg):
self._latest_channelids = msg["channelids"]
self._wakeup()
# entry point 2b: paste in a fully-formed code
def set_code(self, code):
if not isinstance(code, type(u"")): raise TypeError(type(code))
if self._code is not None: raise UsageError
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
with self._timing.add("API set_code"):
self._channelid = int(mo.group(1))
self._set_code(code)
self._start()
def _set_code(self, code):
if self._code is not None: raise UsageError
self._timing.add("code established")
self._code = code
def serialize(self):
# I can only be serialized after get_code/set_code and before
# get_verifier/get_data
if self._code is None: raise UsageError
if self._key is not None: raise UsageError
if self._sent_phases: raise UsageError
if self._got_phases: raise UsageError
data = {
"appid": self._appid,
"relay_url": self._relay_url,
"code": self._code,
"channelid": self._channelid,
"side": self._side,
"spake2": json.loads(self._sp.serialize().decode("ascii")),
"msg1": hexlify(self._msg1).decode("ascii"),
}
return json.dumps(data)
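For reference, a sketch of the JSON blob serialize() produces. Every value below is an illustrative placeholder (the real "spake2" entry is opaque state from self._sp.serialize(), and "msg1" is the hexlified outbound PAKE message):

```python
import json

# illustrative serialize() payload; all values are placeholders
blob = json.dumps({
    "appid": "example.com/demo-app",       # hypothetical appid
    "relay_url": "ws://relay.example/v1",  # hypothetical relay URL
    "code": "4-purple-sausages",
    "channelid": 4,
    "side": "a1b2c3d4e5",
    "spake2": {},      # opaque state from self._sp.serialize()
    "msg1": "00ff",    # hexlified outbound PAKE message
})
restored = json.loads(blob)  # what from_serialized() starts from
```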
# entry point 3: resume a previously-serialized session
@classmethod
def from_serialized(klass, data):
d = json.loads(data)
self = klass(d["appid"], d["relay_url"])
self._side = d["side"]
self._channelid = d["channelid"]
self._set_code(d["code"])
sp_data = json.dumps(d["spake2"]).encode("ascii")
self._sp = SPAKE2_Symmetric.from_serialized(sp_data)
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
return self
@inlineCallbacks
def get_verifier(self):
if self._closed: raise UsageError
if self._code is None: raise UsageError
with self._timing.add("API get_verifier"):
yield self._get_master_key()
# If the caller cares about the verifier, then they'll probably be
# willing to wait a moment to see the _confirm message. Each
# side sends this as soon as it sees the other's PAKE message. So
# the sender should see this hot on the heels of the inbound PAKE
# message (a moment after _get_master_key() returns). The
# receiver will see this a round-trip after they send their PAKE
# (because the sender is using wait=True inside _get_master_key,
# below: otherwise the sender might go do some blocking call).
yield self._msg_get(u"_confirm")
returnValue(self._verifier)
@inlineCallbacks
def _get_master_key(self):
# TODO: prevent multiple invocation
if not self._key:
yield self._claim_channel_and_watch()
yield self._msg_send(u"pake", self._msg1)
pake_msg = yield self._msg_get(u"pake")
with self._timing.add("pake2", waiting="crypto"):
self._key = self._sp.finish(pake_msg)
self._verifier = self.derive_key(u"wormhole:verifier")
self._timing.add("key established")
if self._send_confirm:
# both sides send different (random) confirmation messages
confkey = self.derive_key(u"wormhole:confirmation")
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
yield self._msg_send(u"_confirm", confmsg, wait=True)
@inlineCallbacks
def _msg_send(self, phase, body, wait=False):
self._sent_messages.add( (phase, body) )
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
t = self._timing.add("add", phase=phase, wait=wait)
yield self._ws_send(u"add", phase=phase,
body=hexlify(body).decode("ascii"))
if wait:
while (phase, body) not in self._delivered_messages:
yield self._sleep()
t.finish()
def _ws_handle_message(self, msg):
m = msg["message"]
phase = m["phase"]
body = unhexlify(m["body"].encode("ascii"))
if (phase, body) in self._sent_messages:
self._delivered_messages.add( (phase, body) ) # ack by server
self._wakeup()
return # ignore echoes of our outbound messages
if phase in self._received_messages:
# a channel collision would cause this
err = ServerError("got duplicate phase %s" % phase, self._ws_url)
return self._signal_error(err)
self._received_messages[phase] = body
if phase == u"_confirm":
# TODO: we might not have a master key yet, if the caller wasn't
# waiting in _get_master_key() when a back-to-back pake+_confirm
# message pair arrived.
confkey = self.derive_key(u"wormhole:confirmation")
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
# this makes all API calls fail
return self._signal_error(WrongPasswordError())
# now notify anyone waiting on it
self._wakeup()
@inlineCallbacks
def _msg_get(self, phase):
with self._timing.add("get", phase=phase):
while phase not in self._received_messages:
yield self._sleep() # we can wait a long time here
# that will throw an error if something goes wrong
msg = self._received_messages[phase]
returnValue(msg)
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
if self._key is None:
# call after get_verifier() or get_data()
raise UsageError
return HKDF(self._key, length, CTXinfo=to_bytes(purpose))
def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
data = box.decrypt(encrypted)
return data
@inlineCallbacks
def send_data(self, outbound_data, phase=u"data", wait=False):
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if self._code is None:
raise UsageError("You must set_code() before send_data()")
if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._sent_phases: raise UsageError # only call this once
self._sent_phases.add(phase)
with self._timing.add("API send_data", phase=phase, wait=wait):
# Without predefined roles, we can't derive predictably unique
# keys for each side, so we use the same key for both. We use
# random nonces to keep the messages distinct, and we
# automatically ignore reflections.
yield self._get_master_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._msg_send(phase, outbound_encrypted, wait)
@inlineCallbacks
def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if self._code is None: raise UsageError
if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._got_phases: raise UsageError # only call this once
self._got_phases.add(phase)
with self._timing.add("API get_data", phase=phase):
yield self._get_master_key()
body = yield self._msg_get(phase) # we can wait a long time here
try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body)
returnValue(inbound_data)
except CryptoError:
raise WrongPasswordError
def _ws_closed(self, wasClean, code, reason):
self._ws = None
self._ws_t.finish()
# TODO: schedule reconnect, unless we're done
@inlineCallbacks
def close(self, f=None, mood=None):
"""Do d.addBoth(w.close) at the end of your chain."""
if self._closed:
returnValue(None)
self._closed = True
if not self._ws:
returnValue(None)
if mood is None:
mood = u"happy"
if f:
if f.check(Timeout):
mood = u"lonely"
elif f.check(WrongPasswordError):
mood = u"scary"
elif f.check(TypeError, UsageError):
# preconditions don't warrant reporting mood
pass
else:
mood = u"errory" # other errors do
if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood))
with self._timing.add("API close"):
yield self._deallocate(mood)
# TODO: mark WebSocket as don't-reconnect
self._ws.transport.loseConnection() # probably flushes
del self._ws
self._ws_t.finish()
self._timing_started.finish(mood=mood)
returnValue(f)
@inlineCallbacks
def _deallocate(self, mood):
with self._timing.add("deallocate"):
yield self._ws_send(u"deallocate", mood=mood)
while self._deallocated_status is None:
yield self._sleep(wake_on_error=False)
# TODO: set a timeout, don't wait forever for an ack
# TODO: if the connection is lost, let it go
returnValue(self._deallocated_status)
def _ws_handle_deallocated(self, msg):
self._deallocated_status = msg["status"]
self._wakeup()


@@ -548,6 +548,7 @@ def there_can_be_only_one(contenders):
class Common:
RELAY_DELAY = 2.0
TRANSIT_KEY_LENGTH = SecretBox.KEY_SIZE
def __init__(self, transit_relay, no_listen=False, tor_manager=None,
reactor=reactor, timing=None):

src/wormhole/wormhole.py Normal file

@@ -0,0 +1,845 @@
from __future__ import print_function, absolute_import
import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlparse
from binascii import hexlify, unhexlify
from twisted.internet import defer, endpoints, error
from twisted.internet.threads import deferToThread, blockingCallFromThread
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import log
from autobahn.twisted import websocket
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric
from hashlib import sha256
from . import __version__
from . import codes
#from .errors import ServerError, Timeout
from .errors import (WrongPasswordError, UsageError, WelcomeError,
WormholeClosedError)
from .timing import DebugTiming
from hkdf import Hkdf
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
CONFMSG_NONCE_LENGTH = 128//8
CONFMSG_MAC_LENGTH = 256//8
def make_confmsg(confkey, nonce):
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
def to_bytes(u):
return unicodedata.normalize("NFC", u).encode("utf-8")
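The HKDF wrapper and make_confmsg above can be exercised standalone. This sketch re-implements the same extract-then-expand construction with stdlib hmac (standing in for the third-party hkdf package) to show how a received confirmation body would be checked:

```python
import hmac, os
from hashlib import sha256

def hkdf(skm, outlen, salt=None, ctxinfo=b""):
    # RFC 5869 extract-then-expand with SHA-256; mirrors Hkdf(salt, skm).expand()
    prk = hmac.new(salt if salt is not None else b"\x00" * 32, skm, sha256).digest()
    out, t, i = b"", b"", 1
    while len(out) < outlen:
        t = hmac.new(prk, t + ctxinfo + bytes([i]), sha256).digest()
        out += t
        i += 1
    return out[:outlen]

NONCE_LEN, MAC_LEN = 128 // 8, 256 // 8  # same sizes as CONFMSG_*

def make_confmsg(confkey, nonce):
    # the nonce travels in the clear; the HKDF output binds it to confkey
    return nonce + hkdf(confkey, MAC_LEN, nonce)

def confmsg_ok(confkey, body):
    # the receiver recomputes the message from the leading nonce and compares
    return body == make_confmsg(confkey, body[:NONCE_LEN])
```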
# We send the following messages through the relay server to the far side (by
# sending "add" commands to the server, and getting "message" responses):
#
# phase=setup:
# * unauthenticated version strings (but why?)
# * early warmup for connection hints ("I can do tor, spin up HS")
# * wordlist l10n identifier
# phase=pake: just the SPAKE2 'start' message (binary)
# phase=confirm: key verification (HKDF(key, nonce)+nonce)
# phase=1,2,3,..: application messages
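The phases listed above travel as rendezvous "add" commands. A sketch of the wire shape that _ws_send_command assembles (the exact field set shown here is an assumption drawn from the code below; the random 16-bit "id" exists only for dump-timing correlation):

```python
import json, os
from binascii import hexlify

def encode_add_command(phase, body):
    # JSON payload as assembled by _ws_send_command(u"add", ...)
    return json.dumps({
        "type": "add",
        "id": hexlify(os.urandom(2)).decode("ascii"),  # 16-bit msgid
        "phase": phase,
        "body": hexlify(body).decode("ascii"),         # message bodies travel as hex
    }).encode("utf-8")

msg = json.loads(encode_add_command(u"pake", b"\x01\x02").decode("utf-8"))
```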
class WSClient(websocket.WebSocketClientProtocol):
def onOpen(self):
self.wormhole_open = True
self.factory.d.callback(self)
def onMessage(self, payload, isBinary):
assert not isBinary
self.wormhole._ws_dispatch_response(payload)
def onClose(self, wasClean, code, reason):
if self.wormhole_open:
self.wormhole._ws_closed(wasClean, code, reason)
else:
# we closed before establishing a connection (onConnect) or
# finishing WebSocket negotiation (onOpen): errback
self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto.wormhole = self.wormhole
proto.wormhole_open = False
return proto
class _GetCode:
def __init__(self, code_length, send_command, timing):
self._code_length = code_length
self._send_command = send_command
self._timing = timing
self._allocated_d = defer.Deferred()
@inlineCallbacks
def go(self):
with self._timing.add("allocate"):
self._send_command(u"allocate")
nameplate_id = yield self._allocated_d
code = codes.make_code(nameplate_id, self._code_length)
assert isinstance(code, type(u"")), type(code)
returnValue(code)
def _response_handle_allocated(self, msg):
nid = msg["nameplate"]
assert isinstance(nid, type(u"")), type(nid)
self._allocated_d.callback(nid)
class _InputCode:
def __init__(self, reactor, prompt, code_length, send_command, timing):
self._reactor = reactor
self._prompt = prompt
self._code_length = code_length
self._send_command = send_command
self._timing = timing
@inlineCallbacks
def _list(self):
self._lister_d = defer.Deferred()
self._send_command(u"list")
nameplates = yield self._lister_d
self._lister_d = None
returnValue(nameplates)
def _list_blocking(self):
return blockingCallFromThread(self._reactor, self._list)
@inlineCallbacks
def go(self):
# fetch the list of nameplates ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete
# client)
#
# TODO: send the request early, show the prompt right away, hide the
# latency in the user's indecision and slow typing. If we're lucky
# the answer will come back before they hit TAB.
initial_nameplate_ids = yield self._list()
with self._timing.add("input code", waiting="user"):
t = self._reactor.addSystemEventTrigger("before", "shutdown",
self._warn_readline)
code = yield deferToThread(codes.input_code_with_completion,
self._prompt,
initial_nameplate_ids,
self._list_blocking,
self._code_length)
self._reactor.removeSystemEventTrigger(t)
returnValue(code)
def _response_handle_nameplates(self, msg):
nameplates = msg["nameplates"]
assert isinstance(nameplates, list), type(nameplates)
nids = []
for n in nameplates:
assert isinstance(n, dict), type(n)
nameplate_id = n[u"id"]
assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
nids.append(nameplate_id)
self._lister_d.callback(nids)
def _warn_readline(self):
# When our process receives a SIGINT, Twisted's SIGINT handler will
# stop the reactor and wait for all threads to terminate before the
# process exits. However, if we were waiting for
# input_code_with_completion() when SIGINT happened, the readline
# thread will be blocked waiting for something on stdin. Trick the
# user into satisfying the blocking read so we can exit.
print("\nCommand interrupted: please press Return to quit",
file=sys.stderr)
# Other potential approaches to this problem:
# * hard-terminate our process with os._exit(1), but make sure the
# tty gets reset to a normal mode ("cooked"?) first, so that the
# next shell command the user types is echoed correctly
# * track down the thread (t.p.threadable.getThreadID from inside the
# thread), get a cffi binding to pthread_kill, deliver SIGINT to it
# * allocate a pty pair (pty.openpty), replace sys.stdin with the
# slave, build a pty bridge that copies bytes (and other PTY
# things) from the real stdin to the master, then close the slave
# at shutdown, so readline sees EOF
# * write tab-completion and basic editing (TTY raw mode,
# backspace-is-erase) without readline, probably with curses or
# twisted.conch.insults
# * write a separate program to get codes (maybe just "wormhole
# --internal-get-code"), run it as a subprocess, let it inherit
# stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It
# needs an RPC mechanism (over some extra file descriptors) to ask
# us to fetch the current nameplate_id list.
#
# Note that hard-terminating our process with os.kill(os.getpid(),
# signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread
# doesn't see the signal, and we must still wait for stdin to make
# readline finish.
class _WelcomeHandler:
def __init__(self, url, current_version, signal_error):
self._ws_url = url
self._version_warning_displayed = False
self._motd_displayed = False
self._current_version = current_version
self._signal_error = signal_error
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self._motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" %
(self._ws_url, motd_formatted), file=sys.stderr)
self._motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("current_version" in welcome
and "-" not in self._current_version
and not self._version_warning_displayed
and welcome["current_version"] != self._current_version):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], self._current_version),
file=sys.stderr)
self._version_warning_displayed = True
if "error" in welcome:
return self._signal_error(WelcomeError(welcome["error"]))
# states for nameplates, mailboxes, and the websocket connection
(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing")
class _Wormhole:
def __init__(self, appid, relay_url, reactor, tor_manager, timing):
self._appid = appid
self._ws_url = relay_url
self._reactor = reactor
self._tor_manager = tor_manager
self._timing = timing
self._welcomer = _WelcomeHandler(self._ws_url, __version__,
self._signal_error)
self._side = hexlify(os.urandom(5)).decode("ascii")
self._connection_state = CLOSED
self._connection_waiters = []
self._started_get_code = False
self._get_code = None
self._started_input_code = False
self._code = None
self._nameplate_id = None
self._nameplate_state = CLOSED
self._mailbox_id = None
self._mailbox_state = CLOSED
self._flag_need_nameplate = True
self._flag_need_to_see_mailbox_used = True
self._flag_need_to_build_msg1 = True
self._flag_need_to_send_PAKE = True
self._key = None
self._close_called = False # the close() API has been called
self._closing = False # we've started shutdown
self._disconnect_waiter = defer.Deferred()
self._error = None
self._get_verifier_called = False
self._verifier = None
self._verifier_waiter = None
self._next_send_phase = 0
# send() queues plaintext here, waiting for a connection and the key
self._plaintext_to_send = [] # (phase, plaintext)
self._sent_phases = set() # to detect double-send
self._next_receive_phase = 0
self._receive_waiters = {} # phase -> Deferred
self._received_messages = {} # phase -> plaintext
# API METHODS for applications to call
# You must use at least one of these entry points, to establish the
# wormhole code. Other APIs will stall or be queued until we have one.
# entry point 1: generate a new code. returns a Deferred
def get_code(self, code_length=2): # XX rename to allocate_code()? create_?
return self._API_get_code(code_length)
# entry point 2: interactively type in a code, with completion. returns
# Deferred
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
return self._API_input_code(prompt, code_length)
# entry point 3: paste in a fully-formed code. No return value.
def set_code(self, code):
self._API_set_code(code)
# todo: restore-saved-state entry points
def verify(self):
"""Returns a Deferred that fires when we've heard back from the other
side, and have confirmed that they used the right wormhole code. When
successful, the Deferred fires with a "verifier" (a bytestring) which
can be compared out-of-band before making additional API calls. If
they used the wrong wormhole code, the Deferred errbacks with
WrongPasswordError.
"""
return self._API_verify()
def send(self, outbound_data):
return self._API_send(outbound_data)
def get(self):
return self._API_get()
def derive_key(self, purpose, length):
"""Derive a new key from the established wormhole channel for some
other purpose. This is a deterministic (but pseudorandom) function of
the session key and the 'purpose' string (unicode/py3-string). It
cannot be called until verify() or get() has fired.
"""
return self._API_derive_key(purpose, length)
def close(self, res=None):
"""Collapse the wormhole, freeing up server resources and flushing
all pending messages. Returns a Deferred that fires when everything
is done. It fires with any argument close() was given, to enable use
as a d.addBoth() handler:
w = wormhole(...)
d = w.get()
..
d.addBoth(w.close)
return d
Another reasonable approach is to use inlineCallbacks:
@inlineCallbacks
def pair(self, code):
w = wormhole(...)
try:
them = yield w.get()
finally:
yield w.close()
"""
return self._API_close(res)
# INTERNAL METHODS beyond here
def _start(self):
d = self._connect() # causes stuff to happen
d.addErrback(log.err)
return d # fires when connection is established, if you care
def _make_endpoint(self, hostname, port):
if self._tor_manager:
return self._tor_manager.get_endpoint_for(hostname, port)
# note: HostnameEndpoints have a default 30s timeout
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
def _connect(self):
# TODO: if we lose the connection, make a new one, re-establish the
# state
assert self._side
self._connection_state = OPENING
p = urlparse(self._ws_url)
f = WSFactory(self._ws_url)
f.wormhole = self
f.d = defer.Deferred()
# TODO: if hostname="localhost", I get three factories starting
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
# an error in the factory is masked.
ep = self._make_endpoint(p.hostname, p.port or 80)
# .connect errbacks if the TCP connection fails
d = ep.connect(f)
d.addCallback(self._event_connected)
# f.d is errbacked if WebSocket negotiation fails, and the WebSocket
# drops any data sent before onOpen() fires, so we must wait for it
d.addCallback(lambda _: f.d)
d.addCallback(self._event_ws_opened)
return d
def _event_connected(self, ws):
self._ws = ws
self._ws_t = self._timing.add("websocket")
def _event_ws_opened(self, _):
self._connection_state = OPEN
if self._closing:
return self._maybe_finished_closing()
self._ws_send_command(u"bind", appid=self._appid, side=self._side)
self._maybe_claim_nameplate()
self._maybe_send_pake()
waiters, self._connection_waiters = self._connection_waiters, []
for d in waiters:
d.callback(None)
def _when_connected(self):
if self._connection_state == OPEN:
return defer.succeed(None)
d = defer.Deferred()
self._connection_waiters.append(d)
return d
def _ws_send_command(self, mtype, **kwargs):
# msgid is used by misc/dump-timing.py to correlate our sends with
# their receives, and vice versa. They are also correlated with the
# ACKs we get back from the server (which we otherwise ignore). There
# are so few messages, 16 bits is enough to be mostly-unique.
if self.DEBUG: print("SEND", mtype)
kwargs["id"] = hexlify(os.urandom(2)).decode("ascii")
kwargs["type"] = mtype
payload = json.dumps(kwargs).encode("utf-8")
self._timing.add("ws_send", _side=self._side, **kwargs)
self._ws.sendMessage(payload, False)
DEBUG=False
def _ws_dispatch_response(self, payload):
msg = json.loads(payload.decode("utf-8"))
if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg)
self._timing.add("ws_receive", _side=self._side, message=msg)
mtype = msg["type"]
meth = getattr(self, "_response_handle_"+mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
return
return meth(msg)
def _response_handle_ack(self, msg):
pass
def _response_handle_welcome(self, msg):
self._welcomer.handle_welcome(msg["welcome"])
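_ws_dispatch_response routes messages by name: type "foo" goes to _response_handle_foo if it exists. The pattern in isolation (the message shapes here are illustrative):

```python
class Dispatcher:
    # getattr-based dispatch, as in _ws_dispatch_response
    def _response_handle_ack(self, msg):
        return None  # ACKs are ignored

    def _response_handle_welcome(self, msg):
        return msg["welcome"]

    def dispatch(self, msg):
        meth = getattr(self, "_response_handle_" + msg["type"], None)
        if meth is None:
            # the real client logs unknown types instead of raising
            raise ValueError("Unknown inbound message type %r" % (msg,))
        return meth(msg)
```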
# entry point 1: generate a new code
@inlineCallbacks
def _API_get_code(self, code_length):
if self._code is not None: raise UsageError
if self._started_get_code: raise UsageError
self._started_get_code = True
with self._timing.add("API get_code"):
yield self._when_connected()
gc = _GetCode(code_length, self._ws_send_command, self._timing)
self._get_code = gc
self._response_handle_allocated = gc._response_handle_allocated
# TODO: signal_error
code = yield gc.go()
self._get_code = None
self._nameplate_state = OPEN
self._event_learned_code(code)
returnValue(code)
# entry point 2: interactively type in a code, with completion
@inlineCallbacks
def _API_input_code(self, prompt, code_length):
if self._code is not None: raise UsageError
if self._started_input_code: raise UsageError
self._started_input_code = True
with self._timing.add("API input_code"):
yield self._when_connected()
ic = _InputCode(self._reactor, prompt, code_length,
self._ws_send_command, self._timing)
self._response_handle_nameplates = ic._response_handle_nameplates
# TODO: signal_error
code = yield ic.go()
self._event_learned_code(code)
returnValue(None)
# entry point 3: paste in a fully-formed code
def _API_set_code(self, code):
self._timing.add("API set_code")
if not isinstance(code, type(u"")): raise TypeError(type(code))
if self._code is not None: raise UsageError
self._event_learned_code(code)
# TODO: entry point 4: restore pre-contact saved state (we haven't heard
# from the peer yet, so we still need the nameplate)
# TODO: entry point 5: restore post-contact saved state (so we don't need
# or use the nameplate, only the mailbox)
def _restore_post_contact_state(self, state):
# ...
self._flag_need_nameplate = False
#self._mailbox_id = X(state)
self._event_learned_mailbox()
def _event_learned_code(self, code):
self._timing.add("code established")
self._code = code
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
nid = mo.group(1)
assert isinstance(nid, type(u"")), type(nid)
self._nameplate_id = nid
# fire more events
self._maybe_build_msg1()
self._event_learned_nameplate()
def _maybe_build_msg1(self):
if not (self._code and self._flag_need_to_build_msg1):
return
with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
idSymmetric=to_bytes(self._appid))
self._msg1 = self._sp.start()
self._flag_need_to_build_msg1 = False
self._event_built_msg1()
def _event_built_msg1(self):
self._maybe_send_pake()
# every _maybe_X starts with a set of conditions
# for each such condition Y, every _event_Y must call _maybe_X
def _event_learned_nameplate(self):
self._maybe_claim_nameplate()
def _maybe_claim_nameplate(self):
if not (self._nameplate_id and self._connection_state == OPEN):
return
self._ws_send_command(u"claim", nameplate=self._nameplate_id)
self._nameplate_state = OPEN
def _response_handle_claimed(self, msg):
mailbox_id = msg["mailbox"]
assert isinstance(mailbox_id, type(u"")), type(mailbox_id)
self._mailbox_id = mailbox_id
self._event_learned_mailbox()
def _event_learned_mailbox(self):
if not self._mailbox_id: raise UsageError
assert self._mailbox_state == CLOSED, self._mailbox_state
if self._closing:
return
self._ws_send_command(u"open", mailbox=self._mailbox_id)
self._mailbox_state = OPEN
# causes old messages to be sent now, and subscribes to new messages
self._maybe_send_pake()
self._maybe_send_phase_messages()
def _maybe_send_pake(self):
# TODO: deal with reentrant call
if not (self._connection_state == OPEN
and self._mailbox_state == OPEN
and self._flag_need_to_send_PAKE):
return
self._msg_send(u"pake", self._msg1)
self._flag_need_to_send_PAKE = False
def _event_received_pake(self, pake_msg):
with self._timing.add("pake2", waiting="crypto"):
self._key = self._sp.finish(pake_msg)
self._event_established_key()
def _derive_confirmation_key(self):
return self._derive_key(b"wormhole:confirmation")
def _event_established_key(self):
self._timing.add("key established")
# both sides send different (random) confirmation messages
confkey = self._derive_confirmation_key()
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
confmsg = make_confmsg(confkey, nonce)
self._msg_send(u"confirm", confmsg)
verifier = self._derive_key(b"wormhole:verifier")
self._event_computed_verifier(verifier)
self._maybe_send_phase_messages()
def _API_verify(self):
# TODO: rename "verify()", make it stall until confirm received. If
# you want to discover WrongPasswordError before doing send(), call
# verify() first. If you also want to deny a successful MitM (and
# have some other way to check a long verifier), use the return value
# of verify().
if self._error: return defer.fail(self._error)
if self._get_verifier_called: raise UsageError
self._get_verifier_called = True
if self._verifier:
return defer.succeed(self._verifier)
# TODO: maybe have this wait on _event_received_confirm too
self._verifier_waiter = defer.Deferred()
return self._verifier_waiter
def _event_computed_verifier(self, verifier):
self._verifier = verifier
if self._verifier_waiter:
self._verifier_waiter.callback(verifier)
def _event_received_confirm(self, body):
# TODO: we might not have a master key yet, if the caller wasn't
# waiting in _get_master_key() when a back-to-back pake+_confirm
# message pair arrived.
confkey = self._derive_confirmation_key()
nonce = body[:CONFMSG_NONCE_LENGTH]
if body != make_confmsg(confkey, nonce):
# this makes all API calls fail
if self.DEBUG: print("CONFIRM FAILED")
return self._signal_error(WrongPasswordError(), u"scary")
def _API_send(self, outbound_data):
if self._error: raise self._error
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
phase = self._next_send_phase
self._next_send_phase += 1
self._plaintext_to_send.append( (phase, outbound_data) )
with self._timing.add("API send", phase=phase):
self._maybe_send_phase_messages()
def _derive_phase_key(self, side, phase):
assert isinstance(side, type(u"")), type(side)
assert isinstance(phase, type(u"")), type(phase)
side_bytes = side.encode("ascii")
phase_bytes = phase.encode("ascii")
purpose = (b"wormhole:phase:"
+ sha256(side_bytes).digest()
+ sha256(phase_bytes).digest())
return self._derive_key(purpose)
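Hashing side and phase separately into fixed-width digests keeps the two inputs from running into each other (("ab","c") and ("a","bc") cannot collide). This standalone sketch inlines the HKDF steps with stdlib hmac (standing in for the hkdf package) to show that each (side, phase) pair yields a distinct key:

```python
import hmac
from hashlib import sha256

def derive_phase_key(master_key, side, phase):
    # purpose = prefix + H(side) + H(phase), as in _derive_phase_key
    purpose = (b"wormhole:phase:"
               + sha256(side.encode("ascii")).digest()
               + sha256(phase.encode("ascii")).digest())
    prk = hmac.new(b"\x00" * 32, master_key, sha256).digest()  # HKDF-extract, empty salt
    return hmac.new(prk, purpose + b"\x01", sha256).digest()   # first HKDF-expand block

k_ours = derive_phase_key(b"k" * 32, u"side-a", u"0")
k_theirs = derive_phase_key(b"k" * 32, u"side-b", u"0")
```

Because the sender's side is mixed in, the two directions of phase "0" use different keys, which is what lets the client ignore reflected messages.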
def _maybe_send_phase_messages(self):
# TODO: deal with reentrant call
if not (self._connection_state == OPEN
and self._mailbox_state == OPEN
and self._key):
return
plaintexts = self._plaintext_to_send
self._plaintext_to_send = []
for pm in plaintexts:
(phase_int, plaintext) = pm
assert isinstance(phase_int, int), type(phase_int)
phase = u"%d" % phase_int
data_key = self._derive_phase_key(self._side, phase)
encrypted = self._encrypt_data(data_key, plaintext)
self._msg_send(phase, encrypted)
def _encrypt_data(self, key, data):
# Without predefined roles, we can't derive predictably unique keys
# for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and we automatically ignore
# reflections.
# TODO: HKDF(side, nonce, key) ?? include 'side' to prevent
# reflections, since we no longer compare messages
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _msg_send(self, phase, body):
if phase in self._sent_phases: raise UsageError
assert self._mailbox_state == OPEN, self._mailbox_state
self._sent_phases.add(phase)
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
self._timing.add("add", phase=phase)
self._ws_send_command(u"add", phase=phase,
body=hexlify(body).decode("ascii"))
def _event_mailbox_used(self):
if self.DEBUG: print("_event_mailbox_used")
if self._flag_need_to_see_mailbox_used:
self._maybe_release_nameplate()
self._flag_need_to_see_mailbox_used = False
def _API_derive_key(self, purpose, length):
if self._error: raise self._error
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
return self._derive_key(to_bytes(purpose), length)
def _derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
if self._key is None:
raise UsageError # call derive_key after get_verifier() or get()
return HKDF(self._key, length, CTXinfo=purpose)
def _response_handle_message(self, msg):
side = msg["side"]
phase = msg["phase"]
assert isinstance(phase, type(u"")), type(phase)
body = unhexlify(msg["body"].encode("ascii"))
if side == self._side:
return
self._event_received_peer_message(side, phase, body)
    def _event_received_peer_message(self, side, phase, body):
        # any message in the mailbox means we no longer need the nameplate
        self._event_mailbox_used()

        #if phase in self._received_messages:
        #    # a nameplate collision would cause this
        #    err = ServerError("got duplicate phase %s" % phase, self._ws_url)
        #    return self._signal_error(err)
        #self._received_messages[phase] = body
        if phase == u"pake":
            self._event_received_pake(body)
            return
        if phase == u"confirm":
            self._event_received_confirm(body)
            return

        # It's a phase message, aimed at the application above us. Decrypt
        # and deliver upstairs, notifying anyone waiting on it
        try:
            data_key = self._derive_phase_key(side, phase)
            plaintext = self._decrypt_data(data_key, body)
        except CryptoError:
            e = WrongPasswordError()
            self._signal_error(e, u"scary") # flunk all other API calls
            # make tests fail, if they aren't explicitly catching it
            if self.DEBUG: print("CryptoError in msg received")
            log.err(e)
            if self.DEBUG: print(" did log.err", e)
            return # ignore this message
        self._received_messages[phase] = plaintext
        if phase in self._receive_waiters:
            d = self._receive_waiters.pop(phase)
            d.callback(plaintext)
    def _decrypt_data(self, key, encrypted):
        assert isinstance(key, type(b"")), type(key)
        assert isinstance(encrypted, type(b"")), type(encrypted)
        assert len(key) == SecretBox.KEY_SIZE, len(key)
        box = SecretBox(key)
        data = box.decrypt(encrypted)
        return data
    def _API_get(self):
        if self._error: return defer.fail(self._error)
        phase = u"%d" % self._next_receive_phase
        self._next_receive_phase += 1
        with self._timing.add("API get", phase=phase):
            if phase in self._received_messages:
                return defer.succeed(self._received_messages[phase])
            d = self._receive_waiters[phase] = defer.Deferred()
            return d
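`_API_get` is what turns the mailbox into a sequential record pipe: each `get()` consumes the next numbered phase, returning immediately if that phase has already arrived, and otherwise parking a waiter for `_event_received_peer_message` to fire later. A simplified, callback-based model of that early/late rendezvous (names here are illustrative, not the real API):

```python
class RecordPipe:
    """Toy model of sequential phase delivery: a message may arrive
    either before or after the corresponding get() call."""
    def __init__(self):
        self._next_phase = 0
        self._received = {}  # phase -> payload, arrived before get()
        self._waiters = {}   # phase -> callback, get() before arrival

    def deliver(self, phase, payload):
        if phase in self._waiters:
            self._waiters.pop(phase)(payload)  # late arrival: fire waiter
        else:
            self._received[phase] = payload    # early arrival: stash it

    def get(self, callback):
        phase = u"%d" % self._next_phase       # phases are numbered, in order
        self._next_phase += 1
        if phase in self._received:
            callback(self._received[phase])    # already here: fire now
        else:
            self._waiters[phase] = callback    # not yet: park the waiter

out = []
pipe = RecordPipe()
pipe.deliver(u"0", b"early")  # arrives before anyone asks
pipe.get(out.append)          # phase "0" already stashed: fires immediately
pipe.get(out.append)          # parks a waiter for phase "1"
pipe.deliver(u"1", b"late")   # fires the parked waiter
assert out == [b"early", b"late"]
```

The real code uses Deferreds instead of plain callbacks, but the bookkeeping is the same two dictionaries.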
    def _signal_error(self, error, mood):
        if self.DEBUG: print("_signal_error", error, mood)
        if self._error:
            return
        self._maybe_close(error, mood)
        if self.DEBUG: print("_signal_error done")
    @inlineCallbacks
    def _API_close(self, res, mood=u"happy"):
        if self.DEBUG: print("close")
        if self._close_called: raise UsageError
        self._close_called = True
        self._maybe_close(WormholeClosedError(), mood)
        if self.DEBUG: print("waiting for disconnect")
        yield self._disconnect_waiter
        returnValue(res)
    def _maybe_close(self, error, mood):
        if self._closing:
            return

        # ordering constraints:
        # * must wait for nameplate/mailbox acks before closing the websocket
        # * must mark APIs for failure before errbacking Deferreds
        #   * since we give up control
        # * must mark self._closing before errbacking Deferreds
        #   * since caller may call close() when we give up control
        #   * and close() will reenter _maybe_close

        self._error = error # causes new API calls to fail

        # since we're about to give up control by errbacking any API
        # Deferreds, set self._closing, to make sure that a new call to
        # close() isn't going to confuse anything
        self._closing = True

        # now errback all API deferreds except close(): get_code,
        # input_code, verify, get
        for d in self._connection_waiters: # input_code, get_code (early)
            if self.DEBUG: print("EB cw")
            d.errback(error)
        if self._get_code: # get_code (late)
            if self.DEBUG: print("EB gc")
            self._get_code._allocated_d.errback(error)
        if self._verifier_waiter and not self._verifier_waiter.called:
            if self.DEBUG: print("EB VW")
            self._verifier_waiter.errback(error)
        for d in self._receive_waiters.values():
            if self.DEBUG: print("EB RW")
            d.errback(error)

        # Release nameplate and close mailbox, if either was claimed/open.
        # Since _closing is True when both ACKs come back, the handlers will
        # close the websocket. When *that* finishes, _disconnect_waiter()
        # will fire.
        self._maybe_release_nameplate()
        self._maybe_close_mailbox(mood)

        # In the off chance we got closed before we even claimed the
        # nameplate, give _maybe_finished_closing a chance to run now.
        self._maybe_finished_closing()
    def _maybe_release_nameplate(self):
        if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state)
        if self._nameplate_state == OPEN:
            if self.DEBUG: print(" sending release")
            self._ws_send_command(u"release")
            self._nameplate_state = CLOSING

    def _response_handle_released(self, msg):
        self._nameplate_state = CLOSED
        self._maybe_finished_closing()

    def _maybe_close_mailbox(self, mood):
        if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state)
        if self._mailbox_state == OPEN:
            if self.DEBUG: print(" sending close")
            self._ws_send_command(u"close", mood=mood)
            self._mailbox_state = CLOSING

    def _response_handle_closed(self, msg):
        self._mailbox_state = CLOSED
        self._maybe_finished_closing()
    def _maybe_finished_closing(self):
        if self.DEBUG: print("_maybe_finished_closing", self._closing,
                             self._nameplate_state, self._mailbox_state,
                             self._connection_state)
        if not self._closing:
            return
        if (self._nameplate_state == CLOSED
            and self._mailbox_state == CLOSED
            and self._connection_state == OPEN):
            self._connection_state = CLOSING
            self._drop_connection()
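The shutdown path above enforces one of the ordering constraints noted in `_maybe_close`: the websocket may only be dropped once *both* the nameplate "released" ack and the mailbox "closed" ack have come back from the server. A small state-machine sketch of just that gating logic (a toy model, not the real class):

```python
OPEN, CLOSING, CLOSED = "open", "closing", "closed"

class Teardown:
    """Toy model of wormhole shutdown: the connection is dropped only
    after both the nameplate and the mailbox have been acked CLOSED."""
    def __init__(self):
        self.nameplate = OPEN
        self.mailbox = OPEN
        self.connection = OPEN

    def close(self):
        self.nameplate = CLOSING  # "release" sent, awaiting ack
        self.mailbox = CLOSING    # "close" sent, awaiting ack
        self._maybe_drop()

    def on_released(self):        # server acked the nameplate release
        self.nameplate = CLOSED
        self._maybe_drop()

    def on_closed(self):          # server acked the mailbox close
        self.mailbox = CLOSED
        self._maybe_drop()

    def _maybe_drop(self):
        if (self.nameplate == CLOSED and self.mailbox == CLOSED
                and self.connection == OPEN):
            self.connection = CLOSING  # loseConnection() would go here

t = Teardown()
t.close()
t.on_released()
assert t.connection == OPEN     # still waiting for the mailbox ack
t.on_closed()
assert t.connection == CLOSING  # both acks in: safe to drop the websocket
```

Whichever ack arrives last triggers the drop, which is why both response handlers call back into the same `_maybe_finished_closing` check.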
    def _drop_connection(self):
        # separate method so it can be overridden by tests
        self._ws.transport.loseConnection() # probably flushes output
        # calls _ws_closed() when done

    def _ws_closed(self, wasClean, code, reason):
        # For now (until we add reconnection), losing the websocket means
        # losing everything. Make all API callers fail. Help someone waiting
        # in close() to finish.
        self._connection_state = CLOSED
        self._disconnect_waiter.callback(None)
        self._maybe_finished_closing()

        # what needs to happen when _ws_closed() happens unexpectedly:
        # * errback all API deferreds
        # * maybe: cause new API calls to fail
        # * obviously can't release nameplate or close mailbox
        # * can't re-close websocket
        # * close(wait=True) callers should fire right away
def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None):
    timing = timing or DebugTiming()
    w = _Wormhole(appid, relay_url, reactor, tor_manager, timing)
    w._start()
    return w

def wormhole_from_serialized(data, reactor, timing=None):
    timing = timing or DebugTiming()
    w = _Wormhole.from_serialized(data, reactor, timing)
    return w

@@ -19,6 +19,7 @@ skip_missing_interpreters = True
 [testenv]
 deps =
     pyflakes >= 1.2.3
+    mock
     {env:EXTRA_DEPENDENCY:}
 commands =
     pyflakes setup.py src