diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 6cf4098..6cd72c0 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -129,14 +129,15 @@ class MyTCPServer(socketserver.TCPServer): def process_request(self, request, client_address): description = "<-tcp:%s:%d" % (client_address[0], client_address[1]) - kc = self.owner._have_transit_key - kc.acquire() - while not self.owner._transit_key: - kc.wait() + ready_lock = self.owner._ready_for_connections_lock + ready_lock.acquire() + while not (self.owner._ready_for_connections + and self.owner._transit_key): + ready_lock.wait() # owner._transit_key is either None or set to a value. We don't # modify it from here, so we can release the condition lock before # grabbing the key. - kc.release() + ready_lock.release() # Once it is set, we can get handler_(send|receive)_handshake, which # is what we actually care about. @@ -210,7 +211,8 @@ class Common: self._transit_relays = [] self.winning = threading.Event() self._negotiation_check_lock = threading.Lock() - self._have_transit_key = threading.Condition() + self._ready_for_connections_lock = threading.Condition() + self._ready_for_connections = False self._transit_key = None self._start_server() @@ -272,16 +274,16 @@ class Common: CTXinfo=b"transit_record_sender_key") def set_transit_key(self, key): - # This _have_transit_key condition/lock protects us against the race - # where the sender knows the hints and the key, and connects to the - # receiver's transit socket before the receiver gets relay message - # (and thus the key). - self._have_transit_key.acquire() + # This _ready_for_connections condition/lock protects us against the + # race where the sender knows the hints and the key, and connects to + # the receiver's transit socket before the receiver gets relay + # message (and thus the key). + self._ready_for_connections_lock.acquire() self._transit_key = key self.handler_send_handshake = self._send_this() # no "go" self.handler_expected_handshake = self._expect_this() - self._have_transit_key.notify_all() - self._have_transit_key.release() + self._ready_for_connections_lock.notify_all() + self._ready_for_connections_lock.release() def _start_outbound(self): self._active_connectors = set(self._their_direct_hints) @@ -317,6 +319,10 @@ class Common: start = time.time() self.winning_skt = None self.winning_skt_description = None + self._ready_for_connections_lock.acquire() + self._ready_for_connections = True + self._ready_for_connections_lock.notify_all() + self._ready_for_connections_lock.release() self._start_outbound() # we sit here until one of our inbound or outbound sockets succeeds