transit: fix race, file-xfer basically works, but noisy

The failed connections are throwing exceptions that should be caught and
ignored.
This commit is contained in:
Brian Warner 2015-02-19 15:30:16 -08:00
parent ae68dad441
commit 9f998221da
3 changed files with 36 additions and 15 deletions

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, json import sys, os, json
from binascii import unhexlify from binascii import unhexlify
from nacl.secret import SecretBox from nacl.secret import SecretBox
from wormhole.blocking.transcribe import Receiver from wormhole.blocking.transcribe import Receiver
@ -26,6 +26,7 @@ file_data = data["file"]
xfer_key = unhexlify(file_data["key"].encode("ascii")) xfer_key = unhexlify(file_data["key"].encode("ascii"))
filename = os.path.basename(file_data["filename"]) # unicode filename = os.path.basename(file_data["filename"]) # unicode
filesize = file_data["filesize"] filesize = file_data["filesize"]
encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16
# now receive the rest of the owl # now receive the rest of the owl
tdata = data["transit"] tdata = data["transit"]
@ -34,7 +35,10 @@ transit_receiver.set_transit_key(tdata["key"])
transit_receiver.add_sender_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_sender_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_sender_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_sender_relay_hints(tdata["relay_connection_hints"])
skt = transit_receiver.establish_connection() skt = transit_receiver.establish_connection()
encrypted = skt.recv(filesize) encrypted = skt.recv(encrypted_filesize)
if len(encrypted) != encrypted_filesize:
print("Connection dropped before file received")
sys.exit(1)
decrypted = SecretBox(xfer_key).decrypt(encrypted) decrypted = SecretBox(xfer_key).decrypt(encrypted)

View File

@ -47,7 +47,7 @@ encrypted = box.encrypt(plaintext, nonce)
tdata = them_d["transit"] tdata = them_d["transit"]
transit_sender.add_receiver_hints(tdata["direct_connection_hints"]) transit_sender.add_receiver_hints(tdata["direct_connection_hints"])
skt = transit_sender.establish_connection() skt = transit_sender.establish_connection()
skt.write(encrypted) skt.send(encrypted)
skt.close() skt.close()
print("file sent") print("file sent")

View File

@ -100,17 +100,17 @@ class TransitSender:
with self._negotiation_check_lock: with self._negotiation_check_lock:
if self.winning_skt: if self.winning_skt:
winner = False is_winner = False
else: else:
winner = True is_winner = True
self.winning_skt = skt self.winning_skt = skt
if winner: if is_winner:
self.winning_skt.send("go\n") skt.send("go\n")
self.winning.set() self.winning.set()
else: else:
winner.send("nevermind\n") skt.send("nevermind\n")
winner.close() skt.close()
class BadHandshake(Exception): class BadHandshake(Exception):
pass pass
@ -168,9 +168,19 @@ def handle(skt, client_address, owner, send_handshake, expected_handshake):
class MyTCPServer(SocketServer.TCPServer): class MyTCPServer(SocketServer.TCPServer):
allow_reuse_address = True allow_reuse_address = True
def process_request(self, request, client_address): def process_request(self, request, client_address):
if not self.owner.key: kc = self.owner._have_transit_key
raise BadHandshake("connection received before set_transit_key()") kc.acquire()
while not self.owner._transit_key:
kc.wait()
# owner._transit_key is either None or set to a value. We don't
# modify it from here, so we can release the condition lock before
# grabbing the key.
kc.release()
# Once it is set, we can get handler_(send|receive)_handshake, which
# is what we actually care about.
t = threading.Thread(target=handle, t = threading.Thread(target=handle,
args=(request, client_address, args=(request, client_address,
self.owner, self.owner,
@ -183,7 +193,8 @@ class TransitReceiver:
def __init__(self): def __init__(self):
self.winning = threading.Event() self.winning = threading.Event()
self._negotiation_check_lock = threading.Lock() self._negotiation_check_lock = threading.Lock()
self.key = None self._have_transit_key = threading.Condition()
self._transit_key = None
server = MyTCPServer(("",9999), None) server = MyTCPServer(("",9999), None)
_, port = server.server_address _, port = server.server_address
self.my_direct_hints = ["%s,%d" % (addr, port) self.my_direct_hints = ["%s,%d" % (addr, port)
@ -196,12 +207,18 @@ class TransitReceiver:
def get_direct_hints(self): def get_direct_hints(self):
return self.my_direct_hints return self.my_direct_hints
def set_transit_key(self, key): def set_transit_key(self, key):
# TODO consider race: sender knows the hints and the key, connects to # This _have_transit_key condition/lock protects us against the race
# transit before receiver gets relay message (with key) # where the sender knows the hints and the key, and connects to the
self.key = key # receiver's transit socket before the receiver gets relay message
# (and thus the key).
self._have_transit_key.acquire()
self._transit_key = key
self.handler_send_handshake = build_receiver_handshake(key) self.handler_send_handshake = build_receiver_handshake(key)
self.handler_expected_handshake = build_sender_handshake(key) + "go\n" self.handler_expected_handshake = build_sender_handshake(key) + "go\n"
self._have_transit_key.notify_all()
self._have_transit_key.release()
def add_sender_direct_hints(self, hints): def add_sender_direct_hints(self, hints):
self.sender_direct_hints = hints # TODO ignored self.sender_direct_hints = hints # TODO ignored