magic-wormhole/src/wormhole/blocking/transit.py
Brian Warner 35630661a5 increase establish_connection() timeout to let relay work
If all the direct hints resulted in timeouts (e.g. they were to bad IP
addresses where connections just hang), the relay connection would fail.
The establish_connection() function had the same TIMEOUT as the
direct-hint connector, so it would give up just before the relay
connection was initiated.
2015-03-12 15:52:11 -07:00

334 lines
12 KiB
Python

from __future__ import print_function
import re, threading, socket, SocketServer
from binascii import hexlify
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..const import TRANSIT_RELAY
class TransitError(Exception):
    """Raised by establish_connection() when no socket finished negotiation."""
# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
# listening/accepting side. Same for the receiver.
#
# sender -> receiver: transit sender TXID_HEX ready\n\n
# receiver -> sender: transit receiver RXID_HEX ready\n\n
#
# Any deviations from this result in the socket being closed. The handshake
# messages are designed to provoke an invalid response from other sorts of
# servers (HTTP, SMTP, echo).
#
# If the sender is satisfied with the handshake, and this is the first socket
# to complete negotiation, the sender does:
#
# sender -> receiver: go\n
#
# and the next byte on the wire will be from the application.
#
# If this is not the first socket, the sender does:
#
# sender -> receiver: nevermind\n
#
# and closes the socket.
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender looks for "transit receiver
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key):
    """Return the handshake line the receiving side transmits.

    The identifier embedded in the line is the hex encoding of
    HKDF(key, 32, CTXinfo=b"transit_receiver").
    """
    receiver_id = hexlify(HKDF(key, 32, CTXinfo=b"transit_receiver"))
    return "transit receiver %s ready\n\n" % (receiver_id,)
def build_sender_handshake(key):
    """Return the handshake line the sending side transmits.

    The identifier embedded in the line is the hex encoding of
    HKDF(key, 32, CTXinfo=b"transit_sender"). The trailing "go\n" is not
    included here.
    """
    sender_id = hexlify(HKDF(key, 32, CTXinfo=b"transit_sender"))
    return "transit sender %s ready\n\n" % (sender_id,)
def build_relay_handshake(key):
    """Return the request line sent to a relay before the normal handshake.

    The token in the line is the hex encoding of
    HKDF(key, 32, CTXinfo=b"transit_relay_token").
    """
    relay_token = hexlify(HKDF(key, 32, CTXinfo=b"transit_relay_token"))
    return "please relay %s\n" % (relay_token,)
# Seconds. Used as the connect timeout and per-socket I/O timeout for
# connector()/handle(); establish_connection() waits twice this long.
TIMEOUT = 15
# 1: sender only transmits, receiver only accepts, both wait forever
# 2: sender also accepts, receiver also transmits
# 3: timeouts / stop when no more progress can be made
# 4: add relay
# 5: accelerate shutdown of losing sockets
class BadHandshake(Exception):
    """The peer sent bytes that do not match the expected handshake."""
def force_ascii(s):
    """Encode unicode text to ASCII bytes; return any other value untouched."""
    unicode_type = type(u"")
    return s.encode("ascii") if isinstance(s, unicode_type) else s
def send_to(skt, data):
    """Write all of `data` to `skt`, looping over partial send() results."""
    remaining = data
    while remaining:
        count = skt.send(remaining)
        # drop whatever the socket accepted and retry with the rest
        remaining = remaining[count:]
