diff --git a/src/wormhole/scripts/cmd_receive_blocking.py b/src/wormhole/scripts/cmd_receive_blocking.py index 562a35d..e55a2ec 100644 --- a/src/wormhole/scripts/cmd_receive_blocking.py +++ b/src/wormhole/scripts/cmd_receive_blocking.py @@ -2,107 +2,146 @@ from __future__ import print_function import io, os, sys, json, binascii, six, tempfile, zipfile from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transit import TransitReceiver, TransitError -from ..errors import handle_server_error +from ..errors import handle_server_error, TransferError from .progress import ProgressPrinter APPID = u"lothar.com/wormhole/text-or-file-xfer" -@handle_server_error def receive_blocking(args): - # we're receiving text, or a file - assert isinstance(args.relay_url, type(u"")) + return BlockingReceiver(args).go() - with Wormhole(APPID, args.relay_url) as w: - if args.zeromode: - assert not args.code - args.code = u"0-" - code = args.code +class RespondError(Exception): + def __init__(self, response): + self.response = response + +class BlockingReceiver: + def __init__(self, args): + assert isinstance(args.relay_url, type(u"")) + self.args = args + + def msg(self, *args, **kwargs): + print(*args, file=self.args.stdout, **kwargs) + + @handle_server_error + def go(self): + with Wormhole(APPID, self.args.relay_url) as w: + self.handle_code(w) + verifier = w.get_verifier() + self.show_verifier(verifier) + them_d = self.get_data(w) + try: + if "message" in them_d: + self.handle_text(them_d, w) + return 0 + if "file" in them_d: + f = self.handle_file(them_d) + rp = self.establish_transit(w, them_d) + self.transfer_data(rp, f) + self.write_file(f) + self.close_transit(rp) + elif "directory" in them_d: + f = self.handle_directory(them_d) + rp = self.establish_transit(w, them_d) + self.transfer_data(rp, f) + self.write_directory(f) + self.close_transit(rp) + else: + self.msg(u"I don't know what they're offering\n") + self.msg(u"Offer details:", them_d) + raise RespondError({"error": "unknown offer type"}) + except RespondError as r: + data = json.dumps(r.response).encode("utf-8") + w.send_data(data) + return 1 + return 0 + + def handle_code(self, w): + code = self.args.code + if self.args.zeromode: + assert not code + code = u"0-" if not code: code = w.input_code("Enter receive wormhole code: ", - args.code_length) + self.args.code_length) w.set_code(code) - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - if args.verify: - print(u"Verifier %s." % verifier, file=args.stdout) + def show_verifier(self, verifier): + verifier_hex = binascii.hexlify(verifier).decode("ascii") + if self.args.verify: + self.msg(u"Verifier %s." % verifier_hex) + def get_data(self, w): try: them_bytes = w.get_data() except WrongPasswordError as e: - print(u"ERROR: " + e.explain(), file=sys.stderr) - return 1 + raise TransferError(u"ERROR: " + e.explain()) them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: - print(u"ERROR: " + them_d["error"], file=sys.stderr) - return 1 + raise TransferError(u"ERROR: " + them_d["error"]) + return them_d - if "message" in them_d: - # we're receiving a text message - print(them_d["message"], file=args.stdout) - data = json.dumps({"message_ack": "ok"}).encode("utf-8") - w.send_data(data) - return 0 + def handle_text(self, them_d, w): + # we're receiving a text message + self.msg(them_d["message"]) + data = json.dumps({"message_ack": "ok"}).encode("utf-8") + w.send_data(data) - if "error" in them_d: - print(u"ERROR: " + data["error"], file=sys.stderr) - return 1 + def handle_file(self, them_d): + file_data = them_d["file"] + self.abs_destname = self.decide_destname("file", + file_data["filename"]) + self.xfersize = file_data["filesize"] - if "file" in them_d: - mode = "file" - file_data = them_d["file"] - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - destname = os.path.basename(file_data["filename"]) # unicode - xfersize = file_data["filesize"] - elif "directory" in them_d: - mode = "directory" - file_data = them_d["directory"] - zipmode = file_data["mode"] - if zipmode != "zipfile/deflated": - print(u"Error: unknown directory-transfer mode '%s'" % - (zipmode,), file=args.stdout) - data = json.dumps({"error": "unknown mode"}).encode("utf-8") - w.send_data(data) - return 1 - destname = os.path.basename(file_data["dirname"]) # unicode - xfersize = file_data["zipsize"] - num_files = file_data["numfiles"] - num_bytes = file_data["numbytes"] - else: - print(u"I don't know what they're offering\n", file=args.stdout) - print(u"Offer details:", them_d, file=args.stdout) - data = json.dumps({"error": "unknown offer type"}).encode("utf-8") - w.send_data(data) - return 1 + self.msg(u"Receiving file (%d bytes) into: %s" % + (self.xfersize, os.path.basename(self.abs_destname))) + self.ask_permission() + tmp_destname = self.abs_destname + ".tmp" + return open(tmp_destname, "wb") - if args.output_file: - destname = args.output_file # override - abs_destname = os.path.join(args.cwd, destname) + def handle_directory(self, them_d): + file_data = them_d["directory"] + zipmode = file_data["mode"] + if zipmode != "zipfile/deflated": + self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) + raise RespondError({"error": "unknown mode"}) + self.abs_destname = self.decide_destname("directory", + file_data["dirname"]) + self.xfersize = file_data["zipsize"] + + self.msg(u"Receiving directory (%d bytes) into: %s/" % + (self.xfersize, os.path.basename(self.abs_destname))) + self.msg(u"%d files, %d bytes (uncompressed)" % + (file_data["numfiles"], file_data["numbytes"])) + self.ask_permission() + return tempfile.SpooledTemporaryFile() + + def decide_destname(self, mode, destname): + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + destname = os.path.basename(destname) + if self.args.output_file: + destname = self.args.output_file # override + abs_destname = os.path.join(self.args.cwd, destname) # get confirmation from the user before writing to the local directory if os.path.exists(abs_destname): - print(u"Error: refusing to overwrite existing %s %s" % - (mode, destname), file=args.stdout) - data = json.dumps({"error": "%s already exists" % mode}).encode("utf-8") - w.send_data(data) - return 1 - # TODO: add / to destname - print(u"Receiving %s (%d bytes) into: %s" % (mode, xfersize, destname), - file=args.stdout) - if mode == "directory": - print(u"%d files, %d bytes (uncompressed)" % - (num_files, num_bytes), file=args.stdout) + self.msg(u"Error: refusing to overwrite existing %s %s" % + (mode, destname)) + raise RespondError({"error": "%s already exists" % mode}) + return abs_destname - while True and not args.accept_file: + def ask_permission(self): + while True and not self.args.accept_file: ok = six.moves.input("ok? (y/n): ") if ok.lower().startswith("y"): break print(u"transfer rejected", file=sys.stderr) - data = json.dumps({"error": "transfer rejected"}).encode("utf-8") - w.send_data(data) - return 1 + raise RespondError({"error": "transfer rejected"}) - transit_receiver = TransitReceiver(args.transit_helper) + def establish_transit(self, w, them_d): + transit_key = w.derive_key(APPID+u"/transit-key") + transit_receiver = TransitReceiver(self.args.transit_helper) + transit_receiver.set_transit_key(transit_key) data = json.dumps({ "file_ack": "ok", "transit": { @@ -111,63 +150,56 @@ def receive_blocking(args): }, }).encode("utf-8") w.send_data(data) - # now done with the Wormhole object # now receive the rest of the owl tdata = them_d["transit"] - transit_key = w.derive_key(APPID+u"/transit-key") - transit_receiver.set_transit_key(transit_key) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) record_pipe = transit_receiver.connect() + return record_pipe - print(u"Receiving %d bytes for '%s' (%s).." % - (xfersize, destname, record_pipe.describe()), - file=args.stdout) - if mode == "file": - tmp_destname = abs_destname + ".tmp" - f = open(tmp_destname, "wb") - else: - f = tempfile.SpooledTemporaryFile() + def transfer_data(self, record_pipe, f): + self.msg(u"Receiving (%s).." % record_pipe.describe()) - progress_stdout = args.stdout - if args.hide_progress: + progress_stdout = self.args.stdout + if self.args.hide_progress: progress_stdout = io.StringIO() received = 0 - p = ProgressPrinter(xfersize, progress_stdout) + p = ProgressPrinter(self.xfersize, progress_stdout) p.start() - while received < xfersize: + while received < self.xfersize: try: plaintext = record_pipe.receive_record() except TransitError: - print(u"", file=args.stdout) - print(u"Connection dropped before full file received", - file=args.stdout) - print(u"got %d bytes, wanted %d" % (received, xfersize), - file=args.stdout) + self.msg() + self.msg(u"Connection dropped before full file received") + self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize)) return 1 f.write(plaintext) received += len(plaintext) p.update(received) p.finish() - assert received == xfersize + assert received == self.xfersize - if mode == "file": - f.close() - os.rename(tmp_destname, abs_destname) - print(u"Received file written to %s" % destname, file=args.stdout) - else: - print(u"Unpacking zipfile..", file=args.stdout) - with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: - zf.extractall(path=abs_destname) - # extractall() appears to offer some protection against - # malicious pathnames. For example, "/tmp/oops" and - # "../tmp/oops" both do the same thing as the (safe) - # "tmp/oops". - print(u"Received files written to %s/" % destname, file=args.stdout) - f.close() + def write_file(self, f): + tmp_name = f.name + f.close() + os.rename(tmp_name, self.abs_destname) + self.msg(u"Received file written to %s" % + os.path.basename(self.abs_destname)) + def write_directory(self, f): + self.msg(u"Unpacking zipfile..") + with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: + zf.extractall(path=self.abs_destname) + # extractall() appears to offer some protection against + # malicious pathnames. For example, "/tmp/oops" and + # "../tmp/oops" both do the same thing as the (safe) + # "tmp/oops". + self.msg(u"Received files written to %s/" % + os.path.basename(self.abs_destname)) + f.close() + + def close_transit(self, record_pipe): record_pipe.send_record(b"ok\n") record_pipe.close() - return 0 - diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 0d22737..fe63b33 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -360,7 +360,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): if mode == "text": self.failUnlessEqual(receive_stdout, message+NL) elif mode == "file": - self.failUnlessIn("Receiving {bytes:d} bytes for '{name}'" + self.failUnlessIn("Receiving file ({bytes:d} bytes) into: {name}" .format(bytes=len(message), name=receive_filename), receive_stdout) self.failUnlessIn("Received file written to ", receive_stdout) @@ -369,9 +369,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with open(fn, "r") as f: self.failUnlessEqual(f.read(), message) elif mode == "directory": - self.failUnless(re.search(r"Receiving \d+ bytes for '{name}'" - .format(name=receive_dirname), - receive_stdout)) + want = (r"Receiving directory \(\d+ bytes\) into: {name}/" + .format(name=receive_dirname)) + self.failUnless(re.search(want, receive_stdout), + (want, receive_stdout)) self.failUnlessIn("Received files written to {name}" .format(name=receive_dirname), receive_stdout) fn = os.path.join(receive_dir, receive_dirname)