move transit-relevant files out from magic-wormhole

These files are copied (with roughly-appropriate changes to the top-level
setup.py, NEWS.md, etc.) from magic-wormhole 0.10.3, commit
be166b483c5796ab3a9ad588ccf671b7eabdd96c.
Brian Warner
parent 646ee3e5be
commit 46abd75fda

@@ -0,0 +1,24 @@
# -*- mode: conf -*-
[run]
# only record trace data for wormhole_transit_relay.*
source =
wormhole_transit_relay
# and don't trace the test files themselves, or Versioneer's stuff
omit =
src/wormhole_transit_relay/test/*
src/wormhole_transit_relay/_version.py
# This allows 'coverage combine' to correlate the tracing data built while
# running tests in multiple tox virtualenvs. To take advantage of this
# properly, use "coverage erase" before tox, "coverage run --parallel-mode"
# inside tox to avoid overwriting the output data (by writing it into
# .coverage-XYZ instead of just .coverage), and run "coverage combine"
# afterwards.
[paths]
source =
src/
.tox/*/lib/python*/site-packages/
.tox/pypy*/site-packages/

.gitattributes

@@ -0,0 +1 @@
src/wormhole_transit_relay/_version.py export-subst

@@ -0,0 +1,30 @@
sudo: false
language: python
cache: pip
before_cache:
- rm -f $HOME/.cache/pip/log/debug.log
branches:
except:
- /^WIP-.*$/
python:
- "2.7"
- "3.3"
- "3.4"
- "3.5"
- "3.6"
- "nightly"
install:
- pip install -U pip tox virtualenv codecov
before_script:
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
pip install -U flake8 ;
flake8 *.py src --count --select=E901,E999,F821,F822,F823 --statistics ;
fi
script:
- tox -e coverage
after_success:
- codecov
matrix:
allow_failures:
- python: "3.3"
- python: "nightly"

@@ -0,0 +1,7 @@
include versioneer.py
include src/wormhole_transit_relay/_version.py
include LICENSE README.md NEWS.md
recursive-include docs *.md *.rst *.dot
include .coveragerc tox.ini
include misc/munin/wormhole_transit
include misc/munin/wormhole_transit_alltime

@@ -0,0 +1,5 @@
User-visible changes in "magic-wormhole-transit-relay":
## forked from magic-wormhole-0.10.3 (12-Sep-2017)

@@ -0,0 +1,232 @@
= Transit Protocol =
The Transit protocol is responsible for establishing an encrypted
bidirectional record stream between two programs. It must be given a "transit
key" and a set of "hints" which help locate the other end (which are both
delivered by Wormhole).
The protocol tries hard to create a **direct** connection between the two
ends, but if that fails, it uses a centralized relay server to ferry data
between two separate TCP streams (one to each client).
The current implementation starts with the following:
* detect all of the host's IP addresses
* listen on a random TCP port
* offer the (address,port) pairs as hints
The other side will attempt to connect to each of those ports, as well as
listening on its own socket. After a few seconds without success, they will
both connect to a relay server.
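As a rough illustration of that first step, a client could gather its direct
hints like this (a sketch only; `gather_direct_hints` is a hypothetical
helper, not the actual magic-wormhole code):
```python
import socket

def gather_direct_hints():
    # bind to port 0 so the OS picks a random free TCP port
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("", 0))
    listener.listen(5)
    port = listener.getsockname()[1]
    # one portable (if incomplete) way to enumerate the host's addresses
    addrs = {info[4][0] for info in
             socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET)}
    return listener, [(addr, port) for addr in addrs]
```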
== Roles ==
The Transit protocol has pre-defined "Sender" and "Receiver" roles (unlike
Wormhole, which is symmetric/nobody-goes-first). Each connection must have
exactly one Sender and exactly one Receiver.
The connection itself is bidirectional: either side can send or receive
records. However the connection establishment mechanism needs to know who is
in charge, and the encryption layer needs a way to produce separate keys for
each side.
This may be relaxed in the future, much as Wormhole was.
== Records ==
Transit establishes a **record-pipe**, so the two sides can send and receive
whole records, rather than unframed bytes. This is a side-effect of the
encryption (which uses the NaCl "secretbox" function). The encryption adds 44
bytes of overhead to each record (4-byte length, 24-byte nonce, 16-byte MAC),
so you might want to use slightly larger records for efficiency. The maximum
record size is 2^32 bytes (4GiB). The whole record must be held in memory at
the same time, plus its ciphertext, so very large records are not
recommended.
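To make the framing concrete, here is a hedged sketch using PyNaCl's
`SecretBox`. Whether the length prefix covers the nonce, and how nonces are
chosen, are assumptions for illustration, not a description of the real wire
format:
```python
import struct
from nacl.secret import SecretBox
from nacl.utils import random as nacl_random

def frame_record(box, record):
    nonce = nacl_random(SecretBox.NONCE_SIZE)     # 24 bytes
    sealed = box.encrypt(record, nonce)           # nonce || ct + 16-byte MAC
    return struct.pack(">I", len(sealed)) + bytes(sealed)

box = SecretBox(b"\x00" * SecretBox.KEY_SIZE)     # demo key only!
frame = frame_record(box, b"hello")
assert len(frame) == 4 + 24 + 16 + len(b"hello")  # 44 bytes of overhead
```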
Transit provides **confidentiality**, **integrity**, and **ordering** of
records. Passive attackers can only do the following:
* learn the size and transmission time of each record
* learn the sending and destination IP addresses
In addition, an active attacker is able to:
* delay delivery of individual records, while maintaining ordering (if they
delay record #4, they must delay #5 and later as well)
* terminate the connection at any time
If either side receives a corrupted or out-of-order record, they drop the
connection. Attackers cannot modify the contents of a record, or change the
order of the records, without being detected and the connection being
dropped. If a record is lost (e.g. the receiver observes records #1,#2,#4,
but not #3), the connection is dropped when the unexpected sequence number is
received.
== Handshake ==
The transit key is used to derive several secondary keys. Two of them are
used as a "handshake", to distinguish correct Transit connections from other
programs that happen to connect to the Transit sockets by mistake or malice.
The handshake is also responsible for choosing exactly one TCP connection to
use, even though multiple outbound and inbound connections are being
attempted.
The SENDER-HANDSHAKE is the string `transit sender %s ready\n\n`, with the
`%s` replaced by a hex-encoded 32-byte HKDF derivative of the transit key,
using a "context string" of `transit_sender`. The RECEIVER-HANDSHAKE is the
same but with `receiver` instead of `sender` (both for the string and the
HKDF context).
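For concreteness, the handshake strings could be derived like this (a sketch
using the `cryptography` package's HKDF; the SHA-256 hash and empty salt are
assumptions, not confirmed parameters):
```python
import binascii
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

def handshake_string(transit_key, role):  # role is "sender" or "receiver"
    hkdf = HKDF(algorithm=hashes.SHA256(), length=32, salt=None,
                info=b"transit_" + role.encode("ascii"))
    token = binascii.hexlify(hkdf.derive(transit_key))
    return b"transit " + role.encode("ascii") + b" " + token + b" ready\n\n"
```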
The handshake protocol is like this:
* immediately upon socket connection being made, the Sender writes
SENDER-HANDSHAKE to the socket (regardless of whether the Sender initiated
the TCP connection, or was listening on a socket and just accepted the
connection)
* likewise the Receiver immediately writes RECEIVER-HANDSHAKE to either kind
of socket
* if the Sender sees anything other than RECEIVER-HANDSHAKE as the first
bytes on the wire, it hangs up
* likewise with the Receiver and SENDER-HANDSHAKE
* if the Sender sees that this is the first connection to get
RECEIVER-HANDSHAKE, it sends `go\n`. If some other connection got there
first, it hangs up (or sends `nevermind\n` and then hangs up, but this is
mostly for debugging, and implementations should not depend upon it). After
sending `go`, it switches to encrypted-record mode.
* if the Receiver sees `go\n`, it switches to encrypted-record mode. If the
receiver sees anything else, or a disconnected socket, it disconnects.
To tolerate the inevitable race conditions created by multiple contending
sockets, only the Sender gets to decide which one wins: the first one to make
it past negotiation. Hopefully this is correlated with the fastest connection
pathway. The protocol ignores any socket that is not somewhat affiliated with
the matching Transit instance.
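The Sender's half of that negotiation, sketched over a blocking socket
(illustrative names only; real code must track `is_first` across all
contending sockets):
```python
def sender_negotiate(sock, sender_handshake, receiver_handshake, is_first):
    sock.sendall(sender_handshake)
    buf = b""
    while len(buf) < len(receiver_handshake):
        chunk = sock.recv(len(receiver_handshake) - len(buf))
        if not chunk:
            return False                # peer hung up
        buf += chunk
    if buf != receiver_handshake or not is_first:
        sock.close()                    # wrong peer, or another socket won
        return False
    sock.sendall(b"go\n")               # this connection wins
    return True                         # switch to encrypted-record mode
```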
Hints will frequently point to local IP addresses (local to the other end)
which might be in use by unrelated nearby computers. The handshake helps to
ignore these spurious connections. It is still possible for an attacker to
cause the connection to fail, by intercepting both connections (to learn the
two handshakes), then making new connections to play back the recorded
handshakes, but this level of attacker could simply drop the user's packets
directly.
== Relay ==
The **Transit Relay** is a host which offers TURN-like services for
magic-wormhole instances. It uses a TCP-based protocol with a handshake to
determine which connection wants to be connected to which.
When connecting to a relay, the Transit client first writes RELAY-HANDSHAKE
to the socket, which is `please relay %s\n`, where `%s` is the hex-encoded
32-byte HKDF derivative of the transit key, using `transit_relay_token` as
the context. The client then waits for `ok\n`.
The relay waits for a second connection that uses the same token. When this
happens, the relay sends `ok\n` to both, then wires the connections together,
so that everything received after the token on one is written out (after the
ok) on the other. When either connection is lost, the other will be closed
(the relay does not support "half-close").
When clients use a relay connection, they perform the usual sender/receiver
handshake just after the `ok\n` is received: until that point they pretend
the connection doesn't even exist.
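The client side of the relay handshake might look like this (a
blocking-socket sketch; `connect_via_relay` is a hypothetical helper):
```python
import socket

def connect_via_relay(host, port, relay_token_hex):
    sock = socket.create_connection((host, port))
    sock.sendall(b"please relay " + relay_token_hex + b"\n")
    buf = b""
    while len(buf) < 3:                 # wait for the 3-byte "ok\n"
        chunk = sock.recv(3 - len(buf))
        if not chunk:
            raise ConnectionError("relay hung up")
        buf += chunk
    if buf != b"ok\n":
        sock.close()
        raise ConnectionError("unexpected relay response: %r" % (buf,))
    return sock  # now perform the usual sender/receiver handshake
```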
Direct connections are better, since they are faster and less expensive for
the relay operator. If there are any potentially-viable direct connection
hints available, the Transit instance will wait a few seconds before
attempting to use the relay. If it has no viable direct hints, it will start
using the relay right away. This prefers direct connections, but doesn't
introduce completely unnecessary stalls.
== API ==
First, create a Transit instance, giving it the connection information of the
transit relay. The application must know whether it should use a Sender or a
Receiver:
```python
from wormhole.blocking.transit import TransitSender
s = TransitSender("tcp:relayhost.example.org:12345")
```
Next, ask the Transit for its direct and relay hints. This should be
delivered to the other side via a Wormhole message (i.e. add them to a dict,
serialize it with JSON, send the result as a message with `wormhole.send()`).
```python
direct_hints = s.get_direct_hints()
relay_hints = s.get_relay_hints()
```
Then, perform the Wormhole exchange, which ought to give you the direct and
relay hints of the other side. Tell your Transit instance about their hints.
```python
s.add_their_direct_hints(their_direct_hints)
s.add_their_relay_hints(their_relay_hints)
```
Then use `wormhole.derive_key()` to obtain a shared key for Transit purposes,
and tell your Transit about it. Both sides must use the same derivation
string, and this string must not be used for any other purpose, but beyond
that it doesn't much matter what the exact string is.
```python
key = w.derive_key(application_id + "/transit-key")
s.set_transit_key(key)
```
Finally, tell the Transit instance to connect. This will yield a "record
pipe" object, on which records can be sent and received. If no connection can
be established within a timeout (defaults to 30 seconds), `connect()` will
throw an exception instead. The pipe can be closed with `close()`.
```python
rp = s.connect()
rp.send_record(b"my first record")
their_record = rp.receive_record()
rp.send_record(b"Greatest Hits)
other = rp.receive_record()
rp.close()
```
Records can be sent and received arbitrarily (you are not limited to taking
turns). However the blocking API does not provide a way to send records while
waiting for an inbound record. This *might* work with threads, but it has not
been tested.
== Twisted API ==
The same facilities are available in the asynchronous Twisted environment.
The difference is that some functions return Deferreds instead of immediate
values. The final record-pipe object is a Protocol (TBD: maybe this is a job
for Tubes?), which exposes `receive_record()` as a Deferred-returning
function that internally holds a queue of inbound records.
```python
from twisted.internet.defer import inlineCallbacks
from wormhole.twisted.transit import TransitSender
@inlineCallbacks
def do_transit():
s = TransitSender(relay)
my_relay_hints = s.get_relay_hints()
my_direct_hints = yield s.get_direct_hints()
# (send hints via wormhole)
s.add_their_relay_hints(their_relay_hints)
s.add_their_direct_hints(their_direct_hints)
s.set_transit_key(key)
rp = yield s.connect()
rp.send_record(b"eponymous")
them = yield rp.receive_record()
yield rp.close()
```
This object also implements the `IConsumer`/`IProducer` protocols for
**bytes**, which means you can transfer a file by wiring up a file reader as
a Producer. Each chunk of bytes that the Producer generates will be put into
a single record. The Consumer interface works the same way. This enables
backpressure and flow-control: if the far end (or the network) cannot keep up
with the stream of data, the sender will wait for them to catch up before
filling buffers without bound.
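For example, a file could be streamed through the record pipe with Twisted's
`FileSender`, which feeds chunks to an `IConsumer` while honoring
backpressure (a sketch; the integration details are assumptions):
```python
from twisted.internet.defer import inlineCallbacks
from twisted.protocols.basic import FileSender

@inlineCallbacks
def send_file(rp, path):
    # each chunk FileSender pushes into the consumer becomes one record
    with open(path, "rb") as f:
        yield FileSender().beginFileTransfer(f, rp)
    yield rp.close()
```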

@@ -0,0 +1,33 @@
#! /usr/bin/env python
"""
Use the following in /etc/munin/plugin-conf.d/wormhole :
[wormhole_*]
env.serverdir /path/to/your/wormhole/server
"""
from __future__ import print_function
import os, sys, time, json
CONFIG = """\
graph_title Magic-Wormhole Transit Usage (since reboot)
graph_vlabel Bytes Since Reboot
graph_category network
bytes.label Transit Bytes
bytes.draw LINE1
bytes.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
sys.exit(0)
serverdir = os.environ["serverdir"]
fn = os.path.join(serverdir, "stats.json")
with open(fn) as f:
data = json.load(f)
if time.time() > data["valid_until"]:
sys.exit(1) # expired
t = data["transit"]["since_reboot"]
print "bytes.value", t["bytes"]

@@ -0,0 +1,33 @@
#! /usr/bin/env python
"""
Use the following in /etc/munin/plugin-conf.d/wormhole :
[wormhole_*]
env.serverdir /path/to/your/wormhole/server
"""
from __future__ import print_function
import os, sys, time, json
CONFIG = """\
graph_title Magic-Wormhole Transit Usage (all time)
graph_vlabel Bytes Since DB Creation
graph_category network
bytes.label Transit Bytes
bytes.draw LINE1
bytes.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
    print(CONFIG.rstrip())
sys.exit(0)
serverdir = os.environ["serverdir"]
fn = os.path.join(serverdir, "stats.json")
with open(fn) as f:
data = json.load(f)
if time.time() > data["valid_until"]:
sys.exit(1) # expired
t = data["transit"]["all_time"]
print "bytes.value", t["bytes"]

@@ -0,0 +1,9 @@
[wheel]
universal = 1
[versioneer]
VCS = git
versionfile_source = src/wormhole_transit_relay/_version.py
versionfile_build = wormhole_transit_relay/_version.py
tag_prefix =
parentdir_prefix = magic-wormhole-transit-relay

@@ -0,0 +1,28 @@
from setuptools import setup
import versioneer
commands = versioneer.get_cmdclass()
setup(name="magic-wormhole-transit-relay",
version=versioneer.get_version(),
description="Transit Relay server for Magic-Wormhole",
author="Brian Warner",
author_email="warner-magic-wormhole@lothar.com",
license="MIT",
url="https://github.com/warner/magic-wormhole-transit-relay",
package_dir={"": "src"},
packages=["wormhole_transit_relay",
"wormhole_transit_relay.test",
],
package_data={"wormhole_transit_relay": ["db-schemas/*.sql"]},
install_requires=[
"twisted >= 17.5.0",
],
extras_require={
':sys_platform=="win32"': ["pypiwin32"],
"dev": ["mock", "tox", "pyflakes"],
},
test_suite="wormhole_transit_relay.test",
cmdclass=commands,
)

@@ -0,0 +1,4 @@
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions

@@ -0,0 +1,520 @@
# This file helps to compute a version number in source trees obtained from
# git-archive tarballs (such as those provided by GitHub's download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain. Generated by
# versioneer-0.18 (https://github.com/warner/python-versioneer)
"""Git implementation of _version.py."""
import errno
import os
import re
import subprocess
import sys
def get_keywords():
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
# each be defined on a line of their own. _version.py will just call
# get_keywords().
git_refnames = "$Format:%d$"
git_full = "$Format:%H$"
git_date = "$Format:%ci$"
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
return keywords
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
def get_config():
"""Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates
# _version.py
cfg = VersioneerConfig()
cfg.VCS = "git"
cfg.style = ""
cfg.tag_prefix = ""
cfg.parentdir_prefix = "magic-wormhole-transit-relay"
cfg.versionfile_source = "src/wormhole_transit_relay/_version.py"
cfg.verbose = False
return cfg
class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY = {}
HANDLERS = {}
def register_vcs_handler(vcs, method): # decorator
"""Decorator to mark a method as the handler for a particular VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
for c in commands:
try:
dispcmd = str([c] + args)
# remember shell=False, so use git.cmd on windows, not just git
p = subprocess.Popen([c] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None))
break
except EnvironmentError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
if verbose:
print("unable to run %s" % dispcmd)
print(e)
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = p.communicate()[0].strip()
if sys.version_info[0] >= 3:
stdout = stdout.decode()
if p.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
return None, p.returncode
return stdout, p.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
the project name and a version string. We will also support searching up
    two directory levels for an appropriately named parent directory.
"""
rootdirs = []
for i in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
(str(rootdirs), parentdir_prefix))
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords = {}
try:
f = open(versionfile_abs, "r")
for line in f.readlines():
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
pass
return keywords
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
if not keywords:
raise NotThisMethod("no keywords at all, weird")
date = keywords.get("date")
if date is not None:
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
# it's been around since git-1.5.3, and it's too difficult to
# discover which version we're using, or to work around using an
# older one.
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
# expansion behaves like git log --decorate=short and strips out the
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
print("likely tags: %s" % ",".join(sorted(tags)))
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
if verbose:
print("picking %s" % r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
expanded, and _version.py hasn't already been rewritten with a short
version string, meaning we're inside a checked out source tree.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
raise NotThisMethod("'git rev-parse --git-dir' returned error")
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
"--always", "--long",
"--match", "%s*" % tag_prefix],
cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
# look for -dirty suffix
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
return pieces
# tag
full_tag = mo.group(1)
if not full_tag.startswith(tag_prefix):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
% (full_tag, tag_prefix))
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
# commit: short hex revision ID
pieces["short"] = mo.group(3)
else:
# HEX: no tags
pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
cwd=root)
pieces["distance"] = int(count_out) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
cwd=root)[0].strip()
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
def plus_or_dot(pieces):
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces):
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
Exceptions:
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_pre(pieces):
"""TAG[.post.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post.devDISTANCE
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += ".post.dev%d" % pieces["distance"]
else:
# exception #1
rendered = "0.post.dev%d" % pieces["distance"]
return rendered
def render_pep440_post(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
(a dirty tree will appear "older" than the corresponding clean one),
but you shouldn't be releasing software with -dirty anyways.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
return rendered
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
    Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
return rendered
def render_git_describe(pieces):
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render_git_describe_long(pieces):
"""TAG-DISTANCE-gHEX[-dirty].
    Like 'git describe --tags --dirty --always --long'.
The distance/hash is unconditional.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
if not style or style == "default":
style = "pep440" # the default
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
rendered = render_git_describe(pieces)
elif style == "git-describe-long":
rendered = render_git_describe_long(pieces)
else:
raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
def get_versions():
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
# case we can only use expanded keywords.
cfg = get_config()
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
verbose)
except NotThisMethod:
pass
try:
root = os.path.realpath(__file__)
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for i in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
return render(pieces, cfg.style)
except NotThisMethod:
pass
try:
if cfg.parentdir_prefix:
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
except NotThisMethod:
pass
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to compute version", "date": None}

@@ -0,0 +1,156 @@
from __future__ import print_function
import json
import click
from ..cli.cli import Config, _compose
# can put this back in to get this command as "wormhole server"
# instead
#from ..cli.cli import wormhole
#@wormhole.group()
@click.group()
@click.pass_context
def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
"""
Control a relay server (most users shouldn't need to worry
about this and can use the default server).
"""
# just leaving this pointing to wormhole.cli.cli.Config for now,
# but if we want to keep wormhole-server as a separate command
# should probably have our own Config without all the options the
# server commands don't use
ctx.obj = Config()
def _validate_websocket_protocol_options(ctx, param, value):
return list(_validate_websocket_protocol_option(option) for option in value)
def _validate_websocket_protocol_option(option):
try:
key, value = option.split("=", 1)
except ValueError:
raise click.BadParameter("format options as OPTION=VALUE")
try:
value = json.loads(value)
    except ValueError:
raise click.BadParameter("could not parse JSON value for {}".format(key))
return (key, value)
LaunchArgs = _compose(
click.option(
"--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
),
click.option(
"--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
),
click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",
),
click.option(
"--blur-usage", default=None, type=int,
metavar="SECONDS",
help="round logged access times to improve privacy",
),
click.option(
"--no-daemon", "-n", is_flag=True,
help="Run in the foreground",
),
click.option(
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
),
click.option(
"--allow-list/--disallow-list", default=True,
help="always/never send list of allocated nameplates",
),
click.option(
"--relay-database-path", default="relay.sqlite", metavar="PATH",
help="location for the relay server state database",
),
click.option(
"--stats-json-path", default="stats.json", metavar="PATH",
help="location to write the relay stats file",
),
click.option(
"--websocket-protocol-option", multiple=True, metavar="OPTION=VALUE",
callback=_validate_websocket_protocol_options,
help="a websocket server protocol option to configure",
),
)
@server.command()
@LaunchArgs
@click.pass_obj
def start(cfg, **kwargs):
"""
Start a relay server
"""
for name, value in kwargs.items():
setattr(cfg, name, value)
    from .cmd_server import start_server
start_server(cfg)
@server.command()
@LaunchArgs
@click.pass_obj
def restart(cfg, **kwargs):
"""
Re-start a relay server
"""
for name, value in kwargs.items():
setattr(cfg, name, value)
    from .cmd_server import restart_server
restart_server(cfg)
@server.command()
@click.pass_obj
def stop(cfg):
"""
Stop a relay server
"""
    from .cmd_server import stop_server
stop_server(cfg)
@server.command(name="tail-usage")
@click.pass_obj
def tail_usage(cfg):
"""
Follow the latest usage
"""
    from .cmd_usage import tail_usage
tail_usage(cfg)
@server.command(name='count-channels')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_channels(cfg, json):
"""
Count active channels
"""
    from .cmd_usage import count_channels
cfg.json = json
count_channels(cfg)
@server.command(name='count-events')
@click.option(
"--json", is_flag=True,
)
@click.pass_obj
def count_events(cfg, json):
"""
Count events
"""
    from .cmd_usage import count_events
cfg.json = json
count_events(cfg)

@@ -0,0 +1,73 @@
from __future__ import print_function, unicode_literals
import os, time
from twisted.python import usage
from twisted.scripts import twistd
class MyPlugin(object):
tapname = "xyznode"
def __init__(self, args):
self.args = args
def makeService(self, so):
# delay this import as late as possible, to allow twistd's code to
# accept --reactor= selection
from .server import RelayServer
return RelayServer(
str(self.args.rendezvous),
str(self.args.transit),
self.args.advertise_version,
self.args.relay_database_path,
self.args.blur_usage,
signal_error=self.args.signal_error,
stats_file=self.args.stats_json_path,
allow_list=self.args.allow_list,
)
class MyTwistdConfig(twistd.ServerOptions):
subCommands = [("XYZ", None, usage.Options, "node")]
def start_server(args):
c = MyTwistdConfig()
#twistd_args = tuple(args.twistd_args) + ("XYZ",)
base_args = []
if args.no_daemon:
base_args.append("--nodaemon")
twistd_args = base_args + ["XYZ"]
c.parseOptions(tuple(twistd_args))
c.loadedPlugins = {"XYZ": MyPlugin(args)}
print("starting wormhole relay server")
# this forks and never comes back. The parent calls os._exit(0)
twistd.runApp(c)
def kill_server():
try:
f = open("twistd.pid", "r")
except EnvironmentError:
print("Unable to find twistd.pid: is this really a server directory?")
print("oh well, ignoring 'stop'")
return
pid = int(f.read().strip())
f.close()
os.kill(pid, 15)
print("server process %d sent SIGTERM" % pid)
return
def stop_server(args):
kill_server()
def restart_server(args):
kill_server()
time.sleep(0.1)
timeout = 0
while os.path.exists("twistd.pid") and timeout < 10:
if timeout == 0:
print(" waiting for shutdown..")
timeout += 1
time.sleep(1)
if os.path.exists("twistd.pid"):
print("error: unable to shut down old server")
return 1
print(" old server shut down")
start_server(args)

@@ -0,0 +1,226 @@
from __future__ import print_function, unicode_literals
import os, time, json
from collections import defaultdict
import click
from humanize import naturalsize
from .database import get_db
def abbrev(t):
if t is None:
return "-"
if t > 1.0:
return "%.3fs" % t
if t > 1e-3:
return "%.1fms" % (t*1e3)
return "%.1fus" % (t*1e6)
def print_event(event):
event_type, started, result, total_bytes, waiting_time, total_time = event
followthrough = None
if waiting_time and total_time:
followthrough = total_time - waiting_time
print("%17s: total=%7s wait=%7s ft=%7s size=%s (%s)" %
("%s-%s" % (event_type, result),
abbrev(total_time),
abbrev(waiting_time),
abbrev(followthrough),
naturalsize(total_bytes),
time.ctime(started),
))
def show_usage(args):
print("closed for renovation")
    return 0  # NOTE: the usage report below is currently unreachable
if not os.path.exists("relay.sqlite"):
raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
oldest = None
newest = None
rendezvous_counters = defaultdict(int)
transit_counters = defaultdict(int)
total_transit_bytes = 0
db = get_db("relay.sqlite")
c = db.execute("SELECT * FROM `usage`"
" ORDER BY `started` ASC LIMIT ?",
(args.n,))
for row in c.fetchall():
if row["type"] == "rendezvous":
counters = rendezvous_counters
elif row["type"] == "transit":
counters = transit_counters
total_transit_bytes += row["total_bytes"]
else:
continue
counters["total"] += 1
counters[row["result"]] += 1
if oldest is None or row["started"] < oldest:
oldest = row["started"]
if newest is None or row["started"] > newest:
newest = row["started"]
event = (row["type"], row["started"], row["result"],
row["total_bytes"], row["waiting_time"], row["total_time"])
print_event(event)
if rendezvous_counters["total"] or transit_counters["total"]:
print("---")
print("(most recent started %s ago)" % abbrev(time.time() - newest))
if rendezvous_counters["total"]:
print("rendezvous events:")
counters = rendezvous_counters
elapsed = time.time() - oldest
total = counters["total"]
print(" %d events in %s (%.2f per hour)" % (total, abbrev(elapsed),
(3600 * total / elapsed)))
print("", ", ".join(["%s=%d (%d%%)" %
(k, counters[k], (100.0 * counters[k] / total))
for k in sorted(counters)
if k != "total"]))
if transit_counters["total"]:
print("transit events:")
counters = transit_counters
elapsed = time.time() - oldest
total = counters["total"]
print(" %d events in %s (%.2f per hour)" % (total, abbrev(elapsed),
(3600 * total / elapsed)))
rate = total_transit_bytes / elapsed
print(" %s total bytes, %sps" % (naturalsize(total_transit_bytes),
naturalsize(rate)))
print("", ", ".join(["%s=%d (%d%%)" %
(k, counters[k], (100.0 * counters[k] / total))
for k in sorted(counters)
if k != "total"]))
return 0
def tail_usage(args):
if not os.path.exists("relay.sqlite"):
raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite")
# we don't seem to have unique row IDs, so this is an inaccurate and
# inefficient hack
seen = set()
try:
while True:
old = time.time() - 2*60*60
c = db.execute("SELECT * FROM `usage`"
" WHERE `started` > ?"
" ORDER BY `started` ASC", (old,))
for row in c.fetchall():
event = (row["type"], row["started"], row["result"],
row["total_bytes"], row["waiting_time"],
row["total_time"])
if event not in seen:
print_event(event)
seen.add(event)
time.sleep(2)
except KeyboardInterrupt:
return 0
return 0
def count_channels(args):
if not os.path.exists("relay.sqlite"):
raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite")
c_list = []
c_dict = {}
def add(key, value):
c_list.append((key, value))
c_dict[key] = value
OLD = time.time() - 10*60
def q(query, values=()):
return list(db.execute(query, values).fetchone().values())[0]
add("apps", q("SELECT COUNT(DISTINCT(`app_id`)) FROM `nameplates`"))
add("total nameplates", q("SELECT COUNT() FROM `nameplates`"))
add("waiting nameplates", q("SELECT COUNT() FROM `nameplates`"
" WHERE `second` is null"))
add("connected nameplates", q("SELECT COUNT() FROM `nameplates`"
" WHERE `second` is not null"))
add("stale nameplates", q("SELECT COUNT() FROM `nameplates`"
" where `updated` < ?", (OLD,)))
add("total mailboxes", q("SELECT COUNT() FROM `mailboxes`"))
add("waiting mailboxes", q("SELECT COUNT() FROM `mailboxes`"
" WHERE `second` is null"))
add("connected mailboxes", q("SELECT COUNT() FROM `mailboxes`"
" WHERE `second` is not null"))
stale_mailboxes = 0
for mbox_row in db.execute("SELECT * FROM `mailboxes`").fetchall():
newest = db.execute("SELECT `server_rx` FROM `messages`"
" WHERE `app_id`=? AND `mailbox_id`=?"
" ORDER BY `server_rx` DESC LIMIT 1",
(mbox_row["app_id"], mbox_row["id"])).fetchone()
if newest and newest[0] < OLD:
stale_mailboxes += 1
add("stale mailboxes", stale_mailboxes)
add("messages", q("SELECT COUNT() FROM `messages`"))
if args.json:
print(json.dumps(c_dict))
else:
for (key, value) in c_list:
print(key, value)
return 0
def count_events(args):
if not os.path.exists("relay.sqlite"):
raise click.UsageError(
"cannot find relay.sqlite, please run from the server directory"
)
db = get_db("relay.sqlite")
c_list = []
c_dict = {}
def add(key, value):
c_list.append((key, value))
c_dict[key] = value
def q(query, values=()):
return list(db.execute(query, values).fetchone().values())[0]
add("apps", q("SELECT COUNT(DISTINCT(`app_id`)) FROM `nameplate_usage`"))
add("total nameplates", q("SELECT COUNT() FROM `nameplate_usage`"))
add("happy nameplates", q("SELECT COUNT() FROM `nameplate_usage`"
" WHERE `result`='happy'"))
add("lonely nameplates", q("SELECT COUNT() FROM `nameplate_usage`"
" WHERE `result`='lonely'"))
add("pruney nameplates", q("SELECT COUNT() FROM `nameplate_usage`"
" WHERE `result`='pruney'"))
add("crowded nameplates", q("SELECT COUNT() FROM `nameplate_usage`"
" WHERE `result`='crowded'"))
add("total mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"))
add("happy mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='happy'"))
add("scary mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='scary'"))
add("lonely mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='lonely'"))
add("errory mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='errory'"))
add("pruney mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='pruney'"))
add("crowded mailboxes", q("SELECT COUNT() FROM `mailbox_usage`"
" WHERE `result`='crowded'"))
add("total transit", q("SELECT COUNT() FROM `transit_usage`"))
add("happy transit", q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='happy'"))
add("lonely transit", q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='lonely'"))
add("errory transit", q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='errory'"))
add("transit bytes", q("SELECT SUM(`total_bytes`) FROM `transit_usage`"))
if args.json:
print(json.dumps(c_dict))
else:
for (key, value) in c_list:
print(key, value)
return 0

@@ -0,0 +1,126 @@
from __future__ import unicode_literals
import os
import sqlite3
import tempfile
from pkg_resources import resource_string
from twisted.python import log
class DBError(Exception):
pass
def get_schema(version):
schema_bytes = resource_string("wormhole.server",
"db-schemas/v%d.sql" % version)
return schema_bytes.decode("utf-8")
def get_upgrader(new_version):
schema_bytes = resource_string("wormhole.server",
"db-schemas/upgrade-to-v%d.sql" % new_version)
return schema_bytes.decode("utf-8")
TARGET_VERSION = 3
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
def _initialize_db_schema(db, target_version):
"""Creates the application schema in the given database.
"""
log.msg("populating new database with schema v%s" % target_version)
schema = get_schema(target_version)
db.executescript(schema)
db.execute("INSERT INTO version (version) VALUES (?)",
(target_version,))
db.commit()
def _initialize_db_connection(db):
"""Sets up the db connection object with a row factory and with necessary
foreign key settings.
"""
db.row_factory = dict_factory
db.execute("PRAGMA foreign_keys = ON")
problems = db.execute("PRAGMA foreign_key_check").fetchall()
if problems:
raise DBError("failed foreign key check: %s" % (problems,))
def _open_db_connection(dbfile):
"""Open a new connection to the SQLite3 database at the given path.
"""
try:
db = sqlite3.connect(dbfile)
except (EnvironmentError, sqlite3.OperationalError) as e:
raise DBError("Unable to create/open db file %s: %s" % (dbfile, e))
_initialize_db_connection(db)
return db
def _get_temporary_dbfile(dbfile):
"""Get a temporary filename near the given path.
"""
fd, name = tempfile.mkstemp(
prefix=os.path.basename(dbfile) + ".",
dir=os.path.dirname(dbfile)
)
os.close(fd)
return name
def _atomic_create_and_initialize_db(dbfile, target_version):
"""Create and return a new database, initialized with the application
schema.
If anything goes wrong, nothing is left at the ``dbfile`` path.
"""
temp_dbfile = _get_temporary_dbfile(dbfile)
db = _open_db_connection(temp_dbfile)
_initialize_db_schema(db, target_version)
db.close()
os.rename(temp_dbfile, dbfile)
return _open_db_connection(dbfile)
def get_db(dbfile, target_version=TARGET_VERSION):
"""Open or create the given db file. The parent directory must exist.
Returns the db connection object, or raises DBError.
"""
if dbfile == ":memory:":
db = _open_db_connection(dbfile)
_initialize_db_schema(db, target_version)
elif os.path.exists(dbfile):
db = _open_db_connection(dbfile)
else:
db = _atomic_create_and_initialize_db(dbfile, target_version)
try:
version = db.execute("SELECT version FROM version").fetchone()["version"]
except sqlite3.DatabaseError as e:
# this indicates that the file is not a compatible database format.
# Perhaps it was created with an old version, or it might be junk.
raise DBError("db file is unusable: %s" % e)
while version < target_version:
log.msg(" need to upgrade from %s to %s" % (version, target_version))
try:
upgrader = get_upgrader(version+1)
except ValueError: # ResourceError??
log.msg(" unable to upgrade %s to %s" % (version, version+1))
raise DBError("Unable to upgrade %s to version %s, left at %s"
% (dbfile, version+1, version))
log.msg(" executing upgrader v%s->v%s" % (version, version+1))
db.executescript(upgrader)
db.commit()
version = version+1
if version != target_version:
raise DBError("Unable to handle db version %s" % version)
return db
def dump_db(db):
# to let _iterdump work, we need to restore the original row factory
orig = db.row_factory
try:
db.row_factory = sqlite3.Row
return "".join(db.iterdump())
finally:
db.row_factory = orig

@@ -0,0 +1,68 @@
DROP TABLE `nameplates`;
DROP TABLE `messages`;
DROP TABLE `mailboxes`;
-- Wormhole codes use a "nameplate": a short name which is only used to
-- reference a specific (long-named) mailbox. The codes only use numeric
-- nameplates, but the protocol and server allow arbitrary strings.
CREATE TABLE `nameplates`
(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`app_id` VARCHAR,
`name` VARCHAR,
`mailbox_id` VARCHAR REFERENCES `mailboxes`(`id`),
`request_id` VARCHAR -- from 'allocate' message, for future deduplication
);
CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `name`);
CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`);
CREATE INDEX `nameplates_request_idx` ON `nameplates` (`app_id`, `request_id`);
CREATE TABLE `nameplate_sides`
(
`nameplates_id` REFERENCES `nameplates`(`id`),
`claimed` BOOLEAN, -- True after claim(), False after release()
`side` VARCHAR,
`added` INTEGER -- time when this side first claimed the nameplate
);
-- Clients exchange messages through a "mailbox", which has a long (randomly
-- unique) identifier and a queue of messages.
-- `id` is randomly-generated and unique across all apps.
CREATE TABLE `mailboxes`
(
`app_id` VARCHAR,
`id` VARCHAR PRIMARY KEY,
`updated` INTEGER, -- time of last activity, used for pruning
`for_nameplate` BOOLEAN -- allocated for a nameplate, not standalone
);
CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`);
CREATE TABLE `mailbox_sides`
(
`mailbox_id` REFERENCES `mailboxes`(`id`),
`opened` BOOLEAN, -- True after open(), False after close()
`side` VARCHAR,
`added` INTEGER, -- time when this side first opened the mailbox
`mood` VARCHAR
);
CREATE TABLE `messages`
(
`app_id` VARCHAR,
`mailbox_id` VARCHAR,
`side` VARCHAR,
`phase` VARCHAR, -- numeric or string
`body` VARCHAR,
`server_rx` INTEGER,
`msg_id` VARCHAR
);
CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`);
ALTER TABLE `mailbox_usage` ADD COLUMN `for_nameplate` BOOLEAN;
CREATE INDEX `mailbox_usage_result_idx` ON `mailbox_usage` (`result`);
CREATE INDEX `transit_usage_result_idx` ON `transit_usage` (`result`);
DELETE FROM `version`;
INSERT INTO `version` (`version`) VALUES (3);

@@ -0,0 +1,105 @@
-- note: anything which isn't a boolean, integer, or human-readable unicode
-- string (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 2
);
-- Wormhole codes use a "nameplate": a short identifier which is only used to
-- reference a specific (long-named) mailbox. The codes only use numeric
-- nameplates, but the protocol and server allow arbitrary strings.
CREATE TABLE `nameplates`
(
`app_id` VARCHAR,
`id` VARCHAR,
`mailbox_id` VARCHAR, -- really a foreign key
`side1` VARCHAR, -- side name, or NULL
`side2` VARCHAR, -- side name, or NULL
`request_id` VARCHAR, -- from 'allocate' message, for future deduplication
`crowded` BOOLEAN, -- at some point, three or more sides were involved
`updated` INTEGER, -- time of last activity, used for pruning
-- timing data
`started` INTEGER, -- time when nameplate was opened
`second` INTEGER -- time when second side opened
);
CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `id`);
CREATE INDEX `nameplates_updated_idx` ON `nameplates` (`app_id`, `updated`);
CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`);
CREATE INDEX `nameplates_request_idx` ON `nameplates` (`app_id`, `request_id`);
-- Clients exchange messages through a "mailbox", which has a long (randomly
-- unique) identifier and a queue of messages.
CREATE TABLE `mailboxes`
(
`app_id` VARCHAR,
`id` VARCHAR,
`side1` VARCHAR, -- side name, or NULL
`side2` VARCHAR, -- side name, or NULL
`crowded` BOOLEAN, -- at some point, three or more sides were involved
`first_mood` VARCHAR,
-- timing data for the mailbox itself
`started` INTEGER, -- time when opened
`second` INTEGER -- time when second side opened
);
CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`);
CREATE TABLE `messages`
(
`app_id` VARCHAR,
`mailbox_id` VARCHAR,
`side` VARCHAR,
`phase` VARCHAR, -- numeric or string
`body` VARCHAR,
`server_rx` INTEGER,
`msg_id` VARCHAR
);
CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`);
CREATE TABLE `nameplate_usage`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_time` INTEGER, -- seconds from open to last close/prune
`result` VARCHAR -- happy, lonely, pruney, crowded
-- nameplate moods:
-- "happy": two sides open and close
-- "lonely": one side opens and closes (no response from 2nd side)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`app_id`, `started`);
CREATE TABLE `mailbox_usage`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- rendezvous moods:
-- "happy": both sides close with mood=happy
-- "scary": any side closes with mood=scary (bad MAC, probably wrong pw)
-- "lonely": any side closes with mood=lonely (no response from 2nd side)
-- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`app_id`, `started`);
CREATE TABLE `transit_usage`
(
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_bytes` INTEGER, -- total bytes relayed (both directions)
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- transit moods:
-- "errory": one side gave the wrong handshake
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
);
CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`);

@@ -0,0 +1,115 @@
-- note: anything which isn't a boolean, integer, or human-readable unicode
-- string (i.e. binary strings) will be stored as hex
CREATE TABLE `version`
(
`version` INTEGER -- contains one row, set to 3
);
-- Wormhole codes use a "nameplate": a short name which is only used to
-- reference a specific (long-named) mailbox. The codes only use numeric
-- nameplates, but the protocol and server allow arbitrary strings.
CREATE TABLE `nameplates`
(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`app_id` VARCHAR,
`name` VARCHAR,
`mailbox_id` VARCHAR REFERENCES `mailboxes`(`id`),
`request_id` VARCHAR -- from 'allocate' message, for future deduplication
);
CREATE INDEX `nameplates_idx` ON `nameplates` (`app_id`, `name`);
CREATE INDEX `nameplates_mailbox_idx` ON `nameplates` (`app_id`, `mailbox_id`);
CREATE INDEX `nameplates_request_idx` ON `nameplates` (`app_id`, `request_id`);
CREATE TABLE `nameplate_sides`
(
`nameplates_id` REFERENCES `nameplates`(`id`),
`claimed` BOOLEAN, -- True after claim(), False after release()
`side` VARCHAR,
`added` INTEGER -- time when this side first claimed the nameplate
);
-- Clients exchange messages through a "mailbox", which has a long (randomly
-- unique) identifier and a queue of messages.
-- `id` is randomly-generated and unique across all apps.
CREATE TABLE `mailboxes`
(
`app_id` VARCHAR,
`id` VARCHAR PRIMARY KEY,
`updated` INTEGER, -- time of last activity, used for pruning
`for_nameplate` BOOLEAN -- allocated for a nameplate, not standalone
);
CREATE INDEX `mailboxes_idx` ON `mailboxes` (`app_id`, `id`);
CREATE TABLE `mailbox_sides`
(
`mailbox_id` REFERENCES `mailboxes`(`id`),
`opened` BOOLEAN, -- True after open(), False after close()
`side` VARCHAR,
`added` INTEGER, -- time when this side first opened the mailbox
`mood` VARCHAR
);
CREATE TABLE `messages`
(
`app_id` VARCHAR,
`mailbox_id` VARCHAR,
`side` VARCHAR,
`phase` VARCHAR, -- numeric or string
`body` VARCHAR,
`server_rx` INTEGER,
`msg_id` VARCHAR
);
CREATE INDEX `messages_idx` ON `messages` (`app_id`, `mailbox_id`);
CREATE TABLE `nameplate_usage`
(
`app_id` VARCHAR,
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_time` INTEGER, -- seconds from open to last close/prune
`result` VARCHAR -- happy, lonely, pruney, crowded
-- nameplate moods:
-- "happy": two sides open and close
-- "lonely": one side opens and closes (no response from 2nd side)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `nameplate_usage_idx` ON `nameplate_usage` (`app_id`, `started`);
CREATE TABLE `mailbox_usage`
(
`app_id` VARCHAR,
`for_nameplate` BOOLEAN, -- allocated for a nameplate, not standalone
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- rendezvous moods:
-- "happy": both sides close with mood=happy
-- "scary": any side closes with mood=scary (bad MAC, probably wrong pw)
-- "lonely": any side closes with mood=lonely (no response from 2nd side)
-- "errory": any side closes with mood=errory (other errors)
-- "pruney": channels which get pruned for inactivity
-- "crowded": three or more sides were involved
);
CREATE INDEX `mailbox_usage_idx` ON `mailbox_usage` (`app_id`, `started`);
CREATE INDEX `mailbox_usage_result_idx` ON `mailbox_usage` (`result`);
CREATE TABLE `transit_usage`
(
`started` INTEGER, -- seconds since epoch, rounded to "blur time"
`total_time` INTEGER, -- seconds from open to last close
`waiting_time` INTEGER, -- seconds from start to 2nd side appearing, or None
`total_bytes` INTEGER, -- total bytes relayed (both directions)
`result` VARCHAR -- happy, scary, lonely, errory, pruney
-- transit moods:
-- "errory": one side gave the wrong handshake
-- "lonely": good handshake, but the other side never showed up
-- "happy": both sides gave correct handshake
);
CREATE INDEX `transit_usage_idx` ON `transit_usage` (`started`);
CREATE INDEX `transit_usage_result_idx` ON `transit_usage` (`result`);

@@ -0,0 +1,181 @@
# NO unicode_literals or static.Data() will break, because it demands
# a str on Python 2
from __future__ import print_function
import os, time, json
try:
# 'resource' is unix-only
from resource import getrlimit, setrlimit, RLIMIT_NOFILE
except ImportError: # pragma: nocover
getrlimit, setrlimit, RLIMIT_NOFILE = None, None, None # pragma: nocover
from twisted.python import log
from twisted.internet import reactor, endpoints
from twisted.application import service, internet
from twisted.web import server, static
from twisted.web.resource import Resource
from autobahn.twisted.resource import WebSocketResource
from .database import get_db
from .rendezvous import Rendezvous
from .rendezvous_websocket import WebSocketRendezvousFactory
from .transit_server import Transit
SECONDS = 1.0
MINUTE = 60*SECONDS
CHANNEL_EXPIRATION_TIME = 11*MINUTE
EXPIRATION_CHECK_PERIOD = 10*MINUTE
class Root(Resource):
# child_FOO is a nevow thing, not a twisted.web.resource thing
def __init__(self):
Resource.__init__(self)
self.putChild(b"", static.Data(b"Wormhole Relay\n", "text/plain"))
class PrivacyEnhancedSite(server.Site):
logRequests = True
def log(self, request):
if self.logRequests:
return server.Site.log(self, request)
class RelayServer(service.MultiService):
def __init__(self, rendezvous_web_port, transit_port,
advertise_version, db_url=":memory:", blur_usage=None,
signal_error=None, stats_file=None, allow_list=True,
websocket_protocol_options=()):
service.MultiService.__init__(self)
self._blur_usage = blur_usage
self._allow_list = allow_list
self._db_url = db_url
db = get_db(db_url)
welcome = {
# adding .motd will cause all clients to display the message,
# then keep running normally
#"motd": "Welcome to the public relay.\nPlease enjoy this service.",
# adding .error will cause all clients to fail, with this message
#"error": "This server has been disabled, see URL for details.",
}
if advertise_version:
# The primary (python CLI) implementation will emit a message if
# its version does not match this key. If/when we have
# distributions which include an older version that we still
# expect to be compatible, stop sending this key.
welcome["current_cli_version"] = advertise_version
if signal_error:
welcome["error"] = signal_error
self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list)
self._rendezvous.setServiceParent(self) # for the pruning timer
root = Root()
wsrf = WebSocketRendezvousFactory(None, self._rendezvous)
_set_options(websocket_protocol_options, wsrf)
root.putChild(b"v1", WebSocketResource(wsrf))
site = PrivacyEnhancedSite(root)
if blur_usage:
site.logRequests = False
r = endpoints.serverFromString(reactor, rendezvous_web_port)
rendezvous_web_service = internet.StreamServerEndpointService(r, site)
rendezvous_web_service.setServiceParent(self)
if transit_port:
transit = Transit(db, blur_usage)
transit.setServiceParent(self) # for the timer
t = endpoints.serverFromString(reactor, transit_port)
transit_service = internet.StreamServerEndpointService(t, transit)
transit_service.setServiceParent(self)
self._stats_file = stats_file
if self._stats_file and os.path.exists(self._stats_file):
os.unlink(self._stats_file)
# this will be regenerated immediately, but if something goes
# wrong in dump_stats(), it's better to have a missing file than
# a stale one
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, self.timer)
t.setServiceParent(self)
# make some things accessible for tests
self._db = db
self._root = root
self._rendezvous_web_service = rendezvous_web_service
self._rendezvous_websocket = wsrf
self._transit = None
if transit_port:
self._transit = transit
self._transit_service = transit_service
def increase_rlimits(self):
if getrlimit is None:
log.msg("unable to import 'resource', leaving rlimit alone")
return
soft, hard = getrlimit(RLIMIT_NOFILE)
if soft >= 10000:
log.msg("RLIMIT_NOFILE.soft was %d, leaving it alone" % soft)
return
# OS-X defaults to soft=7168, and reports a huge number for 'hard',
# but won't accept anything more than soft=10240, so we can't just
# set soft=hard. Linux returns (1024, 1048576) and is fine with
# soft=hard. Cygwin is reported to return (256,-1) and accepts up to
# soft=3200. So we try multiple values until something works.
for newlimit in [hard, 10000, 3200, 1024]:
log.msg("changing RLIMIT_NOFILE from (%s,%s) to (%s,%s)" %
(soft, hard, newlimit, hard))
try:
setrlimit(RLIMIT_NOFILE, (newlimit, hard))
log.msg("setrlimit successful")
return
except ValueError as e:
log.msg("error during setrlimit: %s" % e)
continue
except:
log.msg("other error during setrlimit, leaving it alone")
log.err()
return
log.msg("unable to change rlimit, leaving it alone")
def startService(self):
service.MultiService.startService(self)
self.increase_rlimits()
log.msg("websocket listening on /wormhole-relay/ws")
log.msg("Wormhole relay server (Rendezvous and Transit) running")
if self._blur_usage:
log.msg("blurring access times to %d seconds" % self._blur_usage)
log.msg("not logging HTTP requests or Transit connections")
else:
log.msg("not blurring access times")
if not self._allow_list:
log.msg("listing of allocated nameplates disallowed")
def timer(self):
now = time.time()
old = now - CHANNEL_EXPIRATION_TIME
self._rendezvous.prune_all_apps(now, old)
self.dump_stats(now, validity=EXPIRATION_CHECK_PERIOD+60)
def dump_stats(self, now, validity):
if not self._stats_file:
return
tmpfn = self._stats_file + ".tmp"
data = {}
data["created"] = now
data["valid_until"] = now + validity
start = time.time()
data["rendezvous"] = self._rendezvous.get_stats()
data["transit"] = self._transit.get_stats()
log.msg("get_stats took:", time.time() - start)
with open(tmpfn, "wb") as f:
# json.dump(f) has str-vs-unicode issues on py2-vs-py3
f.write(json.dumps(data, indent=1).encode("utf-8"))
f.write(b"\n")
os.rename(tmpfn, self._stats_file)
def _set_options(options, factory):
factory.setProtocolOptions(**dict(options))
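
For illustration, a minimal sketch (not part of this commit) of wiring the RelayServer above into a reactor; the endpoint strings, port numbers, and import path are assumptions:

from twisted.internet import reactor
# hypothetical import path; the real module layout may differ
from wormhole_transit_relay.server.server import RelayServer

relay = RelayServer(rendezvous_web_port="tcp:4000",
                    transit_port="tcp:4001",
                    advertise_version=None,
                    db_url=":memory:",
                    blur_usage=3600)  # round access times to the hour
relay.startService()
reactor.run()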

@ -0,0 +1,61 @@
from __future__ import print_function, unicode_literals
import os
from twisted.python import filepath
from twisted.trial import unittest
from ..server import database
from ..server.database import get_db, TARGET_VERSION, dump_db
class DB(unittest.TestCase):
def test_create_default(self):
db_url = ":memory:"
db = get_db(db_url)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
def test_failed_create_allows_subsequent_create(self):
patch = self.patch(database, "get_schema", lambda version: b"this is a broken schema")
dbfile = filepath.FilePath(self.mktemp())
self.assertRaises(Exception, lambda: get_db(dbfile.path))
patch.restore()
get_db(dbfile.path)
def test_upgrade(self):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "upgrade.db")
self.assertNotEqual(TARGET_VERSION, 2)
# create an old-version DB in a file
db = get_db(fn, 2)
rows = db.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], 2)
del db
# then upgrade the file to the latest version
dbA = get_db(fn, TARGET_VERSION)
rows = dbA.execute("SELECT * FROM version").fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["version"], TARGET_VERSION)
dbA_text = dump_db(dbA)
del dbA
# make sure the upgrades got committed to disk
dbB = get_db(fn, TARGET_VERSION)
dbB_text = dump_db(dbB)
del dbB
self.assertEqual(dbA_text, dbB_text)
# The upgraded schema should be equivalent to that of a new DB.
# However a text dump will differ because ALTER TABLE always appends
# the new column to the end of a table, whereas our schema puts it
# somewhere in the middle (wherever it fits naturally). Also ALTER
# TABLE doesn't include comments.
if False:
latest_db = get_db(":memory:", TARGET_VERSION)
latest_text = dump_db(latest_db)
with open("up.sql","w") as f: f.write(dbA_text)
with open("new.sql","w") as f: f.write(latest_text)
# check with "diff -u _trial_temp/up.sql _trial_temp/new.sql"
self.assertEqual(dbA_text, latest_text)

@ -0,0 +1,306 @@
from __future__ import print_function, unicode_literals
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web import client
from .common import ServerBase
from ..server import transit_server
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self):
blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0)
self.failUnlessEqual(blur(1), 10e3)
self.failUnlessEqual(blur(10e3), 10e3)
self.failUnlessEqual(blur(10e3+1), 20e3)
self.failUnlessEqual(blur(15e3), 20e3)
self.failUnlessEqual(blur(20e3), 20e3)
self.failUnlessEqual(blur(1e6), 1e6)
self.failUnlessEqual(blur(1e6+1), 2e6)
self.failUnlessEqual(blur(1.5e6), 2e6)
self.failUnlessEqual(blur(2e6), 2e6)
self.failUnlessEqual(blur(900e6), 900e6)
self.failUnlessEqual(blur(1000e6), 1000e6)
self.failUnlessEqual(blur(1050e6), 1100e6)
self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_web_request(self):
resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii'))
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
@defer.inlineCallbacks
def test_register(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
self.assertEqual(self.count(), 1)
a1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield self.wait()
self.assertEqual(self.count(), 0)
# the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def wait(self):
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
return d
@defer.inlineCallbacks
def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield self.wait()
self.assertEqual(self.count(), 2) # same-side connections don't match
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_binary_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
# the embedded \n makes the server trigger early, before the full
# expected handshake length has arrived. A non-wormhole client
# writing non-ascii junk to the transit port used to trigger a
# UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure.
a1.transport.write(binary_bad_handshake)
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_old(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

@ -0,0 +1,328 @@
from __future__ import print_function, unicode_literals
import re, time, collections
from twisted.python import log
from twisted.internet import protocol
from twisted.application import service
SECONDS = 1.0
MINUTE = 60*SECONDS
HOUR = 60*MINUTE
DAY = 24*HOUR
MB = 1000*1000
def round_to(size, coarseness):
return int(coarseness*(1+int((size-1)/coarseness)))
def blur_size(size):
if size == 0:
return 0
if size < 1e6:
return round_to(size, 10e3)
if size < 1e9:
return round_to(size, 1e6)
return round_to(size, 100e6)
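# Worked examples for blur_size() (illustrative comments, not in the
# original source); these agree with the values exercised by the tests:
#   blur_size(1)      -> 10e3    (sizes under 1MB round up to 10kB)
#   blur_size(1.5e6)  -> 2e6     (1MB..1GB rounds up to the next 1MB)
#   blur_size(1050e6) -> 1100e6  (over 1GB rounds up to the next 100MB)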
class TransitConnection(protocol.Protocol):
def __init__(self):
self._got_token = False
self._got_side = False
self._token_buffer = b""
self._sent_ok = False
self._buddy = None
self._had_buddy = False
self._total_sent = 0
def describeToken(self):
d = "-"
if self._got_token:
d = self._got_token[:16].decode("ascii")
if self._got_side:
d += "-" + self._got_side.decode("ascii")
else:
d += "-<unsided>"
return d
def connectionMade(self):
self._started = time.time()
self._log_requests = self.factory._log_requests
def dataReceived(self, data):
if self._sent_ok:
# We are an IPushProducer to our buddy's IConsumer, so they'll
# throttle us (by calling pauseProducing()) when their outbound
# buffer is full (e.g. when their downstream pipe is full). In
# practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the
# receiver can handle it.
self._total_sent += len(data)
self._buddy.transport.write(data)
return
if self._got_token: # but not yet sent_ok
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
# else this should be (part of) the token
self._token_buffer += data
buf = self._token_buffer
# old: "please relay {64}\n"
# new: "please relay {64} for side {16}\n"
(old, handshake_len, token) = self._check_old_handshake(buf)
assert old in ("yes", "waiting", "no")
if old == "yes":
# remember they aren't supposed to send anything past their
# handshake until we've said go
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, None)
(new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no")
if new == "yes":
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, side)
if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n")
if self._log_requests:
log.msg("transit handshake failure")
return self.disconnect() # incorrectness yields failure
# else we'll keep waiting
def _check_old_handshake(self, buf):
# old: "please relay {64}\n"
# return ("yes", handshake, token) if buf contains an old-style handshake
# return ("waiting", None, None) if it might eventually contain one
# return ("no", None, None) if it could never contain one
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None)
if len(buf) < wanted:
return ("waiting", None, None)
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if mo:
token = mo.group(1)
return ("yes", wanted, token)
return ("no", None, None)
def _check_new_handshake(self, buf):
# new: "please relay {64} for side {16}\n"
wanted = len("please relay for side \n")+32*2+8*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None, None)
if len(buf) < wanted:
return ("waiting", None, None, None)
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
if mo:
token = mo.group(1)
side = mo.group(2)
return ("yes", wanted, token, side)
return ("no", None, None, None)
def _got_handshake(self, token, side):
self._got_token = token
self._got_side = side
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them):
self._buddy = them
self._had_buddy = True
self.transport.write(b"ok\n")
self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True,
# so this expects the IPushProducer interface, and uses
# pauseProducing() to throttle, and resumeProducing() to unthrottle.
self._buddy.transport.registerProducer(self.transport, True)
# The Transit object calls buddy_connected() on both protocols, so
# there will be two producer/consumer pairs.
def buddy_disconnected(self):
if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None
self.transport.loseConnection()
def connectionLost(self, reason):
if self._buddy:
self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
# Record usage. There are five cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self):
self.transport.loseConnection()
self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
class Transit(protocol.ServerFactory, service.MultiService):
# I manage pairs of simultaneous connections to a secondary TCP port,
# both forwarded to the other. Clients must begin each connection with
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for
# SIDE"). Two connections match if they use the same TOKEN and have
# different SIDEs (the redundant connections are dropped when a match is
# made). Legacy connections match any with the same TOKEN, ignoring SIDE
# (so two legacy connections will match each other).
# I will send "ok\n" when the matching connection is established, or
# disconnect if no matching connection is made within MAX_WAIT_TIME
# seconds. I will disconnect if you send data before the "ok\n". All data
# you get after the "ok\n" will be from the other side. You will not
# receive "ok\n" until the other side has also connected and submitted a
# matching token (and differing SIDE).
# In addition, the connections will be dropped after MAXLENGTH bytes have
# been sent by either side, or MAXTIME seconds have elapsed after the
# matching connections were established. A future API will reveal these
# limits to clients instead of causing mysterious spontaneous failures.
# These relay connections are not half-closeable (unlike full TCP
# connections, applications will not receive any data after half-closing
# their outgoing side). Applications must negotiate shutdown with their
# peer and not close the connection until all data has finished
# transferring in both directions. Applications which only need to send
# data in one direction can use close() as usual.
MAX_WAIT_TIME = 30*SECONDS
MAXLENGTH = 10*MB
MAXTIME = 60*SECONDS
protocol = TransitConnection
def __init__(self, db, blur_usage):
service.MultiService.__init__(self)
self._db = db
self._blur_usage = blur_usage
self._log_requests = blur_usage is None
self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
self._counts = collections.defaultdict(int)
self._count_bytes = 0
def connection_got_token(self, token, new_side, new_tc):
if token not in self._pending_requests:
self._pending_requests[token] = set()
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._log_requests:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"?
self._pending_requests.pop(token)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._log_requests:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def recordUsage(self, started, result, total_bytes,
total_time, waiting_time):
if self._log_requests:
log.msg("Transit.recordUsage (%dB)" % total_bytes)
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
self._db.execute("INSERT INTO `transit_usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._db.commit()
self._counts[result] += 1
self._count_bytes += total_bytes
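# Worked example for the blurring above (illustrative, not in the
# original source): with blur_usage=3600 and started=1505230000,
# `started` rounds down to 1505228400 (the containing hour), and
# total_bytes=123456 becomes blur_size(123456) == 130000.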
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)
if side_tc in self._pending_requests[token]:
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token]
if self._log_requests:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
def transitFailed(self, p):
if self._log_requests:
log.msg("transitFailed %r" % p)
def get_stats(self):
stats = {}
def q(query, values=()):
row = self._db.execute(query, values).fetchone()
return list(row.values())[0]
# current status: expected to be zero most of the time
c = stats["active"] = {}
c["connected"] = len(self._active_connections) / 2
c["waiting"] = len(self._pending_requests)
# usage since last reboot
rb = stats["since_reboot"] = {}
rb["bytes"] = self._count_bytes
rb["total"] = sum(self._counts.values(), 0)
rbm = rb["moods"] = {}
for result, count in self._counts.items():
rbm[result] = count
# historical usage (all-time)
u = stats["all_time"] = {}
u["total"] = q("SELECT COUNT() FROM `transit_usage`")
u["bytes"] = q("SELECT SUM(`total_bytes`) FROM `transit_usage`") or 0
um = u["moods"] = {}
um["happy"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='happy'")
um["lonely"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='lonely'")
um["errory"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='errory'")
return stats
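
A minimal client-side sketch (not part of this commit) of the relay handshake described in the Transit comments above, over a raw socket; the host, port, token, and side values are arbitrary examples (real clients derive the token from the Wormhole transit key):

from binascii import hexlify
import socket

token = b"\x00" * 32   # 32-byte token, sent as 64 hex characters
side = b"\x01" * 8     # 8-byte side, sent as 16 hex characters

s = socket.create_connection(("relay.example.com", 4001))
s.sendall(b"please relay " + hexlify(token) +
          b" for side " + hexlify(side) + b"\n")
# the relay answers "ok\n" only once a matching connection (same token,
# different side) has arrived; payload sent before that gets "impatient\n"
ack = b""
while len(ack) < 3:
    chunk = s.recv(3 - len(ack))
    if not chunk:
        raise RuntimeError("relay closed the connection: %r" % (ack,))
    ack += chunk
assert ack == b"ok\n"
s.sendall(b"hello via relay")  # everything after the ack is relayed verbatim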

@ -0,0 +1,29 @@
# Tox (http://tox.testrun.org/) is a tool for running tests
# in multiple virtualenvs. This configuration file will run the
# test suite on all supported python versions. To use it, "pip install tox"
# and then run "tox" from this directory.
[tox]
envlist = {py27,py34,py35,py36,pypy}
skip_missing_interpreters = True
minversion = 2.4.0
[testenv]
usedevelop = True
extras = dev
deps =
pyflakes >= 1.2.3
commands =
pyflakes setup.py src
wormhole --version
python -m twisted.trial {posargs:wormhole_transit_relay}
[testenv:coverage]
deps =
pyflakes >= 1.2.3
coverage
commands =
pyflakes setup.py src
wormhole --version
coverage run --branch -m twisted.trial {posargs:wormhole_transit_relay}
coverage xml

File diff suppressed because it is too large