Merge branch 'remove-transit'

This removes the Transit Relay server from "wormhole-server", since it's been
split out into it's own package:

* https://github.com/warner/magic-wormhole-transit-relay
* https://pypi.python.org/pypi/magic-wormhole-transit-relay/0.1.0

The magic-wormhole tests now import that external
magic-wormhole-transit-relay package to exercise the client-side
functionality, as well as for the integration tests that do end-to-end
transfers. A normal "pip install magic-wormhole" will no longer include
transit-relay functionality.

The next step will be to remove the Rendezvous Server functionality too,
following the same path (create a new package, copy the server code into it,
get it working, remove that code from magic-wormhole, rewrite the tests to
import the external package).
This commit is contained in:
Brian Warner 2017-11-13 14:15:03 -08:00
commit 9d4c6aad3b
15 changed files with 20 additions and 749 deletions

View File

@ -11,5 +11,3 @@ include misc/munin/wormhole_errors
include misc/munin/wormhole_event_rate
include misc/munin/wormhole_events
include misc/munin/wormhole_events_alltime
include misc/munin/wormhole_transit
include misc/munin/wormhole_transit_alltime

View File

@ -184,6 +184,8 @@ addresses of each client with the other (inside the encrypted message), and
both clients first attempt to connect directly. If this fails, they fall back
to using the transit relay. As before, the host/port of a public server is
baked into the library, and should be sufficient to handle moderate traffic.
Code for the Transit Relay is provided a separate package named
`magic-wormhole-transit-relay`.
The protocol includes provisions to deliver notices and error messages to
clients: if either relay must be shut down, these channels will be used to

View File

@ -22,12 +22,6 @@ mailboxes.type GAUGE
messages.label Messages
messages.draw LINE1
messages.type GAUGE
transit_waiting.label Transit Waiting
transit_waiting.draw LINE1
transit_waiting.type GAUGE
transit_connected.label Transit Connected
transit_connected.draw LINE1
transit_connected.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
@ -45,6 +39,3 @@ ra = data["rendezvous"]["active"]
print "nameplates.value", ra["nameplates_total"]
print "mailboxes.value", ra["mailboxes_total"]
print "messages.value", ra["messages_total"]
ta = data["transit"]["active"]
print "transit_waiting.value", ta["waiting"]
print "transit_connected.value", ta["connected"]

View File

