Merge branch 'senddir'

This commit is contained in:
Brian Warner 2015-11-29 01:37:12 -06:00
commit 143d6dbc74
4 changed files with 311 additions and 97 deletions

View File

@ -1,88 +1,46 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six import os, sys, json, binascii, six, tempfile, zipfile
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
@handle_server_error def accept_file(args, them_d, w):
def receive(args):
# we're receiving text, or a file
from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitReceiver, TransitError from ..blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress from .progress import start_progress, update_progress, finish_progress
assert isinstance(args.relay_url, type(u""))
with Wormhole(APPID, args.relay_url) as w: file_data = them_d["file"]
if args.zeromode: # the basename() is intended to protect us against
assert not args.code # "~/.ssh/authorized_keys" and other attacks
args.code = u"0-" filename = os.path.basename(file_data["filename"]) # unicode
code = args.code filesize = file_data["filesize"]
if not code:
code = w.input_code("Enter receive wormhole code: ", args.code_length)
w.set_code(code)
if args.verify: # get confirmation from the user before writing to the local directory
verifier = binascii.hexlify(w.get_verifier()).decode("ascii") if os.path.exists(filename):
print("Verifier %s." % verifier) print("Error: refusing to overwrite existing file %s" % (filename,))
data = json.dumps({"error": "file already exists"}).encode("utf-8")
try:
them_bytes = w.get_data()
except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr)
return 1
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
print("ERROR: " + them_d["error"], file=sys.stderr)
return 1
if "message" in them_d:
# we're receiving a text message
print(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data)
return 0
if not "file" in them_d:
print("I don't know what they're offering\n")
print(them_d)
return 1
if "error" in them_d:
print("ERROR: " + data["error"], file=sys.stderr)
return 1
file_data = them_d["file"]
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
filename = os.path.basename(file_data["filename"]) # unicode
filesize = file_data["filesize"]
# get confirmation from the user before writing to the local directory
if os.path.exists(filename):
print("Error: refusing to overwrite existing file %s" % (filename,))
data = json.dumps({"error": "file already exists"}).encode("utf-8")
w.send_data(data)
return 1
print("Receiving file (%d bytes) into: %s" % (filesize, filename))
while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print("transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data)
return 1
transit_receiver = TransitReceiver(args.transit_helper)
data = json.dumps({
"file_ack": "ok",
"transit": {
"direct_connection_hints": transit_receiver.get_direct_hints(),
"relay_connection_hints": transit_receiver.get_relay_hints(),
},
}).encode("utf-8")
w.send_data(data) w.send_data(data)
return 1
print("Receiving file (%d bytes) into: %s" % (filesize, filename))
while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print("transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data)
return 1
transit_receiver = TransitReceiver(args.transit_helper)
data = json.dumps({
"file_ack": "ok",
"transit": {
"direct_connection_hints": transit_receiver.get_direct_hints(),
"relay_connection_hints": transit_receiver.get_relay_hints(),
},
}).encode("utf-8")
w.send_data(data)
# now done with the Wormhole object
# now receive the rest of the owl # now receive the rest of the owl
tdata = them_d["transit"] tdata = them_d["transit"]
@ -118,3 +76,144 @@ def receive(args):
record_pipe.send_record(b"ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0 return 0
def accept_directory(args, them_d, w):
from ..blocking.transit import TransitReceiver, TransitError
from .progress import start_progress, update_progress, finish_progress
file_data = them_d["directory"]
mode = file_data["mode"]
if mode != "zipfile/deflated":
print("Error: unknown directory-transfer mode '%s'" % (mode,))
data = json.dumps({"error": "unknown mode"}).encode("utf-8")
w.send_data(data)
return 1
if args.output_file:
dirname = args.output_file
else:
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
dirname = os.path.basename(file_data["dirname"]) # unicode
filesize = file_data["zipsize"]
num_files = file_data["numfiles"]
num_bytes = file_data["numbytes"]
if os.path.exists(dirname):
print("Error: refusing to overwrite existing directory %s" % (dirname,))
data = json.dumps({"error": "directory already exists"}).encode("utf-8")
w.send_data(data)
return 1
print("Receiving directory into: %s/" % (dirname,))
print("%d files, %d bytes (%d compressed)" % (num_files, num_bytes,
filesize))
while True and not args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print("transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data)
return 1
transit_receiver = TransitReceiver(args.transit_helper)
data = json.dumps({
"file_ack": "ok",
"transit": {
"direct_connection_hints": transit_receiver.get_direct_hints(),
"relay_connection_hints": transit_receiver.get_relay_hints(),
},
}).encode("utf-8")
w.send_data(data)
# now done with the Wormhole object
# now receive the rest of the owl
tdata = them_d["transit"]
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect()
print("Receiving %d bytes for '%s' (%s).." % (filesize, dirname,
transit_receiver.describe()))
f = tempfile.SpooledTemporaryFile()
received = 0
next_update = start_progress(filesize)
while received < filesize:
try:
plaintext = record_pipe.receive_record()
except TransitError:
print()
print("Connection dropped before full file received")
print("got %d bytes, wanted %d" % (received, filesize))
return 1
f.write(plaintext)
received += len(plaintext)
next_update = update_progress(next_update, received, filesize)
finish_progress(filesize)
assert received == filesize
print("Unpacking zipfile..")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=dirname)
# extractall() appears to offer some protection against malicious
# pathnames. For example, "/tmp/oops" and "../tmp/oops" both do the
# same thing as the (safe) "tmp/oops".
print("Received files written to %s/" % dirname)
record_pipe.send_record(b"ok\n")
record_pipe.close()
return 0
@handle_server_error
def receive(args):
# we're receiving text, or a file
from ..blocking.transcribe import Wormhole, WrongPasswordError
assert isinstance(args.relay_url, type(u""))
with Wormhole(APPID, args.relay_url) as w:
if args.zeromode:
assert not args.code
args.code = u"0-"
code = args.code
if not code:
code = w.input_code("Enter receive wormhole code: ", args.code_length)
w.set_code(code)
if args.verify:
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
print("Verifier %s." % verifier)
try:
them_bytes = w.get_data()
except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr)
return 1
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
print("ERROR: " + them_d["error"], file=sys.stderr)
return 1
if "message" in them_d:
# we're receiving a text message
print(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data)
return 0
if "error" in them_d:
print("ERROR: " + data["error"], file=sys.stderr)
return 1
if "file" in them_d:
return accept_file(args, them_d, w)
if "directory" in them_d:
return accept_directory(args, them_d, w)
print("I don't know what they're offering\n")
print("Offer details:", them_d)
data = json.dumps({"error": "unknown offer type"}).encode("utf-8")
w.send_data(data)
return 1

View File

@ -1,12 +1,12 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii, six import os, sys, json, binascii, six, tempfile, zipfile
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = u"lothar.com/wormhole/text-or-file-xfer" APPID = u"lothar.com/wormhole/text-or-file-xfer"
@handle_server_error @handle_server_error
def send(args): def send(args):
# we're sending text, or a file # we're sending text, or a file/directory
from ..blocking.transcribe import Wormhole, WrongPasswordError from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitSender from ..blocking.transit import TransitSender
from .progress import start_progress, update_progress, finish_progress from .progress import start_progress, update_progress, finish_progress
@ -26,25 +26,60 @@ def send(args):
"message": text, "message": text,
} }
else: else:
if not os.path.isfile(args.what): if not os.path.exists(args.what):
print("Cannot send: no file named '%s'" % args.what) print("Cannot send: no file/directory named '%s'" % args.what)
return 1 return 1
# we're sending a file
sending_message = False sending_message = False
filesize = os.stat(args.what).st_size
basename = os.path.basename(args.what)
print("Sending %d byte file named '%s'" % (filesize, basename))
transit_sender = TransitSender(args.transit_helper) transit_sender = TransitSender(args.transit_helper)
phase1 = { phase1 = {
"file": {
"filename": basename,
"filesize": filesize,
},
"transit": { "transit": {
"direct_connection_hints": transit_sender.get_direct_hints(), "direct_connection_hints": transit_sender.get_direct_hints(),
"relay_connection_hints": transit_sender.get_relay_hints(), "relay_connection_hints": transit_sender.get_relay_hints(),
}, },
} }
basename = os.path.basename(args.what)
if os.path.isfile(args.what):
# we're sending a file
filesize = os.stat(args.what).st_size
phase1["file"] = {
"filename": basename,
"filesize": filesize,
}
print("Sending %d byte file named '%s'" % (filesize, basename))
fd_to_send = open(args.what, "rb")
elif os.path.isdir(args.what):
print("Building zipfile..")
# We're sending a directory. Create a zipfile in a tempdir and
# send that.
fd_to_send = tempfile.SpooledTemporaryFile()
# TODO: I think ZIP_DEFLATED means compressed.. check it
num_files = 0
num_bytes = 0
tostrip = len(args.what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
for path,dirs,files in os.walk(args.what):
# path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain
# "" or "subdir"
localpath = list(path.split(os.sep)[tostrip:])
for fn in files:
archivename = os.path.join(*tuple(localpath+[fn]))
localfilename = os.path.join(path, fn)
zf.write(localfilename, archivename)
num_bytes += os.stat(localfilename).st_size
num_files += 1
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
phase1["directory"] = {
"mode": "zipfile/deflated",
"dirname": basename,
"zipsize": filesize,
"numbytes": num_bytes,
"numfiles": num_files,
}
print("Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename))
with Wormhole(APPID, args.relay_url) as w: with Wormhole(APPID, args.relay_url) as w:
if args.zeromode: if args.zeromode:
@ -114,7 +149,7 @@ def send(args):
print("Sending (%s).." % transit_sender.describe()) print("Sending (%s).." % transit_sender.describe())
CHUNKSIZE = 64*1024 CHUNKSIZE = 64*1024
with open(args.what, "rb") as f: with fd_to_send as f:
sent = 0 sent = 0
next_update = start_progress(filesize) next_update = start_progress(filesize)
while sent < filesize: while sent < filesize:

View File

@ -76,21 +76,21 @@ sp_tail_usage.set_defaults(func=cmd_usage.tail_usage)
# CLI: send # CLI: send
p = subparsers.add_parser("send", p = subparsers.add_parser("send",
description="Send text message or file", description="Send text message, file, or directory",
usage="wormhole send [FILENAME]") usage="wormhole send [FILENAME|DIRNAME]")
p.add_argument("--text", metavar="MESSAGE", p.add_argument("--text", metavar="MESSAGE",
help="text message to send, instead of a file. Use '-' to read from stdin.") help="text message to send, instead of a file. Use '-' to read from stdin.")
p.add_argument("--code", metavar="CODE", help="human-generated code phrase", p.add_argument("--code", metavar="CODE", help="human-generated code phrase",
type=type(u"")) type=type(u""))
p.add_argument("-0", dest="zeromode", action="store_true", p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode") help="enable no-code anything-goes mode")
p.add_argument("what", nargs="?", default=None, metavar="[FILENAME]", p.add_argument("what", nargs="?", default=None, metavar="[FILENAME|DIRNAME]",
help="the file to send") help="the file/directory to send")
p.set_defaults(func=cmd_send.send) p.set_defaults(func=cmd_send.send)
# CLI: receive # CLI: receive
p = subparsers.add_parser("receive", p = subparsers.add_parser("receive",
description="Receive a text message or file", description="Receive a text message, file, or directory",
usage="wormhole receive [CODE]") usage="wormhole receive [CODE]")
p.add_argument("-0", dest="zeromode", action="store_true", p.add_argument("-0", dest="zeromode", action="store_true",
help="enable no-code anything-goes mode") help="enable no-code anything-goes mode")
@ -98,10 +98,10 @@ p.add_argument("-t", "--only-text", dest="only_text", action="store_true",
help="refuse file transfers, only accept text transfers") help="refuse file transfers, only accept text transfers")
p.add_argument("--accept-file", dest="accept_file", action="store_true", p.add_argument("--accept-file", dest="accept_file", action="store_true",
help="accept file transfer with asking for confirmation") help="accept file transfer with asking for confirmation")
p.add_argument("-o", "--output-file", default=None, metavar="FILENAME", p.add_argument("-o", "--output-file", default=None, metavar="FILENAME|DIRNAME",
help=dedent("""\ help=dedent("""\
The file to create, overriding the filename suggested by the The file or directory to create, overriding the name suggested
sender."""), by the sender."""),
) )
p.add_argument("code", nargs="?", default=None, metavar="[CODE]", p.add_argument("code", nargs="?", default=None, metavar="[CODE]",
help=dedent("""\ help=dedent("""\

View File

@ -1,4 +1,4 @@
import os, sys import os, sys, re
from twisted.trial import unittest from twisted.trial import unittest
from twisted.python import procutils, log from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.utils import getProcessOutputAndValue
@ -177,3 +177,83 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
self.failUnlessEqual(f.read(), message) self.failUnlessEqual(f.read(), message)
d1.addCallback(_check_receiver) d1.addCallback(_check_receiver)
return d1 return d1
def test_send_directory_pre_generated_code(self):
return self._do_test_send_directory_pre_generated_code(False)
def test_send_directory_pre_generated_code_override(self):
return self._do_test_send_directory_pre_generated_code(True)
def _do_test_send_directory_pre_generated_code(self, override_dirname):
self.maxDiff=None
code = u"1-abc"
dirname = "testdir"
def message(i):
return "test message %d\n" % i
source_parent_dir = self.mktemp()
os.mkdir(source_parent_dir)
os.mkdir(os.path.join(source_parent_dir, "middle"))
source_dir = os.path.join(source_parent_dir, "middle", dirname)
os.mkdir(source_dir)
for i in range(5):
with open(os.path.join(source_dir, str(i)), "w") as f:
f.write(message(i))
target_parent_dir = self.mktemp()
os.mkdir(target_parent_dir)
wormhole = self.find_executable()
server_args = ["--relay-url", self.relayurl]
send_args = server_args + [
"send",
"--code", code,
os.path.join("middle", dirname),
]
receive_args = server_args + [
"receive", "--accept-file",
]
if override_dirname:
receive_args.extend(["-o", "outdir"])
dirname = "outdir"
receive_args.append(code)
d1 = getProcessOutputAndValue(wormhole, send_args,
path=source_parent_dir)
d2 = getProcessOutputAndValue(wormhole, receive_args,
path=target_parent_dir)
def _check_sender(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(err, "")
self.failUnlessIn("Sending directory", out)
self.failUnlessIn("named 'testdir'", out)
self.failUnlessIn("On the other computer, please run: "
"wormhole receive\n"
"Wormhole code is: %s\n\n" % code,
out)
self.failUnlessIn("File sent.. waiting for confirmation\n"
"Confirmation received. Transfer complete.\n",
out)
self.failUnlessEqual(rc, 0)
return d2
d1.addCallback(_check_sender)
def _check_receiver(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnless(re.search(r"Receiving \d+ bytes for '%s'" %
dirname, out))
self.failUnlessIn("Received files written to %s" % dirname, out)
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
fn = os.path.join(target_parent_dir, dirname)
self.failUnless(os.path.exists(fn))
for i in range(5):
fn = os.path.join(target_parent_dir, dirname, str(i))
with open(fn, "r") as f:
self.failUnlessEqual(f.read(), message(i))
d1.addCallback(_check_receiver)
return d1