diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index ac6bbab..1f9135c 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -1,6 +1,7 @@ from __future__ import print_function import re, time, threading, socket, SocketServer -from binascii import hexlify +from binascii import hexlify, unhexlify +from nacl.secret import SecretBox from ..util import ipaddrs from ..util.hkdf import HKDF from ..const import TRANSIT_RELAY @@ -212,6 +213,62 @@ class MyTCPServer(SocketServer.TCPServer): t.start() +class TransitClosed(TransitError): + pass + +class BadNonce(TransitError): + pass + +class ReceiveBuffer: + def __init__(self, skt): + self.skt = skt + self.buf = b"" + + def read(self, count): + while len(self.buf) < count: + more = self.skt.recv(4096) + if not more: + raise TransitClosed + self.buf += more + rc = self.buf[:count] + self.buf = self.buf[count:] + return rc + +class RecordPipe: + def __init__(self, skt, send_key, receive_key): + self.skt = skt + self.send_box = SecretBox(send_key) + self.send_nonce = 0 + self.receive_buf = ReceiveBuffer(self.skt) + self.receive_box = SecretBox(receive_key) + self.next_receive_nonce = 0 + + def send_record(self, record): + assert SecretBox.NONCE_SIZE == 24 + assert self.send_nonce < 2**(8*24) + assert len(record) < 2**(8*4) + nonce = unhexlify("%048x" % self.send_nonce) # big-endian + self.send_nonce += 1 + encrypted = self.send_box.encrypt(record, nonce) + length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long + send_to(self.skt, length) + send_to(self.skt, encrypted) + + def receive_record(self): + length_buf = self.receive_buf.read(4) + length = int(hexlify(length_buf), 16) + encrypted = self.receive_buf.read(length) + nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended + nonce = int(hexlify(nonce_buf), 16) + if nonce != self.next_receive_nonce: + raise BadNonce("received out-of-order record") + self.next_receive_nonce += 1 + record = self.receive_box.decrypt(encrypted) + return record + + def close(self): + self.skt.close() + class Common: def __init__(self): self.winning = threading.Event() @@ -253,6 +310,22 @@ class Common: else: return build_sender_handshake(self._transit_key) + "go\n" + def _sender_record_key(self): + if self.is_sender: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") + else: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") + + def _receiver_record_key(self): + if self.is_sender: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_receiver_key") + else: + return HKDF(self._transit_key, SecretBox.KEY_SIZE, + CTXinfo=b"transit_record_sender_key") + def set_transit_key(self, key): # This _have_transit_key condition/lock protects us against the race # where the sender knows the hints and the key, and connects to the @@ -288,7 +361,7 @@ class Common: for hint in self._their_relay_hints: self._start_connector(hint, is_relay=True) - def establish_connection(self): + def establish_socket(self): start = time.time() self.winning_skt = None self.winning_skt_description = None @@ -340,6 +413,11 @@ class Common: send_to(skt, "nevermind\n") skt.close() + def connect(self): + skt = self.establish_socket() + return RecordPipe(skt, self._sender_record_key(), + self._receiver_record_key()) + class TransitSender(Common): is_sender = True diff --git a/src/wormhole/scripts/cmd_receive_file.py b/src/wormhole/scripts/cmd_receive_file.py index 8745514..e578d31 100644 --- a/src/wormhole/scripts/cmd_receive_file.py +++ b/src/wormhole/scripts/cmd_receive_file.py @@ -1,8 +1,7 @@ from __future__ import print_function import sys, os, json -from nacl.secret import SecretBox from wormhole.blocking.transcribe import Receiver, WrongPasswordError -from wormhole.blocking.transit import TransitReceiver +from wormhole.blocking.transit import TransitReceiver, TransitError from .progress import start_progress, update_progress, finish_progress APPID = "lothar.com/wormhole/file-xfer" @@ -28,10 +27,8 @@ def receive_file(so): #print("their data: %r" % (data,)) file_data = data["file"] - xfer_key = r.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE) filename = os.path.basename(file_data["filename"]) # unicode filesize = file_data["filesize"] - encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16 # now receive the rest of the owl tdata = data["transit"] @@ -39,46 +36,47 @@ def receive_file(so): transit_receiver.set_transit_key(transit_key) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) - skt = transit_receiver.establish_connection() + record_pipe = transit_receiver.connect() print("Receiving %d bytes for '%s' (%s).." % (filesize, filename, transit_receiver.describe())) - encrypted = b"" - next_update = start_progress(encrypted_filesize) - while len(encrypted) < encrypted_filesize: - more = skt.recv(encrypted_filesize - len(encrypted)) - if not more: - print() - print("Connection dropped before full file received") - print("got %d bytes, wanted %d" % (len(encrypted), encrypted_filesize)) - return 1 - encrypted += more - next_update = update_progress(next_update, len(encrypted), - encrypted_filesize) - finish_progress(encrypted_filesize) - assert len(encrypted) == encrypted_filesize - - print("Decrypting..") - decrypted = SecretBox(xfer_key).decrypt(encrypted) - # only write to the current directory, and never overwrite anything here = os.path.abspath(os.getcwd()) target = os.path.abspath(os.path.join(here, filename)) if os.path.dirname(target) != here: print("Error: suggested filename (%s) would be outside current directory" % (filename,)) - skt.send("bad filename\n") - skt.close() + record_pipe.send_record("bad filename\n") + record_pipe.close() return 1 if os.path.exists(target): print("Error: refusing to overwrite existing file %s" % (filename,)) - skt.send("file already exists\n") - skt.close() + record_pipe.send_record("file already exists\n") + record_pipe.close() return 1 - with open(target, "wb") as f: - f.write(decrypted) + tmp = target + ".tmp" + + with open(tmp, "wb") as f: + received = 0 + next_update = start_progress(filesize) + while received < filesize: + try: + plaintext = record_pipe.receive_record() + except TransitError: + print() + print("Connection dropped before full file received") + print("got %d bytes, wanted %d" % (received, filesize)) + return 1 + f.write(plaintext) + received += len(plaintext) + next_update = update_progress(next_update, received, filesize) + finish_progress(filesize) + assert received == filesize + + os.rename(tmp, target) + print("Received file written to %s" % filename) - skt.send("ok\n") - skt.close() + record_pipe.send_record("ok\n") + record_pipe.close() return 0 diff --git a/src/wormhole/scripts/cmd_send_file.py b/src/wormhole/scripts/cmd_send_file.py index 566ae5d..09c4996 100644 --- a/src/wormhole/scripts/cmd_send_file.py +++ b/src/wormhole/scripts/cmd_send_file.py @@ -1,6 +1,5 @@ from __future__ import print_function import os, sys, json -from nacl.secret import SecretBox from wormhole.blocking.transcribe import Initiator, WrongPasswordError from wormhole.blocking.transit import TransitSender from .progress import start_progress, update_progress, finish_progress @@ -37,36 +36,30 @@ def send_file(so): return 1 them_d = json.loads(them_bytes.decode("utf-8")) #print("them: %r" % (them_d,)) - xfer_key = i.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE) - print("Encrypting %d bytes.." % filesize) - - box = SecretBox(xfer_key) - with open(filename, "rb") as f: - plaintext = f.read() - nonce = os.urandom(SecretBox.NONCE_SIZE) - encrypted = box.encrypt(plaintext, nonce) tdata = them_d["transit"] transit_key = i.derive_key(APPID+"/transit-key") transit_sender.set_transit_key(transit_key) transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) - skt = transit_sender.establish_connection() + record_pipe = transit_sender.connect() print("Sending (%s).." % transit_sender.describe()) - sent = 0 - next_update = start_progress(len(encrypted)) - while sent < len(encrypted): - sent += skt.send(encrypted[sent:]) - next_update = update_progress(next_update, sent, len(encrypted)) - finish_progress(len(encrypted)) + CHUNKSIZE = 64*1024 + with open(filename, "rb") as f: + sent = 0 + next_update = start_progress(filesize) + while sent < filesize: + plaintext = f.read(CHUNKSIZE) + record_pipe.send_record(plaintext) + sent += len(plaintext) + next_update = update_progress(next_update, sent, filesize) + finish_progress(filesize) print("File sent.. waiting for confirmation") - # ack is a short newline-terminated string, followed by socket close. A long - # read is probably good enough. - ack = skt.recv(300) + ack = record_pipe.receive_record() if ack == "ok\n": print("Confirmation received. Transfer complete.") return 0