def wait_for(skt, expected, description):
    """Read from `skt` one byte at a time until `expected` has arrived.

    Returns None once every byte of `expected` has been received.

    Raises BadHandshake as soon as a received byte deviates from
    `expected`, or when the peer closes the connection before the full
    handshake arrived. `description` is only used in the error message.
    Socket errors and timeouts from recv() propagate to the caller.
    """
    got = b""
    while len(got) < len(expected):
        more = skt.recv(1)
        if not more:
            # recv() returning nothing means the peer hung up. Without this
            # check `got` never grows and the loop would spin forever.
            raise BadHandshake("disconnect after merely '%r' on %s" %
                               (got, description))
        got += more
        if expected[:len(got)] != got:
            raise BadHandshake("got '%r' want '%r' on %s" %
                               (got, expected, description))
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
# publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint):
    """Split a "tcp:ADDR:PORT" hint into a (host, port) tuple.

    Returns None, after printing a diagnostic, when the hint has no
    TYPE: prefix, when TYPE is not "tcp", or when the value does not end
    in ":<digits>". ADDR itself may contain colons.
    """
    typed = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
    if typed is None:
        print("unparseable hint '%s'" % (hint,))
        return None
    hint_type, hint_value = typed.groups()
    if hint_type != "tcp":
        print("unknown hint type '%s' in '%s'" % (hint_type, hint))
        return None
    endpoint = re.search(r'^(.*):(\d+)$', hint_value)
    if endpoint is None:
        print("unparseable TCP hint '%s'" % (hint,))
        return None
    host, port_text = endpoint.groups()
    try:
        port = int(port_text)
    except ValueError:
        print("non-numeric port in TCP hint '%s'" % (hint,))
        return None
    return (host, port)
def connector(owner, hint, description,
              send_handshake, expected_handshake, relay_handshake=None):
    """Thread body: connect out to one hint and negotiate the handshake.

    owner: object providing _connector_failed(hint) and
        _negotiation_finished(skt, description).
    hint: "tcp:ADDR:PORT" string, parsed with parse_hint_tcp().
    description: label passed to the owner and used in error messages.
    send_handshake / expected_handshake: what we transmit and what the
        peer must answer with.
    relay_handshake: when given, it is sent first and the peer must answer
        "ok\n" before the normal handshake starts.

    On success the socket is handed to owner._negotiation_finished(). On an
    unparseable hint, a socket error/timeout, or a BadHandshake, the socket
    (if any) is closed and owner._connector_failed(hint) is called. Any
    other exception is treated as a coding error: the socket is closed and
    the exception propagates out of the thread.
    """
    parsed_hint = parse_hint_tcp(hint)
    if not parsed_hint:
        # An unparseable hint is a failed connector too. If the owner is
        # never told, the hint stays in its active set and the relay
        # fallback cannot start.
        owner._connector_failed(hint)
        return
    addr, port = parsed_hint
    skt = None
    negotiated = False
    try:
        skt = socket.create_connection((addr, port),
                                       TIMEOUT) # timeout or ECONNREFUSED
        skt.settimeout(TIMEOUT)
        #print("socket(%s) connected" % (description,))
        if relay_handshake:
            send_to(skt, relay_handshake)
            wait_for(skt, "ok\n", description)
            #print("relay ready %r" % (description,))
        send_to(skt, send_handshake)
        wait_for(skt, expected_handshake, description)
        #print("connector ready %r" % (description,))
        negotiated = True
    except (socket.error, socket.timeout, BadHandshake):
        # expected failure modes: ignored here, reported to the owner below
        pass
    finally:
        # Cleanup lives in `finally` so that an unexpected exception (a
        # coding error) keeps propagating unchanged. A bare `raise` placed
        # after a nested try/except can re-raise the wrong exception on
        # Python 2.
        if not negotiated and skt is not None:
            try:
                skt.shutdown(socket.SHUT_WR)
            except socket.error:
                pass
            skt.close()
    if not negotiated:
        owner._connector_failed(hint)
        return
    # owner is now responsible for the socket
    owner._negotiation_finished(skt, description) # note thread
