magic-wormhole/src/wormhole/blocking/transit.py
Brian Warner 95d0e68cf2 transit: avoid near-infinite loop upon connector error
Now we will never try any hint more than once. Previously we'd hit the
relay hint over and over until the timeout fired.
2015-12-03 16:22:03 -06:00

442 lines
16 KiB
Python

from __future__ import print_function
import re, time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..errors import UsageError
class TransitError(Exception):
    """Base class for transit failures.

    Raised directly when no connection could be established; the more
    specific transit errors in this module derive from it.
    """
# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
# listening/accepting side. Same for the receiver.
#
# sender -> receiver: transit sender TXID_HEX ready\n\n
# receiver -> sender: transit receiver RXID_HEX ready\n\n
#
# Any deviations from this result in the socket being closed. The handshake
# messages are designed to provoke an invalid response from other sorts of
# servers (HTTP, SMTP, echo).
#
# If the sender is satisfied with the handshake, and this is the first socket
# to complete negotiation, the sender does:
#
# sender -> receiver: go\n
#
# and the next byte on the wire will be from the application.
#
# If this is not the first socket, the sender does:
#
# sender -> receiver: nevermind\n
#
# and closes the socket.
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender looks for "transit receiver
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key):
    """Return the handshake message transmitted by the receiving side.

    The message embeds the hex encoding of 32 bytes derived from ``key``
    with HKDF under the ``transit_receiver`` context, and ends with a blank
    line.
    """
    receiver_id = hexlify(HKDF(key, 32, CTXinfo=b"transit_receiver"))
    return b"".join([b"transit receiver ", receiver_id, b" ready\n\n"])
def build_sender_handshake(key):
    """Return the handshake message transmitted by the sending side.

    The message embeds the hex encoding of 32 bytes derived from ``key``
    with HKDF under the ``transit_sender`` context, and ends with a blank
    line. It does not include the trailing "go"/"nevermind" verdict.
    """
    sender_id = hexlify(HKDF(key, 32, CTXinfo=b"transit_sender"))
    return b"".join([b"transit sender ", sender_id, b" ready\n\n"])
def build_relay_handshake(key):
    """Return the request line sent to a relay before the normal handshake.

    The token is the hex encoding of 32 bytes derived from ``key`` with HKDF
    under the ``transit_relay_token`` context.
    """
    relay_token = hexlify(HKDF(key, 32, CTXinfo=b"transit_relay_token"))
    return b"".join([b"please relay ", relay_token, b"\n"])
# Seconds allowed for each blocking socket operation during negotiation; the
# overall wait for a winning connection is twice this value.
TIMEOUT = 15
# 1: sender only transmits, receiver only accepts, both wait forever
# 2: sender also accepts, receiver also transmits
# 3: timeouts / stop when no more progress can be made
# 4: add relay
# 5: accelerate shutdown of losing sockets
class BadHandshake(Exception):
    """The peer sent bytes that do not match the expected handshake."""
def send_to(skt, data):
    """Write all of ``data`` to ``skt``, looping over partial sends.

    ``skt.send`` may accept fewer bytes than offered, so the unsent tail is
    re-offered until nothing remains. Socket errors propagate to the caller.
    """
    total = len(data)
    offset = 0
    while offset < total:
        accepted = skt.send(data[offset:])
        offset = offset + accepted
def wait_for(skt, expected, description):
    """Read from ``skt`` until exactly ``expected`` has been received.

    Bytes are consumed one at a time so that nothing beyond the handshake is
    read off the socket: whatever follows belongs to the next protocol stage.

    :param skt: connected socket to read from
    :param expected: the byte string the peer must send
    :param description: label for this connection, used in error messages
    :raises BadHandshake: on the first byte that deviates from ``expected``,
        or if the peer closes the connection before sending all of it
    """
    got = b""
    while len(got) < len(expected):
        more = skt.recv(1)
        if not more:
            # recv() returns an empty string once the peer has closed the
            # connection; without this check the loop would spin forever.
            raise BadHandshake("disconnect after merely '%r' on %s" %
                               (got, description))
        got += more
        if expected[:len(got)] != got:
            raise BadHandshake("got '%r' want '%r' on %s" %
                               (got, expected, description))
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
# publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint):
    """Split a ``tcp:ADDR:PORT`` hint into ``(ADDR, PORT)``.

    ``hint`` must be a unicode string. The port is taken from the digits
    after the last colon, so ADDR itself may contain colons. Returns a
    ``(host, port)`` tuple with an integer port, or None (after printing a
    diagnostic) when the hint cannot be parsed or is not of type ``tcp``.
    """
    assert isinstance(hint, type(u""))
    typed = re.match(r'^([a-zA-Z0-9]+):(.*)$', hint)
    if typed is None:
        print("unparseable hint '%s'" % (hint,))
        return None

    kind, value = typed.groups()
    if kind != "tcp":
        print("unknown hint type '%s' in '%s'" % (kind, hint))
        return None

    endpoint = re.match(r'^(.*):(\d+)$', value)
    if endpoint is None:
        print("unparseable TCP hint '%s'" % (hint,))
        return None

    host, port_text = endpoint.groups()
    try:
        port = int(port_text)
    except ValueError:
        print("non-numeric port in TCP hint '%s'" % (hint,))
        return None
    return host, port
