merge cmd_receive_blocking into cmd_receive_twisted

This commit is contained in:
Brian Warner 2016-04-15 17:36:52 -07:00
parent 8c67a98259
commit 9b53bb96c6
3 changed files with 97 additions and 348 deletions

View File

@ -1,216 +0,0 @@
from __future__ import print_function
import io, os, sys, json, binascii, six, tempfile, zipfile
from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitReceiver, TransitError
from ..errors import handle_server_error, TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer"
def receive_blocking(args):
return BlockingReceiver(args).go()
class RespondError(Exception):
def __init__(self, response):
self.response = response
class BlockingReceiver:
def __init__(self, args):
assert isinstance(args.relay_url, type(u""))
self.args = args
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
@handle_server_error
def go(self):
with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w:
self.handle_code(w)
verifier = w.get_verifier()
self.show_verifier(verifier)
them_d = self.get_data(w)
try:
if "message" in them_d:
self.handle_text(them_d, w)
return 0
if "file" in them_d:
f = self.handle_file(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_file(f)
self.close_transit(rp)
elif "directory" in them_d:
f = self.handle_directory(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_directory(f)
self.close_transit(rp)
else:
self.msg(u"I don't know what they're offering\n")
self.msg(u"Offer details:", them_d)
raise RespondError({"error": "unknown offer type"})
except RespondError as r:
data = json.dumps(r.response).encode("utf-8")
w.send_data(data)
return 1
return 0
def handle_code(self, w):
code = self.args.code
if self.args.zeromode:
assert not code
code = u"0-"
if not code:
code = w.input_code("Enter receive wormhole code: ",
self.args.code_length)
w.set_code(code)
def show_verifier(self, verifier):
verifier_hex = binascii.hexlify(verifier).decode("ascii")
if self.args.verify:
self.msg(u"Verifier %s." % verifier_hex)
def get_data(self, w):
try:
them_bytes = w.get_data()
except WrongPasswordError as e:
raise TransferError(u"ERROR: " + e.explain())
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
raise TransferError(u"ERROR: " + them_d["error"])
return them_d
def handle_text(self, them_d, w):
# we're receiving a text message
self.msg(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data)
def handle_file(self, them_d):
file_data = them_d["file"]
self.abs_destname = self.decide_destname("file",
file_data["filename"])
self.xfersize = file_data["filesize"]
self.msg(u"Receiving file (%d bytes) into: %s" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
def handle_directory(self, them_d):
file_data = them_d["directory"]
zipmode = file_data["mode"]
if zipmode != "zipfile/deflated":
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
raise RespondError({"error": "unknown mode"})
self.abs_destname = self.decide_destname("directory",
file_data["dirname"])
self.xfersize = file_data["zipsize"]
self.msg(u"Receiving directory (%d bytes) into: %s/" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.msg(u"%d files, %d bytes (uncompressed)" %
(file_data["numfiles"], file_data["numbytes"]))
self.ask_permission()
return tempfile.SpooledTemporaryFile()
def decide_destname(self, mode, destname):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.join(self.args.cwd, destname)
# get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname):
self.msg(u"Error: refusing to overwrite existing %s %s" %
(mode, destname))
raise RespondError({"error": "%s already exists" % mode})
return abs_destname
def ask_permission(self):
_start = self.args.timing.add_event("permission", waiting="user")
while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print(u"transfer rejected", file=sys.stderr)
self.args.timing.finish_event(_start, answer="no")
raise RespondError({"error": "transfer rejected"})
self.args.timing.finish_event(_start, answer="yes")
def establish_transit(self, w, them_d):
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen,
timing=self.args.timing)
transit_receiver.set_transit_key(transit_key)
data = json.dumps({
"file_ack": "ok",
"transit": {
"direct_connection_hints": transit_receiver.get_direct_hints(),
"relay_connection_hints": transit_receiver.get_relay_hints(),
},
}).encode("utf-8")
w.send_data(data)
# now receive the rest of the owl
tdata = them_d["transit"]
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect()
return record_pipe
def transfer_data(self, record_pipe, f):
self.msg(u"Receiving (%s).." % record_pipe.describe())
_start = self.args.timing.add_event("rx file")
progress_stdout = self.args.stdout
if self.args.hide_progress:
progress_stdout = io.StringIO()
received = 0
p = ProgressPrinter(self.xfersize, progress_stdout)
p.start()
while received < self.xfersize:
try:
plaintext = record_pipe.receive_record()
except TransitError:
self.msg()
self.msg(u"Connection dropped before full file received")
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
return 1
f.write(plaintext)
received += len(plaintext)
p.update(received)
p.finish()
self.args.timing.finish_event(_start)
assert received == self.xfersize
def write_file(self, f):
tmp_name = f.name
f.close()
os.rename(tmp_name, self.abs_destname)
self.msg(u"Received file written to %s" %
os.path.basename(self.abs_destname))
def write_directory(self, f):
self.msg(u"Unpacking zipfile..")
_start = self.args.timing.add_event("unpack zip")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against
# malicious pathnames. For example, "/tmp/oops" and
# "../tmp/oops" both do the same thing as the (safe)
# "tmp/oops".
self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
f.close()
self.args.timing.finish_event(_start)
def close_transit(self, record_pipe):
_start = self.args.timing.add_event("ack")
record_pipe.send_record(b"ok\n")
record_pipe.close()
self.args.timing.finish_event(_start)

View File

@ -1,13 +1,19 @@
from __future__ import print_function
import io, json
import io, os, sys, json, binascii, six, tempfile, zipfile
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
from ..errors import TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer"
class RespondError(Exception):
def __init__(self, response):
self.response = response
def receive_twisted_sync(args):
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
# directly)
@ -34,7 +40,14 @@ def receive_twisted_sync(args):
def receive_twisted(args):
return TwistedReceiver(args).go()
class TwistedReceiver(BlockingReceiver):
class TwistedReceiver:
def __init__(self, args):
assert isinstance(args.relay_url, type(u""))
self.args = args
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
# TODO: @handle_server_error
@inlineCallbacks
@ -101,6 +114,11 @@ class TwistedReceiver(BlockingReceiver):
self.args.code_length)
yield w.set_code(code)
def show_verifier(self, verifier):
verifier_hex = binascii.hexlify(verifier).decode("ascii")
if self.args.verify:
self.msg(u"Verifier %s." % verifier_hex)
@inlineCallbacks
def get_data(self, w):
try:
@ -119,6 +137,61 @@ class TwistedReceiver(BlockingReceiver):
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
yield w.send_data(data)
def handle_file(self, them_d):
file_data = them_d["file"]
self.abs_destname = self.decide_destname("file",
file_data["filename"])
self.xfersize = file_data["filesize"]
self.msg(u"Receiving file (%d bytes) into: %s" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
def handle_directory(self, them_d):
file_data = them_d["directory"]
zipmode = file_data["mode"]
if zipmode != "zipfile/deflated":
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
raise RespondError({"error": "unknown mode"})
self.abs_destname = self.decide_destname("directory",
file_data["dirname"])
self.xfersize = file_data["zipsize"]
self.msg(u"Receiving directory (%d bytes) into: %s/" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.msg(u"%d files, %d bytes (uncompressed)" %
(file_data["numfiles"], file_data["numbytes"]))
self.ask_permission()
return tempfile.SpooledTemporaryFile()
def decide_destname(self, mode, destname):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.join(self.args.cwd, destname)
# get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname):
self.msg(u"Error: refusing to overwrite existing %s %s" %
(mode, destname))
raise RespondError({"error": "%s already exists" % mode})
return abs_destname
def ask_permission(self):
_start = self.args.timing.add_event("permission", waiting="user")
while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print(u"transfer rejected", file=sys.stderr)
self.args.timing.finish_event(_start, answer="no")
raise RespondError({"error": "transfer rejected"})
self.args.timing.finish_event(_start, answer="yes")
@inlineCallbacks
def establish_transit(self, w, them_d, tor_manager):
transit_key = w.derive_key(APPID+u"/transit-key")
@ -169,6 +242,27 @@ class TwistedReceiver(BlockingReceiver):
returnValue(1) # TODO: exit properly
assert received == self.xfersize
def write_file(self, f):
tmp_name = f.name
f.close()
os.rename(tmp_name, self.abs_destname)
self.msg(u"Received file written to %s" %
os.path.basename(self.abs_destname))
def write_directory(self, f):
self.msg(u"Unpacking zipfile..")
_start = self.args.timing.add_event("unpack zip")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against
# malicious pathnames. For example, "/tmp/oops" and
# "../tmp/oops" both do the same thing as the (safe)
# "tmp/oops".
self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
f.close()
self.args.timing.finish_event(_start)
@inlineCallbacks
def close_transit(self, record_pipe):
_start = self.args.timing.add_event("ack")

View File

@ -1,129 +0,0 @@
from __future__ import print_function
import json, binascii, six
from ..errors import TransferError
from .progress import ProgressPrinter
from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitSender
from ..errors import handle_server_error
from .send_common import (APPID, handle_zero, build_other_command,
build_phase1_data)
@handle_server_error
def send_blocking(args):
assert isinstance(args.relay_url, type(u""))
handle_zero(args)
phase1, fd_to_send = build_phase1_data(args)
other_cmd = build_other_command(args)
print(u"On the other computer, please run: %s" % other_cmd,
file=args.stdout)
if fd_to_send is not None:
transit_sender = TransitSender(args.transit_helper,
no_listen=args.no_listen,
timing=args.timing)
transit_data = {
"direct_connection_hints": transit_sender.get_direct_hints(),
"relay_connection_hints": transit_sender.get_relay_hints(),
}
phase1["transit"] = transit_data
with Wormhole(APPID, args.relay_url, timing=args.timing) as w:
if args.code:
w.set_code(args.code)
code = args.code
else:
code = w.get_code(args.code_length)
if not args.zeromode:
print(u"Wormhole code is: %s" % code, file=args.stdout)
print(u"", file=args.stdout)
# get the verifier, because that also lets us derive the transit key,
# which we want to set before revealing the connection hints to the
# far side, so we'll be ready for them when they connect
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
if args.verify:
_do_verify(verifier, w)
if fd_to_send is not None:
transit_key = w.derive_key(APPID+"/transit-key")
transit_sender.set_transit_key(transit_key)
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
w.send_data(my_phase1_bytes)
try:
them_phase1_bytes = w.get_data()
except WrongPasswordError as e:
raise TransferError(e.explain())
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if fd_to_send is None:
if them_phase1["message_ack"] == "ok":
print(u"text message sent", file=args.stdout)
return 0
raise TransferError("error sending text: %r" % (them_phase1,))
return _send_file_blocking(them_phase1, fd_to_send,
transit_sender, args.stdout, args.hide_progress,
args.timing)
def _do_verify(verifier, w):
while True:
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes":
break
if ok.lower() == "no":
reject_data = json.dumps({"error": "verification rejected",
}).encode("utf-8")
w.send_data(reject_data)
raise TransferError("verification rejected, abandoning transfer")
def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
stdout, hide_progress, timing):
# we're sending a file, if they accept it
if "error" in them_phase1:
raise TransferError("remote error, transfer abandoned: %s"
% them_phase1["error"])
if them_phase1.get("file_ack") != "ok":
raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_phase1,))
tdata = them_phase1["transit"]
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect()
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
_start = timing.add_event("tx file")
CHUNKSIZE = 64*1024
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
p = ProgressPrinter(filesize, stdout)
with fd_to_send as f:
sent = 0
if not hide_progress:
p.start()
while sent < filesize:
plaintext = f.read(CHUNKSIZE)
record_pipe.send_record(plaintext)
sent += len(plaintext)
if not hide_progress:
p.update(sent)
if not hide_progress:
p.finish()
timing.finish_event(_start)
_start = timing.add_event("get ack")
print(u"File sent.. waiting for confirmation", file=stdout)
ack = record_pipe.receive_record()
record_pipe.close()
if ack == b"ok\n":
print(u"Confirmation received. Transfer complete.", file=stdout)
timing.finish_event(_start, ack="ok")
return 0
timing.finish_event(_start, ack="failed")
raise TransferError("Transfer failed (remote says: %r)" % ack)