transit: finish refactoring, combine mostly into a single class

This commit is contained in:
Brian Warner 2015-02-19 23:55:05 -08:00
parent af5f2053b8
commit 6f64b6d326

View File

@ -175,6 +175,40 @@ class Common:
def add_their_relay_hints(self, hints):
self._their_relay_hints = hints # ignored
def _send_this(self):
if self.is_sender:
return build_sender_handshake(self._transit_key)
else:
return build_receiver_handshake(self._transit_key)
def _expect_this(self):
if self.is_sender:
return build_receiver_handshake(self._transit_key)
else:
return build_sender_handshake(self._transit_key) + "go\n"
def set_transit_key(self, key):
# This _have_transit_key condition/lock protects us against the race
# where the sender knows the hints and the key, and connects to the
# receiver's transit socket before the receiver gets relay message
# (and thus the key).
self._have_transit_key.acquire()
self._transit_key = key
self.handler_send_handshake = self._send_this() # no "go"
self.handler_expected_handshake = self._expect_this()
self._have_transit_key.notify_all()
self._have_transit_key.release()
def _start_outbound(self):
self.connectors = []
for hint in self._their_direct_hints:
t = threading.Thread(target=connector,
args=(self, hint,
self._send_this(),
self._expect_this()))
t.daemon = True
t.start()
def establish_connection(self):
self.winning_skt = None
self._start_outbound()
@ -191,32 +225,6 @@ class Common:
return self.winning_skt
raise TransitError
class TransitSender(Common):
server_port = 9999
def set_transit_key(self, key):
# This _have_transit_key condition/lock protects us against the race
# where the sender knows the hints and the key, and connects to the
# receiver's transit socket before the receiver gets relay message
# (and thus the key).
self._have_transit_key.acquire()
self._transit_key = key
self.handler_send_handshake = build_sender_handshake(key) # no "go"
self.handler_expected_handshake = build_receiver_handshake(key)
self._have_transit_key.notify_all()
self._have_transit_key.release()
def _start_outbound(self):
sender_handshake = build_sender_handshake(self._transit_key)
receiver_handshake = build_receiver_handshake(self._transit_key)
self.connectors = []
for hint in self._their_direct_hints:
t = threading.Thread(target=connector,
args=(self, hint,
sender_handshake, receiver_handshake))
t.daemon = True
t.start()
def _negotiation_finished(self, skt):
# inbound/outbound sockets call this when they finish negotiation.
# The first one wins and gets a "go". Any subsequent ones lose and
@ -230,51 +238,18 @@ class TransitSender(Common):
self.winning_skt = skt
if is_winner:
skt.send("go\n")
if self.is_sender:
skt.send("go\n")
self.winning.set()
else:
skt.send("nevermind\n")
if self.is_sender:
skt.send("nevermind\n")
skt.close()
class TransitSender(Common):
server_port = 9999
is_sender = True
class TransitReceiver(Common):
server_port = 9998
def set_transit_key(self, key):
# This _have_transit_key condition/lock protects us against the race
# where the sender knows the hints and the key, and connects to the
# receiver's transit socket before the receiver gets relay message
# (and thus the key).
self._have_transit_key.acquire()
self._transit_key = key
self.handler_send_handshake = build_receiver_handshake(key)
self.handler_expected_handshake = build_sender_handshake(key) + "go\n"
self._have_transit_key.notify_all()
self._have_transit_key.release()
def _start_outbound(self):
sender_handshake = build_sender_handshake(self._transit_key) + "go\n"
receiver_handshake = build_receiver_handshake(self._transit_key)
self.connectors = []
for hint in self._their_direct_hints:
t = threading.Thread(target=connector,
args=(self, hint, # SWAPPED
receiver_handshake, sender_handshake))
t.daemon = True
t.start()
def _negotiation_finished(self, skt):
with self._negotiation_check_lock:
if self.winning_skt:
winner = False
else:
winner = True
self.winning_skt = skt
if winner:
self.winning.set()
else:
winner.close()
print("weird, receiver was given duplicate winner")
is_sender = False