def debug(msg):
    """Print ``msg`` when tracing is enabled; tracing is hard-wired off."""
    tracing_enabled = False
    if not tracing_enabled:
        return
    print(msg)
def since(start):
    """Return the seconds elapsed between ``start`` (a time.time() value) and now."""
    now = time.time()
    return now - start
def connector(owner, hint, description,
              send_handshake, expected_handshake, relay_handshake=None):
    """Dial one hint and negotiate the transit handshake on the new socket.

    Intended to run in its own thread. On success the connected socket is
    handed to ``owner._negotiation_finished``; on failure the socket is
    closed and ``owner._connector_failed(hint)`` is called.

    :param owner: object providing ``_negotiation_finished(skt, description)``
        and ``_connector_failed(hint)``
    :param hint: unicode connection hint of the form ``tcp:ADDR:PORT``
    :param description: label for this connection, used in messages
    :param send_handshake: bytes to transmit once connected
    :param expected_handshake: bytes the peer must answer with
    :param relay_handshake: if given, bytes sent first to ask a relay to
        join us to the peer; the relay must answer ``ok\\n``
    :raises Exception: anything other than a socket error, timeout or
        BadHandshake is re-raised after the socket is cleaned up, since it
        indicates a coding error rather than a network failure
    """
    start = time.time()
    parsed_hint = parse_hint_tcp(hint)
    if not parsed_hint:
        # An unparseable hint can never succeed. Report it as failed so the
        # owner stops counting it as active and can move on to the relay.
        owner._connector_failed(hint)
        return
    addr, port = parsed_hint
    skt = None
    debug("+ connector(%s)" % hint)
    try:
        skt = socket.create_connection((addr, port),
                                       TIMEOUT) # timeout or ECONNREFUSED
        skt.settimeout(TIMEOUT)
        debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
        if relay_handshake:
            debug(" - sending relay_handshake")
            send_to(skt, relay_handshake)
            # must be a bytes literal: recv() yields bytes, and on Python 3
            # a text literal would never compare equal to them
            wait_for(skt, b"ok\n", description)
            debug(" - relay ready CT+%.1f" % (since(start),))
        send_to(skt, send_handshake)
        wait_for(skt, expected_handshake, description)
        debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
    except Exception as e:
        debug(" - timeout(%s) CT+%.1f" % (hint, since(start)))
        try:
            if skt:
                skt.shutdown(socket.SHUT_WR)
        except socket.error:
            pass
        if skt:
            skt.close()
        # ignore socket errors, warn about coding errors
        if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
            raise
        debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
        owner._connector_failed(hint)
        return
    # owner is now responsible for the socket
    owner._negotiation_finished(skt, description) # note thread
