transit: provide encrypted record-pipe, use it for file-xfer

This commit is contained in:
Brian Warner 2015-03-12 18:14:42 -07:00
parent 8b3e5836ee
commit fcd2678dfd
3 changed files with 121 additions and 52 deletions

View File

@ -1,6 +1,7 @@
from __future__ import print_function from __future__ import print_function
import re, time, threading, socket, SocketServer import re, time, threading, socket, SocketServer
from binascii import hexlify from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs from ..util import ipaddrs
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
from ..const import TRANSIT_RELAY from ..const import TRANSIT_RELAY
@ -212,6 +213,62 @@ class MyTCPServer(SocketServer.TCPServer):
t.start() t.start()
class TransitClosed(TransitError):
pass
class BadNonce(TransitError):
pass
class ReceiveBuffer:
def __init__(self, skt):
self.skt = skt
self.buf = b""
def read(self, count):
while len(self.buf) < count:
more = self.skt.recv(4096)
if not more:
raise TransitClosed
self.buf += more
rc = self.buf[:count]
self.buf = self.buf[count:]
return rc
class RecordPipe:
def __init__(self, skt, send_key, receive_key):
self.skt = skt
self.send_box = SecretBox(send_key)
self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0
def send_record(self, record):
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
self.send_nonce += 1
encrypted = self.send_box.encrypt(record, nonce)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
send_to(self.skt, length)
send_to(self.skt, encrypted)
def receive_record(self):
length_buf = self.receive_buf.read(4)
length = int(hexlify(length_buf), 16)
encrypted = self.receive_buf.read(length)
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record")
self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted)
return record
def close(self):
self.skt.close()
class Common: class Common:
def __init__(self): def __init__(self):
self.winning = threading.Event() self.winning = threading.Event()
@ -253,6 +310,22 @@ class Common:
else: else:
return build_sender_handshake(self._transit_key) + "go\n" return build_sender_handshake(self._transit_key) + "go\n"
def _sender_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key): def set_transit_key(self, key):
# This _have_transit_key condition/lock protects us against the race # This _have_transit_key condition/lock protects us against the race
# where the sender knows the hints and the key, and connects to the # where the sender knows the hints and the key, and connects to the
@ -288,7 +361,7 @@ class Common:
for hint in self._their_relay_hints: for hint in self._their_relay_hints:
self._start_connector(hint, is_relay=True) self._start_connector(hint, is_relay=True)
def establish_connection(self): def establish_socket(self):
start = time.time() start = time.time()
self.winning_skt = None self.winning_skt = None
self.winning_skt_description = None self.winning_skt_description = None
@ -340,6 +413,11 @@ class Common:
send_to(skt, "nevermind\n") send_to(skt, "nevermind\n")
skt.close() skt.close()
def connect(self):
skt = self.establish_socket()
return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key())
class TransitSender(Common): class TransitSender(Common):
is_sender = True is_sender = True

View File

