From 14285079091b566a50c37bcf6653e759e36e4fa6 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 25 Nov 2015 01:24:07 -0600 Subject: [PATCH] refactor cmd_receive.py, split accept_file() to a separate function --- src/wormhole/scripts/cmd_receive.py | 155 +++++++++++++++------------- 1 file changed, 81 insertions(+), 74 deletions(-) diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index 2936638..10eab3d 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -4,85 +4,43 @@ from ..errors import handle_server_error APPID = u"lothar.com/wormhole/text-or-file-xfer" -@handle_server_error -def receive(args): - # we're receiving text, or a file - from ..blocking.transcribe import Wormhole, WrongPasswordError +def accept_file(args, them_d, w): from ..blocking.transit import TransitReceiver, TransitError from .progress import start_progress, update_progress, finish_progress - assert isinstance(args.relay_url, type(u"")) - with Wormhole(APPID, args.relay_url) as w: - if args.zeromode: - assert not args.code - args.code = u"0-" - code = args.code - if not code: - code = w.input_code("Enter receive wormhole code: ", args.code_length) - w.set_code(code) + file_data = them_d["file"] + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + filename = os.path.basename(file_data["filename"]) # unicode + filesize = file_data["filesize"] - if args.verify: - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - print("Verifier %s." % verifier) - - try: - them_bytes = w.get_data() - except WrongPasswordError as e: - print("ERROR: " + e.explain(), file=sys.stderr) - return 1 - them_d = json.loads(them_bytes.decode("utf-8")) - if "error" in them_d: - print("ERROR: " + them_d["error"], file=sys.stderr) - return 1 - - if "message" in them_d: - # we're receiving a text message - print(them_d["message"]) - data = json.dumps({"message_ack": "ok"}).encode("utf-8") - w.send_data(data) - return 0 - - if not "file" in them_d: - print("I don't know what they're offering\n") - print(them_d) - return 1 - - if "error" in them_d: - print("ERROR: " + data["error"], file=sys.stderr) - return 1 - - file_data = them_d["file"] - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - filename = os.path.basename(file_data["filename"]) # unicode - filesize = file_data["filesize"] - - # get confirmation from the user before writing to the local directory - if os.path.exists(filename): - print("Error: refusing to overwrite existing file %s" % (filename,)) - data = json.dumps({"error": "file already exists"}).encode("utf-8") - w.send_data(data) - return 1 - - print("Receiving file (%d bytes) into: %s" % (filesize, filename)) - while True and not args.accept_file: - ok = six.moves.input("ok? (y/n): ") - if ok.lower().startswith("y"): - break - print("transfer rejected", file=sys.stderr) - data = json.dumps({"error": "transfer rejected"}).encode("utf-8") - w.send_data(data) - return 1 - - transit_receiver = TransitReceiver(args.transit_helper) - data = json.dumps({ - "file_ack": "ok", - "transit": { - "direct_connection_hints": transit_receiver.get_direct_hints(), - "relay_connection_hints": transit_receiver.get_relay_hints(), - }, - }).encode("utf-8") + # get confirmation from the user before writing to the local directory + if os.path.exists(filename): + print("Error: refusing to overwrite existing file %s" % (filename,)) + data = json.dumps({"error": "file already exists"}).encode("utf-8") w.send_data(data) + return 1 + + print("Receiving file (%d bytes) into: %s" % (filesize, filename)) + while True and not args.accept_file: + ok = six.moves.input("ok? (y/n): ") + if ok.lower().startswith("y"): + break + print("transfer rejected", file=sys.stderr) + data = json.dumps({"error": "transfer rejected"}).encode("utf-8") + w.send_data(data) + return 1 + + transit_receiver = TransitReceiver(args.transit_helper) + data = json.dumps({ + "file_ack": "ok", + "transit": { + "direct_connection_hints": transit_receiver.get_direct_hints(), + "relay_connection_hints": transit_receiver.get_relay_hints(), + }, + }).encode("utf-8") + w.send_data(data) + # now done with the Wormhole object # now receive the rest of the owl tdata = them_d["transit"] @@ -118,3 +76,52 @@ def receive(args): record_pipe.send_record(b"ok\n") record_pipe.close() return 0 + +@handle_server_error +def receive(args): + # we're receiving text, or a file + from ..blocking.transcribe import Wormhole, WrongPasswordError + assert isinstance(args.relay_url, type(u"")) + + with Wormhole(APPID, args.relay_url) as w: + if args.zeromode: + assert not args.code + args.code = u"0-" + code = args.code + if not code: + code = w.input_code("Enter receive wormhole code: ", args.code_length) + w.set_code(code) + + if args.verify: + verifier = binascii.hexlify(w.get_verifier()).decode("ascii") + print("Verifier %s." % verifier) + + try: + them_bytes = w.get_data() + except WrongPasswordError as e: + print("ERROR: " + e.explain(), file=sys.stderr) + return 1 + them_d = json.loads(them_bytes.decode("utf-8")) + if "error" in them_d: + print("ERROR: " + them_d["error"], file=sys.stderr) + return 1 + + if "message" in them_d: + # we're receiving a text message + print(them_d["message"]) + data = json.dumps({"message_ack": "ok"}).encode("utf-8") + w.send_data(data) + return 0 + + if "error" in them_d: + print("ERROR: " + data["error"], file=sys.stderr) + return 1 + + if "file" in them_d: + return accept_file(args, them_d, w) + + print("I don't know what they're offering\n") + print("Offer details:", them_d) + data = json.dumps({"error": "unknown offer type"}).encode("utf-8") + w.send_data(data) + return 1