diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 515061a..0e25843 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -57,13 +57,17 @@ TIMEOUT=15 class BadHandshake(Exception): pass +def force_ascii(s): + if isinstance(s, type(u"")): + return s.encode("ascii") + return s + def connector(owner, hint, send_handshake, expected_handshake): - if isinstance(hint, type(u"")): - hint = hint.encode("ascii") addr,port = hint.split(",") skt = None try: - skt = socket.create_connection((addr,port)) # timeout here + skt = socket.create_connection((addr,port), + TIMEOUT) # timeout or ECONNREFUSED skt.settimeout(TIMEOUT) #print("socket(%s) connected" % (hint,)) skt.send(send_handshake) @@ -85,7 +89,7 @@ def connector(owner, hint, send_handshake, expected_handshake): # ignore socket errors, warn about coding errors if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): raise - return + owner._connector_failed(hint) # owner is now responsible for the socket owner._negotiation_finished(skt) # note thread @@ -169,9 +173,9 @@ class Common: return [] def add_their_direct_hints(self, hints): - self._their_direct_hints = hints + self._their_direct_hints = [force_ascii(h) for h in hints] def add_their_relay_hints(self, hints): - self._their_relay_hints = hints # ignored + self._their_relay_hints = [force_ascii(h) for h in hints] def _send_this(self): if self.is_sender: @@ -198,14 +202,23 @@ class Common: self._have_transit_key.release() def _start_outbound(self): - self.connectors = [] + self._active_connectors = set(self._their_direct_hints) for hint in self._their_direct_hints: - t = threading.Thread(target=connector, - args=(self, hint, - self._send_this(), - self._expect_this())) - t.daemon = True - t.start() + self._start_connector(hint) + if not self._their_direct_hints: + self._start_relay_connectors() + + def _start_connector(self, hint): + t = threading.Thread(target=connector, + args=(self, hint, + self._send_this(), + self._expect_this())) + t.daemon = True + t.start() + + def _start_relay_connectors(self): + for hint in self._their_relay_hints: + self._start_connector(hint) def establish_connection(self): self.winning_skt = None @@ -223,6 +236,11 @@ class Common: return self.winning_skt raise TransitError + def _connector_failed(self, hint): + self._active_connectors.remove(hint) + if not self._active_connectors: + self._start_relay_connectors() + def _negotiation_finished(self, skt): # inbound/outbound sockets call this when they finish negotiation. # The first one wins and gets a "go". Any subsequent ones lose and