diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 6b5f1d1..cd35b23 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -6,9 +6,6 @@ from ..util.hkdf import HKDF class TransitError(Exception): pass -class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): - pass - # The beginning of each TCP connection consists of the following handshake # messages. The sender transmits the same text regardless of whether it is on # the initiating/connecting end of the TCP connection, or on the @@ -49,6 +46,8 @@ def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") return "transit sender %s ready\n\n" % hexlify(hexid) +TIMEOUT=10000 + # 1: sender only transmits, receiver only accepts, both wait forever # 2: sender also accepts, receiver also transmits # 3: timeouts / stop when no more progress can be made @@ -79,11 +78,11 @@ class TransitSender: t = threading.Thread(target=connector, args=(self, hint, sender_handshake, receiver_handshake)) + t.daemon = True t.start() # we sit here until one of our inbound or outbound sockets succeeds - timeout = 10.0 - flag = self.winning.wait(timeout) + flag = self.winning.wait(TIMEOUT) if not flag: # timeout: self.winning_skt will not be set. ish. race. @@ -107,7 +106,7 @@ class TransitSender: self.winning_skt = skt if winner: - winner.send("go\n") + self.winning_skt.send("go\n") self.winning.set() else: winner.send("nevermind\n") @@ -119,7 +118,7 @@ class BadHandshake(Exception): def connector(owner, hint, send_handshake, expected_handshake): addr,port = hint.split(",") skt = socket.create_connection((addr,port)) # timeout here - skt.settimeout(5.0) + skt.settimeout(TIMEOUT) print "socket(%s) connected" % hint try: skt.send(send_handshake) @@ -145,10 +144,10 @@ def connector(owner, hint, send_handshake, expected_handshake): def handle(skt, client_address, owner, send_handshake, expected_handshake): try: print "handle", skt - skt.settimeout(5.0) + skt.settimeout(TIMEOUT) skt.send(send_handshake) got = b"" - # for the receiver, this includes the "ok\n" + # for the receiver, this includes the "go\n" while len(got) < len(expected_handshake): got += skt.recv(1) if expected_handshake[:len(got)] != got: @@ -175,11 +174,13 @@ class MyTCPServer(SocketServer.TCPServer): self.owner, self.owner.handler_send_handshake, self.owner.handler_expected_handshake)) - t.daemon = False + t.daemon = True t.start() class TransitReceiver: def __init__(self): + self.winning = threading.Event() + self._negotiation_check_lock = threading.Lock() server = MyTCPServer(("",9999), None) _, port = server.server_address self.my_direct_hints = ["%s,%d" % (addr, port) @@ -197,7 +198,7 @@ class TransitReceiver: # transit before receiver gets relay message (with key) self.key = key self.handler_send_handshake = build_receiver_handshake(key) - self.handler_expected_handshake = build_sender_handshake(key) + "ok\n" + self.handler_expected_handshake = build_sender_handshake(key) + "go\n" def add_sender_direct_hints(self, hints): self.sender_direct_hints = hints # TODO ignored @@ -208,8 +209,7 @@ class TransitReceiver: self.winning_skt = None # we sit here until one of our inbound or outbound sockets succeeds - timeout = 10.0 - flag = self.winning.wait(timeout) + flag = self.winning.wait(TIMEOUT) if not flag: # timeout: self.winning_skt will not be set. ish. race.