remove blocking.transit
This commit is contained in:
parent
4e937c2100
commit
db137c26e5
|
@ -1,400 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import time, threading, socket
|
|
||||||
from six.moves import socketserver
|
|
||||||
from binascii import hexlify, unhexlify
|
|
||||||
from nacl.secret import SecretBox
|
|
||||||
from ..util import ipaddrs
|
|
||||||
from ..util.hkdf import HKDF
|
|
||||||
from ..errors import UsageError
|
|
||||||
from ..timing import DebugTiming
|
|
||||||
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
|
|
||||||
BadNonce,
|
|
||||||
build_receiver_handshake,
|
|
||||||
build_sender_handshake,
|
|
||||||
build_relay_handshake,
|
|
||||||
parse_hint_tcp)
|
|
||||||
|
|
||||||
TIMEOUT=15
|
|
||||||
|
|
||||||
# 1: sender only transmits, receiver only accepts, both wait forever
|
|
||||||
# 2: sender also accepts, receiver also transmits
|
|
||||||
# 3: timeouts / stop when no more progress can be made
|
|
||||||
# 4: add relay
|
|
||||||
# 5: accelerate shutdown of losing sockets
|
|
||||||
|
|
||||||
def send_to(skt, data):
|
|
||||||
sent = 0
|
|
||||||
while sent < len(data):
|
|
||||||
sent += skt.send(data[sent:])
|
|
||||||
|
|
||||||
def wait_for_line(skt, max_length, description):
|
|
||||||
got = b""
|
|
||||||
while len(got) < max_length:
|
|
||||||
got += skt.recv(1)
|
|
||||||
if got.endswith(b"\n"):
|
|
||||||
return got[:-1]
|
|
||||||
raise BadHandshake("exceeded max_length, got %r on %s" %
|
|
||||||
(got, description))
|
|
||||||
|
|
||||||
def wait_for(skt, expected, description):
|
|
||||||
assert isinstance(expected, type(b""))
|
|
||||||
got = b""
|
|
||||||
while len(got) < len(expected):
|
|
||||||
got += skt.recv(1)
|
|
||||||
if expected[:len(got)] != got:
|
|
||||||
raise BadHandshake("got %r want %r on %s" %
|
|
||||||
(got, expected, description))
|
|
||||||
|
|
||||||
def debug(msg):
|
|
||||||
if False:
|
|
||||||
print(msg)
|
|
||||||
def since(start):
|
|
||||||
return time.time() - start
|
|
||||||
|
|
||||||
def connector(owner, hint, description,
|
|
||||||
send_handshake, expected_handshake, relay_handshake=None):
|
|
||||||
start = time.time()
|
|
||||||
parsed_hint = parse_hint_tcp(hint)
|
|
||||||
if not parsed_hint:
|
|
||||||
return # unparseable
|
|
||||||
addr,port = parsed_hint
|
|
||||||
skt = None
|
|
||||||
debug("+ connector(%s)" % hint)
|
|
||||||
try:
|
|
||||||
skt = socket.create_connection((addr,port),
|
|
||||||
TIMEOUT) # timeout or ECONNREFUSED
|
|
||||||
skt.settimeout(TIMEOUT)
|
|
||||||
debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
|
|
||||||
if relay_handshake:
|
|
||||||
debug(" - sending relay_handshake")
|
|
||||||
send_to(skt, relay_handshake)
|
|
||||||
relay_msg = wait_for_line(skt, 10000, description)
|
|
||||||
if relay_msg != b"ok":
|
|
||||||
raise BadHandshake(relay_msg)
|
|
||||||
debug(" - relay ready CT+%.1f" % (since(start),))
|
|
||||||
send_to(skt, send_handshake)
|
|
||||||
wait_for(skt, expected_handshake, description)
|
|
||||||
debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
|
|
||||||
except Exception as e:
|
|
||||||
debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start)))
|
|
||||||
try:
|
|
||||||
if skt:
|
|
||||||
skt.shutdown(socket.SHUT_WR)
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
if skt:
|
|
||||||
skt.close()
|
|
||||||
# ignore socket errors, warn about coding errors
|
|
||||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
|
||||||
raise
|
|
||||||
debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
|
|
||||||
owner._connector_failed(hint)
|
|
||||||
return
|
|
||||||
# owner is now responsible for the socket
|
|
||||||
owner._negotiation_finished(skt, description) # note thread
|
|
||||||
|
|
||||||
def handle(skt, client_address, owner, description,
|
|
||||||
send_handshake, expected_handshake):
|
|
||||||
try:
|
|
||||||
debug("handle %r" % (skt,))
|
|
||||||
skt.settimeout(TIMEOUT)
|
|
||||||
send_to(skt, send_handshake)
|
|
||||||
got = b""
|
|
||||||
# for the receiver, this includes the "go\n"
|
|
||||||
while len(got) < len(expected_handshake):
|
|
||||||
more = skt.recv(1)
|
|
||||||
if not more:
|
|
||||||
raise BadHandshake("disconnect after merely '%r'" % got)
|
|
||||||
got += more
|
|
||||||
if expected_handshake[:len(got)] != got:
|
|
||||||
raise BadHandshake("got '%r' want '%r'" %
|
|
||||||
(got, expected_handshake))
|
|
||||||
debug("handler negotiation finished %r" % (client_address,))
|
|
||||||
except Exception as e:
|
|
||||||
debug("handler failed %r" % (client_address,))
|
|
||||||
try:
|
|
||||||
# this raises socket.err(EBADF) if the socket was already closed
|
|
||||||
skt.shutdown(socket.SHUT_WR)
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
skt.close() # this appears to be idempotent
|
|
||||||
# ignore socket errors, warn about coding errors
|
|
||||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
|
||||||
raise
|
|
||||||
return
|
|
||||||
# owner is now responsible for the socket
|
|
||||||
owner._negotiation_finished(skt, description) # note thread
|
|
||||||
|
|
||||||
class MyTCPServer(socketserver.TCPServer):
|
|
||||||
allow_reuse_address = True
|
|
||||||
|
|
||||||
def process_request(self, request, client_address):
|
|
||||||
description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
|
|
||||||
ready_lock = self.owner._ready_for_connections_lock
|
|
||||||
ready_lock.acquire()
|
|
||||||
while not (self.owner._ready_for_connections
|
|
||||||
and self.owner._transit_key):
|
|
||||||
ready_lock.wait()
|
|
||||||
# owner._transit_key is either None or set to a value. We don't
|
|
||||||
# modify it from here, so we can release the condition lock before
|
|
||||||
# grabbing the key.
|
|
||||||
ready_lock.release()
|
|
||||||
|
|
||||||
# Once it is set, we can get handler_(send|receive)_handshake, which
|
|
||||||
# is what we actually care about.
|
|
||||||
t = threading.Thread(target=handle,
|
|
||||||
args=(request, client_address,
|
|
||||||
self.owner, description,
|
|
||||||
self.owner.handler_send_handshake,
|
|
||||||
self.owner.handler_expected_handshake))
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
class ReceiveBuffer:
|
|
||||||
def __init__(self, skt):
|
|
||||||
self.skt = skt
|
|
||||||
self.buf = b""
|
|
||||||
|
|
||||||
def read(self, count):
|
|
||||||
while len(self.buf) < count:
|
|
||||||
more = self.skt.recv(4096)
|
|
||||||
if not more:
|
|
||||||
raise TransitClosed
|
|
||||||
self.buf += more
|
|
||||||
rc = self.buf[:count]
|
|
||||||
self.buf = self.buf[count:]
|
|
||||||
return rc
|
|
||||||
|
|
||||||
class RecordPipe:
|
|
||||||
def __init__(self, skt, send_key, receive_key, description):
|
|
||||||
self.skt = skt
|
|
||||||
self.send_box = SecretBox(send_key)
|
|
||||||
self.send_nonce = 0
|
|
||||||
self.receive_buf = ReceiveBuffer(self.skt)
|
|
||||||
self.receive_box = SecretBox(receive_key)
|
|
||||||
self.next_receive_nonce = 0
|
|
||||||
self._description = description
|
|
||||||
|
|
||||||
def describe(self):
|
|
||||||
return self._description
|
|
||||||
|
|
||||||
def send_record(self, record):
|
|
||||||
if not isinstance(record, type(b"")): raise UsageError
|
|
||||||
assert SecretBox.NONCE_SIZE == 24
|
|
||||||
assert self.send_nonce < 2**(8*24)
|
|
||||||
assert len(record) < 2**(8*4)
|
|
||||||
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
|
|
||||||
self.send_nonce += 1
|
|
||||||
encrypted = self.send_box.encrypt(record, nonce)
|
|
||||||
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
|
|
||||||
send_to(self.skt, length)
|
|
||||||
send_to(self.skt, encrypted)
|
|
||||||
|
|
||||||
def receive_record(self):
|
|
||||||
length_buf = self.receive_buf.read(4)
|
|
||||||
length = int(hexlify(length_buf), 16)
|
|
||||||
encrypted = self.receive_buf.read(length)
|
|
||||||
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
|
|
||||||
nonce = int(hexlify(nonce_buf), 16)
|
|
||||||
if nonce != self.next_receive_nonce:
|
|
||||||
raise BadNonce("received out-of-order record")
|
|
||||||
self.next_receive_nonce += 1
|
|
||||||
record = self.receive_box.decrypt(encrypted)
|
|
||||||
return record
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.skt.close()
|
|
||||||
|
|
||||||
class Common:
|
|
||||||
def __init__(self, transit_relay, no_listen=False, timing=None):
|
|
||||||
if transit_relay:
|
|
||||||
if not isinstance(transit_relay, type(u"")):
|
|
||||||
raise UsageError
|
|
||||||
self._transit_relays = [transit_relay]
|
|
||||||
else:
|
|
||||||
self._transit_relays = []
|
|
||||||
self._no_listen = no_listen
|
|
||||||
self._timing = timing or DebugTiming()
|
|
||||||
self._timing_started = self._timing.add_event("transit")
|
|
||||||
self.winning = threading.Event()
|
|
||||||
self._negotiation_check_lock = threading.Lock()
|
|
||||||
self._ready_for_connections_lock = threading.Condition()
|
|
||||||
self._ready_for_connections = False
|
|
||||||
self._transit_key = None
|
|
||||||
self._start_server()
|
|
||||||
|
|
||||||
def _start_server(self):
|
|
||||||
if self._no_listen:
|
|
||||||
self.my_direct_hints = []
|
|
||||||
self.listener = None
|
|
||||||
return
|
|
||||||
server = MyTCPServer(("", 0), None)
|
|
||||||
_, port = server.server_address
|
|
||||||
self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
|
|
||||||
for addr in ipaddrs.find_addresses()]
|
|
||||||
server.owner = self
|
|
||||||
server_thread = threading.Thread(target=server.serve_forever)
|
|
||||||
server_thread.daemon = True
|
|
||||||
server_thread.start()
|
|
||||||
self.listener = server
|
|
||||||
|
|
||||||
def get_direct_hints(self):
|
|
||||||
return self.my_direct_hints
|
|
||||||
def get_relay_hints(self):
|
|
||||||
return self._transit_relays
|
|
||||||
|
|
||||||
def add_their_direct_hints(self, hints):
|
|
||||||
for h in hints:
|
|
||||||
if not isinstance(h, type(u"")):
|
|
||||||
raise TypeError("hint '%r' should be unicode, not %s"
|
|
||||||
% (h, type(h)))
|
|
||||||
self._their_direct_hints = list(hints)
|
|
||||||
def add_their_relay_hints(self, hints):
|
|
||||||
for h in hints:
|
|
||||||
if not isinstance(h, type(u"")):
|
|
||||||
raise TypeError("hint '%r' should be unicode, not %s"
|
|
||||||
% (h, type(h)))
|
|
||||||
self._their_relay_hints = list(hints)
|
|
||||||
|
|
||||||
def _send_this(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return build_sender_handshake(self._transit_key)
|
|
||||||
else:
|
|
||||||
return build_receiver_handshake(self._transit_key)
|
|
||||||
|
|
||||||
def _expect_this(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return build_receiver_handshake(self._transit_key)
|
|
||||||
else:
|
|
||||||
return build_sender_handshake(self._transit_key) + b"go\n"
|
|
||||||
|
|
||||||
def _sender_record_key(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_sender_key")
|
|
||||||
else:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_receiver_key")
|
|
||||||
|
|
||||||
def _receiver_record_key(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_receiver_key")
|
|
||||||
else:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_sender_key")
|
|
||||||
|
|
||||||
def set_transit_key(self, key):
|
|
||||||
# This _ready_for_connections condition/lock protects us against the
|
|
||||||
# race where the sender knows the hints and the key, and connects to
|
|
||||||
# the receiver's transit socket before the receiver gets relay
|
|
||||||
# message (and thus the key).
|
|
||||||
self._ready_for_connections_lock.acquire()
|
|
||||||
self._transit_key = key
|
|
||||||
self.handler_send_handshake = self._send_this() # no "go"
|
|
||||||
self.handler_expected_handshake = self._expect_this()
|
|
||||||
self._ready_for_connections_lock.notify_all()
|
|
||||||
self._ready_for_connections_lock.release()
|
|
||||||
|
|
||||||
def _start_outbound(self):
|
|
||||||
self._active_connectors = set(self._their_direct_hints)
|
|
||||||
self._attempted_connectors = set()
|
|
||||||
for hint in self._their_direct_hints:
|
|
||||||
self._start_connector(hint)
|
|
||||||
if not self._their_direct_hints:
|
|
||||||
self._start_relay_connectors()
|
|
||||||
|
|
||||||
def _start_connector(self, hint, is_relay=False):
|
|
||||||
# Don't try any hint more than once. If all hints fail, we'll
|
|
||||||
# eventually timeout. We make no attempt to fail any faster.
|
|
||||||
if hint in self._attempted_connectors:
|
|
||||||
return
|
|
||||||
self._attempted_connectors.add(hint)
|
|
||||||
description = "->%s" % (hint,)
|
|
||||||
if is_relay:
|
|
||||||
description = "->relay:%s" % (hint,)
|
|
||||||
args = (self, hint, description,
|
|
||||||
self._send_this(), self._expect_this())
|
|
||||||
if is_relay:
|
|
||||||
args = args + (build_relay_handshake(self._transit_key),)
|
|
||||||
t = threading.Thread(target=connector, args=args)
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
def _start_relay_connectors(self):
|
|
||||||
self._active_connectors.update(self._their_direct_hints)
|
|
||||||
for hint in self._their_relay_hints:
|
|
||||||
self._start_connector(hint, is_relay=True)
|
|
||||||
|
|
||||||
def establish_socket(self):
|
|
||||||
start = time.time()
|
|
||||||
self.winning_skt = None
|
|
||||||
self.winning_skt_description = None
|
|
||||||
self._ready_for_connections_lock.acquire()
|
|
||||||
self._ready_for_connections = True
|
|
||||||
self._ready_for_connections_lock.notify_all()
|
|
||||||
self._ready_for_connections_lock.release()
|
|
||||||
self._start_outbound()
|
|
||||||
|
|
||||||
# we sit here until one of our inbound or outbound sockets succeeds
|
|
||||||
flag = self.winning.wait(2*TIMEOUT)
|
|
||||||
debug("wait returned at %.1f" % (since(start),))
|
|
||||||
|
|
||||||
if not flag:
|
|
||||||
# timeout: self.winning_skt will not be set. ish. race.
|
|
||||||
pass
|
|
||||||
if self.listener:
|
|
||||||
self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
|
|
||||||
if self.winning_skt:
|
|
||||||
return self.winning_skt
|
|
||||||
raise TransitError("timeout")
|
|
||||||
|
|
||||||
def _connector_failed(self, hint):
|
|
||||||
debug("- failed connector %s" % hint)
|
|
||||||
# XXX this was .remove, and occasionally got KeyError
|
|
||||||
self._active_connectors.discard(hint)
|
|
||||||
if not self._active_connectors:
|
|
||||||
self._start_relay_connectors()
|
|
||||||
|
|
||||||
def _negotiation_finished(self, skt, description):
|
|
||||||
# inbound/outbound sockets call this when they finish negotiation.
|
|
||||||
# The first one wins and gets a "go". Any subsequent ones lose and
|
|
||||||
# get a "nevermind" before being closed.
|
|
||||||
|
|
||||||
with self._negotiation_check_lock:
|
|
||||||
if self.winning_skt:
|
|
||||||
is_winner = False
|
|
||||||
else:
|
|
||||||
is_winner = True
|
|
||||||
self.winning_skt = skt
|
|
||||||
self.winning_skt_description = description
|
|
||||||
|
|
||||||
if is_winner:
|
|
||||||
if self.is_sender:
|
|
||||||
send_to(skt, b"go\n")
|
|
||||||
self.winning.set()
|
|
||||||
else:
|
|
||||||
if self.is_sender:
|
|
||||||
try:
|
|
||||||
send_to(skt, b"nevermind\n")
|
|
||||||
except socket.error:
|
|
||||||
# They realized this connection is not going to win, and
|
|
||||||
# closed it so fast we didn't get a chance to tell them
|
|
||||||
# it lost. This happens in unit tests.
|
|
||||||
pass
|
|
||||||
skt.close()
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
_start = self._timing.add_event("transit connect")
|
|
||||||
skt = self.establish_socket()
|
|
||||||
self._timing.finish_event(_start)
|
|
||||||
return RecordPipe(skt, self._sender_record_key(),
|
|
||||||
self._receiver_record_key(),
|
|
||||||
self.winning_skt_description)
|
|
||||||
|
|
||||||
class TransitSender(Common):
|
|
||||||
is_sender = True
|
|
||||||
|
|
||||||
class TransitReceiver(Common):
|
|
||||||
is_sender = False
|
|
|
@ -1,14 +1,11 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import json
|
import json
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks
|
from twisted.internet.defer import gatherResults, succeed
|
||||||
from twisted.internet.threads import deferToThread
|
from twisted.internet.threads import deferToThread
|
||||||
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
||||||
WrongPasswordError)
|
WrongPasswordError)
|
||||||
from ..blocking.eventsource import EventSourceFollower
|
from ..blocking.eventsource import EventSourceFollower
|
||||||
from ..blocking.transit import (TransitSender, TransitReceiver,
|
|
||||||
build_sender_handshake,
|
|
||||||
build_receiver_handshake)
|
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
|
|
||||||
APPID = u"appid"
|
APPID = u"appid"
|
||||||
|
@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase):
|
||||||
(u"message", u"three"),
|
(u"message", u"three"),
|
||||||
(u"e2", u"four"),
|
(u"e2", u"four"),
|
||||||
])
|
])
|
||||||
|
|
||||||
class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
|
|
||||||
def test_hints(self):
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
hints = r.get_direct_hints()
|
|
||||||
self.assertTrue(len(hints), hints)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_direct_to_receiver(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
|
|
||||||
# force the connection to be sender->receiver
|
|
||||||
s.set_transit_key(key)
|
|
||||||
# only use 127.0.0.1
|
|
||||||
hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1]
|
|
||||||
s.add_their_direct_hints([hint])
|
|
||||||
s.add_their_relay_hints([])
|
|
||||||
r.set_transit_key(key)
|
|
||||||
r.add_their_direct_hints([])
|
|
||||||
r.add_their_relay_hints([])
|
|
||||||
|
|
||||||
# it'd be nice to factor this chunk out with 'yield from', but that
|
|
||||||
# didn't appear until python-3.3, and isn't in py2 at all.
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_direct_to_sender(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
|
|
||||||
# force the connection to be receiver->sender
|
|
||||||
s.set_transit_key(key)
|
|
||||||
s.add_their_direct_hints([])
|
|
||||||
s.add_their_relay_hints([])
|
|
||||||
r.set_transit_key(key)
|
|
||||||
hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1]
|
|
||||||
r.add_their_direct_hints([hint])
|
|
||||||
r.add_their_relay_hints([])
|
|
||||||
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_relay(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
# force the connection to use the relay by not revealing direct hints
|
|
||||||
s.set_transit_key(key)
|
|
||||||
s.add_their_direct_hints([])
|
|
||||||
s.add_their_relay_hints(r.get_relay_hints())
|
|
||||||
r.set_transit_key(key)
|
|
||||||
r.add_their_direct_hints([])
|
|
||||||
r.add_their_relay_hints(s.get_relay_hints())
|
|
||||||
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
# TODO: this may be racy if we don't poll the server to make sure
|
|
||||||
# it's witnessed the first connection closing before querying the DB
|
|
||||||
#import time
|
|
||||||
#yield deferToThread(time.sleep, 1)
|
|
||||||
|
|
||||||
# check the transit relay's DB, make sure it counted the bytes
|
|
||||||
db = self._transit_server._db
|
|
||||||
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
|
|
||||||
rows = c.fetchall()
|
|
||||||
self.assertEqual(len(rows), 1)
|
|
||||||
row = rows[0]
|
|
||||||
self.assertEqual(row["result"], u"happy")
|
|
||||||
# Sender first writes relay_handshake and waits for OK, but that's
|
|
||||||
# not counted by the transit server. Then sender writes
|
|
||||||
# sender_handshake and waits for receiver_handshake. Then sender
|
|
||||||
# writes GO and the body. Body is length-prefixed SecretBox, so
|
|
||||||
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
|
|
||||||
sender_count = (len(build_sender_handshake(b""))+
|
|
||||||
len(b"go\n")+
|
|
||||||
4+24+len(b"01234")+16)
|
|
||||||
# Receiver first writes relay_handshake and waits for OK, but that's
|
|
||||||
# not counted. Then receiver writes receiver_handshake and waits for
|
|
||||||
# sender_handshake+GO.
|
|
||||||
receiver_count = len(build_receiver_handshake(b""))
|
|
||||||
self.assertEqual(row["total_bytes"], sender_count+receiver_count)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user