def handle(skt, client_address, owner, description,
           send_handshake, expected_handshake):
    """Negotiate the transit handshake on an accepted (inbound) socket.

    Sends ``send_handshake``, then reads one byte at a time until
    ``expected_handshake`` has arrived in full. On success the socket is
    passed to ``owner._negotiation_finished``. Socket errors, timeouts and
    handshake mismatches close the socket quietly; any other exception is
    re-raised after the socket has been closed.
    """
    quiet_failures = (socket.error, socket.timeout, BadHandshake)
    try:
        debug("handle %r" % (skt,))
        skt.settimeout(TIMEOUT)
        send_to(skt, send_handshake)
        wanted = len(expected_handshake)
        received = b""
        # for the receiver, this includes the "go\n"
        while len(received) < wanted:
            byte = skt.recv(1)
            if not byte:
                raise BadHandshake("disconnect after merely '%r'" % received)
            received += byte
            if not expected_handshake.startswith(received):
                raise BadHandshake("got '%r' want '%r'" %
                                   (received, expected_handshake))
        debug("handler negotiation finished %r" % (client_address,))
    except Exception as e:
        debug("handler failed %r" % (client_address,))
        try:
            # this raises socket.err(EBADF) if the socket was already closed
            skt.shutdown(socket.SHUT_WR)
        except socket.error:
            pass
        skt.close() # this appears to be idempotent
        # ignore socket errors, warn about coding errors
        if isinstance(e, quiet_failures):
            return
        raise
    # owner is now responsible for the socket
    owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(socketserver.TCPServer):
    """TCP listener that hands each accepted socket to ``handle`` in a thread.

    The ``owner`` attribute must be assigned after construction; it supplies
    the ``_have_transit_key`` condition, ``_transit_key``, and the
    ``handler_send_handshake`` / ``handler_expected_handshake`` byte strings.
    """
    allow_reuse_address = True

    def process_request(self, request, client_address):
        """Wait for the transit key, then negotiate on a daemon thread.

        Blocks until ``owner._transit_key`` is set, because the handshake
        strings derived from it do not exist before then.
        """
        description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
        kc = self.owner._have_transit_key
        # 'with' releases the condition's lock even if the wait is
        # interrupted, unlike a bare acquire()/release() pair.
        with kc:
            while not self.owner._transit_key:
                kc.wait()
        # owner._transit_key is either None or set to a value. We don't
        # modify it from here, so the condition lock is released before
        # grabbing the key.
        # Once it is set, we can get handler_(send|receive)_handshake, which
        # is what we actually care about.
        t = threading.Thread(target=handle,
                             args=(request, client_address,
                                   self.owner, description,
                                   self.owner.handler_send_handshake,
                                   self.owner.handler_expected_handshake))
        t.daemon = True
        t.start()
class TransitClosed(TransitError):
    """The peer closed the connection while more data was still expected."""
class BadNonce(TransitError):
    """A received record carried a nonce other than the next one expected."""
class ReceiveBuffer:
    """Buffered reader that returns exactly the requested number of bytes."""

    def __init__(self, skt):
        self.skt = skt
        # bytes already received from the socket but not yet handed out
        self.buf = b""

    def read(self, count):
        """Return the next ``count`` bytes, blocking on the socket as needed.

        Raises TransitClosed if the socket reports end-of-stream before
        enough bytes have arrived.
        """
        while count > len(self.buf):
            chunk = self.skt.recv(4096)
            if not chunk:
                raise TransitClosed
            self.buf += chunk
        head, self.buf = self.buf[:count], self.buf[count:]
        return head
class RecordPipe:
    """Encrypted, length-prefixed record stream over a connected socket.

    Each record is sealed with a SecretBox using a counter nonce and sent
    after a 4-byte big-endian length. Incoming records must carry nonces in
    strictly increasing order starting from zero.
    """

    def __init__(self, skt, send_key, receive_key):
        self.skt = skt
        self.send_box = SecretBox(send_key)
        self.send_nonce = 0
        self.receive_buf = ReceiveBuffer(self.skt)
        self.receive_box = SecretBox(receive_key)
        self.next_receive_nonce = 0

    def send_record(self, record):
        """Encrypt ``record`` (bytes) and write it to the socket.

        Raises UsageError if ``record`` is not a byte string.
        """
        if not isinstance(record, type(b"")):
            raise UsageError
        assert SecretBox.NONCE_SIZE == 24
        assert self.send_nonce < 2**(8*24)
        assert len(record) < 2**(8*4)
        nonce_hex = "%048x" % self.send_nonce
        nonce = unhexlify(nonce_hex) # big-endian
        self.send_nonce += 1
        ciphertext = self.send_box.encrypt(record, nonce)
        header = unhexlify("%08x" % len(ciphertext)) # always 4 bytes long
        for part in (header, ciphertext):
            send_to(self.skt, part)

    def receive_record(self):
        """Read, verify and decrypt the next record; return its plaintext.

        Raises BadNonce if the record's nonce is not the next one expected.
        """
        header = self.receive_buf.read(4)
        body_length = int(hexlify(header), 16)
        body = self.receive_buf.read(body_length)
        # assume the nonce is prepended to the ciphertext
        claimed_nonce = int(hexlify(body[:SecretBox.NONCE_SIZE]), 16)
        if claimed_nonce != self.next_receive_nonce:
            raise BadNonce("received out-of-order record")
        self.next_receive_nonce += 1
        return self.receive_box.decrypt(body)

    def close(self):
        """Close the underlying socket."""
        self.skt.close()