@ -1,8 +1,7 @@
from __future__ import print_function from __future__ import print_function
import sys, os, json import sys, os, json
from nacl.secret import SecretBox
from wormhole.blocking.transcribe import Receiver, WrongPasswordError from wormhole.blocking.transcribe import Receiver, WrongPasswordError
from wormhole.blocking.transit import TransitReceiver from wormhole.blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress from .progress import start_progress, update_progress, finish_progress
APPID = "lothar.com/wormhole/file-xfer" APPID = "lothar.com/wormhole/file-xfer"
@ -28,10 +27,8 @@ def receive_file(so):
#print("their data: %r" % (data,)) #print("their data: %r" % (data,))
file_data = data["file"] file_data = data["file"]
xfer_key = r.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
filename = os.path.basename(file_data["filename"]) # unicode filename = os.path.basename(file_data["filename"]) # unicode
filesize = file_data["filesize"] filesize = file_data["filesize"]
encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16
# now receive the rest of the owl # now receive the rest of the owl
tdata = data["transit"] tdata = data["transit"]
@ -39,46 +36,47 @@ def receive_file(so):
transit_receiver.set_transit_key(transit_key) transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
skt = transit_receiver.establish_connection() record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename, print("Receiving %d bytes for '%s' (%s).." % (filesize, filename,
transit_receiver.describe())) transit_receiver.describe()))
encrypted = b""
next_update = start_progress(encrypted_filesize)
while len(encrypted) < encrypted_filesize:
more = skt.recv(encrypted_filesize - len(encrypted))
if not more:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (len(encrypted), encrypted_filesize))
return 1
encrypted += more
next_update = update_progress(next_update, len(encrypted),
encrypted_filesize)
finish_progress(encrypted_filesize)
assert len(encrypted) == encrypted_filesize
print("Decrypting..")
decrypted = SecretBox(xfer_key).decrypt(encrypted)
# only write to the current directory, and never overwrite anything # only write to the current directory, and never overwrite anything
here = os.path.abspath(os.getcwd()) here = os.path.abspath(os.getcwd())
target = os.path.abspath(os.path.join(here, filename)) target = os.path.abspath(os.path.join(here, filename))
if os.path.dirname(target) != here: if os.path.dirname(target) != here:
print("Error: suggested filename (%s) would be outside current directory" print("Error: suggested filename (%s) would be outside current directory"
% (filename,)) % (filename,))
skt.send("bad filename\n") record_pipe.send_record("bad filename\n")
skt.close() record_pipe.close()
return 1 return 1
if os.path.exists(target): if os.path.exists(target):
print("Error: refusing to overwrite existing file %s" % (filename,)) print("Error: refusing to overwrite existing file %s" % (filename,))
skt.send("file already exists\n") record_pipe.send_record("file already exists\n")
skt.close() record_pipe.close()
return 1 return 1
with open(target, "wb") as f: tmp = target + ".tmp"
f.write(decrypted)
with open(tmp, "wb") as f:
received = 0
next_update = start_progress(filesize)
while received < filesize:
try:
plaintext = record_pipe.receive_record()
except TransitError:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (received, filesize))
return 1
f.write(plaintext)
received += len(plaintext)
next_update = update_progress(next_update, received, filesize)
finish_progress(filesize)
assert received == filesize
os.rename(tmp, target)
print("Received file written to %s" % filename) print("Received file written to %s" % filename)
skt.send("ok\n") record_pipe.send_record("ok\n")
skt.close() record_pipe.close()
return 0 return 0

View File

@ -1,6 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json import os, sys, json
from nacl.secret import SecretBox
from wormhole.blocking.transcribe import Initiator, WrongPasswordError from wormhole.blocking.transcribe import Initiator, WrongPasswordError
from wormhole.blocking.transit import TransitSender from wormhole.blocking.transit import TransitSender
from .progress import start_progress, update_progress, finish_progress from .progress import start_progress, update_progress, finish_progress
@ -37,36 +36,30 @@ def send_file(so):
return 1 return 1
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
#print("them: %r" % (them_d,)) #print("them: %r" % (them_d,))
xfer_key = i.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
print("Encrypting %d bytes.." % filesize)
box = SecretBox(xfer_key)
with open(filename, "rb") as f:
plaintext = f.read()
nonce = os.urandom(SecretBox.NONCE_SIZE)
encrypted = box.encrypt(plaintext, nonce)
tdata = them_d["transit"] tdata = them_d["transit"]
transit_key = i.derive_key(APPID+"/transit-key") transit_key = i.derive_key(APPID+"/transit-key")
transit_sender.set_transit_key(transit_key) transit_sender.set_transit_key(transit_key)
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
skt = transit_sender.establish_connection() record_pipe = transit_sender.connect()
print("Sending (%s).." % transit_sender.describe()) print("Sending (%s).." % transit_sender.describe())
CHUNKSIZE = 64*1024
with open(filename, "rb") as f:
sent = 0 sent = 0
next_update = start_progress(len(encrypted)) next_update = start_progress(filesize)
while sent < len(encrypted): while sent < filesize:
sent += skt.send(encrypted[sent:]) plaintext = f.read(CHUNKSIZE)
next_update = update_progress(next_update, sent, len(encrypted)) record_pipe.send_record(plaintext)
finish_progress(len(encrypted)) sent += len(plaintext)
next_update = update_progress(next_update, sent, filesize)
finish_progress(filesize)
print("File sent.. waiting for confirmation") print("File sent.. waiting for confirmation")
# ack is a short newline-terminated string, followed by socket close. A long ack = record_pipe.receive_record()
# read is probably good enough.
ack = skt.recv(300)
if ack == "ok\n": if ack == "ok\n":
print("Confirmation received. Transfer complete.") print("Confirmation received. Transfer complete.")
return 0 return 0