@ -22,9 +22,6 @@ mailboxes.type GAUGE
mailboxes_scary.label Mailboxes (scary)
mailboxes_scary.draw LINE1
mailboxes_scary.type GAUGE
transit.label Transit
transit.draw LINE1
transit.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
@ -44,5 +41,3 @@ print "nameplates.value", (r["nameplates_total"]
print "mailboxes.value", (r["mailboxes_total"]
- r["mailbox_moods"].get("happy", 0))
print "mailboxes_scary.value", r["mailbox_moods"].get("scary", 0)
t = data["transit"]["since_reboot"]
print "transit.value", (t["total"] - t["moods"].get("happy", 0))

View File

@ -1,33 +0,0 @@
#! /usr/bin/env python
"""
Use the following in /etc/munin/plugin-conf.d/wormhole :
[wormhole_*]
env.serverdir /path/to/your/wormhole/server
"""
import os, sys, time, json
CONFIG = """\
graph_title Magic-Wormhole Transit Usage (since reboot)
graph_vlabel Bytes Since Reboot
graph_category network
bytes.label Transit Bytes
bytes.draw LINE1
bytes.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
print CONFIG.rstrip()
sys.exit(0)
serverdir = os.environ["serverdir"]
fn = os.path.join(serverdir, "stats.json")
with open(fn) as f:
data = json.load(f)
if time.time() > data["valid_until"]:
sys.exit(1) # expired
t = data["transit"]["since_reboot"]
print "bytes.value", t["bytes"]

View File

@ -1,33 +0,0 @@
#! /usr/bin/env python
"""
Use the following in /etc/munin/plugin-conf.d/wormhole :
[wormhole_*]
env.serverdir /path/to/your/wormhole/server
"""
import os, sys, time, json
CONFIG = """\
graph_title Magic-Wormhole Transit Usage (all time)
graph_vlabel Bytes Since DB Creation
graph_category network
bytes.label Transit Bytes
bytes.draw LINE1
bytes.type GAUGE
"""
if len(sys.argv) > 1 and sys.argv[1] == "config":
print CONFIG.rstrip()
sys.exit(0)
serverdir = os.environ["serverdir"]
fn = os.path.join(serverdir, "stats.json")
with open(fn) as f:
data = json.load(f)
if time.time() > data["valid_until"]:
sys.exit(1) # expired
t = data["transit"]["all_time"]
print "bytes.value", t["bytes"]

View File

@ -41,7 +41,8 @@ setup(name="magic-wormhole",
],
extras_require={
':sys_platform=="win32"': ["pypiwin32"],
"dev": ["mock", "tox", "pyflakes"],
"dev": ["mock", "tox", "pyflakes",
"magic-wormhole-transit-relay==0.1.0"],
},
test_suite="wormhole.test",
cmdclass=commands,

View File

@ -41,10 +41,6 @@ LaunchArgs = _compose(
"--rendezvous", default="tcp:4000", metavar="tcp:PORT",
help="endpoint specification for the rendezvous port",
),
click.option(
"--transit", default="tcp:4001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port",
),
click.option(
"--advertise-version", metavar="VERSION",
help="version to recommend to clients",

View File

@ -15,7 +15,6 @@ class MyPlugin(object):
from .server import RelayServer
return RelayServer(
str(self.args.rendezvous),
str(self.args.transit),
self.args.advertise_version,
self.args.relay_database_path,
self.args.blur_usage,

View File

@ -16,7 +16,6 @@ from autobahn.twisted.resource import WebSocketResource
from .database import get_db
from .rendezvous import Rendezvous
from .rendezvous_websocket import WebSocketRendezvousFactory
from .transit_server import Transit
SECONDS = 1.0
MINUTE = 60*SECONDS
@ -38,7 +37,7 @@ class PrivacyEnhancedSite(server.Site):
class RelayServer(service.MultiService):
def __init__(self, rendezvous_web_port, transit_port,
def __init__(self, rendezvous_web_port,
advertise_version, db_url=":memory:", blur_usage=None,
signal_error=None, stats_file=None, allow_list=True,
websocket_protocol_options=()):
@ -82,13 +81,6 @@ class RelayServer(service.MultiService):
rendezvous_web_service = internet.StreamServerEndpointService(r, site)
rendezvous_web_service.setServiceParent(self)
if transit_port:
transit = Transit(db, blur_usage)
transit.setServiceParent(self) # for the timer
t = endpoints.serverFromString(reactor, transit_port)
transit_service = internet.StreamServerEndpointService(t, transit)
transit_service.setServiceParent(self)
self._stats_file = stats_file
if self._stats_file and os.path.exists(self._stats_file):
os.unlink(self._stats_file)
@ -103,10 +95,6 @@ class RelayServer(service.MultiService):
self._root = root
self._rendezvous_web_service = rendezvous_web_service
self._rendezvous_websocket = wsrf
self._transit = None
if transit_port:
self._transit = transit
self._transit_service = transit_service
def increase_rlimits(self):
if getrlimit is None:
@ -141,10 +129,10 @@ class RelayServer(service.MultiService):
service.MultiService.startService(self)
self.increase_rlimits()
log.msg("websocket listening on /wormhole-relay/ws")
log.msg("Wormhole relay server (Rendezvous and Transit) running")
log.msg("Wormhole relay server (Rendezvous) running")
if self._blur_usage:
log.msg("blurring access times to %d seconds" % self._blur_usage)
log.msg("not logging HTTP requests or Transit connections")
log.msg("not logging HTTP requests")
else:
log.msg("not blurring access times")
if not self._allow_list:
@ -167,8 +155,6 @@ class RelayServer(service.MultiService):
start = time.time()
data["rendezvous"] = self._rendezvous.get_stats()
if self._transit:
data["transit"] = self._transit.get_stats()
log.msg("get_stats took:", time.time() - start)
with open(tmpfn, "wb") as f:

View File

@ -1,328 +0,0 @@
from __future__ import print_function, unicode_literals
import re, time, collections
from twisted.python import log
from twisted.internet import protocol
from twisted.application import service
SECONDS = 1.0
MINUTE = 60*SECONDS
HOUR = 60*MINUTE
DAY = 24*HOUR
MB = 1000*1000
def round_to(size, coarseness):
return int(coarseness*(1+int((size-1)/coarseness)))
def blur_size(size):
if size == 0:
return 0
if size < 1e6:
return round_to(size, 10e3)
if size < 1e9:
return round_to(size, 1e6)
return round_to(size, 100e6)
class TransitConnection(protocol.Protocol):
def __init__(self):
self._got_token = False
self._got_side = False
self._token_buffer = b""
self._sent_ok = False
self._buddy = None
self._had_buddy = False
self._total_sent = 0
def describeToken(self):
d = "-"
if self._got_token:
d = self._got_token[:16].decode("ascii")
if self._got_side:
d += "-" + self._got_side.decode("ascii")
else:
d += "-<unsided>"
return d
def connectionMade(self):
self._started = time.time()
self._log_requests = self.factory._log_requests
def dataReceived(self, data):
if self._sent_ok:
# We are an IPushProducer to our buddy's IConsumer, so they'll
# throttle us (by calling pauseProducing()) when their outbound
# buffer is full (e.g. when their downstream pipe is full). In
# practice, this buffers about 10MB per connection, after which
# point the sender will only transmit data as fast as the
# receiver can handle it.
self._total_sent += len(data)
self._buddy.transport.write(data)
return
if self._got_token: # but not yet sent_ok
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
# else this should be (part of) the token
self._token_buffer += data
buf = self._token_buffer
# old: "please relay {64}\n"
# new: "please relay {64} for side {16}\n"
(old, handshake_len, token) = self._check_old_handshake(buf)
assert old in ("yes", "waiting", "no")
if old == "yes":
# remember they aren't supposed to send anything past their
# handshake until we've said go
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, None)
(new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no")
if new == "yes":
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
if self._log_requests:
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, side)
if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n")
if self._log_requests:
log.msg("transit handshake failure")
return self.disconnect() # incorrectness yields failure
# else we'll keep waiting
def _check_old_handshake(self, buf):
# old: "please relay {64}\n"
# return ("yes", handshake, token) if buf contains an old-style handshake
# return ("waiting", None, None) if it might eventually contain one
# return ("no", None, None) if it could never contain one
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None)
if len(buf) < wanted:
return ("waiting", None, None)
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if mo:
token = mo.group(1)
return ("yes", wanted, token)
return ("no", None, None)
def _check_new_handshake(self, buf):
# new: "please relay {64} for side {16}\n"
wanted = len("please relay for side \n")+32*2+8*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None, None)
if len(buf) < wanted:
return ("waiting", None, None, None)
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
if mo:
token = mo.group(1)
side = mo.group(2)
return ("yes", wanted, token, side)
return ("no", None, None, None)
def _got_handshake(self, token, side):
self._got_token = token
self._got_side = side
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them):
self._buddy = them
self._had_buddy = True
self.transport.write(b"ok\n")
self._sent_ok = True
# Connect the two as a producer/consumer pair. We use streaming=True,
# so this expects the IPushProducer interface, and uses
# pauseProducing() to throttle, and resumeProducing() to unthrottle.
self._buddy.transport.registerProducer(self.transport, True)
# The Transit object calls buddy_connected() on both protocols, so
# there will be two producer/consumer pairs.
def buddy_disconnected(self):
if self._log_requests:
log.msg("buddy_disconnected %s" % self.describeToken())
self._buddy = None
self.transport.loseConnection()
def connectionLost(self, reason):
if self._buddy:
self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
# Record usage. There are four cases:
# * 1: we connected, never had a buddy
# * 2: we connected first, we disconnect before the buddy
# * 3: we connected first, buddy disconnects first
# * 4: buddy connected first, we disconnect before buddy
# * 5: buddy connected first, buddy disconnects first
# whoever disconnects first gets to write the usage record (1,2,4)
finished = time.time()
if not self._had_buddy: # 1
total_time = finished - self._started
self.factory.recordUsage(self._started, "lonely", 0,
total_time, None)
if self._had_buddy and self._buddy: # 2,4
total_bytes = self._total_sent + self._buddy._total_sent
starts = [self._started, self._buddy._started]
total_time = finished - min(starts)
waiting_time = max(starts) - min(starts)
self.factory.recordUsage(self._started, "happy", total_bytes,
total_time, waiting_time)
def disconnect(self):
self.transport.loseConnection()
self.factory.transitFailed(self)
finished = time.time()
total_time = finished - self._started
self.factory.recordUsage(self._started, "errory", 0,
total_time, None)
class Transit(protocol.ServerFactory, service.MultiService):
# I manage pairs of simultaneous connections to a secondary TCP port,
# both forwarded to the other. Clients must begin each connection with
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for
# SIDE"). Two connections match if they use the same TOKEN and have
# different SIDEs (the redundant connections are dropped when a match is
# made). Legacy connections match any with the same TOKEN, ignoring SIDE
# (so two legacy connections will match each other).
# I will send "ok\n" when the matching connection is established, or
# disconnect if no matching connection is made within MAX_WAIT_TIME
# seconds. I will disconnect if you send data before the "ok\n". All data
# you get after the "ok\n" will be from the other side. You will not
# receive "ok\n" until the other side has also connected and submitted a
# matching token (and differing SIDE).
# In addition, the connections will be dropped after MAXLENGTH bytes have
# been sent by either side, or MAXTIME seconds have elapsed after the
# matching connections were established. A future API will reveal these
# limits to clients instead of causing mysterious spontaneous failures.
# These relay connections are not half-closeable (unlike full TCP
# connections, applications will not receive any data after half-closing
# their outgoing side). Applications must negotiate shutdown with their
# peer and not close the connection until all data has finished
# transferring in both directions. Applications which only need to send
# data in one direction can use close() as usual.
MAX_WAIT_TIME = 30*SECONDS
MAXLENGTH = 10*MB
MAXTIME = 60*SECONDS
protocol = TransitConnection
def __init__(self, db, blur_usage):
service.MultiService.__init__(self)
self._db = db
self._blur_usage = blur_usage
self._log_requests = blur_usage is None
self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
self._counts = collections.defaultdict(int)
self._count_bytes = 0
def connection_got_token(self, token, new_side, new_tc):
if token not in self._pending_requests:
self._pending_requests[token] = set()
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
if self._log_requests:
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"?
self._pending_requests.pop(token)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
if self._log_requests:
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def recordUsage(self, started, result, total_bytes,
total_time, waiting_time):
if self._log_requests:
log.msg("Transit.recordUsage (%dB)" % total_bytes)
if self._blur_usage:
started = self._blur_usage * (started // self._blur_usage)
total_bytes = blur_size(total_bytes)
self._db.execute("INSERT INTO `transit_usage`"
" (`started`, `total_time`, `waiting_time`,"
" `total_bytes`, `result`)"
" VALUES (?,?,?, ?,?)",
(started, total_time, waiting_time,
total_bytes, result))
self._db.commit()
self._counts[result] += 1
self._count_bytes += total_bytes
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)
if side_tc in self._pending_requests[token]:
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token]
if self._log_requests:
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(tc)
def transitFailed(self, p):
if self._log_requests:
log.msg("transitFailed %r" % p)
pass
def get_stats(self):
stats = {}
def q(query, values=()):
row = self._db.execute(query, values).fetchone()
return list(row.values())[0]
# current status: expected to be zero most of the time
c = stats["active"] = {}
c["connected"] = len(self._active_connections) / 2
c["waiting"] = len(self._pending_requests)
# usage since last reboot
rb = stats["since_reboot"] = {}
rb["bytes"] = self._count_bytes
rb["total"] = sum(self._counts.values(), 0)
rbm = rb["moods"] = {}
for result, count in self._counts.items():
rbm[result] = count
# historical usage (all-time)
u = stats["all_time"] = {}
u["total"] = q("SELECT COUNT() FROM `transit_usage`")
u["bytes"] = q("SELECT SUM(`total_bytes`) FROM `transit_usage`") or 0
um = u["moods"] = {}
um["happy"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='happy'")
um["lonely"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='lonely'")
um["errory"] = q("SELECT COUNT() FROM `transit_usage`"
" WHERE `result`='errory'")
return stats

View File

@ -1,12 +1,13 @@
# no unicode_literals untill twisted update
from twisted.application import service
from twisted.internet import defer, task, reactor
from twisted.application import service, internet
from twisted.internet import defer, task, reactor, endpoints
from twisted.python import log
from click.testing import CliRunner
import mock
from ..cli import cli
from ..transit import allocate_tcp_port
from ..server.server import RelayServer
from wormhole_transit_relay.transit_server import Transit
class ServerBase:
def setUp(self):
@ -16,20 +17,25 @@ class ServerBase:
self.sp = service.MultiService()
self.sp.startService()
self.relayport = allocate_tcp_port()
self.transitport = allocate_tcp_port()
# need to talk to twisted team about only using unicode in
# endpoints.serverFromString
s = RelayServer("tcp:%d:interface=127.0.0.1" % self.relayport,
"tcp:%s:interface=127.0.0.1" % self.transitport,
advertise_version=advertise_version,
signal_error=error)
s.setServiceParent(self.sp)
self._relay_server = s
self._rendezvous = s._rendezvous
self._transit_server = s._transit
self.relayurl = u"ws://127.0.0.1:%d/v1" % self.relayport
self.rdv_ws_port = self.relayport
# ws://127.0.0.1:%d/wormhole-relay/ws
self.transitport = allocate_tcp_port()
ep = endpoints.serverFromString(reactor,
"tcp:%d:interface=127.0.0.1" %
self.transitport)
self._transit_server = f = Transit(blur_usage=None, log_file=None,
usage_db=None)
internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp)
self.transit = u"tcp:127.0.0.1:%d" % self.transitport
def tearDown(self):

View File

@ -13,13 +13,11 @@ from ..server.database import get_db
def easy_relay(
rendezvous_web_port=str("tcp:0"),
transit_port=str("tcp:0"),
advertise_version=None,
**kwargs
):
return server.RelayServer(
rendezvous_web_port,
transit_port,
advertise_version,
**kwargs
)
@ -33,7 +31,7 @@ class RLimits(unittest.TestCase):
# is easier than just passing "tcp:0"
ep = endpoints.TCP4ServerEndpoint(None, 0)
with patch_s("endpoints.serverFromString", return_value=ep):
s = server.RelayServer("fake", None, None)
s = server.RelayServer("fake", None)
fakelog = []
def checklog(*expected):
self.assertEqual(fakelog, list(expected))
@ -1403,7 +1401,6 @@ class DumpStats(unittest.TestCase):
self.assertEqual(data["created"], now)
self.assertEqual(data["valid_until"], now+validity)
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
self.assertEqual(data["transit"]["all_time"]["total"], 0)
class Startup(unittest.TestCase):

View File

@ -10,9 +10,9 @@ from twisted.internet import defer, task, endpoints, protocol, address, error
from twisted.internet.defer import gatherResults, inlineCallbacks
from twisted.python import log, failure
from twisted.test import proto_helpers
from wormhole_transit_relay import transit_server
from ..errors import InternalError
from .. import transit
from ..server import transit_server
from .common import ServerBase
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError

View File

@ -1,306 +0,0 @@
from __future__ import print_function, unicode_literals
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web import client
from .common import ServerBase
from ..server import transit_server
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self):
blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0)
self.failUnlessEqual(blur(1), 10e3)
self.failUnlessEqual(blur(10e3), 10e3)
self.failUnlessEqual(blur(10e3+1), 20e3)
self.failUnlessEqual(blur(15e3), 20e3)
self.failUnlessEqual(blur(20e3), 20e3)
self.failUnlessEqual(blur(1e6), 1e6)
self.failUnlessEqual(blur(1e6+1), 2e6)
self.failUnlessEqual(blur(1.5e6), 2e6)
self.failUnlessEqual(blur(2e6), 2e6)
self.failUnlessEqual(blur(900e6), 900e6)
self.failUnlessEqual(blur(1000e6), 1000e6)
self.failUnlessEqual(blur(1050e6), 1100e6)
self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_web_request(self):
resp = yield client.getPage('http://127.0.0.1:{}/'.format(self.relayport).encode('ascii'))
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
@defer.inlineCallbacks
def test_register(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
self.assertEqual(self.count(), 1)
a1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield self.wait()
self.assertEqual(self.count(), 0)
# the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def wait(self):
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
return d
@defer.inlineCallbacks
def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield self.wait()
self.assertEqual(self.count(), 2) # same-side connections don't match
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_binary_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
binary_bad_handshake = b"\x00\x01\xe0\x0f\n\xff"
# the embedded \n makes the server trigger early, before the full
# expected handshake length has arrived. A non-wormhole client
# writing non-ascii junk to the transit port used to trigger a
# UnicodeDecodeError when it tried to coerce the incoming handshake
# to unicode, due to the ("\n" in buf) check. This was fixed to use
# (b"\n" in buf). This exercises the old failure.
a1.transport.write(binary_bad_handshake)
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_old(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience_new(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
# sending too many bytes is impatience.
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()