remove blocking.transit

This commit is contained in:
Brian Warner 2016-04-15 16:39:05 -07:00
parent 4e937c2100
commit db137c26e5
2 changed files with 1 additions and 502 deletions

View File

@ -1,400 +0,0 @@
from __future__ import print_function
import time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..errors import UsageError
from ..timing import DebugTiming
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
BadNonce,
build_receiver_handshake,
build_sender_handshake,
build_relay_handshake,
parse_hint_tcp)
TIMEOUT=15
# 1: sender only transmits, receiver only accepts, both wait forever
# 2: sender also accepts, receiver also transmits
# 3: timeouts / stop when no more progress can be made
# 4: add relay
# 5: accelerate shutdown of losing sockets
def send_to(skt, data):
sent = 0
while sent < len(data):
sent += skt.send(data[sent:])
def wait_for_line(skt, max_length, description):
got = b""
while len(got) < max_length:
got += skt.recv(1)
if got.endswith(b"\n"):
return got[:-1]
raise BadHandshake("exceeded max_length, got %r on %s" %
(got, description))
def wait_for(skt, expected, description):
assert isinstance(expected, type(b""))
got = b""
while len(got) < len(expected):
got += skt.recv(1)
if expected[:len(got)] != got:
raise BadHandshake("got %r want %r on %s" %
(got, expected, description))
def debug(msg):
if False:
print(msg)
def since(start):
return time.time() - start
def connector(owner, hint, description,
send_handshake, expected_handshake, relay_handshake=None):
start = time.time()
parsed_hint = parse_hint_tcp(hint)
if not parsed_hint:
return # unparseable
addr,port = parsed_hint
skt = None
debug("+ connector(%s)" % hint)
try:
skt = socket.create_connection((addr,port),
TIMEOUT) # timeout or ECONNREFUSED
skt.settimeout(TIMEOUT)
debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
if relay_handshake:
debug(" - sending relay_handshake")
send_to(skt, relay_handshake)
relay_msg = wait_for_line(skt, 10000, description)
if relay_msg != b"ok":
raise BadHandshake(relay_msg)
debug(" - relay ready CT+%.1f" % (since(start),))
send_to(skt, send_handshake)
wait_for(skt, expected_handshake, description)
debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
except Exception as e:
debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start)))
try:
if skt:
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
if skt:
skt.close()
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
owner._connector_failed(hint)
return
# owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread
def handle(skt, client_address, owner, description,
send_handshake, expected_handshake):
try:
debug("handle %r" % (skt,))
skt.settimeout(TIMEOUT)
send_to(skt, send_handshake)
got = b""
# for the receiver, this includes the "go\n"
while len(got) < len(expected_handshake):
more = skt.recv(1)
if not more:
raise BadHandshake("disconnect after merely '%r'" % got)
got += more
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r'" %
(got, expected_handshake))
debug("handler negotiation finished %r" % (client_address,))
except Exception as e:
debug("handler failed %r" % (client_address,))
try:
# this raises socket.err(EBADF) if the socket was already closed
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
skt.close() # this appears to be idempotent
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(socketserver.TCPServer):
allow_reuse_address = True
def process_request(self, request, client_address):
description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
ready_lock = self.owner._ready_for_connections_lock
ready_lock.acquire()
while not (self.owner._ready_for_connections
and self.owner._transit_key):
ready_lock.wait()
# owner._transit_key is either None or set to a value. We don't
# modify it from here, so we can release the condition lock before
# grabbing the key.
ready_lock.release()
# Once it is set, we can get handler_(send|receive)_handshake, which
# is what we actually care about.
t = threading.Thread(target=handle,
args=(request, client_address,
self.owner, description,
self.owner.handler_send_handshake,
self.owner.handler_expected_handshake))
t.daemon = True
t.start()
class ReceiveBuffer:
def __init__(self, skt):
self.skt = skt
self.buf = b""
def read(self, count):
while len(self.buf) < count:
more = self.skt.recv(4096)
if not more:
raise TransitClosed
self.buf += more
rc = self.buf[:count]
self.buf = self.buf[count:]
return rc
class RecordPipe:
def __init__(self, skt, send_key, receive_key, description):
self.skt = skt
self.send_box = SecretBox(send_key)
self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0
self._description = description
def describe(self):
return self._description
def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
self.send_nonce += 1
encrypted = self.send_box.encrypt(record, nonce)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
send_to(self.skt, length)
send_to(self.skt, encrypted)
def receive_record(self):
length_buf = self.receive_buf.read(4)
length = int(hexlify(length_buf), 16)
encrypted = self.receive_buf.read(length)
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record")
self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted)
return record
def close(self):
self.skt.close()
class Common:
def __init__(self, transit_relay, no_listen=False, timing=None):
if transit_relay:
if not isinstance(transit_relay, type(u"")):
raise UsageError
self._transit_relays = [transit_relay]
else:
self._transit_relays = []
self._no_listen = no_listen
self._timing = timing or DebugTiming()
self._timing_started = self._timing.add_event("transit")
self.winning = threading.Event()
self._negotiation_check_lock = threading.Lock()
self._ready_for_connections_lock = threading.Condition()
self._ready_for_connections = False
self._transit_key = None
self._start_server()
def _start_server(self):
if self._no_listen:
self.my_direct_hints = []
self.listener = None
return
server = MyTCPServer(("", 0), None)
_, port = server.server_address
self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
for addr in ipaddrs.find_addresses()]
server.owner = self
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
self.listener = server
def get_direct_hints(self):
return self.my_direct_hints
def get_relay_hints(self):
return self._transit_relays
def add_their_direct_hints(self, hints):
for h in hints:
if not isinstance(h, type(u"")):
raise TypeError("hint '%r' should be unicode, not %s"
% (h, type(h)))
self._their_direct_hints = list(hints)
def add_their_relay_hints(self, hints):
for h in hints:
if not isinstance(h, type(u"")):
raise TypeError("hint '%r' should be unicode, not %s"
% (h, type(h)))
self._their_relay_hints = list(hints)
def _send_this(self):
if self.is_sender:
return build_sender_handshake(self._transit_key)
else:
return build_receiver_handshake(self._transit_key)
def _expect_this(self):
if self.is_sender:
return build_receiver_handshake(self._transit_key)
else:
return build_sender_handshake(self._transit_key) + b"go\n"
def _sender_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key):
# This _ready_for_connections condition/lock protects us against the
# race where the sender knows the hints and the key, and connects to
# the receiver's transit socket before the receiver gets relay
# message (and thus the key).
self._ready_for_connections_lock.acquire()
self._transit_key = key
self.handler_send_handshake = self._send_this() # no "go"
self.handler_expected_handshake = self._expect_this()
self._ready_for_connections_lock.notify_all()
self._ready_for_connections_lock.release()
def _start_outbound(self):
self._active_connectors = set(self._their_direct_hints)
self._attempted_connectors = set()
for hint in self._their_direct_hints:
self._start_connector(hint)
if not self._their_direct_hints:
self._start_relay_connectors()
def _start_connector(self, hint, is_relay=False):
# Don't try any hint more than once. If all hints fail, we'll
# eventually timeout. We make no attempt to fail any faster.
if hint in self._attempted_connectors:
return
self._attempted_connectors.add(hint)
description = "->%s" % (hint,)
if is_relay:
description = "->relay:%s" % (hint,)
args = (self, hint, description,
self._send_this(), self._expect_this())
if is_relay:
args = args + (build_relay_handshake(self._transit_key),)
t = threading.Thread(target=connector, args=args)
t.daemon = True
t.start()
def _start_relay_connectors(self):
self._active_connectors.update(self._their_direct_hints)
for hint in self._their_relay_hints:
self._start_connector(hint, is_relay=True)
def establish_socket(self):
start = time.time()
self.winning_skt = None
self.winning_skt_description = None
self._ready_for_connections_lock.acquire()
self._ready_for_connections = True
self._ready_for_connections_lock.notify_all()
self._ready_for_connections_lock.release()
self._start_outbound()
# we sit here until one of our inbound or outbound sockets succeeds
flag = self.winning.wait(2*TIMEOUT)
debug("wait returned at %.1f" % (since(start),))
if not flag:
# timeout: self.winning_skt will not be set. ish. race.
pass
if self.listener:
self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
if self.winning_skt:
return self.winning_skt
raise TransitError("timeout")
def _connector_failed(self, hint):
debug("- failed connector %s" % hint)
# XXX this was .remove, and occasionally got KeyError
self._active_connectors.discard(hint)
if not self._active_connectors:
self._start_relay_connectors()
def _negotiation_finished(self, skt, description):
# inbound/outbound sockets call this when they finish negotiation.
# The first one wins and gets a "go". Any subsequent ones lose and
# get a "nevermind" before being closed.
with self._negotiation_check_lock:
if self.winning_skt:
is_winner = False
else:
is_winner = True
self.winning_skt = skt
self.winning_skt_description = description
if is_winner:
if self.is_sender:
send_to(skt, b"go\n")
self.winning.set()
else:
if self.is_sender:
try:
send_to(skt, b"nevermind\n")
except socket.error:
# They realized this connection is not going to win, and
# closed it so fast we didn't get a chance to tell them
# it lost. This happens in unit tests.
pass
skt.close()
def connect(self):
_start = self._timing.add_event("transit connect")
skt = self.establish_socket()
self._timing.finish_event(_start)
return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key(),
self.winning_skt_description)
class TransitSender(Common):
is_sender = True
class TransitReceiver(Common):
is_sender = False

