cmd_receive: refactor (slight message changes)

This prepares the way for a twisted-based implementation.
This commit is contained in:
Brian Warner 2016-02-28 00:52:49 -08:00
parent 6654efb429
commit 01064325a2
2 changed files with 146 additions and 113 deletions

View File

@ -2,107 +2,146 @@ from __future__ import print_function
import io, os, sys, json, binascii, six, tempfile, zipfile import io, os, sys, json, binascii, six, tempfile, zipfile
from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitReceiver, TransitError from ..blocking.transit import TransitReceiver, TransitError
from ..errors import handle_server_error from ..errors import handle_server_error, TransferError
from .progress import ProgressPrinter from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
@handle_server_error
def receive_blocking(args): def receive_blocking(args):
# we're receiving text, or a file return BlockingReceiver(args).go()
assert isinstance(args.relay_url, type(u""))
with Wormhole(APPID, args.relay_url) as w: class RespondError(Exception):
if args.zeromode: def __init__(self, response):
assert not args.code self.response = response
args.code = u"0-"
code = args.code class BlockingReceiver:
def __init__(self, args):
assert isinstance(args.relay_url, type(u""))
self.args = args
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
@handle_server_error
def go(self):
with Wormhole(APPID, self.args.relay_url) as w:
self.handle_code(w)
verifier = w.get_verifier()
self.show_verifier(verifier)
them_d = self.get_data(w)
try:
if "message" in them_d:
self.handle_text(them_d, w)
return 0
if "file" in them_d:
f = self.handle_file(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_file(f)
self.close_transit(rp)
elif "directory" in them_d:
f = self.handle_directory(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_directory(f)
self.close_transit(rp)
else:
self.msg(u"I don't know what they're offering\n")
self.msg(u"Offer details:", them_d)
raise RespondError({"error": "unknown offer type"})
except RespondError as r:
data = json.dumps(r.response).encode("utf-8")
w.send_data(data)
return 1
return 0
def handle_code(self, w):
code = self.args.code
if self.args.zeromode:
assert not code
code = u"0-"
if not code: if not code:
code = w.input_code("Enter receive wormhole code: ", code = w.input_code("Enter receive wormhole code: ",
args.code_length) self.args.code_length)
w.set_code(code) w.set_code(code)
verifier = binascii.hexlify(w.get_verifier()).decode("ascii") def show_verifier(self, verifier):
if args.verify: verifier_hex = binascii.hexlify(verifier).decode("ascii")
print(u"Verifier %s." % verifier, file=args.stdout) if self.args.verify:
self.msg(u"Verifier %s." % verifier_hex)
def get_data(self, w):
try: try:
them_bytes = w.get_data() them_bytes = w.get_data()
except WrongPasswordError as e: except WrongPasswordError as e:
print(u"ERROR: " + e.explain(), file=sys.stderr) raise TransferError(u"ERROR: " + e.explain())
return 1
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d: if "error" in them_d:
print(u"ERROR: " + them_d["error"], file=sys.stderr) raise TransferError(u"ERROR: " + them_d["error"])
return 1 return them_d
if "message" in them_d: def handle_text(self, them_d, w):
# we're receiving a text message # we're receiving a text message
print(them_d["message"], file=args.stdout) self.msg(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8") data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 0
if "error" in them_d: def handle_file(self, them_d):
print(u"ERROR: " + data["error"], file=sys.stderr)
return 1
if "file" in them_d:
mode = "file"
file_data = them_d["file"] file_data = them_d["file"]
# the basename() is intended to protect us against self.abs_destname = self.decide_destname("file",
# "~/.ssh/authorized_keys" and other attacks file_data["filename"])
destname = os.path.basename(file_data["filename"]) # unicode self.xfersize = file_data["filesize"]
xfersize = file_data["filesize"]
elif "directory" in them_d: self.msg(u"Receiving file (%d bytes) into: %s" %
mode = "directory" (self.xfersize, os.path.basename(self.abs_destname)))
self.ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
def handle_directory(self, them_d):
file_data = them_d["directory"] file_data = them_d["directory"]
zipmode = file_data["mode"] zipmode = file_data["mode"]
if zipmode != "zipfile/deflated": if zipmode != "zipfile/deflated":
print(u"Error: unknown directory-transfer mode '%s'" % self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
(zipmode,), file=args.stdout) raise RespondError({"error": "unknown mode"})
data = json.dumps({"error": "unknown mode"}).encode("utf-8") self.abs_destname = self.decide_destname("directory",
w.send_data(data) file_data["dirname"])
return 1 self.xfersize = file_data["zipsize"]
destname = os.path.basename(file_data["dirname"]) # unicode
xfersize = file_data["zipsize"]
num_files = file_data["numfiles"]
num_bytes = file_data["numbytes"]
else:
print(u"I don't know what they're offering\n", file=args.stdout)
print(u"Offer details:", them_d, file=args.stdout)
data = json.dumps({"error": "unknown offer type"}).encode("utf-8")
w.send_data(data)
return 1
if args.output_file: self.msg(u"Receiving directory (%d bytes) into: %s/" %
destname = args.output_file # override (self.xfersize, os.path.basename(self.abs_destname)))
abs_destname = os.path.join(args.cwd, destname) self.msg(u"%d files, %d bytes (uncompressed)" %
(file_data["numfiles"], file_data["numbytes"]))
self.ask_permission()
return tempfile.SpooledTemporaryFile()
def decide_destname(self, mode, destname):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.join(self.args.cwd, destname)
# get confirmation from the user before writing to the local directory # get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname): if os.path.exists(abs_destname):
print(u"Error: refusing to overwrite existing %s %s" % self.msg(u"Error: refusing to overwrite existing %s %s" %
(mode, destname), file=args.stdout) (mode, destname))
data = json.dumps({"error": "%s already exists" % mode}).encode("utf-8") raise RespondError({"error": "%s already exists" % mode})
w.send_data(data) return abs_destname
return 1
# TODO: add / to destname
print(u"Receiving %s (%d bytes) into: %s" % (mode, xfersize, destname),
file=args.stdout)
if mode == "directory":
print(u"%d files, %d bytes (uncompressed)" %
(num_files, num_bytes), file=args.stdout)
while True and not args.accept_file: def ask_permission(self):
while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ") ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"): if ok.lower().startswith("y"):
break break
print(u"transfer rejected", file=sys.stderr) print(u"transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8") raise RespondError({"error": "transfer rejected"})
w.send_data(data)
return 1
transit_receiver = TransitReceiver(args.transit_helper) def establish_transit(self, w, them_d):
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver = TransitReceiver(self.args.transit_helper)
transit_receiver.set_transit_key(transit_key)
data = json.dumps({ data = json.dumps({
"file_ack": "ok", "file_ack": "ok",
"transit": { "transit": {
@ -111,63 +150,56 @@ def receive_blocking(args):
}, },
}).encode("utf-8") }).encode("utf-8")
w.send_data(data) w.send_data(data)
# now done with the Wormhole object
# now receive the rest of the owl # now receive the rest of the owl
tdata = them_d["transit"] tdata = them_d["transit"]
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect() record_pipe = transit_receiver.connect()
return record_pipe
print(u"Receiving %d bytes for '%s' (%s).." % def transfer_data(self, record_pipe, f):
(xfersize, destname, record_pipe.describe()), self.msg(u"Receiving (%s).." % record_pipe.describe())
file=args.stdout)
if mode == "file":
tmp_destname = abs_destname + ".tmp"
f = open(tmp_destname, "wb")
else:
f = tempfile.SpooledTemporaryFile()
progress_stdout = args.stdout progress_stdout = self.args.stdout
if args.hide_progress: if self.args.hide_progress:
progress_stdout = io.StringIO() progress_stdout = io.StringIO()
received = 0 received = 0
p = ProgressPrinter(xfersize, progress_stdout) p = ProgressPrinter(self.xfersize, progress_stdout)
p.start() p.start()
while received < xfersize: while received < self.xfersize:
try: try:
plaintext = record_pipe.receive_record() plaintext = record_pipe.receive_record()
except TransitError: except TransitError:
print(u"", file=args.stdout) self.msg()
print(u"Connection dropped before full file received", self.msg(u"Connection dropped before full file received")
file=args.stdout) self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
print(u"got %d bytes, wanted %d" % (received, xfersize),
file=args.stdout)
return 1 return 1
f.write(plaintext) f.write(plaintext)
received += len(plaintext) received += len(plaintext)
p.update(received) p.update(received)
p.finish() p.finish()
assert received == xfersize assert received == self.xfersize
if mode == "file": def write_file(self, f):
tmp_name = f.name
f.close() f.close()
os.rename(tmp_destname, abs_destname) os.rename(tmp_name, self.abs_destname)
print(u"Received file written to %s" % destname, file=args.stdout) self.msg(u"Received file written to %s" %
else: os.path.basename(self.abs_destname))
print(u"Unpacking zipfile..", file=args.stdout)
def write_directory(self, f):
self.msg(u"Unpacking zipfile..")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=abs_destname) zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against # extractall() appears to offer some protection against
# malicious pathnames. For example, "/tmp/oops" and # malicious pathnames. For example, "/tmp/oops" and
# "../tmp/oops" both do the same thing as the (safe) # "../tmp/oops" both do the same thing as the (safe)
# "tmp/oops". # "tmp/oops".
print(u"Received files written to %s/" % destname, file=args.stdout) self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
f.close() f.close()
def close_transit(self, record_pipe):
record_pipe.send_record(b"ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0

View File

@ -360,7 +360,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
if mode == "text": if mode == "text":
self.failUnlessEqual(receive_stdout, message+NL) self.failUnlessEqual(receive_stdout, message+NL)
elif mode == "file": elif mode == "file":
self.failUnlessIn("Receiving {bytes:d} bytes for '{name}'" self.failUnlessIn("Receiving file ({bytes:d} bytes) into: {name}"
.format(bytes=len(message), .format(bytes=len(message),
name=receive_filename), receive_stdout) name=receive_filename), receive_stdout)
self.failUnlessIn("Received file written to ", receive_stdout) self.failUnlessIn("Received file written to ", receive_stdout)
@ -369,9 +369,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
with open(fn, "r") as f: with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message) self.failUnlessEqual(f.read(), message)
elif mode == "directory": elif mode == "directory":
self.failUnless(re.search(r"Receiving \d+ bytes for '{name}'" want = (r"Receiving directory \(\d+ bytes\) into: {name}/"
.format(name=receive_dirname), .format(name=receive_dirname))
receive_stdout)) self.failUnless(re.search(want, receive_stdout),
(want, receive_stdout))
self.failUnlessIn("Received files written to {name}" self.failUnlessIn("Received files written to {name}"
.format(name=receive_dirname), receive_stdout) .format(name=receive_dirname), receive_stdout)
fn = os.path.join(receive_dir, receive_dirname) fn = os.path.join(receive_dir, receive_dirname)