def handle(skt, client_address, owner, description,
           send_handshake, expected_handshake):
    """Thread body: negotiate the handshake on one accepted connection.

    skt: the accepted socket.
    client_address: the peer address; not used by the negotiation itself.
    owner: object providing _negotiation_finished(skt, description).
    description: label passed on to the owner.
    send_handshake / expected_handshake: what we transmit and what the
        peer must answer with (for the receiver this includes "go\n").

    On success the socket is handed to owner._negotiation_finished(). On a
    socket error/timeout or a BadHandshake the socket is closed and the
    function returns quietly. Any other exception is treated as a coding
    error: the socket is closed and the exception propagates.
    """
    negotiated = False
    try:
        #print("handle %r" % (skt,))
        skt.settimeout(TIMEOUT)
        send_to(skt, send_handshake)
        got = b""
        # for the receiver, this includes the "go\n"
        while len(got) < len(expected_handshake):
            more = skt.recv(1)
            if not more:
                raise BadHandshake("disconnect after merely '%r'" % got)
            got += more
            if expected_handshake[:len(got)] != got:
                raise BadHandshake("got '%r' want '%r'" %
                                   (got, expected_handshake))
        #print("handler negotiation finished %r" % (client_address,))
        negotiated = True
    except (socket.error, socket.timeout, BadHandshake):
        # expected failure modes: ignore, the socket is closed below
        #print("handler failed %r" % (client_address,))
        pass
    finally:
        # Cleanup lives in `finally` so that an unexpected exception (a
        # coding error) keeps propagating unchanged. A bare `raise` placed
        # after a nested try/except can re-raise the wrong exception on
        # Python 2.
        if not negotiated:
            try:
                # this raises socket.err(EBADF) if the socket was already
                # closed
                skt.shutdown(socket.SHUT_WR)
            except socket.error:
                pass
            skt.close() # this appears to be idempotent
    if not negotiated:
        return
    # owner is now responsible for the socket
    owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(SocketServer.TCPServer):
    """Listening server that negotiates each inbound connection in a thread.

    The creator must set the `owner` attribute (a Common instance) before
    serving; it supplies the transit key condition and the handshakes.
    """
    allow_reuse_address = True

    def process_request(self, request, client_address):
        """Wait for the owner's transit key, then start a handle() thread.

        Blocks the serving thread until owner._transit_key is set, because
        the handshake strings only exist once the key is known.
        """
        description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
        key_ready = self.owner._have_transit_key
        # `with` releases the condition lock even if wait() is interrupted
        # by an exception.
        with key_ready:
            while not self.owner._transit_key:
                key_ready.wait()
            # owner._transit_key is either None or set to a value. We don't
            # modify it from here, so we can release the condition lock
            # before grabbing the key.
        # Once it is set, we can get handler_(send|receive)_handshake, which
        # is what we actually care about.
        t = threading.Thread(target=handle,
                             args=(request, client_address,
                                   self.owner, description,
                                   self.owner.handler_send_handshake,
                                   self.owner.handler_expected_handshake))
        t.daemon = True
        t.start()
class Common:
    """State and negotiation logic shared by TransitSender/TransitReceiver.

    Creating an instance starts a listening TCP server on an ephemeral
    port. Subclasses provide the boolean class attribute `is_sender`, which
    selects the handshake we send and the one we expect.

    The methods used from outside this class in the visible code are:
    get_direct_hints()/get_relay_hints(), add_their_direct_hints()/
    add_their_relay_hints(), set_transit_key(), establish_connection() and
    describe(). How callers order them is not shown here.
    """

    def __init__(self):
        # set once some socket has won the negotiation
        self.winning = threading.Event()
        # serializes the first-finisher-wins decision
        self._negotiation_check_lock = threading.Lock()
        # inbound connections wait on this until the key is known
        self._have_transit_key = threading.Condition()
        self._transit_key = None
        # Created here rather than in establish_connection(): the listener
        # started below can finish negotiating an inbound socket as soon as
        # the transit key is set, which may be before establish_connection()
        # runs.
        self.winning_skt = None
        self.winning_skt_description = None
        self._start_server()

    def _start_server(self):
        """Start the listener thread and compute our direct hints."""
        server = MyTCPServer(("", 0), None)
        _, port = server.server_address
        # one hint per address reported by ipaddrs.find_addresses()
        self.my_direct_hints = ["tcp:%s:%d" % (addr, port)
                                for addr in ipaddrs.find_addresses()]
        server.owner = self
        server_thread = threading.Thread(target=server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
        self.listener = server

    def get_direct_hints(self):
        """Return the "tcp:ADDR:PORT" hints for our own listener."""
        return self.my_direct_hints

    def get_relay_hints(self):
        """Return the relay hints we offer (just TRANSIT_RELAY)."""
        return [TRANSIT_RELAY]

    def add_their_direct_hints(self, hints):
        """Record the peer's direct hints, replacing any earlier ones."""
        self._their_direct_hints = [force_ascii(h) for h in hints]

    def add_their_relay_hints(self, hints):
        """Record the peer's relay hints, replacing any earlier ones."""
        self._their_relay_hints = [force_ascii(h) for h in hints]

    def _send_this(self):
        """Return the handshake our side transmits (without "go")."""
        if self.is_sender:
            return build_sender_handshake(self._transit_key)
        else:
            return build_receiver_handshake(self._transit_key)

    def _expect_this(self):
        """Return the handshake the other side must send us."""
        if self.is_sender:
            return build_receiver_handshake(self._transit_key)
        else:
            # the receiver also waits for the sender's "go"
            return build_sender_handshake(self._transit_key) + "go\n"

    def set_transit_key(self, key):
        """Store the key and wake inbound connections waiting for it."""
        # This _have_transit_key condition/lock protects us against the race
        # where the sender knows the hints and the key, and connects to the
        # receiver's transit socket before the receiver gets relay message
        # (and thus the key).
        # `with` guarantees the lock is released even if building the
        # handshakes raises.
        with self._have_transit_key:
            self._transit_key = key
            self.handler_send_handshake = self._send_this() # no "go"
            self.handler_expected_handshake = self._expect_this()
            self._have_transit_key.notify_all()

    def _start_outbound(self):
        """Start one connector per direct hint, or the relays if none."""
        self._active_connectors = set(self._their_direct_hints)
        for hint in self._their_direct_hints:
            self._start_connector(hint)
        if not self._their_direct_hints:
            self._start_relay_connectors()

    def _start_connector(self, hint, is_relay=False):
        """Run connector() for `hint` in a daemon thread."""
        description = "->%s" % (hint,)
        if is_relay:
            description = "->relay:%s" % (hint,)
        args = (self, hint, description,
                self._send_this(), self._expect_this())
        if is_relay:
            args = args + (build_relay_handshake(self._transit_key),)
        t = threading.Thread(target=connector, args=args)
        t.daemon = True
        t.start()

    def _start_relay_connectors(self):
        """Start a relay connector for each of the peer's relay hints."""
        for hint in self._their_relay_hints:
            self._start_connector(hint, is_relay=True)

    def establish_connection(self):
        """Start outbound attempts and wait for the first negotiated socket.

        Waits up to 2*TIMEOUT seconds, shuts the listener down, and returns
        the winning socket. Raises TransitError if no socket won.
        """
        # winning_skt is deliberately not reset here: an inbound socket may
        # already have won before this method was called.
        self._start_outbound()
        # we sit here until one of our inbound or outbound sockets succeeds
        # On timeout self.winning_skt will not be set. ish. race.
        self.winning.wait(2*TIMEOUT)
        if self.listener:
            self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
        if self.winning_skt:
            return self.winning_skt
        raise TransitError

    def describe(self):
        """Return the winning socket's description, or a placeholder."""
        if not self.winning_skt_description:
            return "not yet established"
        return self.winning_skt_description

    def _connector_failed(self, hint):
        """Called by connector threads when their hint did not work out.

        When the last tracked direct hint has failed, the relay connectors
        are started.
        """
        if hint not in self._active_connectors:
            # Relay hints are never added to the active set, and a repeated
            # direct hint is removed by its first failure. Ignore these
            # instead of raising KeyError in the connector thread or
            # restarting the relay connectors.
            return
        self._active_connectors.remove(hint)
        if not self._active_connectors:
            self._start_relay_connectors()

    def _negotiation_finished(self, skt, description):
        """Decide whether `skt` is the winning socket.

        inbound/outbound sockets call this when they finish negotiation.
        The first one wins and gets a "go". Any subsequent ones lose and
        get a "nevermind" before being closed. Only the sender transmits
        "go"/"nevermind".
        """
        with self._negotiation_check_lock:
            is_winner = not self.winning_skt
            if is_winner:
                self.winning_skt = skt
                self.winning_skt_description = description
        if is_winner:
            if self.is_sender:
                send_to(skt, "go\n")
            self.winning.set()
        else:
            try:
                if self.is_sender:
                    send_to(skt, "nevermind\n")
            finally:
                # close the losing socket even if the peer is already gone
                skt.close()
class TransitSender(Common):
    """The sending end: sends the sender handshake and "go"/"nevermind"."""
    is_sender = True
class TransitReceiver(Common):
    """The receiving end: sends the receiver handshake and waits for "go"."""
    is_sender = False