make blocking/send-file work on py3

* declare transit records and handshake keys are bytes, not str
* declare transit connection hints to be str
* use six.moves.socketserver, six.moves.input for Verifier query
* argparse "--version" writes to stderr on py2, stdout on py3
* avoid xrange(), use subprocess.Popen(universal_newlines=True)
This commit is contained in:
Brian Warner 2015-09-27 23:40:00 -07:00
parent 8fe41e135d
commit b5d470fcda
6 changed files with 46 additions and 28 deletions

View File

@ -219,15 +219,16 @@ python2, "bytes" in python3):
* application identifier * application identifier
* verifier string * verifier string
* data in * data in/out
* data out
* derived-key "purpose" string * derived-key "purpose" string
* transit records in/out
Some human-readable parameters are passed as strings: "str" in python2, "str" Some human-readable parameters are passed as strings: "str" in python2, "str"
(i.e. unicode) in python3: (i.e. unicode) in python3:
* wormhole code * wormhole code
* relay/transit URLs * relay/transit URLs
* transit connection hints (e.g. "host:port")
## Detailed Example ## Detailed Example

View File

@ -1,9 +1,11 @@
from __future__ import print_function from __future__ import print_function
import re, time, threading, socket, SocketServer import re, time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from nacl.secret import SecretBox from nacl.secret import SecretBox
from ..util import ipaddrs from ..util import ipaddrs
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
from ..errors import UsageError
class TransitError(Exception): class TransitError(Exception):
pass pass
@ -40,15 +42,15 @@ class TransitError(Exception):
def build_receiver_handshake(key): def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return "transit receiver %s ready\n\n" % hexlify(hexid) return b"transit receiver %s ready\n\n" % hexlify(hexid)
def build_sender_handshake(key): def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return "transit sender %s ready\n\n" % hexlify(hexid) return b"transit sender %s ready\n\n" % hexlify(hexid)
def build_relay_handshake(key): def build_relay_handshake(key):
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return "please relay %s\n" % hexlify(token) return b"please relay %s\n" % hexlify(token)
TIMEOUT=15 TIMEOUT=15
@ -62,11 +64,6 @@ TIMEOUT=15
class BadHandshake(Exception): class BadHandshake(Exception):
pass pass
def force_ascii(s):
if isinstance(s, type(u"")):
return s.encode("ascii")
return s
def send_to(skt, data): def send_to(skt, data):
sent = 0 sent = 0
while sent < len(data): while sent < len(data):
@ -87,6 +84,7 @@ def wait_for(skt, expected, description):
# publisher wants anonymity, their only hint's ADDR will end in .onion . # publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint): def parse_hint_tcp(hint):
assert isinstance(hint, str)
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo: if not mo:
@ -187,7 +185,7 @@ def handle(skt, client_address, owner, description,
# owner is now responsible for the socket # owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(SocketServer.TCPServer): class MyTCPServer(socketserver.TCPServer):
allow_reuse_address = True allow_reuse_address = True
def process_request(self, request, client_address): def process_request(self, request, client_address):
@ -243,6 +241,7 @@ class RecordPipe:
self.next_receive_nonce = 0 self.next_receive_nonce = 0
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24) assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4) assert len(record) < 2**(8*4)
@ -294,9 +293,9 @@ class Common:
return [self._transit_relay] return [self._transit_relay]
def add_their_direct_hints(self, hints): def add_their_direct_hints(self, hints):
self._their_direct_hints = [force_ascii(h) for h in hints] self._their_direct_hints = [str(h) for h in hints]
def add_their_relay_hints(self, hints): def add_their_relay_hints(self, hints):
self._their_relay_hints = [force_ascii(h) for h in hints] self._their_relay_hints = [str(h) for h in hints]
def _send_this(self): def _send_this(self):
if self.is_sender: if self.is_sender:
@ -308,7 +307,7 @@ class Common:
if self.is_sender: if self.is_sender:
return build_receiver_handshake(self._transit_key) return build_receiver_handshake(self._transit_key)
else: else:
return build_sender_handshake(self._transit_key) + "go\n" return build_sender_handshake(self._transit_key) + b"go\n"
def _sender_record_key(self): def _sender_record_key(self):
if self.is_sender: if self.is_sender:
@ -407,11 +406,11 @@ class Common:
if is_winner: if is_winner:
if self.is_sender: if self.is_sender:
send_to(skt, "go\n") send_to(skt, b"go\n")
self.winning.set() self.winning.set()
else: else:
if self.is_sender: if self.is_sender:
send_to(skt, "nevermind\n") send_to(skt, b"nevermind\n")
skt.close() skt.close()
def connect(self): def connect(self):

View File

