diff --git a/bin/receive_file.py b/bin/receive_file.py index f30ce2f..9a0e50f 100644 --- a/bin/receive_file.py +++ b/bin/receive_file.py @@ -1,5 +1,5 @@ from __future__ import print_function -import os, json +import sys, os, json from binascii import unhexlify from nacl.secret import SecretBox from wormhole.blocking.transcribe import Receiver @@ -26,6 +26,7 @@ file_data = data["file"] xfer_key = unhexlify(file_data["key"].encode("ascii")) filename = os.path.basename(file_data["filename"]) # unicode filesize = file_data["filesize"] +encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16 # now receive the rest of the owl tdata = data["transit"] @@ -34,7 +35,10 @@ transit_receiver.set_transit_key(tdata["key"]) transit_receiver.add_sender_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_sender_relay_hints(tdata["relay_connection_hints"]) skt = transit_receiver.establish_connection() -encrypted = skt.recv(filesize) +encrypted = skt.recv(encrypted_filesize) +if len(encrypted) != encrypted_filesize: + print("Connection dropped before file received") + sys.exit(1) decrypted = SecretBox(xfer_key).decrypt(encrypted) diff --git a/bin/send_file.py b/bin/send_file.py index e5c1951..4fdb430 100644 --- a/bin/send_file.py +++ b/bin/send_file.py @@ -47,7 +47,7 @@ encrypted = box.encrypt(plaintext, nonce) tdata = them_d["transit"] transit_sender.add_receiver_hints(tdata["direct_connection_hints"]) skt = transit_sender.establish_connection() -skt.write(encrypted) +skt.send(encrypted) skt.close() print("file sent") diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index d950b1d..ab88abc 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -100,17 +100,17 @@ class TransitSender: with self._negotiation_check_lock: if self.winning_skt: - winner = False + is_winner = False else: - winner = True + is_winner = True self.winning_skt = skt - if winner: - self.winning_skt.send("go\n") + if is_winner: + skt.send("go\n") self.winning.set() else: - winner.send("nevermind\n") - winner.close() + skt.send("nevermind\n") + skt.close() class BadHandshake(Exception): pass @@ -168,9 +168,19 @@ def handle(skt, client_address, owner, send_handshake, expected_handshake): class MyTCPServer(SocketServer.TCPServer): allow_reuse_address = True + def process_request(self, request, client_address): - if not self.owner.key: - raise BadHandshake("connection received before set_transit_key()") + kc = self.owner._have_transit_key + kc.acquire() + while not self.owner._transit_key: + kc.wait() + # owner._transit_key is either None or set to a value. We don't + # modify it from here, so we can release the condition lock before + # grabbing the key. + kc.release() + + # Once it is set, we can get handler_(send|receive)_handshake, which + # is what we actually care about. t = threading.Thread(target=handle, args=(request, client_address, self.owner, @@ -183,7 +193,8 @@ class TransitReceiver: def __init__(self): self.winning = threading.Event() self._negotiation_check_lock = threading.Lock() - self.key = None + self._have_transit_key = threading.Condition() + self._transit_key = None server = MyTCPServer(("",9999), None) _, port = server.server_address self.my_direct_hints = ["%s,%d" % (addr, port) @@ -196,12 +207,18 @@ class TransitReceiver: def get_direct_hints(self): return self.my_direct_hints + def set_transit_key(self, key): - # TODO consider race: sender knows the hints and the key, connects to - # transit before receiver gets relay message (with key) - self.key = key + # This _have_transit_key condition/lock protects us against the race + # where the sender knows the hints and the key, and connects to the + # receiver's transit socket before the receiver gets relay message + # (and thus the key). + self._have_transit_key.acquire() + self._transit_key = key self.handler_send_handshake = build_receiver_handshake(key) self.handler_expected_handshake = build_sender_handshake(key) + "go\n" + self._have_transit_key.notify_all() + self._have_transit_key.release() def add_sender_direct_hints(self, hints): self.sender_direct_hints = hints # TODO ignored