diff --git a/src/wormhole/scripts/cmd_send_twisted.py b/src/wormhole/scripts/cmd_send_twisted.py new file mode 100644 index 0000000..2de6785 --- /dev/null +++ b/src/wormhole/scripts/cmd_send_twisted.py @@ -0,0 +1,145 @@ +from __future__ import print_function +import io, json, binascii, six +from twisted.protocols import basic +from twisted.internet import reactor +from twisted.internet.defer import inlineCallbacks, returnValue +from ..errors import TransferError +from .progress import ProgressPrinter +from ..twisted.transcribe import Wormhole, WrongPasswordError +from ..twisted.transit import TransitSender +from .send_common import (APPID, handle_zero, build_other_command, + build_phase1_data) + +def send_twisted_sync(args): + d = send_twisted(args) + # try to use twisted.internet.task.react(f) here (but it calls sys.exit + # directly) + rc = [] + def _done(res): + rc.extend([True, res]) + reactor.stop() + def _err(f): + rc.extend([False, f.value]) + reactor.stop() + d.addCallbacks(_done, _err) + reactor.run() + if rc[0]: + return rc[1] + raise rc[1] + +@inlineCallbacks +def send_twisted(args): + assert isinstance(args.relay_url, type(u"")) + handle_zero(args) + phase1, fd_to_send = build_phase1_data(args) + other_cmd = build_other_command(args) + print(u"On the other computer, please run: %s" % other_cmd, + file=args.stdout) + + w = Wormhole(APPID, args.relay_url) + + if fd_to_send: + transit_sender = TransitSender(args.transit_helper) + phase1["transit"] = transit_data = {} + transit_data["relay_connection_hints"] = transit_sender.get_relay_hints() + direct_hints = yield transit_sender.get_direct_hints() + transit_data["direct_connection_hints"] = direct_hints + + if args.code: + w.set_code(args.code) + code = args.code + else: + code = yield w.get_code(args.code_length) + + if not args.zeromode: + print(u"Wormhole code is: %s" % code, file=args.stdout) + print(u"", file=args.stdout) + + if args.verify: + verifier_bytes = yield w.get_verifier() + verifier = binascii.hexlify(verifier_bytes).decode("ascii") + while True: + ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier) + if ok.lower() == "yes": + break + if ok.lower() == "no": + reject_data = json.dumps({"error": "verification rejected", + }).encode("utf-8") + yield w.send_data(reject_data) + raise TransferError("verification rejected, abandoning transfer") + + my_phase1_bytes = json.dumps(phase1).encode("utf-8") + yield w.send_data(my_phase1_bytes) + + try: + them_phase1_bytes = yield w.get_data() + except WrongPasswordError as e: + raise TransferError(e.explain()) + + them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) + + if fd_to_send is None: + if them_phase1["message_ack"] == "ok": + print(u"text message sent", file=args.stdout) + yield w.close() + returnValue(0) # terminates this function + raise TransferError("error sending text: %r" % (them_phase1,)) + + if "error" in them_phase1: + raise TransferError("remote error, transfer abandoned: %s" + % them_phase1["error"]) + if them_phase1.get("file_ack") != "ok": + raise TransferError("ambiguous response from remote, " + "transfer abandoned: %s" % (them_phase1,)) + tdata = them_phase1["transit"] + # this is happening too late: the other side already connects to our + # server + transit_key = w.derive_key(APPID+"/transit-key") + transit_sender.set_transit_key(transit_key) + yield w.close() + yield _send_file_twisted(tdata, transit_sender, fd_to_send, + args.stdout, args.hide_progress) + returnValue(0) + +class ProgressingFileSender(basic.FileSender): + def __init__(self, filesize, stdout): + self._sent = 0 + self._progress = ProgressPrinter(filesize, stdout) + self._progress.start() + def progress(self, data): + self._sent += len(data) + self._progress.update(self._sent) + return data + def beginFileTransfer(self, file, consumer): + d = basic.FileSender.beginFileTransfer(self, file, consumer, + self.progress) + d.addCallback(self.done) + return d + def done(self, res): + self._progress.finish() + return res + +@inlineCallbacks +def _send_file_twisted(tdata, transit_sender, fd_to_send, + stdout, hide_progress): + transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) + transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) + + fd_to_send.seek(0,2) + filesize = fd_to_send.tell() + fd_to_send.seek(0,0) + progress_stdout = stdout + if hide_progress: + progress_stdout = io.StringIO() + pfs = ProgressingFileSender(filesize, progress_stdout) + + record_pipe = yield transit_sender.connect() + # record_pipe should implement IConsumer, chunks are just records + print(u"Sending (%s).." % transit_sender.describe(), file=stdout) + yield pfs.beginFileTransfer(fd_to_send, record_pipe) + print(u"File sent.. waiting for confirmation", file=stdout) + ack = yield record_pipe.receive_record() + record_pipe.close() + if ack != b"ok\n": + raise TransferError("Transfer failed (remote says: %r)" % ack) + print(u"Confirmation received. Transfer complete.", file=stdout) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 1e9767c..da7d8b1 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -1,3 +1,4 @@ +from __future__ import print_function import os, sys, re, io, zipfile from twisted.trial import unittest from twisted.python import procutils, log @@ -6,7 +7,7 @@ from twisted.internet.defer import inlineCallbacks from twisted.internet.threads import deferToThread from .. import __version__ from .common import ServerBase -from ..scripts import runner, cmd_send_blocking, cmd_receive +from ..scripts import runner, cmd_send_blocking, cmd_send_twisted, cmd_receive from ..scripts.send_common import build_phase1_data from ..errors import TransferError @@ -207,7 +208,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, - mode="text", override_filename=False): + mode="text", override_filename=False, + sender_twisted=False, receiver_twisted=False): assert mode in ("text", "file", "directory") common_args = ["--hide-progress", "--relay-url", self.relayurl, @@ -295,7 +297,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): rargs.cwd = receive_dir rargs.stdout = io.StringIO() rargs.stderr = io.StringIO() - send_d = deferToThread(cmd_send_blocking.send_blocking, sargs) + if sender_twisted: + send_d = cmd_send_twisted.send_twisted(sargs) + else: + send_d = deferToThread(cmd_send_blocking.send_blocking, sargs) + assert not receiver_twisted # not importable yet receive_d = deferToThread(cmd_receive.receive, rargs) send_rc = yield send_d @@ -308,8 +314,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.maxDiff = None # show full output for assertion failures - # check sender self.failUnlessEqual(send_stderr, "") + self.failUnlessEqual(receive_stderr, "") + + # check sender if mode == "text": expected = ("Sending text message (%d bytes)\n" "On the other computer, please run: " @@ -337,10 +345,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessIn("File sent.. waiting for confirmation\n" "Confirmation received. Transfer complete.\n", send_stdout) - self.failUnlessEqual(send_rc, 0) # check receiver - self.failUnlessEqual(receive_stderr, "") if mode == "text": self.failUnlessEqual(receive_stdout, message+"\n") elif mode == "file": @@ -362,19 +368,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): fn = os.path.join(receive_dir, receive_dirname, str(i)) with open(fn, "r") as f: self.failUnlessEqual(f.read(), message(i)) + + self.failUnlessEqual(send_rc, 0) self.failUnlessEqual(receive_rc, 0) def test_text(self): return self._do_test() def test_text_subprocess(self): return self._do_test(as_subprocess=True) + def test_text_twisted_to_blocking(self): + return self._do_test(sender_twisted=True) def test_file(self): return self._do_test(mode="file") def test_file_override(self): return self._do_test(mode="file", override_filename=True) + def test_file_twisted_to_blocking(self): + return self._do_test(mode="file", sender_twisted=True) def test_directory(self): return self._do_test(mode="directory") def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + def test_directory_twisted_to_blocking(self): + return self._do_test(mode="directory", sender_twisted=True)