@ -68,12 +68,12 @@ def receive_file(args):
if os.path.dirname(target) != here: if os.path.dirname(target) != here:
print("Error: suggested filename (%s) would be outside current directory" print("Error: suggested filename (%s) would be outside current directory"
% (filename,)) % (filename,))
record_pipe.send_record("bad filename\n") record_pipe.send_record(b"bad filename\n")
record_pipe.close() record_pipe.close()
return 1 return 1
if os.path.exists(target) and not args.overwrite: if os.path.exists(target) and not args.overwrite:
print("Error: refusing to overwrite existing file %s" % (filename,)) print("Error: refusing to overwrite existing file %s" % (filename,))
record_pipe.send_record("file already exists\n") record_pipe.send_record(b"file already exists\n")
record_pipe.close() record_pipe.close()
return 1 return 1
tmp = target + ".tmp" tmp = target + ".tmp"
@ -98,6 +98,6 @@ def receive_file(args):
os.rename(tmp, target) os.rename(tmp, target)
print("Received file written to %s" % target) print("Received file written to %s" % target)
record_pipe.send_record("ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0 return 0

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii import os, sys, json, binascii, six
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = b"lothar.com/wormhole/file-xfer" APPID = b"lothar.com/wormhole/file-xfer"
@ -37,7 +37,7 @@ def send_file(args):
if args.verify: if args.verify:
verifier = binascii.hexlify(w.get_verifier()) verifier = binascii.hexlify(w.get_verifier())
while True: while True:
ok = raw_input("Verifier %s. ok? (yes/no): " % verifier) ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes": if ok.lower() == "yes":
break break
if ok.lower() == "no": if ok.lower() == "no":
@ -91,7 +91,7 @@ def send_file(args):
print("File sent.. waiting for confirmation") print("File sent.. waiting for confirmation")
ack = record_pipe.receive_record() ack = record_pipe.receive_record()
if ack == "ok\n": if ack == b"ok\n":
print("Confirmation received. Transfer complete.") print("Confirmation received. Transfer complete.")
return 0 return 0
else: else:

View File

@ -55,12 +55,21 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
d = getProcessOutputAndValue(wormhole, ["--version"]) d = getProcessOutputAndValue(wormhole, ["--version"])
def _check(res): def _check(res):
out, err, rc = res out, err, rc = res
self.failUnlessEqual(out, "") # argparse on py2 sends --version to stderr
# argparse on py3 sends --version to stdout
# aargh
out = out.decode("utf-8")
err = err.decode("utf-8")
if "DistributionNotFound" in err: if "DistributionNotFound" in err:
log.msg("stderr was %s" % err) log.msg("stderr was %s" % err)
last = err.strip().split("\n")[-1] last = err.strip().split("\n")[-1]
self.fail("wormhole not runnable: %s" % last) self.fail("wormhole not runnable: %s" % last)
self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__) if sys.version_info[0] == 2:
self.failUnlessEqual(out, "")
self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__)
else:
self.failUnlessEqual(err, "")
self.failUnlessEqual(out, "magic-wormhole %s\n" % __version__)
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
d.addCallback(_check) d.addCallback(_check)
return d return d
@ -92,6 +101,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args) d2 = getProcessOutputAndValue(wormhole, receive_args)
def _check_sender(res): def _check_sender(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, self.failUnlessEqual(out,
"On the other computer, please run: " "On the other computer, please run: "
"wormhole receive-text\n" "wormhole receive-text\n"
@ -104,6 +115,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender) d1.addCallback(_check_sender)
def _check_receiver(res): def _check_receiver(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, message+"\n") self.failUnlessEqual(out, message+"\n")
self.failUnlessEqual(err, "") self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
@ -137,6 +150,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir) d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir)
def _check_sender(res): def _check_sender(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive-file\n" "wormhole receive-file\n"
"Wormhole code is '%s'\n\n" % code, "Wormhole code is '%s'\n\n" % code,
@ -150,6 +165,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender) d1.addCallback(_check_sender)
def _check_receiver(res): def _check_receiver(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("Receiving %d bytes for 'testfile'" % len(message), self.failUnlessIn("Receiving %d bytes for 'testfile'" % len(message),
out) out)
self.failUnlessIn("Received file written to ", out) self.failUnlessIn("Received file written to ", out)

View File

@ -32,7 +32,7 @@ def find_addresses():
commands = _unix_commands commands = _unix_commands
for (pathtotool, args, regex) in commands: for (pathtotool, args, regex) in commands:
assert os.path.isabs(pathtotool) assert os.path.isabs(pathtotool), pathtotool
if not os.path.isfile(pathtotool): if not os.path.isfile(pathtotool):
continue continue
try: try:
@ -46,12 +46,13 @@ def find_addresses():
def _query(path, args, regex): def _query(path, args, regex):
env = {'LANG': 'en_US.UTF-8'} env = {'LANG': 'en_US.UTF-8'}
TRIES = 5 TRIES = 5
for trial in xrange(TRIES): for trial in range(TRIES):
try: try:
p = subprocess.Popen([path] + list(args), p = subprocess.Popen([path] + list(args),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
env=env) env=env,
universal_newlines=True)
(output, err) = p.communicate() (output, err) = p.communicate()
break break
except OSError as e: except OSError as e: