transit: provide encrypted record-pipe, use it for file-xfer

This commit is contained in:
Brian Warner 2015-03-12 18:14:42 -07:00
parent 8b3e5836ee
commit fcd2678dfd
3 changed files with 121 additions and 52 deletions

View File

@ -1,6 +1,7 @@
from __future__ import print_function
import re, time, threading, socket, SocketServer
from binascii import hexlify
from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..const import TRANSIT_RELAY
@ -212,6 +213,62 @@ class MyTCPServer(SocketServer.TCPServer):
t.start()
class TransitClosed(TransitError):
pass
class BadNonce(TransitError):
pass
class ReceiveBuffer:
def __init__(self, skt):
self.skt = skt
self.buf = b""
def read(self, count):
while len(self.buf) < count:
more = self.skt.recv(4096)
if not more:
raise TransitClosed
self.buf += more
rc = self.buf[:count]
self.buf = self.buf[count:]
return rc
class RecordPipe:
def __init__(self, skt, send_key, receive_key):
self.skt = skt
self.send_box = SecretBox(send_key)
self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0
def send_record(self, record):
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
self.send_nonce += 1
encrypted = self.send_box.encrypt(record, nonce)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
send_to(self.skt, length)
send_to(self.skt, encrypted)
def receive_record(self):
length_buf = self.receive_buf.read(4)
length = int(hexlify(length_buf), 16)
encrypted = self.receive_buf.read(length)
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record")
self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted)
return record
def close(self):
self.skt.close()
class Common:
def __init__(self):
self.winning = threading.Event()
@ -253,6 +310,22 @@ class Common:
else:
return build_sender_handshake(self._transit_key) + "go\n"
def _sender_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key):
# This _have_transit_key condition/lock protects us against the race
# where the sender knows the hints and the key, and connects to the
@ -288,7 +361,7 @@ class Common:
for hint in self._their_relay_hints:
self._start_connector(hint, is_relay=True)
def establish_connection(self):
def establish_socket(self):
start = time.time()
self.winning_skt = None
self.winning_skt_description = None
@ -340,6 +413,11 @@ class Common:
send_to(skt, "nevermind\n")
skt.close()
def connect(self):
skt = self.establish_socket()
return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key())
class TransitSender(Common):
is_sender = True

View File

@ -1,8 +1,7 @@
from __future__ import print_function
import sys, os, json
from nacl.secret import SecretBox
from wormhole.blocking.transcribe import Receiver, WrongPasswordError
from wormhole.blocking.transit import TransitReceiver
from wormhole.blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress
APPID = "lothar.com/wormhole/file-xfer"
@ -28,10 +27,8 @@ def receive_file(so):
#print("their data: %r" % (data,))
file_data = data["file"]
xfer_key = r.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
filename = os.path.basename(file_data["filename"]) # unicode
filesize = file_data["filesize"]
encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16
# now receive the rest of the owl
tdata = data["transit"]
@ -39,46 +36,47 @@ def receive_file(so):
transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
skt = transit_receiver.establish_connection()
record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename,
transit_receiver.describe()))
encrypted = b""
next_update = start_progress(encrypted_filesize)
while len(encrypted) < encrypted_filesize:
more = skt.recv(encrypted_filesize - len(encrypted))
if not more:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (len(encrypted), encrypted_filesize))
return 1
encrypted += more
next_update = update_progress(next_update, len(encrypted),
encrypted_filesize)
finish_progress(encrypted_filesize)
assert len(encrypted) == encrypted_filesize
print("Decrypting..")
decrypted = SecretBox(xfer_key).decrypt(encrypted)
# only write to the current directory, and never overwrite anything
here = os.path.abspath(os.getcwd())
target = os.path.abspath(os.path.join(here, filename))
if os.path.dirname(target) != here:
print("Error: suggested filename (%s) would be outside current directory"
% (filename,))
skt.send("bad filename\n")
skt.close()
record_pipe.send_record("bad filename\n")
record_pipe.close()
return 1
if os.path.exists(target):
print("Error: refusing to overwrite existing file %s" % (filename,))
skt.send("file already exists\n")
skt.close()
record_pipe.send_record("file already exists\n")
record_pipe.close()
return 1
with open(target, "wb") as f:
f.write(decrypted)
tmp = target + ".tmp"
with open(tmp, "wb") as f:
received = 0
next_update = start_progress(filesize)
while received < filesize:
try:
plaintext = record_pipe.receive_record()
except TransitError:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (received, filesize))
return 1
f.write(plaintext)
received += len(plaintext)
next_update = update_progress(next_update, received, filesize)
finish_progress(filesize)
assert received == filesize
os.rename(tmp, target)
print("Received file written to %s" % filename)
skt.send("ok\n")
skt.close()
record_pipe.send_record("ok\n")
record_pipe.close()
return 0

View File

@ -1,6 +1,5 @@
from __future__ import print_function
import os, sys, json
from nacl.secret import SecretBox
from wormhole.blocking.transcribe import Initiator, WrongPasswordError
from wormhole.blocking.transit import TransitSender
from .progress import start_progress, update_progress, finish_progress
@ -37,36 +36,30 @@ def send_file(so):
return 1
them_d = json.loads(them_bytes.decode("utf-8"))
#print("them: %r" % (them_d,))
xfer_key = i.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
print("Encrypting %d bytes.." % filesize)
box = SecretBox(xfer_key)
with open(filename, "rb") as f:
plaintext = f.read()
nonce = os.urandom(SecretBox.NONCE_SIZE)
encrypted = box.encrypt(plaintext, nonce)
tdata = them_d["transit"]
transit_key = i.derive_key(APPID+"/transit-key")
transit_sender.set_transit_key(transit_key)
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
skt = transit_sender.establish_connection()
record_pipe = transit_sender.connect()
print("Sending (%s).." % transit_sender.describe())
sent = 0
next_update = start_progress(len(encrypted))
while sent < len(encrypted):
sent += skt.send(encrypted[sent:])
next_update = update_progress(next_update, sent, len(encrypted))
finish_progress(len(encrypted))
CHUNKSIZE = 64*1024
with open(filename, "rb") as f:
sent = 0
next_update = start_progress(filesize)
while sent < filesize:
plaintext = f.read(CHUNKSIZE)
record_pipe.send_record(plaintext)
sent += len(plaintext)
next_update = update_progress(next_update, sent, filesize)
finish_progress(filesize)
print("File sent.. waiting for confirmation")
# ack is a short newline-terminated string, followed by socket close. A long
# read is probably good enough.
ack = skt.recv(300)
ack = record_pipe.receive_record()
if ack == "ok\n":
print("Confirmation received. Transfer complete.")
return 0