From 99e08c2e37d6f3667f7fcaabee5072455691add5 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 19 Feb 2015 19:09:08 -0800 Subject: [PATCH] transit: use bidirectional connections --- bin/receive_file.py | 8 ++--- bin/send_file.py | 9 +++-- src/wormhole/blocking/transit.py | 62 ++++++++++++++++++++++++-------- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/bin/receive_file.py b/bin/receive_file.py index 1809f83..16dcfa8 100644 --- a/bin/receive_file.py +++ b/bin/receive_file.py @@ -8,11 +8,11 @@ APPID = "lothar.com/wormhole/file-xfer" # we're receiving transit_receiver = TransitReceiver() -direct_hints = transit_receiver.get_direct_hints() mydata = json.dumps({ "transit": { - "direct_connection_hints": direct_hints, + "direct_connection_hints": transit_receiver.get_direct_hints(), + "relay_connection_hints": transit_receiver.get_relay_hints(), }, }).encode("utf-8") r = Receiver(APPID, mydata) @@ -31,8 +31,8 @@ encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16 tdata = data["transit"] transit_key = r.derive_key(APPID+"/transit-key") transit_receiver.set_transit_key(transit_key) -transit_receiver.add_sender_direct_hints(tdata["direct_connection_hints"]) -transit_receiver.add_sender_relay_hints(tdata["relay_connection_hints"]) +transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) +transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) skt = transit_receiver.establish_connection() print("Receiving %d bytes.." % filesize) encrypted = b"" diff --git a/bin/send_file.py b/bin/send_file.py index e8c75df..c21ea3e 100644 --- a/bin/send_file.py +++ b/bin/send_file.py @@ -10,8 +10,6 @@ APPID = "lothar.com/wormhole/file-xfer" filename = sys.argv[1] assert os.path.isfile(filename) transit_sender = TransitSender() -direct_hints = transit_sender.get_direct_hints() -relay_hints = transit_sender.get_relay_hints() filesize = os.stat(filename).st_size data = json.dumps({ @@ -20,8 +18,8 @@ data = json.dumps({ "filesize": filesize, }, "transit": { - "direct_connection_hints": direct_hints, - "relay_connection_hints": relay_hints, + "direct_connection_hints": transit_sender.get_direct_hints(), + "relay_connection_hints": transit_sender.get_relay_hints(), }, }).encode("utf-8") @@ -44,7 +42,8 @@ encrypted = box.encrypt(plaintext, nonce) tdata = them_d["transit"] transit_key = i.derive_key(APPID+"/transit-key") transit_sender.set_transit_key(transit_key) -transit_sender.add_receiver_hints(tdata["direct_connection_hints"]) +transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) +transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) skt = transit_sender.establish_connection() print("Sending %d bytes.." % filesize) diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index df90b27..f46be0f 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -153,13 +153,28 @@ class TransitSender: self._negotiation_check_lock = threading.Lock() self._have_transit_key = threading.Condition() self._transit_key = None + self._start_server() + + def _start_server(self): + server = MyTCPServer(("",9999), None) + _, port = server.server_address + self.my_direct_hints = ["%s,%d" % (addr, port) + for addr in ipaddrs.find_addresses()] + server.owner = self + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + self.listener = server def get_direct_hints(self): - return [] + return self.my_direct_hints def get_relay_hints(self): return [] - def add_receiver_hints(self, hints): - self.receiver_hints = hints + + def add_their_direct_hints(self, hints): + self._their_direct_hints = hints + def add_their_relay_hints(self, hints): + self._their_relay_hints = hints # ignored def set_transit_key(self, key): # This _have_transit_key condition/lock protects us against the race @@ -168,24 +183,26 @@ class TransitSender: # (and thus the key). self._have_transit_key.acquire() self._transit_key = key - #self.handler_send_handshake = build_receiver_handshake(key) - #self.handler_expected_handshake = build_sender_handshake(key) + "go\n" + self.handler_send_handshake = build_sender_handshake(key) # no "go" + self.handler_expected_handshake = build_receiver_handshake(key) self._have_transit_key.notify_all() self._have_transit_key.release() - def establish_connection(self): + def _start_outbound(self): sender_handshake = build_sender_handshake(self._transit_key) receiver_handshake = build_receiver_handshake(self._transit_key) - self.listener = None self.connectors = [] - self.winning_skt = None - for hint in self.receiver_hints: + for hint in self._their_direct_hints: t = threading.Thread(target=connector, args=(self, hint, sender_handshake, receiver_handshake)) t.daemon = True t.start() + def establish_connection(self): + self.winning_skt = None + self._start_outbound() + # we sit here until one of our inbound or outbound sockets succeeds flag = self.winning.wait(TIMEOUT) @@ -225,7 +242,10 @@ class TransitReceiver: self._negotiation_check_lock = threading.Lock() self._have_transit_key = threading.Condition() self._transit_key = None - server = MyTCPServer(("",9999), None) + self._start_server() + + def _start_server(self): + server = MyTCPServer(("",9998), None) _, port = server.server_address self.my_direct_hints = ["%s,%d" % (addr, port) for addr in ipaddrs.find_addresses()] @@ -237,6 +257,13 @@ class TransitReceiver: def get_direct_hints(self): return self.my_direct_hints + def get_relay_hints(self): + return [] + + def add_their_direct_hints(self, hints): + self._their_direct_hints = hints + def add_their_relay_hints(self, hints): + self._their_relay_hints = hints # ignored def set_transit_key(self, key): # This _have_transit_key condition/lock protects us against the race @@ -250,13 +277,20 @@ class TransitReceiver: self._have_transit_key.notify_all() self._have_transit_key.release() - def add_sender_direct_hints(self, hints): - self.sender_direct_hints = hints # TODO ignored - def add_sender_relay_hints(self, hints): - self.sender_relay_hints = hints # TODO ignored + def _start_outbound(self): + sender_handshake = build_sender_handshake(self._transit_key) + "go\n" + receiver_handshake = build_receiver_handshake(self._transit_key) + self.connectors = [] + for hint in self._their_direct_hints: + t = threading.Thread(target=connector, + args=(self, hint, # SWAPPED + receiver_handshake, sender_handshake)) + t.daemon = True + t.start() def establish_connection(self): self.winning_skt = None + self._start_outbound() # we sit here until one of our inbound or outbound sockets succeeds flag = self.winning.wait(TIMEOUT)