add twisted form of sender
Currently this is only invokable from tests.
This commit is contained in:
parent
7ceffd783a
commit
aa27bfd32c
145
src/wormhole/scripts/cmd_send_twisted.py
Normal file
145
src/wormhole/scripts/cmd_send_twisted.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import io, json, binascii, six
|
||||||
|
from twisted.protocols import basic
|
||||||
|
from twisted.internet import reactor
|
||||||
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
|
from ..errors import TransferError
|
||||||
|
from .progress import ProgressPrinter
|
||||||
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
|
from ..twisted.transit import TransitSender
|
||||||
|
from .send_common import (APPID, handle_zero, build_other_command,
|
||||||
|
build_phase1_data)
|
||||||
|
|
||||||
|
def send_twisted_sync(args):
|
||||||
|
d = send_twisted(args)
|
||||||
|
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
||||||
|
# directly)
|
||||||
|
rc = []
|
||||||
|
def _done(res):
|
||||||
|
rc.extend([True, res])
|
||||||
|
reactor.stop()
|
||||||
|
def _err(f):
|
||||||
|
rc.extend([False, f.value])
|
||||||
|
reactor.stop()
|
||||||
|
d.addCallbacks(_done, _err)
|
||||||
|
reactor.run()
|
||||||
|
if rc[0]:
|
||||||
|
return rc[1]
|
||||||
|
raise rc[1]
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def send_twisted(args):
|
||||||
|
assert isinstance(args.relay_url, type(u""))
|
||||||
|
handle_zero(args)
|
||||||
|
phase1, fd_to_send = build_phase1_data(args)
|
||||||
|
other_cmd = build_other_command(args)
|
||||||
|
print(u"On the other computer, please run: %s" % other_cmd,
|
||||||
|
file=args.stdout)
|
||||||
|
|
||||||
|
w = Wormhole(APPID, args.relay_url)
|
||||||
|
|
||||||
|
if fd_to_send:
|
||||||
|
transit_sender = TransitSender(args.transit_helper)
|
||||||
|
phase1["transit"] = transit_data = {}
|
||||||
|
transit_data["relay_connection_hints"] = transit_sender.get_relay_hints()
|
||||||
|
direct_hints = yield transit_sender.get_direct_hints()
|
||||||
|
transit_data["direct_connection_hints"] = direct_hints
|
||||||
|
|
||||||
|
if args.code:
|
||||||
|
w.set_code(args.code)
|
||||||
|
code = args.code
|
||||||
|
else:
|
||||||
|
code = yield w.get_code(args.code_length)
|
||||||
|
|
||||||
|
if not args.zeromode:
|
||||||
|
print(u"Wormhole code is: %s" % code, file=args.stdout)
|
||||||
|
print(u"", file=args.stdout)
|
||||||
|
|
||||||
|
if args.verify:
|
||||||
|
verifier_bytes = yield w.get_verifier()
|
||||||
|
verifier = binascii.hexlify(verifier_bytes).decode("ascii")
|
||||||
|
while True:
|
||||||
|
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
|
||||||
|
if ok.lower() == "yes":
|
||||||
|
break
|
||||||
|
if ok.lower() == "no":
|
||||||
|
reject_data = json.dumps({"error": "verification rejected",
|
||||||
|
}).encode("utf-8")
|
||||||
|
yield w.send_data(reject_data)
|
||||||
|
raise TransferError("verification rejected, abandoning transfer")
|
||||||
|
|
||||||
|
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
||||||
|
yield w.send_data(my_phase1_bytes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
them_phase1_bytes = yield w.get_data()
|
||||||
|
except WrongPasswordError as e:
|
||||||
|
raise TransferError(e.explain())
|
||||||
|
|
||||||
|
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
if fd_to_send is None:
|
||||||
|
if them_phase1["message_ack"] == "ok":
|
||||||
|
print(u"text message sent", file=args.stdout)
|
||||||
|
yield w.close()
|
||||||
|
returnValue(0) # terminates this function
|
||||||
|
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||||
|
|
||||||
|
if "error" in them_phase1:
|
||||||
|
raise TransferError("remote error, transfer abandoned: %s"
|
||||||
|
% them_phase1["error"])
|
||||||
|
if them_phase1.get("file_ack") != "ok":
|
||||||
|
raise TransferError("ambiguous response from remote, "
|
||||||
|
"transfer abandoned: %s" % (them_phase1,))
|
||||||
|
tdata = them_phase1["transit"]
|
||||||
|
# this is happening too late: the other side already connects to our
|
||||||
|
# server
|
||||||
|
transit_key = w.derive_key(APPID+"/transit-key")
|
||||||
|
transit_sender.set_transit_key(transit_key)
|
||||||
|
yield w.close()
|
||||||
|
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
|
args.stdout, args.hide_progress)
|
||||||
|
returnValue(0)
|
||||||
|
|
||||||
|
class ProgressingFileSender(basic.FileSender):
|
||||||
|
def __init__(self, filesize, stdout):
|
||||||
|
self._sent = 0
|
||||||
|
self._progress = ProgressPrinter(filesize, stdout)
|
||||||
|
self._progress.start()
|
||||||
|
def progress(self, data):
|
||||||
|
self._sent += len(data)
|
||||||
|
self._progress.update(self._sent)
|
||||||
|
return data
|
||||||
|
def beginFileTransfer(self, file, consumer):
|
||||||
|
d = basic.FileSender.beginFileTransfer(self, file, consumer,
|
||||||
|
self.progress)
|
||||||
|
d.addCallback(self.done)
|
||||||
|
return d
|
||||||
|
def done(self, res):
|
||||||
|
self._progress.finish()
|
||||||
|
return res
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
|
stdout, hide_progress):
|
||||||
|
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||||
|
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||||
|
|
||||||
|
fd_to_send.seek(0,2)
|
||||||
|
filesize = fd_to_send.tell()
|
||||||
|
fd_to_send.seek(0,0)
|
||||||
|
progress_stdout = stdout
|
||||||
|
if hide_progress:
|
||||||
|
progress_stdout = io.StringIO()
|
||||||
|
pfs = ProgressingFileSender(filesize, progress_stdout)
|
||||||
|
|
||||||
|
record_pipe = yield transit_sender.connect()
|
||||||
|
# record_pipe should implement IConsumer, chunks are just records
|
||||||
|
print(u"Sending (%s).." % transit_sender.describe(), file=stdout)
|
||||||
|
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
||||||
|
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||||
|
ack = yield record_pipe.receive_record()
|
||||||
|
record_pipe.close()
|
||||||
|
if ack != b"ok\n":
|
||||||
|
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
||||||
|
print(u"Confirmation received. Transfer complete.", file=stdout)
|
|
@ -1,3 +1,4 @@
|
||||||
|
from __future__ import print_function
|
||||||
import os, sys, re, io, zipfile
|
import os, sys, re, io, zipfile
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.python import procutils, log
|
from twisted.python import procutils, log
|
||||||
|
@ -6,7 +7,7 @@ from twisted.internet.defer import inlineCallbacks
|
||||||
from twisted.internet.threads import deferToThread
|
from twisted.internet.threads import deferToThread
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from ..scripts import runner, cmd_send_blocking, cmd_receive
|
from ..scripts import runner, cmd_send_blocking, cmd_send_twisted, cmd_receive
|
||||||
from ..scripts.send_common import build_phase1_data
|
from ..scripts.send_common import build_phase1_data
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
|
|
||||||
|
@ -207,7 +208,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _do_test(self, as_subprocess=False,
|
def _do_test(self, as_subprocess=False,
|
||||||
mode="text", override_filename=False):
|
mode="text", override_filename=False,
|
||||||
|
sender_twisted=False, receiver_twisted=False):
|
||||||
assert mode in ("text", "file", "directory")
|
assert mode in ("text", "file", "directory")
|
||||||
common_args = ["--hide-progress",
|
common_args = ["--hide-progress",
|
||||||
"--relay-url", self.relayurl,
|
"--relay-url", self.relayurl,
|
||||||
|
@ -295,7 +297,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
rargs.cwd = receive_dir
|
rargs.cwd = receive_dir
|
||||||
rargs.stdout = io.StringIO()
|
rargs.stdout = io.StringIO()
|
||||||
rargs.stderr = io.StringIO()
|
rargs.stderr = io.StringIO()
|
||||||
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
|
if sender_twisted:
|
||||||
|
send_d = cmd_send_twisted.send_twisted(sargs)
|
||||||
|
else:
|
||||||
|
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
|
||||||
|
assert not receiver_twisted # not importable yet
|
||||||
receive_d = deferToThread(cmd_receive.receive, rargs)
|
receive_d = deferToThread(cmd_receive.receive, rargs)
|
||||||
|
|
||||||
send_rc = yield send_d
|
send_rc = yield send_d
|
||||||
|
@ -308,8 +314,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
|
|
||||||
self.maxDiff = None # show full output for assertion failures
|
self.maxDiff = None # show full output for assertion failures
|
||||||
|
|
||||||
# check sender
|
|
||||||
self.failUnlessEqual(send_stderr, "")
|
self.failUnlessEqual(send_stderr, "")
|
||||||
|
self.failUnlessEqual(receive_stderr, "")
|
||||||
|
|
||||||
|
# check sender
|
||||||
if mode == "text":
|
if mode == "text":
|
||||||
expected = ("Sending text message (%d bytes)\n"
|
expected = ("Sending text message (%d bytes)\n"
|
||||||
"On the other computer, please run: "
|
"On the other computer, please run: "
|
||||||
|
@ -337,10 +345,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
self.failUnlessIn("File sent.. waiting for confirmation\n"
|
self.failUnlessIn("File sent.. waiting for confirmation\n"
|
||||||
"Confirmation received. Transfer complete.\n",
|
"Confirmation received. Transfer complete.\n",
|
||||||
send_stdout)
|
send_stdout)
|
||||||
self.failUnlessEqual(send_rc, 0)
|
|
||||||
|
|
||||||
# check receiver
|
# check receiver
|
||||||
self.failUnlessEqual(receive_stderr, "")
|
|
||||||
if mode == "text":
|
if mode == "text":
|
||||||
self.failUnlessEqual(receive_stdout, message+"\n")
|
self.failUnlessEqual(receive_stdout, message+"\n")
|
||||||
elif mode == "file":
|
elif mode == "file":
|
||||||
|
@ -362,19 +368,27 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
fn = os.path.join(receive_dir, receive_dirname, str(i))
|
fn = os.path.join(receive_dir, receive_dirname, str(i))
|
||||||
with open(fn, "r") as f:
|
with open(fn, "r") as f:
|
||||||
self.failUnlessEqual(f.read(), message(i))
|
self.failUnlessEqual(f.read(), message(i))
|
||||||
|
|
||||||
|
self.failUnlessEqual(send_rc, 0)
|
||||||
self.failUnlessEqual(receive_rc, 0)
|
self.failUnlessEqual(receive_rc, 0)
|
||||||
|
|
||||||
def test_text(self):
|
def test_text(self):
|
||||||
return self._do_test()
|
return self._do_test()
|
||||||
def test_text_subprocess(self):
|
def test_text_subprocess(self):
|
||||||
return self._do_test(as_subprocess=True)
|
return self._do_test(as_subprocess=True)
|
||||||
|
def test_text_twisted_to_blocking(self):
|
||||||
|
return self._do_test(sender_twisted=True)
|
||||||
|
|
||||||
def test_file(self):
|
def test_file(self):
|
||||||
return self._do_test(mode="file")
|
return self._do_test(mode="file")
|
||||||
def test_file_override(self):
|
def test_file_override(self):
|
||||||
return self._do_test(mode="file", override_filename=True)
|
return self._do_test(mode="file", override_filename=True)
|
||||||
|
def test_file_twisted_to_blocking(self):
|
||||||
|
return self._do_test(mode="file", sender_twisted=True)
|
||||||
|
|
||||||
def test_directory(self):
|
def test_directory(self):
|
||||||
return self._do_test(mode="directory")
|
return self._do_test(mode="directory")
|
||||||
def test_directory_override(self):
|
def test_directory_override(self):
|
||||||
return self._do_test(mode="directory", override_filename=True)
|
return self._do_test(mode="directory", override_filename=True)
|
||||||
|
def test_directory_twisted_to_blocking(self):
|
||||||
|
return self._do_test(mode="directory", sender_twisted=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user