add twisted form of sender

Currently this is only invokable from tests.
This commit is contained in:
Brian Warner 2016-02-17 17:22:54 -08:00
parent 7ceffd783a
commit aa27bfd32c
2 changed files with 165 additions and 6 deletions

View File

@ -0,0 +1,145 @@
from __future__ import print_function
import io, json, binascii, six
from twisted.protocols import basic
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue
from ..errors import TransferError
from .progress import ProgressPrinter
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitSender
from .send_common import (APPID, handle_zero, build_other_command,
build_phase1_data)
def send_twisted_sync(args):
d = send_twisted(args)
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
# directly)
rc = []
def _done(res):
rc.extend([True, res])
reactor.stop()
def _err(f):
rc.extend([False, f.value])
reactor.stop()
d.addCallbacks(_done, _err)
reactor.run()
if rc[0]:
return rc[1]
raise rc[1]
@inlineCallbacks
def send_twisted(args):
assert isinstance(args.relay_url, type(u""))
handle_zero(args)
phase1, fd_to_send = build_phase1_data(args)
other_cmd = build_other_command(args)
print(u"On the other computer, please run: %s" % other_cmd,
file=args.stdout)
w = Wormhole(APPID, args.relay_url)
if fd_to_send:
transit_sender = TransitSender(args.transit_helper)
phase1["transit"] = transit_data = {}
transit_data["relay_connection_hints"] = transit_sender.get_relay_hints()
direct_hints = yield transit_sender.get_direct_hints()
transit_data["direct_connection_hints"] = direct_hints
if args.code:
w.set_code(args.code)
code = args.code
else:
code = yield w.get_code(args.code_length)
if not args.zeromode:
print(u"Wormhole code is: %s" % code, file=args.stdout)
print(u"", file=args.stdout)
if args.verify:
verifier_bytes = yield w.get_verifier()
verifier = binascii.hexlify(verifier_bytes).decode("ascii")
while True:
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes":
break
if ok.lower() == "no":
reject_data = json.dumps({"error": "verification rejected",
}).encode("utf-8")
yield w.send_data(reject_data)
raise TransferError("verification rejected, abandoning transfer")
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
yield w.send_data(my_phase1_bytes)
try:
them_phase1_bytes = yield w.get_data()
except WrongPasswordError as e:
raise TransferError(e.explain())
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if fd_to_send is None:
if them_phase1["message_ack"] == "ok":
print(u"text message sent", file=args.stdout)
yield w.close()
returnValue(0) # terminates this function
raise TransferError("error sending text: %r" % (them_phase1,))
if "error" in them_phase1:
raise TransferError("remote error, transfer abandoned: %s"
% them_phase1["error"])
if them_phase1.get("file_ack") != "ok":
raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_phase1,))
tdata = them_phase1["transit"]
# this is happening too late: the other side already connects to our
# server
transit_key = w.derive_key(APPID+"/transit-key")
transit_sender.set_transit_key(transit_key)
yield w.close()
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
args.stdout, args.hide_progress)
returnValue(0)
class ProgressingFileSender(basic.FileSender):
def __init__(self, filesize, stdout):
self._sent = 0
self._progress = ProgressPrinter(filesize, stdout)
self._progress.start()
def progress(self, data):
self._sent += len(data)
self._progress.update(self._sent)
return data
def beginFileTransfer(self, file, consumer):
d = basic.FileSender.beginFileTransfer(self, file, consumer,
self.progress)
d.addCallback(self.done)
return d
def done(self, res):
self._progress.finish()
return res
@inlineCallbacks
def _send_file_twisted(tdata, transit_sender, fd_to_send,
stdout, hide_progress):
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
progress_stdout = stdout
if hide_progress:
progress_stdout = io.StringIO()
pfs = ProgressingFileSender(filesize, progress_stdout)
record_pipe = yield transit_sender.connect()
# record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
print(u"File sent.. waiting for confirmation", file=stdout)
ack = yield record_pipe.receive_record()
record_pipe.close()
if ack != b"ok\n":
raise TransferError("Transfer failed (remote says: %r)" % ack)
print(u"Confirmation received. Transfer complete.", file=stdout)

View File

@ -1,3 +1,4 @@
from __future__ import print_function
import os, sys, re, io, zipfile
from twisted.trial import unittest
from twisted.python import procutils, log
@ -6,7 +7,7 @@ from twisted.internet.defer import inlineCallbacks
from twisted.internet.threads import deferToThread
from .. import __version__
from .common import ServerBase
from ..scripts import runner, cmd_send_blocking, cmd_receive
from ..scripts import runner, cmd_send_blocking, cmd_send_twisted, cmd_receive
from ..scripts.send_common import build_phase1_data
from ..errors import TransferError
@ -207,7 +208,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks
def _do_test(self, as_subprocess=False,
mode="text", override_filename=False):
mode="text", override_filename=False,
sender_twisted=False, receiver_twisted=False):
assert mode in ("text", "file", "directory")
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
@ -295,7 +297,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
rargs.cwd = receive_dir
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
if sender_twisted:
send_d = cmd_send_twisted.send_twisted(sargs)
else:
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
assert not receiver_twisted # not importable yet
receive_d = deferToThread(cmd_receive.receive, rargs)
send_rc = yield send_d
@ -308,8 +314,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.maxDiff = None # show full output for assertion failures
# check sender
self.failUnlessEqual(send_stderr, "")
self.failUnlessEqual(receive_stderr, "")
# check sender
if mode == "text":
expected = ("Sending text message (%d bytes)\n"
"On the other computer, please run: "
@ -337,10 +345,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessIn("File sent.. waiting for confirmation\n"
"Confirmation received. Transfer complete.\n",
send_stdout)
self.failUnlessEqual(send_rc, 0)
# check receiver
self.failUnlessEqual(receive_stderr, "")
if mode == "text":
self.failUnlessEqual(receive_stdout, message+"\n")
elif mode == "file":
@ -362,19 +368,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
fn = os.path.join(receive_dir, receive_dirname, str(i))
with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i))
self.failUnlessEqual(send_rc, 0)
self.failUnlessEqual(receive_rc, 0)
def test_text(self):
return self._do_test()
def test_text_subprocess(self):
return self._do_test(as_subprocess=True)
def test_text_twisted_to_blocking(self):
return self._do_test(sender_twisted=True)
def test_file(self):
return self._do_test(mode="file")
def test_file_override(self):
return self._do_test(mode="file", override_filename=True)
def test_file_twisted_to_blocking(self):
return self._do_test(mode="file", sender_twisted=True)
def test_directory(self):
return self._do_test(mode="directory")
def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True)
def test_directory_twisted_to_blocking(self):
return self._do_test(mode="directory", sender_twisted=True)