transit: provide encrypted record-pipe, use it for file-xfer
This commit is contained in:
parent
8b3e5836ee
commit
fcd2678dfd
|
@ -1,6 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import re, time, threading, socket, SocketServer
|
||||
from binascii import hexlify
|
||||
from binascii import hexlify, unhexlify
|
||||
from nacl.secret import SecretBox
|
||||
from ..util import ipaddrs
|
||||
from ..util.hkdf import HKDF
|
||||
from ..const import TRANSIT_RELAY
|
||||
|
@ -212,6 +213,62 @@ class MyTCPServer(SocketServer.TCPServer):
|
|||
t.start()
|
||||
|
||||
|
||||
class TransitClosed(TransitError):
|
||||
pass
|
||||
|
||||
class BadNonce(TransitError):
|
||||
pass
|
||||
|
||||
class ReceiveBuffer:
|
||||
def __init__(self, skt):
|
||||
self.skt = skt
|
||||
self.buf = b""
|
||||
|
||||
def read(self, count):
|
||||
while len(self.buf) < count:
|
||||
more = self.skt.recv(4096)
|
||||
if not more:
|
||||
raise TransitClosed
|
||||
self.buf += more
|
||||
rc = self.buf[:count]
|
||||
self.buf = self.buf[count:]
|
||||
return rc
|
||||
|
||||
class RecordPipe:
|
||||
def __init__(self, skt, send_key, receive_key):
|
||||
self.skt = skt
|
||||
self.send_box = SecretBox(send_key)
|
||||
self.send_nonce = 0
|
||||
self.receive_buf = ReceiveBuffer(self.skt)
|
||||
self.receive_box = SecretBox(receive_key)
|
||||
self.next_receive_nonce = 0
|
||||
|
||||
def send_record(self, record):
|
||||
assert SecretBox.NONCE_SIZE == 24
|
||||
assert self.send_nonce < 2**(8*24)
|
||||
assert len(record) < 2**(8*4)
|
||||
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
|
||||
self.send_nonce += 1
|
||||
encrypted = self.send_box.encrypt(record, nonce)
|
||||
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
|
||||
send_to(self.skt, length)
|
||||
send_to(self.skt, encrypted)
|
||||
|
||||
def receive_record(self):
|
||||
length_buf = self.receive_buf.read(4)
|
||||
length = int(hexlify(length_buf), 16)
|
||||
encrypted = self.receive_buf.read(length)
|
||||
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
|
||||
nonce = int(hexlify(nonce_buf), 16)
|
||||
if nonce != self.next_receive_nonce:
|
||||
raise BadNonce("received out-of-order record")
|
||||
self.next_receive_nonce += 1
|
||||
record = self.receive_box.decrypt(encrypted)
|
||||
return record
|
||||
|
||||
def close(self):
|
||||
self.skt.close()
|
||||
|
||||
class Common:
|
||||
def __init__(self):
|
||||
self.winning = threading.Event()
|
||||
|
@ -253,6 +310,22 @@ class Common:
|
|||
else:
|
||||
return build_sender_handshake(self._transit_key) + "go\n"
|
||||
|
||||
def _sender_record_key(self):
|
||||
if self.is_sender:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_sender_key")
|
||||
else:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_receiver_key")
|
||||
|
||||
def _receiver_record_key(self):
|
||||
if self.is_sender:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_receiver_key")
|
||||
else:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_sender_key")
|
||||
|
||||
def set_transit_key(self, key):
|
||||
# This _have_transit_key condition/lock protects us against the race
|
||||
# where the sender knows the hints and the key, and connects to the
|
||||
|
@ -288,7 +361,7 @@ class Common:
|
|||
for hint in self._their_relay_hints:
|
||||
self._start_connector(hint, is_relay=True)
|
||||
|
||||
def establish_connection(self):
|
||||
def establish_socket(self):
|
||||
start = time.time()
|
||||
self.winning_skt = None
|
||||
self.winning_skt_description = None
|
||||
|
@ -340,6 +413,11 @@ class Common:
|
|||
send_to(skt, "nevermind\n")
|
||||
skt.close()
|
||||
|
||||
def connect(self):
|
||||
skt = self.establish_socket()
|
||||
return RecordPipe(skt, self._sender_record_key(),
|
||||
self._receiver_record_key())
|
||||
|
||||
class TransitSender(Common):
|
||||
is_sender = True
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import sys, os, json
|
||||
from nacl.secret import SecretBox
|
||||
from wormhole.blocking.transcribe import Receiver, WrongPasswordError
|
||||
from wormhole.blocking.transit import TransitReceiver
|
||||
from wormhole.blocking.transit import TransitReceiver, TransitError
|
||||
from .progress import start_progress, update_progress, finish_progress
|
||||
|
||||
APPID = "lothar.com/wormhole/file-xfer"
|
||||
|
@ -28,10 +27,8 @@ def receive_file(so):
|
|||
#print("their data: %r" % (data,))
|
||||
|
||||
file_data = data["file"]
|
||||
xfer_key = r.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
|
||||
filename = os.path.basename(file_data["filename"]) # unicode
|
||||
filesize = file_data["filesize"]
|
||||
encrypted_filesize = filesize + SecretBox.NONCE_SIZE+16
|
||||
|
||||
# now receive the rest of the owl
|
||||
tdata = data["transit"]
|
||||
|
@ -39,46 +36,47 @@ def receive_file(so):
|
|||
transit_receiver.set_transit_key(transit_key)
|
||||
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
skt = transit_receiver.establish_connection()
|
||||
record_pipe = transit_receiver.connect()
|
||||
|
||||
print("Receiving %d bytes for '%s' (%s).." % (filesize, filename,
|
||||
transit_receiver.describe()))
|
||||
|
||||
encrypted = b""
|
||||
next_update = start_progress(encrypted_filesize)
|
||||
while len(encrypted) < encrypted_filesize:
|
||||
more = skt.recv(encrypted_filesize - len(encrypted))
|
||||
if not more:
|
||||
print()
|
||||
print("Connection dropped before full file received")
|
||||
print("got %d bytes, wanted %d" % (len(encrypted), encrypted_filesize))
|
||||
return 1
|
||||
encrypted += more
|
||||
next_update = update_progress(next_update, len(encrypted),
|
||||
encrypted_filesize)
|
||||
finish_progress(encrypted_filesize)
|
||||
assert len(encrypted) == encrypted_filesize
|
||||
|
||||
print("Decrypting..")
|
||||
decrypted = SecretBox(xfer_key).decrypt(encrypted)
|
||||
|
||||
# only write to the current directory, and never overwrite anything
|
||||
here = os.path.abspath(os.getcwd())
|
||||
target = os.path.abspath(os.path.join(here, filename))
|
||||
if os.path.dirname(target) != here:
|
||||
print("Error: suggested filename (%s) would be outside current directory"
|
||||
% (filename,))
|
||||
skt.send("bad filename\n")
|
||||
skt.close()
|
||||
record_pipe.send_record("bad filename\n")
|
||||
record_pipe.close()
|
||||
return 1
|
||||
if os.path.exists(target):
|
||||
print("Error: refusing to overwrite existing file %s" % (filename,))
|
||||
skt.send("file already exists\n")
|
||||
skt.close()
|
||||
record_pipe.send_record("file already exists\n")
|
||||
record_pipe.close()
|
||||
return 1
|
||||
with open(target, "wb") as f:
|
||||
f.write(decrypted)
|
||||
tmp = target + ".tmp"
|
||||
|
||||
with open(tmp, "wb") as f:
|
||||
received = 0
|
||||
next_update = start_progress(filesize)
|
||||
while received < filesize:
|
||||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
except TransitError:
|
||||
print()
|
||||
print("Connection dropped before full file received")
|
||||
print("got %d bytes, wanted %d" % (received, filesize))
|
||||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
next_update = update_progress(next_update, received, filesize)
|
||||
finish_progress(filesize)
|
||||
assert received == filesize
|
||||
|
||||
os.rename(tmp, target)
|
||||
|
||||
print("Received file written to %s" % filename)
|
||||
skt.send("ok\n")
|
||||
skt.close()
|
||||
record_pipe.send_record("ok\n")
|
||||
record_pipe.close()
|
||||
return 0
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, json
|
||||
from nacl.secret import SecretBox
|
||||
from wormhole.blocking.transcribe import Initiator, WrongPasswordError
|
||||
from wormhole.blocking.transit import TransitSender
|
||||
from .progress import start_progress, update_progress, finish_progress
|
||||
|
@ -37,36 +36,30 @@ def send_file(so):
|
|||
return 1
|
||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
||||
#print("them: %r" % (them_d,))
|
||||
xfer_key = i.derive_key(APPID+"/xfer-key", SecretBox.KEY_SIZE)
|
||||
|
||||
print("Encrypting %d bytes.." % filesize)
|
||||
|
||||
box = SecretBox(xfer_key)
|
||||
with open(filename, "rb") as f:
|
||||
plaintext = f.read()
|
||||
nonce = os.urandom(SecretBox.NONCE_SIZE)
|
||||
encrypted = box.encrypt(plaintext, nonce)
|
||||
|
||||
tdata = them_d["transit"]
|
||||
transit_key = i.derive_key(APPID+"/transit-key")
|
||||
transit_sender.set_transit_key(transit_key)
|
||||
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
skt = transit_sender.establish_connection()
|
||||
record_pipe = transit_sender.connect()
|
||||
|
||||
print("Sending (%s).." % transit_sender.describe())
|
||||
|
||||
sent = 0
|
||||
next_update = start_progress(len(encrypted))
|
||||
while sent < len(encrypted):
|
||||
sent += skt.send(encrypted[sent:])
|
||||
next_update = update_progress(next_update, sent, len(encrypted))
|
||||
finish_progress(len(encrypted))
|
||||
CHUNKSIZE = 64*1024
|
||||
with open(filename, "rb") as f:
|
||||
sent = 0
|
||||
next_update = start_progress(filesize)
|
||||
while sent < filesize:
|
||||
plaintext = f.read(CHUNKSIZE)
|
||||
record_pipe.send_record(plaintext)
|
||||
sent += len(plaintext)
|
||||
next_update = update_progress(next_update, sent, filesize)
|
||||
finish_progress(filesize)
|
||||
|
||||
print("File sent.. waiting for confirmation")
|
||||
# ack is a short newline-terminated string, followed by socket close. A long
|
||||
# read is probably good enough.
|
||||
ack = skt.recv(300)
|
||||
ack = record_pipe.receive_record()
|
||||
if ack == "ok\n":
|
||||
print("Confirmation received. Transfer complete.")
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user