add twisted form of sender
Currently this is only invokable from tests.
This commit is contained in:
parent
7ceffd783a
commit
aa27bfd32c
145
src/wormhole/scripts/cmd_send_twisted.py
Normal file
145
src/wormhole/scripts/cmd_send_twisted.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
from __future__ import print_function
|
||||
import io, json, binascii, six
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitSender
|
||||
from .send_common import (APPID, handle_zero, build_other_command,
|
||||
build_phase1_data)
|
||||
|
||||
def send_twisted_sync(args):
|
||||
d = send_twisted(args)
|
||||
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
||||
# directly)
|
||||
rc = []
|
||||
def _done(res):
|
||||
rc.extend([True, res])
|
||||
reactor.stop()
|
||||
def _err(f):
|
||||
rc.extend([False, f.value])
|
||||
reactor.stop()
|
||||
d.addCallbacks(_done, _err)
|
||||
reactor.run()
|
||||
if rc[0]:
|
||||
return rc[1]
|
||||
raise rc[1]
|
||||
|
||||
@inlineCallbacks
|
||||
def send_twisted(args):
|
||||
assert isinstance(args.relay_url, type(u""))
|
||||
handle_zero(args)
|
||||
phase1, fd_to_send = build_phase1_data(args)
|
||||
other_cmd = build_other_command(args)
|
||||
print(u"On the other computer, please run: %s" % other_cmd,
|
||||
file=args.stdout)
|
||||
|
||||
w = Wormhole(APPID, args.relay_url)
|
||||
|
||||
if fd_to_send:
|
||||
transit_sender = TransitSender(args.transit_helper)
|
||||
phase1["transit"] = transit_data = {}
|
||||
transit_data["relay_connection_hints"] = transit_sender.get_relay_hints()
|
||||
direct_hints = yield transit_sender.get_direct_hints()
|
||||
transit_data["direct_connection_hints"] = direct_hints
|
||||
|
||||
if args.code:
|
||||
w.set_code(args.code)
|
||||
code = args.code
|
||||
else:
|
||||
code = yield w.get_code(args.code_length)
|
||||
|
||||
if not args.zeromode:
|
||||
print(u"Wormhole code is: %s" % code, file=args.stdout)
|
||||
print(u"", file=args.stdout)
|
||||
|
||||
if args.verify:
|
||||
verifier_bytes = yield w.get_verifier()
|
||||
verifier = binascii.hexlify(verifier_bytes).decode("ascii")
|
||||
while True:
|
||||
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
|
||||
if ok.lower() == "yes":
|
||||
break
|
||||
if ok.lower() == "no":
|
||||
reject_data = json.dumps({"error": "verification rejected",
|
||||
}).encode("utf-8")
|
||||
yield w.send_data(reject_data)
|
||||
raise TransferError("verification rejected, abandoning transfer")
|
||||
|
||||
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
||||
yield w.send_data(my_phase1_bytes)
|
||||
|
||||
try:
|
||||
them_phase1_bytes = yield w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
raise TransferError(e.explain())
|
||||
|
||||
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
||||
|
||||
if fd_to_send is None:
|
||||
if them_phase1["message_ack"] == "ok":
|
||||
print(u"text message sent", file=args.stdout)
|
||||
yield w.close()
|
||||
returnValue(0) # terminates this function
|
||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||
|
||||
if "error" in them_phase1:
|
||||
raise TransferError("remote error, transfer abandoned: %s"
|
||||
% them_phase1["error"])
|
||||
if them_phase1.get("file_ack") != "ok":
|
||||
raise TransferError("ambiguous response from remote, "
|
||||
"transfer abandoned: %s" % (them_phase1,))
|
||||
tdata = them_phase1["transit"]
|
||||
# this is happening too late: the other side already connects to our
|
||||
# server
|
||||
transit_key = w.derive_key(APPID+"/transit-key")
|
||||
transit_sender.set_transit_key(transit_key)
|
||||
yield w.close()
|
||||
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
args.stdout, args.hide_progress)
|
||||
returnValue(0)
|
||||
|
||||
class ProgressingFileSender(basic.FileSender):
|
||||
def __init__(self, filesize, stdout):
|
||||
self._sent = 0
|
||||
self._progress = ProgressPrinter(filesize, stdout)
|
||||
self._progress.start()
|
||||
def progress(self, data):
|
||||
self._sent += len(data)
|
||||
self._progress.update(self._sent)
|
||||
return data
|
||||
def beginFileTransfer(self, file, consumer):
|
||||
d = basic.FileSender.beginFileTransfer(self, file, consumer,
|
||||
self.progress)
|
||||
d.addCallback(self.done)
|
||||
return d
|
||||
def done(self, res):
|
||||
self._progress.finish()
|
||||
return res
|
||||
|
||||
@inlineCallbacks
|
||||
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
stdout, hide_progress):
|
||||
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
progress_stdout = stdout
|
||||
if hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
pfs = ProgressingFileSender(filesize, progress_stdout)
|
||||
|
||||
record_pipe = yield transit_sender.connect()
|
||||
# record_pipe should implement IConsumer, chunks are just records
|
||||
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
|
||||
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
ack = yield record_pipe.receive_record()
|
||||
record_pipe.close()
|
||||
if ack != b"ok\n":
|
||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, re, io, zipfile
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import procutils, log
|
||||
|
@ -6,7 +7,7 @@ from twisted.internet.defer import inlineCallbacks
|
|||
from twisted.internet.threads import deferToThread
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
from ..scripts import runner, cmd_send_blocking, cmd_receive
|
||||
from ..scripts import runner, cmd_send_blocking, cmd_send_twisted, cmd_receive
|
||||
from ..scripts.send_common import build_phase1_data
|
||||
from ..errors import TransferError
|
||||
|
||||
|
@ -207,7 +208,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def _do_test(self, as_subprocess=False,
|
||||
mode="text", override_filename=False):
|
||||
mode="text", override_filename=False,
|
||||
sender_twisted=False, receiver_twisted=False):
|
||||
assert mode in ("text", "file", "directory")
|
||||
common_args = ["--hide-progress",
|
||||
"--relay-url", self.relayurl,
|
||||
|
@ -295,7 +297,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
rargs.cwd = receive_dir
|
||||
rargs.stdout = io.StringIO()
|
||||
rargs.stderr = io.StringIO()
|
||||
if sender_twisted:
|
||||
send_d = cmd_send_twisted.send_twisted(sargs)
|
||||
else:
|
||||
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
|
||||
assert not receiver_twisted # not importable yet
|
||||
receive_d = deferToThread(cmd_receive.receive, rargs)
|
||||
|
||||
send_rc = yield send_d
|
||||
|
@ -308,8 +314,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
|
||||
self.maxDiff = None # show full output for assertion failures
|
||||
|
||||
# check sender
|
||||
self.failUnlessEqual(send_stderr, "")
|
||||
self.failUnlessEqual(receive_stderr, "")
|
||||
|
||||
# check sender
|
||||
if mode == "text":
|
||||
expected = ("Sending text message (%d bytes)\n"
|
||||
"On the other computer, please run: "
|
||||
|
@ -337,10 +345,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
self.failUnlessIn("File sent.. waiting for confirmation\n"
|
||||
"Confirmation received. Transfer complete.\n",
|
||||
send_stdout)
|
||||
self.failUnlessEqual(send_rc, 0)
|
||||
|
||||
# check receiver
|
||||
self.failUnlessEqual(receive_stderr, "")
|
||||
if mode == "text":
|
||||
self.failUnlessEqual(receive_stdout, message+"\n")
|
||||
elif mode == "file":
|
||||
|
@ -362,19 +368,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
fn = os.path.join(receive_dir, receive_dirname, str(i))
|
||||
with open(fn, "r") as f:
|
||||
self.failUnlessEqual(f.read(), message(i))
|
||||
|
||||
self.failUnlessEqual(send_rc, 0)
|
||||
self.failUnlessEqual(receive_rc, 0)
|
||||
|
||||
def test_text(self):
|
||||
return self._do_test()
|
||||
def test_text_subprocess(self):
|
||||
return self._do_test(as_subprocess=True)
|
||||
def test_text_twisted_to_blocking(self):
|
||||
return self._do_test(sender_twisted=True)
|
||||
|
||||
def test_file(self):
|
||||
return self._do_test(mode="file")
|
||||
def test_file_override(self):
|
||||
return self._do_test(mode="file", override_filename=True)
|
||||
def test_file_twisted_to_blocking(self):
|
||||
return self._do_test(mode="file", sender_twisted=True)
|
||||
|
||||
def test_directory(self):
|
||||
return self._do_test(mode="directory")
|
||||
def test_directory_override(self):
|
||||
return self._do_test(mode="directory", override_filename=True)
|
||||
def test_directory_twisted_to_blocking(self):
|
||||
return self._do_test(mode="directory", sender_twisted=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user