INCOMPATIBILITY: send+expect hash of data after xfer

This enhances the ACK that wormhole-receive returns when it finishes
receiving all the data to be a dictionary. The dict includes the SHA256
hash of everything it received, and the sender checks this for a match
before declaring the transfer to be a success. This guards against data
being shuffled somehow during transit.
This commit is contained in:
Brian Warner 2016-05-25 19:36:56 -07:00
parent d8f6126916
commit 708bcf36d4
2 changed files with 32 additions and 15 deletions

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile, hashlib
from tqdm import tqdm from tqdm import tqdm
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
@ -164,16 +164,16 @@ class TwistedReceiver:
f = self._handle_file(them_d) f = self._handle_file(them_d)
self._send_permission(w) self._send_permission(w)
rp = yield self._establish_transit() rp = yield self._establish_transit()
yield self._transfer_data(rp, f) datahash = yield self._transfer_data(rp, f)
self._write_file(f) self._write_file(f)
yield self._close_transit(rp) yield self._close_transit(rp, datahash)
elif "directory" in them_d: elif "directory" in them_d:
f = self._handle_directory(them_d) f = self._handle_directory(them_d)
self._send_permission(w) self._send_permission(w)
rp = yield self._establish_transit() rp = yield self._establish_transit()
yield self._transfer_data(rp, f) datahash = yield self._transfer_data(rp, f)
self._write_directory(f) self._write_directory(f)
yield self._close_transit(rp) yield self._close_transit(rp, datahash)
else: else:
self._msg(u"I don't know what they're offering\n") self._msg(u"I don't know what they're offering\n")
self._msg(u"Offer details: %r" % (them_d,)) self._msg(u"Offer details: %r" % (them_d,))
@ -257,9 +257,12 @@ class TwistedReceiver:
progress = tqdm(file=self.args.stdout, progress = tqdm(file=self.args.stdout,
disable=self.args.hide_progress, disable=self.args.hide_progress,
unit="B", unit_scale=True, total=self.xfersize) unit="B", unit_scale=True, total=self.xfersize)
hasher = hashlib.sha256()
with progress: with progress:
received = yield record_pipe.writeToFile(f, self.xfersize, received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update) progress.update,
hasher.update)
datahash = hasher.digest()
# except TransitError # except TransitError
if received < self.xfersize: if received < self.xfersize:
@ -268,6 +271,7 @@ class TwistedReceiver:
self._msg(u"got %d bytes, wanted %d" % (received, self.xfersize)) self._msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
raise TransferError("Connection dropped before full file received") raise TransferError("Connection dropped before full file received")
assert received == self.xfersize assert received == self.xfersize
returnValue(datahash)
def _write_file(self, f): def _write_file(self, f):
tmp_name = f.name tmp_name = f.name
@ -290,7 +294,10 @@ class TwistedReceiver:
f.close() f.close()
@inlineCallbacks @inlineCallbacks
def _close_transit(self, record_pipe): def _close_transit(self, record_pipe, datahash):
datahash_hex = binascii.hexlify(datahash).decode("ascii")
ack = {u"ack": u"ok", u"sha256": datahash_hex}
ack_bytes = json.dumps(ack).encode("utf-8")
with self.args.timing.add("send ack"): with self.args.timing.add("send ack"):
yield record_pipe.send_record(b"ok\n") yield record_pipe.send_record(ack_bytes)
yield record_pipe.close() yield record_pipe.close()

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six, tempfile, zipfile import os, sys, json, binascii, six, tempfile, zipfile, hashlib
from tqdm import tqdm from tqdm import tqdm
from twisted.python import log from twisted.python import log
from twisted.protocols import basic from twisted.protocols import basic
@ -236,11 +236,11 @@ class Sender:
raise TransferError("ambiguous response from remote, " raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_answer,)) "transfer abandoned: %s" % (them_answer,))
yield self._send_file_twisted() yield self._send_file()
@inlineCallbacks @inlineCallbacks
def _send_file_twisted(self): def _send_file(self):
ts = self._transit_sender ts = self._transit_sender
self._fd_to_send.seek(0,2) self._fd_to_send.seek(0,2)
@ -253,10 +253,12 @@ class Sender:
stdout = self._args.stdout stdout = self._args.stdout
print(u"Sending (%s).." % record_pipe.describe(), file=stdout) print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
hasher = hashlib.sha256()
progress = tqdm(file=stdout, disable=self._args.hide_progress, progress = tqdm(file=stdout, disable=self._args.hide_progress,
unit="B", unit_scale=True, unit="B", unit_scale=True,
total=filesize) total=filesize)
def _count(data): def _count_and_hash(data):
hasher.update(data)
progress.update(len(data)) progress.update(len(data))
return data return data
fs = basic.FileSender() fs = basic.FileSender()
@ -264,14 +266,22 @@ class Sender:
with self._timing.add("tx file"): with self._timing.add("tx file"):
with progress: with progress:
yield fs.beginFileTransfer(self._fd_to_send, record_pipe, yield fs.beginFileTransfer(self._fd_to_send, record_pipe,
transform=_count) transform=_count_and_hash)
expected_hash = hasher.digest()
expected_hex = binascii.hexlify(expected_hash).decode("ascii")
print(u"File sent.. waiting for confirmation", file=stdout) print(u"File sent.. waiting for confirmation", file=stdout)
with self._timing.add("get ack") as t: with self._timing.add("get ack") as t:
ack = yield record_pipe.receive_record() ack_bytes = yield record_pipe.receive_record()
record_pipe.close() record_pipe.close()
if ack != b"ok\n": ack = json.loads(ack_bytes.decode("utf-8"))
ok = ack.get(u"ack", u"")
if ok != u"ok":
t.detail(ack="failed") t.detail(ack="failed")
raise TransferError("Transfer failed (remote says: %r)" % ack) raise TransferError("Transfer failed (remote says: %r)" % ack)
if u"sha256" in ack:
if ack[u"sha256"] != expected_hex:
t.detail(datahash="failed")
raise TransferError("Transfer failed (bad remote hash)")
print(u"Confirmation received. Transfer complete.", file=stdout) print(u"Confirmation received. Transfer complete.", file=stdout)
t.detail(ack="ok") t.detail(ack="ok")