diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 0e25843..92fa03c 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -45,6 +45,9 @@ def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") return "transit sender %s ready\n\n" % hexlify(hexid) +def build_relay_token(key): + return "PLEASE RELAY\n" + TIMEOUT=15 # 1: sender only transmits, receiver only accepts, both wait forever @@ -62,7 +65,16 @@ def force_ascii(s): return s.encode("ascii") return s -def connector(owner, hint, send_handshake, expected_handshake): +def wait_for(skt, expected, hint): + got = b"" + while len(got) < len(expected): + got += skt.recv(1) + if expected[:len(got)] != got: + raise BadHandshake("got '%r' want '%r' on %s" % + (got, expected, hint)) + +def connector(owner, hint, send_handshake, expected_handshake, + relay_token=None): addr,port = hint.split(",") skt = None try: @@ -70,13 +82,11 @@ def connector(owner, hint, send_handshake, expected_handshake): TIMEOUT) # timeout or ECONNREFUSED skt.settimeout(TIMEOUT) #print("socket(%s) connected" % (hint,)) + if relay_token: + skt.send(relay_token) + wait_for(skt, "ok\n", hint) skt.send(send_handshake) - got = b"" - while len(got) < len(expected_handshake): - got += skt.recv(1) - if expected_handshake[:len(got)] != got: - raise BadHandshake("got '%r' want '%r' on %s" % - (got, expected_handshake, hint)) + wait_for(skt, expected_handshake, hint) #print("connector ready %r" % (hint,)) except Exception as e: try: @@ -208,17 +218,17 @@ class Common: if not self._their_direct_hints: self._start_relay_connectors() - def _start_connector(self, hint): - t = threading.Thread(target=connector, - args=(self, hint, - self._send_this(), - self._expect_this())) + def _start_connector(self, hint, is_relay=False): + args = (self, hint, self._send_this(), self._expect_this()) + if is_relay: + args = args + (build_relay_token(self._transit_key),) + t = threading.Thread(target=connector, args=args) t.daemon = True t.start() def _start_relay_connectors(self): for hint in self._their_relay_hints: - self._start_connector(hint) + self._start_connector(hint, is_relay=True) def establish_connection(self): self.winning_skt = None