View File

@ -1,14 +1,11 @@
from __future__ import print_function from __future__ import print_function
import json import json
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks from twisted.internet.defer import gatherResults, succeed
from twisted.internet.threads import deferToThread from twisted.internet.threads import deferToThread
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
WrongPasswordError) WrongPasswordError)
from ..blocking.eventsource import EventSourceFollower from ..blocking.eventsource import EventSourceFollower
from ..blocking.transit import (TransitSender, TransitReceiver,
build_sender_handshake,
build_receiver_handshake)
from .common import ServerBase from .common import ServerBase
APPID = u"appid" APPID = u"appid"
@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase):
(u"message", u"three"), (u"message", u"three"),
(u"e2", u"four"), (u"e2", u"four"),
]) ])
class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
def test_hints(self):
r = TransitReceiver(self.transit)
hints = r.get_direct_hints()
self.assertTrue(len(hints), hints)
@inlineCallbacks
def test_direct_to_receiver(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to be sender->receiver
s.set_transit_key(key)
# only use 127.0.0.1
hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1]
s.add_their_direct_hints([hint])
s.add_their_relay_hints([])
r.set_transit_key(key)
r.add_their_direct_hints([])
r.add_their_relay_hints([])
# it'd be nice to factor this chunk out with 'yield from', but that
# didn't appear until python-3.3, and isn't in py2 at all.
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
@inlineCallbacks
def test_direct_to_sender(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to be receiver->sender
s.set_transit_key(key)
s.add_their_direct_hints([])
s.add_their_relay_hints([])
r.set_transit_key(key)
hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1]
r.add_their_direct_hints([hint])
r.add_their_relay_hints([])
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
@inlineCallbacks
def test_relay(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to use the relay by not revealing direct hints
s.set_transit_key(key)
s.add_their_direct_hints([])
s.add_their_relay_hints(r.get_relay_hints())
r.set_transit_key(key)
r.add_their_direct_hints([])
r.add_their_relay_hints(s.get_relay_hints())
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
# TODO: this may be racy if we don't poll the server to make sure
# it's witnessed the first connection closing before querying the DB
#import time
#yield deferToThread(time.sleep, 1)
# check the transit relay's DB, make sure it counted the bytes
db = self._transit_server._db
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
rows = c.fetchall()
self.assertEqual(len(rows), 1)
row = rows[0]
self.assertEqual(row["result"], u"happy")
# Sender first writes relay_handshake and waits for OK, but that's
# not counted by the transit server. Then sender writes
# sender_handshake and waits for receiver_handshake. Then sender
# writes GO and the body. Body is length-prefixed SecretBox, so
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
sender_count = (len(build_sender_handshake(b""))+
len(b"go\n")+
4+24+len(b"01234")+16)
# Receiver first writes relay_handshake and waits for OK, but that's
# not counted. Then receiver writes receiver_handshake and waits for
# sender_handshake+GO.
receiver_count = len(build_receiver_handshake(b""))
self.assertEqual(row["total_bytes"], sender_count+receiver_count)