class Common:
    """Shared machinery for both ends of a blocking transit connection.

    An instance listens for inbound connections and, once told to, dials the
    peer's hints. Every socket that completes the handshake reports to
    ``_negotiation_finished``; the first one wins and the rest are closed.
    Subclasses must define the boolean class attribute ``is_sender``.
    """

    def __init__(self, transit_relay):
        """Start listening immediately.

        :param transit_relay: unicode hint for the relay to advertise
        :raises UsageError: if ``transit_relay`` is not a unicode string
        """
        if not isinstance(transit_relay, type(u"")): raise UsageError
        self._transit_relay = transit_relay
        self.winning = threading.Event()
        self._negotiation_check_lock = threading.Lock()
        self._have_transit_key = threading.Condition()
        self._transit_key = None
        # Peer hints default to "none known" so establish_socket() works even
        # if add_their_*_hints() was never called.
        self._their_direct_hints = []
        self._their_relay_hints = []
        # These must exist before the listener starts: an inbound connection
        # may finish negotiating before establish_socket() is ever called.
        self.winning_skt = None
        self.winning_skt_description = None
        self._start_server()

    def _start_server(self):
        """Bind a listener on an OS-chosen port and serve it on a daemon thread."""
        server = MyTCPServer(("", 0), None)
        _, port = server.server_address
        self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
                                for addr in ipaddrs.find_addresses()]
        server.owner = self
        server_thread = threading.Thread(target=server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
        self.listener = server

    def get_direct_hints(self):
        """Return the ``tcp:ADDR:PORT`` hints on which we are listening."""
        return self.my_direct_hints

    def get_relay_hints(self):
        """Return a one-element list holding our configured relay hint."""
        return [self._transit_relay]

    def add_their_direct_hints(self, hints):
        """Record the peer's direct hints; raise TypeError unless all are unicode."""
        for h in hints:
            if not isinstance(h, type(u"")):
                raise TypeError("hint '%r' should be unicode, not %s"
                                % (h, type(h)))
        self._their_direct_hints = list(hints)

    def add_their_relay_hints(self, hints):
        """Record the peer's relay hints; raise TypeError unless all are unicode."""
        for h in hints:
            if not isinstance(h, type(u"")):
                raise TypeError("hint '%r' should be unicode, not %s"
                                % (h, type(h)))
        self._their_relay_hints = list(hints)

    def _send_this(self):
        """Return the handshake bytes this side transmits (without "go")."""
        if self.is_sender:
            return build_sender_handshake(self._transit_key)
        else:
            return build_receiver_handshake(self._transit_key)

    def _expect_this(self):
        """Return the handshake bytes this side must receive from the peer.

        For the receiver this includes the sender's trailing ``go\\n``.
        """
        if self.is_sender:
            return build_receiver_handshake(self._transit_key)
        else:
            return build_sender_handshake(self._transit_key) + b"go\n"

    def _sender_record_key(self):
        """Return the key passed to RecordPipe as this side's send key.

        The sender gets the ``transit_record_sender_key`` derivation and the
        receiver gets ``transit_record_receiver_key``.
        """
        if self.is_sender:
            return HKDF(self._transit_key, SecretBox.KEY_SIZE,
                        CTXinfo=b"transit_record_sender_key")
        else:
            return HKDF(self._transit_key, SecretBox.KEY_SIZE,
                        CTXinfo=b"transit_record_receiver_key")

    def _receiver_record_key(self):
        """Return the key passed to RecordPipe as this side's receive key.

        This is the opposite derivation from ``_sender_record_key``.
        """
        if self.is_sender:
            return HKDF(self._transit_key, SecretBox.KEY_SIZE,
                        CTXinfo=b"transit_record_receiver_key")
        else:
            return HKDF(self._transit_key, SecretBox.KEY_SIZE,
                        CTXinfo=b"transit_record_sender_key")

    def set_transit_key(self, key):
        """Store the transit key and wake any inbound connections waiting for it."""
        # This _have_transit_key condition/lock protects us against the race
        # where the sender knows the hints and the key, and connects to the
        # receiver's transit socket before the receiver gets relay message
        # (and thus the key).
        # 'with' guarantees the lock is released even if deriving the
        # handshakes raises.
        with self._have_transit_key:
            self._transit_key = key
            self.handler_send_handshake = self._send_this() # no "go"
            self.handler_expected_handshake = self._expect_this()
            self._have_transit_key.notify_all()

    def _start_outbound(self):
        """Dial every direct hint, or go straight to the relays if there are none."""
        self._active_connectors = set(self._their_direct_hints)
        self._attempted_connectors = set()
        for hint in self._their_direct_hints:
            self._start_connector(hint)
        if not self._their_direct_hints:
            self._start_relay_connectors()

    def _start_connector(self, hint, is_relay=False):
        """Launch a daemon thread running ``connector`` for ``hint``."""
        # Don't try any hint more than once. If all hints fail, we'll
        # eventually timeout. We make no attempt to fail any faster.
        if hint in self._attempted_connectors:
            return
        self._attempted_connectors.add(hint)
        description = "->%s" % (hint,)
        if is_relay:
            description = "->relay:%s" % (hint,)
        args = (self, hint, description,
                self._send_this(), self._expect_this())
        if is_relay:
            args = args + (build_relay_handshake(self._transit_key),)
        t = threading.Thread(target=connector, args=args)
        t.daemon = True
        t.start()

    def _start_relay_connectors(self):
        """Dial the peer's relay hints."""
        # Track the relay hints themselves: they are what is in flight now.
        # The direct hints have already failed (or never existed).
        self._active_connectors.update(self._their_relay_hints)
        for hint in self._their_relay_hints:
            self._start_connector(hint, is_relay=True)

    def establish_socket(self):
        """Start outbound attempts and block until a connection wins.

        Waits up to ``2*TIMEOUT`` seconds for any inbound or outbound socket
        to finish negotiation, then shuts the listener down.

        :returns: the winning socket
        :raises TransitError: if no socket finished negotiation in time
        """
        start = time.time()
        self._start_outbound()
        # we sit here until one of our inbound or outbound sockets succeeds
        self.winning.wait(2*TIMEOUT)
        debug("wait returned at %.1f" % (since(start),))
        if self.listener:
            self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
        if self.winning_skt:
            return self.winning_skt
        raise TransitError("unable to establish a transit connection")

    def describe(self):
        """Return the winning connection's description, or a placeholder."""
        if not self.winning_skt_description:
            return "not yet established"
        return self.winning_skt_description

    def _connector_failed(self, hint):
        """Note a failed outbound attempt; fall back to relays when none remain."""
        debug("- failed connector %s" % hint)
        # XXX this was .remove, and occasionally got KeyError
        self._active_connectors.discard(hint)
        if not self._active_connectors:
            self._start_relay_connectors()

    def _negotiation_finished(self, skt, description):
        """Decide whether ``skt`` is the winning connection.

        Called (from connector/handler threads) by every socket that
        completes the handshake.
        """
        # inbound/outbound sockets call this when they finish negotiation.
        # The first one wins and gets a "go". Any subsequent ones lose and
        # get a "nevermind" before being closed.
        with self._negotiation_check_lock:
            is_winner = not self.winning_skt
            if is_winner:
                self.winning_skt = skt
                self.winning_skt_description = description
        if is_winner:
            if self.is_sender:
                send_to(skt, b"go\n")
            self.winning.set()
        else:
            try:
                if self.is_sender:
                    send_to(skt, b"nevermind\n")
            finally:
                # close the loser even if the courtesy message fails
                skt.close()

    def connect(self):
        """Establish the connection and wrap it in an encrypted RecordPipe."""
        skt = self.establish_socket()
        return RecordPipe(skt, self._sender_record_key(),
                          self._receiver_record_key())
class TransitSender(Common):
    """The sending end: transmits the sender handshake and the go/nevermind verdict."""
    is_sender = True
class TransitReceiver(Common):
    """The receiving end: transmits the receiver handshake and awaits the sender's "go"."""
    is_sender = False