diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index cd928a3..1421ff3 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -175,6 +175,40 @@ class Common: def add_their_relay_hints(self, hints): self._their_relay_hints = hints # ignored + def _send_this(self): + if self.is_sender: + return build_sender_handshake(self._transit_key) + else: + return build_receiver_handshake(self._transit_key) + + def _expect_this(self): + if self.is_sender: + return build_receiver_handshake(self._transit_key) + else: + return build_sender_handshake(self._transit_key) + "go\n" + + def set_transit_key(self, key): + # This _have_transit_key condition/lock protects us against the race + # where the sender knows the hints and the key, and connects to the + # receiver's transit socket before the receiver gets relay message + # (and thus the key). + self._have_transit_key.acquire() + self._transit_key = key + self.handler_send_handshake = self._send_this() # no "go" + self.handler_expected_handshake = self._expect_this() + self._have_transit_key.notify_all() + self._have_transit_key.release() + + def _start_outbound(self): + self.connectors = [] + for hint in self._their_direct_hints: + t = threading.Thread(target=connector, + args=(self, hint, + self._send_this(), + self._expect_this())) + t.daemon = True + t.start() + def establish_connection(self): self.winning_skt = None self._start_outbound() @@ -191,32 +225,6 @@ class Common: return self.winning_skt raise TransitError -class TransitSender(Common): - server_port = 9999 - - def set_transit_key(self, key): - # This _have_transit_key condition/lock protects us against the race - # where the sender knows the hints and the key, and connects to the - # receiver's transit socket before the receiver gets relay message - # (and thus the key). - self._have_transit_key.acquire() - self._transit_key = key - self.handler_send_handshake = build_sender_handshake(key) # no "go" - self.handler_expected_handshake = build_receiver_handshake(key) - self._have_transit_key.notify_all() - self._have_transit_key.release() - - def _start_outbound(self): - sender_handshake = build_sender_handshake(self._transit_key) - receiver_handshake = build_receiver_handshake(self._transit_key) - self.connectors = [] - for hint in self._their_direct_hints: - t = threading.Thread(target=connector, - args=(self, hint, - sender_handshake, receiver_handshake)) - t.daemon = True - t.start() - def _negotiation_finished(self, skt): # inbound/outbound sockets call this when they finish negotiation. # The first one wins and gets a "go". Any subsequent ones lose and @@ -230,51 +238,18 @@ class TransitSender(Common): self.winning_skt = skt if is_winner: - skt.send("go\n") + if self.is_sender: + skt.send("go\n") self.winning.set() else: - skt.send("nevermind\n") + if self.is_sender: + skt.send("nevermind\n") skt.close() - +class TransitSender(Common): + server_port = 9999 + is_sender = True class TransitReceiver(Common): server_port = 9998 - - def set_transit_key(self, key): - # This _have_transit_key condition/lock protects us against the race - # where the sender knows the hints and the key, and connects to the - # receiver's transit socket before the receiver gets relay message - # (and thus the key). - self._have_transit_key.acquire() - self._transit_key = key - self.handler_send_handshake = build_receiver_handshake(key) - self.handler_expected_handshake = build_sender_handshake(key) + "go\n" - self._have_transit_key.notify_all() - self._have_transit_key.release() - - def _start_outbound(self): - sender_handshake = build_sender_handshake(self._transit_key) + "go\n" - receiver_handshake = build_receiver_handshake(self._transit_key) - self.connectors = [] - for hint in self._their_direct_hints: - t = threading.Thread(target=connector, - args=(self, hint, # SWAPPED - receiver_handshake, sender_handshake)) - t.daemon = True - t.start() - - def _negotiation_finished(self, skt): - with self._negotiation_check_lock: - if self.winning_skt: - winner = False - else: - winner = True - self.winning_skt = skt - - if winner: - self.winning.set() - else: - winner.close() - print("weird, receiver was given duplicate winner") - + is_sender = False