transit: fix handshake

Also make all threads daemonic, so they won't keep the process alive.
Also crank up the timeouts for manual testing.
This commit is contained in:
Brian Warner 2015-02-18 16:20:35 -08:00
parent 18ff9f9fd6
commit f459d59b48

View File

@ -6,9 +6,6 @@ from ..util.hkdf import HKDF
class TransitError(Exception): class TransitError(Exception):
pass pass
class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
pass
# The beginning of each TCP connection consists of the following handshake # The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on # messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the # the initiating/connecting end of the TCP connection, or on the
@ -49,6 +46,8 @@ def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return "transit sender %s ready\n\n" % hexlify(hexid) return "transit sender %s ready\n\n" % hexlify(hexid)
TIMEOUT=10000
# 1: sender only transmits, receiver only accepts, both wait forever # 1: sender only transmits, receiver only accepts, both wait forever
# 2: sender also accepts, receiver also transmits # 2: sender also accepts, receiver also transmits
# 3: timeouts / stop when no more progress can be made # 3: timeouts / stop when no more progress can be made
@ -79,11 +78,11 @@ class TransitSender:
t = threading.Thread(target=connector, t = threading.Thread(target=connector,
args=(self, hint, args=(self, hint,
sender_handshake, receiver_handshake)) sender_handshake, receiver_handshake))
t.daemon = True
t.start() t.start()
# we sit here until one of our inbound or outbound sockets succeeds # we sit here until one of our inbound or outbound sockets succeeds
timeout = 10.0 flag = self.winning.wait(TIMEOUT)
flag = self.winning.wait(timeout)
if not flag: if not flag:
# timeout: self.winning_skt will not be set. ish. race. # timeout: self.winning_skt will not be set. ish. race.
@ -107,7 +106,7 @@ class TransitSender:
self.winning_skt = skt self.winning_skt = skt
if winner: if winner:
winner.send("go\n") self.winning_skt.send("go\n")
self.winning.set() self.winning.set()
else: else:
winner.send("nevermind\n") winner.send("nevermind\n")
@ -119,7 +118,7 @@ class BadHandshake(Exception):
def connector(owner, hint, send_handshake, expected_handshake): def connector(owner, hint, send_handshake, expected_handshake):
addr,port = hint.split(",") addr,port = hint.split(",")
skt = socket.create_connection((addr,port)) # timeout here skt = socket.create_connection((addr,port)) # timeout here
skt.settimeout(5.0) skt.settimeout(TIMEOUT)
print "socket(%s) connected" % hint print "socket(%s) connected" % hint
try: try:
skt.send(send_handshake) skt.send(send_handshake)
@ -145,10 +144,10 @@ def connector(owner, hint, send_handshake, expected_handshake):
def handle(skt, client_address, owner, send_handshake, expected_handshake): def handle(skt, client_address, owner, send_handshake, expected_handshake):
try: try:
print "handle", skt print "handle", skt
skt.settimeout(5.0) skt.settimeout(TIMEOUT)
skt.send(send_handshake) skt.send(send_handshake)
got = b"" got = b""
# for the receiver, this includes the "ok\n" # for the receiver, this includes the "go\n"
while len(got) < len(expected_handshake): while len(got) < len(expected_handshake):
got += skt.recv(1) got += skt.recv(1)
if expected_handshake[:len(got)] != got: if expected_handshake[:len(got)] != got:
@ -175,11 +174,13 @@ class MyTCPServer(SocketServer.TCPServer):
self.owner, self.owner,
self.owner.handler_send_handshake, self.owner.handler_send_handshake,
self.owner.handler_expected_handshake)) self.owner.handler_expected_handshake))
t.daemon = False t.daemon = True
t.start() t.start()
class TransitReceiver: class TransitReceiver:
def __init__(self): def __init__(self):
self.winning = threading.Event()
self._negotiation_check_lock = threading.Lock()
server = MyTCPServer(("",9999), None) server = MyTCPServer(("",9999), None)
_, port = server.server_address _, port = server.server_address
self.my_direct_hints = ["%s,%d" % (addr, port) self.my_direct_hints = ["%s,%d" % (addr, port)
@ -197,7 +198,7 @@ class TransitReceiver:
# transit before receiver gets relay message (with key) # transit before receiver gets relay message (with key)
self.key = key self.key = key
self.handler_send_handshake = build_receiver_handshake(key) self.handler_send_handshake = build_receiver_handshake(key)
self.handler_expected_handshake = build_sender_handshake(key) + "ok\n" self.handler_expected_handshake = build_sender_handshake(key) + "go\n"
def add_sender_direct_hints(self, hints): def add_sender_direct_hints(self, hints):
self.sender_direct_hints = hints # TODO ignored self.sender_direct_hints = hints # TODO ignored
@ -208,8 +209,7 @@ class TransitReceiver:
self.winning_skt = None self.winning_skt = None
# we sit here until one of our inbound or outbound sockets succeeds # we sit here until one of our inbound or outbound sockets succeeds
timeout = 10.0 flag = self.winning.wait(TIMEOUT)
flag = self.winning.wait(timeout)
if not flag: if not flag:
# timeout: self.winning_skt will not be set. ish. race. # timeout: self.winning_skt will not be set. ish. race.