Merge branch 'py3'

This adds python3 compatibility for blocking.transcribe and
blocking.transit, enough to allow the four
"wormhole (send|receive)-(text|file)" commands to work. These are all
tested by travis (via "trial wormhole").

"wormhole server" runs under py3, but only with --no-daemon (until
twisted.python.logfile is ported).

twisted.transcribe doesn't work yet (it needs twisted.web.client.Agent,
plus more local porting work).
This commit is contained in:
Brian Warner 2015-09-28 00:53:35 -07:00
commit 0838f9bc43
21 changed files with 283 additions and 137 deletions

View File

@ -53,7 +53,7 @@ The synchronous+blocking flow looks like this:
from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"initiator's data"
i = Wormhole("appid", RENDEZVOUS_RELAY)
i = Wormhole(b"appid", RENDEZVOUS_RELAY)
code = i.get_code()
print("Invitation Code: %s" % code)
theirdata = i.get_data(mydata)
@ -66,7 +66,7 @@ from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"receiver's data"
code = sys.argv[1]
r = Wormhole("appid", RENDEZVOUS_RELAY)
r = Wormhole(b"appid", RENDEZVOUS_RELAY)
r.set_code(code)
theirdata = r.get_data(mydata)
print("Their data: %s" % theirdata.decode("ascii"))
@ -81,7 +81,7 @@ from twisted.internet import reactor
from wormhole.public_relay import RENDEZVOUS_RELAY
from wormhole.twisted.transcribe import Wormhole
outbound_message = b"outbound data"
w1 = Wormhole("appid", RENDEZVOUS_RELAY)
w1 = Wormhole(b"appid", RENDEZVOUS_RELAY)
d = w1.get_code()
def _got_code(code):
print "Invitation Code:", code
@ -97,7 +97,7 @@ reactor.run()
On the other side, you call `set_code()` instead of waiting for `get_code()`:
```python
w2 = Wormhole("appid", RENDEZVOUS_RELAY)
w2 = Wormhole(b"appid", RENDEZVOUS_RELAY)
w2.set_code(code)
d = w2.get_data(my_message)
...
@ -132,6 +132,8 @@ include randomly-selected words or characters. Dice, coin flips, shuffled
cards, or repeated sampling of a high-resolution stopwatch are all useful
techniques.
Note that the code is a human-readable string (the python "str" type: so
unicode in python3, plain bytes in python2).
## Application Identifier
@ -140,7 +142,8 @@ simple bytestring that distinguishes one application from another. To ensure
uniqueness, use a domain name. To use multiple apps for a single domain, just
use a string like `example.com/app1`. This string must be the same on both
clients, otherwise they will not see each other. The invitation codes are
scoped to the app-id.
scoped to the app-id. Note that the app-id must be a bytestring, not unicode,
so on python3 use `b"appid"`.
Distinct app-ids reduce the size of the connection-id numbers. If fewer than
ten initiators are active for a given app-id, the connection-id will only
@ -209,6 +212,24 @@ To properly checkpoint the process, you should store the first message
(returned by `start()`) next to the serialized wormhole instance, so you can
re-send it if necessary.
## Bytes, Strings, Unicode, and Python 3
All cryptographically-sensitive parameters are passed as bytes ("str" in
python2, "bytes" in python3):
* application identifier
* verifier string
* data in/out
* derived-key "purpose" string
* transit records in/out
Some human-readable parameters are passed as strings: "str" in python2, "str"
(i.e. unicode) in python3:
* wormhole code
* relay/transit URLs
* transit connection hints (e.g. "host:port")
## Detailed Example
```python

View File

@ -20,7 +20,8 @@ setup(name="magic-wormhole",
package_data={"wormhole": ["db-schemas/*.sql"]},
entry_points={"console_scripts":
["wormhole = wormhole.scripts.runner:entry"]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse"],
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six"],
test_suite="wormhole.test",
cmdclass=commands,
)

View File

@ -1,3 +1,4 @@
import six
import requests
class EventSourceFollower:
@ -13,11 +14,12 @@ class EventSourceFollower:
def _get_fields(self, lines):
while True:
first_line = lines.next() # raises StopIteration when closed
first_line = next(lines) # raises StopIteration when closed
assert isinstance(first_line, type(six.u(""))), type(first_line)
fieldname, data = first_line.split(": ", 1)
data_lines = [data]
while True:
next_line = lines.next()
next_line = next(lines)
if not next_line: # empty string, original was "\n"
yield (fieldname, "\n".join(data_lines))
break
@ -30,12 +32,16 @@ class EventSourceFollower:
# for a long time. I'd prefer that chunk_size behaved like
# read(size), and gave you 1<=x<=size bytes in response.
eventtype = "message"
lines_iter = self.resp.iter_lines(chunk_size=1)
lines_iter = self.resp.iter_lines(chunk_size=1, decode_unicode=True)
for (fieldname, data) in self._get_fields(lines_iter):
# fieldname/data are unicode on both py2 and py3. On py2, where
# ("abc"==u"abc" is True), this compares unicode against str,
# which matches. On py3, where (b"abc"=="abc" is False), this
# compares unicode against unicode, which matches.
if fieldname == "data":
yield (eventtype, data)
eventtype = "message"
elif fieldname == "event":
eventtype = data
else:
print("weird fieldname", fieldname, data)
print("weird fieldname", fieldname, type(fieldname), data)

View File

@ -29,6 +29,7 @@ class Wormhole:
version_warning_displayed = False
def __init__(self, appid, relay):
if not isinstance(appid, type(b"")): raise UsageError
self.appid = appid
self.relay = relay
if not self.relay.endswith("/"): raise UsageError
@ -89,9 +90,10 @@ class Wormhole:
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
self.side = hexlify(os.urandom(5))
self.side = hexlify(os.urandom(5)).decode("ascii")
channel_id = self._allocate_channel() # allocate channel
code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code)
self._start()
return code
@ -108,10 +110,11 @@ class Wormhole:
return code
def set_code(self, code): # used for human-made pre-generated codes
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError
if self.side is not None: raise UsageError
self._set_code_and_channel_id(code)
self.side = hexlify(os.urandom(5))
self.side = hexlify(os.urandom(5)).decode("ascii")
self._start()
def _set_code_and_channel_id(self, code):
@ -164,12 +167,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key)
data = box.decrypt(encrypted)
@ -191,6 +198,7 @@ class Wormhole:
def get_data(self, outbound_data):
# only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError
if self.channel_id is None: raise UsageError
try:

View File

@ -1,9 +1,11 @@
from __future__ import print_function
import re, time, threading, socket, SocketServer
import re, time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..errors import UsageError
class TransitError(Exception):
pass
@ -40,15 +42,15 @@ class TransitError(Exception):
def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return "transit receiver %s ready\n\n" % hexlify(hexid)
return b"transit receiver %s ready\n\n" % hexlify(hexid)
def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return "transit sender %s ready\n\n" % hexlify(hexid)
return b"transit sender %s ready\n\n" % hexlify(hexid)
def build_relay_handshake(key):
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return "please relay %s\n" % hexlify(token)
return b"please relay %s\n" % hexlify(token)
TIMEOUT=15
@ -62,11 +64,6 @@ TIMEOUT=15
class BadHandshake(Exception):
pass
def force_ascii(s):
if isinstance(s, type(u"")):
return s.encode("ascii")
return s
def send_to(skt, data):
sent = 0
while sent < len(data):
@ -87,6 +84,7 @@ def wait_for(skt, expected, description):
# publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint):
assert isinstance(hint, str)
# return tuple or None for an unparseable hint
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
@ -187,7 +185,7 @@ def handle(skt, client_address, owner, description,
# owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(SocketServer.TCPServer):
class MyTCPServer(socketserver.TCPServer):
allow_reuse_address = True
def process_request(self, request, client_address):
@ -243,6 +241,7 @@ class RecordPipe:
self.next_receive_nonce = 0
def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
@ -294,9 +293,9 @@ class Common:
return [self._transit_relay]
def add_their_direct_hints(self, hints):
self._their_direct_hints = [force_ascii(h) for h in hints]
self._their_direct_hints = [str(h) for h in hints]
def add_their_relay_hints(self, hints):
self._their_relay_hints = [force_ascii(h) for h in hints]
self._their_relay_hints = [str(h) for h in hints]
def _send_this(self):
if self.is_sender:
@ -308,7 +307,7 @@ class Common:
if self.is_sender:
return build_receiver_handshake(self._transit_key)
else:
return build_sender_handshake(self._transit_key) + "go\n"
return build_sender_handshake(self._transit_key) + b"go\n"
def _sender_record_key(self):
if self.is_sender:
@ -407,11 +406,11 @@ class Common:
if is_winner:
if self.is_sender:
send_to(skt, "go\n")
send_to(skt, b"go\n")
self.winning.set()
else:
if self.is_sender:
send_to(skt, "nevermind\n")
send_to(skt, b"nevermind\n")
skt.close()
def connect(self):

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import os
import os, six
from .wordlist import (byte_to_even_word, byte_to_odd_word,
even_words_lowercase, odd_words_lowercase)
@ -81,7 +81,7 @@ def input_code_with_completion(prompt, get_channel_ids, code_length):
readline.parse_and_bind("tab: complete")
readline.set_completer(c.wrap_completer)
readline.set_completer_delims("")
code = raw_input(prompt)
code = six.moves.input(prompt)
return code
if __name__ == "__main__":

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import sys, os, json, binascii
from ..errors import handle_server_error
APPID = "lothar.com/wormhole/file-xfer"
APPID = b"lothar.com/wormhole/file-xfer"
@handle_server_error
def receive_file(args):
@ -50,7 +50,7 @@ def receive_file(args):
# now receive the rest of the owl
tdata = data["transit"]
transit_key = w.derive_key(APPID+"/transit-key")
transit_key = w.derive_key(APPID+b"/transit-key")
transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
@ -68,12 +68,12 @@ def receive_file(args):
if os.path.dirname(target) != here:
print("Error: suggested filename (%s) would be outside current directory"
% (filename,))
record_pipe.send_record("bad filename\n")
record_pipe.send_record(b"bad filename\n")
record_pipe.close()
return 1
if os.path.exists(target) and not args.overwrite:
print("Error: refusing to overwrite existing file %s" % (filename,))
record_pipe.send_record("file already exists\n")
record_pipe.send_record(b"file already exists\n")
record_pipe.close()
return 1
tmp = target + ".tmp"
@ -98,6 +98,6 @@ def receive_file(args):
os.rename(tmp, target)
print("Received file written to %s" % target)
record_pipe.send_record("ok\n")
record_pipe.send_record(b"ok\n")
record_pipe.close()
return 0

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import sys, json, binascii
from ..errors import handle_server_error
APPID = "lothar.com/wormhole/text-xfer"
APPID = b"lothar.com/wormhole/text-xfer"
@handle_server_error
def receive_text(args):

View File

@ -1,8 +1,8 @@
from __future__ import print_function
import os, sys, json, binascii
import os, sys, json, binascii, six
from ..errors import handle_server_error
APPID = "lothar.com/wormhole/file-xfer"
APPID = b"lothar.com/wormhole/file-xfer"
@handle_server_error
def send_file(args):
@ -37,7 +37,7 @@ def send_file(args):
if args.verify:
verifier = binascii.hexlify(w.get_verifier())
while True:
ok = raw_input("Verifier %s. ok? (yes/no): " % verifier)
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes":
break
if ok.lower() == "no":
@ -70,7 +70,7 @@ def send_file(args):
tdata = them_d["transit"]
transit_key = w.derive_key(APPID+"/transit-key")
transit_key = w.derive_key(APPID+b"/transit-key")
transit_sender.set_transit_key(transit_key)
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
@ -91,7 +91,7 @@ def send_file(args):
print("File sent.. waiting for confirmation")
ack = record_pipe.receive_record()
if ack == "ok\n":
if ack == b"ok\n":
print("Confirmation received. Transfer complete.")
return 0
else:

View File

@ -1,8 +1,8 @@
from __future__ import print_function
import sys, json, binascii
import sys, json, binascii, six
from ..errors import handle_server_error
APPID = "lothar.com/wormhole/text-xfer"
APPID = b"lothar.com/wormhole/text-xfer"
@handle_server_error
def send_text(args):
@ -31,7 +31,7 @@ def send_text(args):
if args.verify:
verifier = binascii.hexlify(w.get_verifier())
while True:
ok = raw_input("Verifier %s. ok? (yes/no): " % verifier)
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes":
break
if ok.lower() == "no":

View File

@ -1,3 +1,4 @@
from __future__ import print_function
import sys, argparse
from textwrap import dedent
from .. import public_relay
@ -39,6 +40,7 @@ sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]",
# help=dedent("""\
@ -120,14 +122,19 @@ def run(args, stdout, stderr, executable=None):
also invoked by entry() below."""
args = parser.parse_args()
if not getattr(args, "func", None):
# So far this only works on py3. py2 exits with a really terse
# "error: too few arguments" during parse_args().
parser.print_help()
sys.exit(0)
try:
#rc = command.func(args, stdout, stderr)
rc = args.func(args)
return rc
except ImportError as e:
print >>stderr, "--- ImportError ---"
print >>stderr, e
print >>stderr, "Please run 'python setup.py build'"
print("--- ImportError ---", file=stderr)
print(e, file=stderr)
print("Please run 'python setup.py build'", file=stderr)
raise
return 1
@ -138,4 +145,4 @@ def entry():
if __name__ == "__main__":
args = parser.parse_args()
print args
print(args)

View File

@ -22,8 +22,11 @@ def start_server(args):
c = MyTwistdConfig()
#twistd_args = tuple(args.twistd_args) + ("XYZ",)
twistd_args = ("XYZ",) # TODO: allow user to add twistd-specific args
c.parseOptions(twistd_args)
base_args = []
if args.no_daemon:
base_args.append("--nodaemon")
twistd_args = base_args + ["XYZ"]
c.parseOptions(tuple(twistd_args))
c.loadedPlugins = {"XYZ": MyPlugin(args)}
print("starting wormhole relay server")

View File

@ -26,24 +26,24 @@ class EventsProtocol:
# face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the reponse body.
self.request.write(": %s\n\n" % comment)
self.request.write(b": %s\n\n" % comment)
def sendEvent(self, data, name=None, id=None, retry=None):
if name:
self.request.write("event: %s\n" % name.encode("utf-8"))
self.request.write(b"event: %s\n" % name.encode("utf-8"))
# e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message".
self.request.write("\n")
self.request.write(b"\n")
if id:
self.request.write("id: %s\n" % id.encode("utf-8"))
self.request.write("\n")
self.request.write(b"id: %s\n" % id.encode("utf-8"))
self.request.write(b"\n")
if retry:
self.request.write("retry: %d\n" % retry) # milliseconds
self.request.write("\n")
self.request.write(b"retry: %d\n" % retry) # milliseconds
self.request.write(b"\n")
for line in data.splitlines():
self.request.write("data: %s\n" % line.encode("utf-8"))
self.request.write("\n")
self.request.write(b"data: %s\n" % line.encode("utf-8"))
self.request.write(b"\n")
def stop(self):
self.request.finish()
@ -72,15 +72,15 @@ class Channel(resource.Resource):
def render_GET(self, request):
# rest of URL is: SIDE/poll/MSGNUM
their_side = request.postpath[0]
if request.postpath[1] != "poll":
request.setResponseCode(http.BAD_REQUEST, "GET to wrong URL")
return "GET is only for /SIDE/poll/MSGNUM"
their_msgnum = request.postpath[2]
if "text/event-stream" not in (request.getHeader("accept") or ""):
request.setResponseCode(http.BAD_REQUEST, "Must use EventSource")
return "Must use EventSource (Content-Type: text/event-stream)"
request.setHeader("content-type", "text/event-stream")
their_side = request.postpath[0].decode("utf-8")
if request.postpath[1] != b"poll":
request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
return b"GET is only for /SIDE/poll/MSGNUM"
their_msgnum = request.postpath[2].decode("utf-8")
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
return b"Must use EventSource (Content-Type: text/event-stream)"
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome")
handle = (their_side, their_msgnum, ep)
@ -107,20 +107,20 @@ class Channel(resource.Resource):
def render_POST(self, request):
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
side = request.postpath[0]
verb = request.postpath[1]
side = request.postpath[0].decode("utf-8")
verb = request.postpath[1].decode("utf-8")
if verb == "deallocate":
deleted = self.relay.maybe_free_child(self.channel_id, side)
if deleted:
return "deleted\n"
return "waiting\n"
return b"deleted\n"
return b"waiting\n"
if verb not in ("post", "poll"):
request.setResponseCode(http.BAD_REQUEST)
return "bad verb, want 'post' or 'poll'\n"
return b"bad verb, want 'post' or 'poll'\n"
msgnum = request.postpath[2]
msgnum = request.postpath[2].decode("utf-8")
other_messages = []
for row in self.db.execute("SELECT `message` FROM `messages`"
@ -131,7 +131,9 @@ class Channel(resource.Resource):
other_messages.append(row["message"])
if verb == "post":
data = json.load(request.content)
#data = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `msgnum`, `message`, `when`)"
" VALUES (?,?,?,?,?)",
@ -144,9 +146,10 @@ class Channel(resource.Resource):
self.db.commit()
self.message_added(side, msgnum, data["message"])
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"messages": other_messages})+"\n"
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"messages": other_messages}
return (json.dumps(data)+"\n").encode("utf-8")
def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
@ -183,9 +186,10 @@ class Allocator(resource.Resource):
self.db.commit()
log.msg("allocated #%d, now have %d DB channels" %
(channel_id, len(get_allocated(self.db))))
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-id": channel_id})+"\n"
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channel-id": channel_id}
return (json.dumps(data)+"\n").encode("utf-8")
class ChannelList(resource.Resource):
def __init__(self, db, welcome):
@ -195,9 +199,10 @@ class ChannelList(resource.Resource):
def render_GET(self, request):
c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
allocated = sorted(set([row["channel_id"] for row in c.fetchall()]))
request.setHeader("content-type", "application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome,
"channel-ids": allocated})+"\n"
request.setHeader(b"content-type", b"application/json; charset=utf-8")
data = {"welcome": self.welcome,
"channel-ids": allocated}
return (json.dumps(data)+"\n").encode("utf-8")
class Relay(resource.Resource):
def __init__(self, db, welcome):
@ -207,11 +212,11 @@ class Relay(resource.Resource):
self.channels = {}
def getChild(self, path, request):
if path == "allocate":
if path == b"allocate":
return Allocator(self.db, self.welcome)
if path == "list":
if path == b"list":
return ChannelList(self.db, self.welcome)
if not re.search(r'^\d+$', path):
if not re.search(br'^\d+$', path):
return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id",
"invalid channel id")
@ -381,7 +386,7 @@ class Root(resource.Resource):
# child_FOO is a nevow thing, not a twisted.web.resource thing
def __init__(self):
resource.Resource.__init__(self)
self.putChild("", static.Data("Wormhole Relay\n", "text/plain"))
self.putChild(b"", static.Data(b"Wormhole Relay\n", "text/plain"))
class RelayServer(service.MultiService):
def __init__(self, relayport, transitport, advertise_version,
@ -405,7 +410,7 @@ class RelayServer(service.MultiService):
self.relayport_service = EndpointServerService(r, site)
self.relayport_service.setServiceParent(self)
self.relay = Relay(self.db, welcome) # accessible from tests
self.root.putChild("wormhole-relay", self.relay)
self.root.putChild(b"wormhole-relay", self.relay)
t = internet.TimerService(EXPIRATION_CHECK_PERIOD,
self.relay.prune_old_channels)
t.setServiceParent(self)

View File

@ -10,14 +10,14 @@ class Blocking(ServerBase, unittest.TestCase):
# with deferToThread()
def test_basic(self):
appid = "appid"
appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl)
w2 = BlockingWormhole(appid, self.relayurl)
d = deferToThread(w1.get_code)
def _got_code(code):
w2.set_code(code)
d1 = deferToThread(w1.get_data, "data1")
d2 = deferToThread(w2.get_data, "data2")
d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
@ -25,35 +25,35 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
d.addCallback(_done)
return d
def test_fixed_code(self):
appid = "appid"
appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl)
w2 = BlockingWormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant")
d1 = deferToThread(w1.get_data, "data1")
d2 = deferToThread(w2.get_data, "data2")
d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
d.addCallback(_done)
return d
def test_errors(self):
appid = "appid"
appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data")
self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code)
@ -65,7 +65,7 @@ class Blocking(ServerBase, unittest.TestCase):
return d
def test_serialize(self):
appid = "appid"
appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = BlockingWormhole(appid, self.relayurl)
@ -79,8 +79,8 @@ class Blocking(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
new_w1 = BlockingWormhole.from_serialized(s)
d1 = deferToThread(new_w1.get_data, "data1")
d2 = deferToThread(w2.get_data, "data2")
d1 = deferToThread(new_w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
@ -88,8 +88,8 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done)
return d

View File

@ -55,12 +55,21 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
d = getProcessOutputAndValue(wormhole, ["--version"])
def _check(res):
out, err, rc = res
self.failUnlessEqual(out, "")
# argparse on py2 sends --version to stderr
# argparse on py3 sends --version to stdout
# aargh
out = out.decode("utf-8")
err = err.decode("utf-8")
if "DistributionNotFound" in err:
log.msg("stderr was %s" % err)
last = err.strip().split("\n")[-1]
self.fail("wormhole not runnable: %s" % last)
self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__)
if sys.version_info[0] == 2:
self.failUnlessEqual(out, "")
self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__)
else:
self.failUnlessEqual(err, "")
self.failUnlessEqual(out, "magic-wormhole %s\n" % __version__)
self.failUnlessEqual(rc, 0)
d.addCallback(_check)
return d
@ -92,6 +101,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args)
def _check_sender(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out,
"On the other computer, please run: "
"wormhole receive-text\n"
@ -104,6 +115,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender)
def _check_receiver(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, message+"\n")
self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0)
@ -137,6 +150,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir)
def _check_sender(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("On the other computer, please run: "
"wormhole receive-file\n"
"Wormhole code is '%s'\n\n" % code,
@ -150,6 +165,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender)
def _check_receiver(res):
out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("Receiving %d bytes for 'testfile'" % len(message),
out)
self.failUnlessIn("Received file written to ", out)

View File

@ -0,0 +1,49 @@
import sys
import requests
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.threads import deferToThread
from twisted.web.client import getPage, Agent, readBody
from .common import ServerBase
class Reachable(ServerBase, unittest.TestCase):
def test_getPage(self):
# client.getPage requires str/unicode URL, returns bytes
url = self.relayurl.replace("wormhole-relay/", "").encode("ascii")
d = getPage(url)
def _got(res):
self.failUnlessEqual(res, b"Wormhole Relay\n")
d.addCallback(_got)
return d
def test_agent(self):
# client.Agent is not yet ported: it wants URLs to be both unicode
# and bytes at the same time.
# https://twistedmatrix.com/trac/ticket/7407
if sys.version_info[0] > 2:
raise unittest.SkipTest("twisted.web.client.Agent does not yet support py3")
url = self.relayurl.replace("wormhole-relay/", "").encode("ascii")
agent = Agent(reactor)
d = agent.request("GET", url)
def _check(resp):
self.failUnlessEqual(resp.code, 200)
return readBody(resp)
d.addCallback(_check)
def _got(res):
self.failUnlessEqual(res, b"Wormhole Relay\n")
d.addCallback(_got)
return d
def test_requests(self):
# requests requires bytes URL, returns unicode
url = self.relayurl.replace("wormhole-relay/", "")
def _get(url):
r = requests.get(url)
r.raise_for_status()
return r.text
d = deferToThread(_get, url)
def _got(res):
self.failUnlessEqual(res, "Wormhole Relay\n")
d.addCallback(_got)
return d

View File

@ -1,22 +1,19 @@
import json
import sys, json
from twisted.trial import unittest
from twisted.internet import defer
from ..twisted.transcribe import Wormhole, UsageError
from .common import ServerBase
#from twisted.python import log
#import sys
#log.startLogging(sys.stdout)
class Basic(ServerBase, unittest.TestCase):
def test_basic(self):
appid = "appid"
appid = b"appid"
w1 = Wormhole(appid, self.relayurl)
w2 = Wormhole(appid, self.relayurl)
d = w1.get_code()
def _got_code(code):
w2.set_code(code)
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
d1 = w1.get_data(b"data1")
d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
@ -24,35 +21,35 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
d.addCallback(_done)
return d
def test_fixed_code(self):
appid = "appid"
appid = b"appid"
w1 = Wormhole(appid, self.relayurl)
w2 = Wormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant")
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
d1 = w1.get_data(b"data1")
d2 = w2.get_data(b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
d.addCallback(_done)
return d
def test_errors(self):
appid = "appid"
appid = b"appid"
w1 = Wormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data")
self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code)
@ -62,7 +59,7 @@ class Basic(ServerBase, unittest.TestCase):
return d
def test_serialize(self):
appid = "appid"
appid = b"appid"
w1 = Wormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = Wormhole(appid, self.relayurl)
@ -76,8 +73,8 @@ class Basic(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
new_w1 = Wormhole.from_serialized(s)
d1 = new_w1.get_data("data1")
d2 = w2.get_data("data2")
d1 = new_w1.get_data(b"data1")
d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
@ -85,8 +82,14 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done)
return d
if sys.version_info[0] >= 3:
Basic.skip = "twisted is not yet sufficiently ported to py3"
# as of 15.4.0, Twisted is still missing:
# * web.client.Agent (for all non-EventSource POSTs in transcribe.py)
# * python.logfile (to allow daemonization of 'wormhole server')

View File

@ -1,3 +1,4 @@
from __future__ import print_function
import sys, json
from twisted.internet import reactor
from .transcribe import Wormhole
@ -12,15 +13,15 @@ if sys.argv[1] == "send-text":
data = json.dumps({"message": message}).encode("utf-8")
d = w.get_code()
def _got_code(code):
print "code is:", code
print("code is:", code)
return w.get_data(data)
d.addCallback(_got_code)
def _got_data(them_bytes):
them_d = json.loads(them_bytes.decode("utf-8"))
if them_d["message"] == "ok":
print "text sent"
print("text sent")
else:
print "error sending text: %r" % (them_d,)
print("error sending text: %r" % (them_d,))
d.addCallback(_got_data)
elif sys.argv[1] == "receive-text":
code = sys.argv[2]
@ -30,9 +31,9 @@ elif sys.argv[1] == "receive-text":
def _got_data(them_bytes):
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
print >>sys.stderr, "ERROR: " + them_d["error"]
print("ERROR: " + them_d["error"], file=sys.stderr)
return 1
print them_d["message"]
print(them_d["message"])
d.addCallback(_got_data)
else:
raise ValueError("bad command")

View File

@ -1,11 +1,18 @@
#import sys
from twisted.python import log, failure
from twisted.internet import reactor, defer, protocol
from twisted.application import service
from twisted.protocols import basic
from twisted.web.client import Agent, ResponseDone
from twisted.web.http_headers import Headers
from cgi import parse_header
from ..util.eventual import eventually
#if sys.version_info[0] == 2:
# to_unicode = unicode
#else:
# to_unicode = str
class EventSourceParser(basic.LineOnlyReceiver):
delimiter = "\n"
@ -15,6 +22,10 @@ class EventSourceParser(basic.LineOnlyReceiver):
self.handler = handler
self.done_deferred = defer.Deferred()
self.eventtype = "message"
self.encoding = "utf-8"
def set_encoding(self, encoding):
self.encoding = encoding
def connectionLost(self, why):
if why.check(ResponseDone):
@ -40,6 +51,8 @@ class EventSourceParser(basic.LineOnlyReceiver):
self.current_field = None
self.current_lines[:] = []
return
line = line.decode(self.encoding)
#line = to_unicode(line, self.encoding)
if self.current_field is None:
self.current_field, data = line.split(": ", 1)
self.current_lines.append(data)
@ -90,7 +103,11 @@ class EventSource: # TODO: service.Service
raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
if self.when_connected:
self.when_connected()
#if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]:
default_ct = "text/event-stream; charset=utf-8"
ct_headers = resp.headers.getRawHeaders("content-type", [default_ct])
ct, ct_params = parse_header(ct_headers[0])
assert ct == "text/event-stream", ct
self.proto.set_encoding(ct_params.get("charset", "utf-8"))
resp.deliverBody(self.proto)
if self.cancelled:
self.kill_connection()

View File

@ -38,6 +38,7 @@ class Wormhole:
version_warning_displayed = False
def __init__(self, appid, relay):
if not isinstance(appid, type(b"")): raise UsageError
self.appid = appid
self.relay = relay
self.agent = web_client.Agent(reactor)
@ -109,6 +110,7 @@ class Wormhole:
d = self._allocate_channel()
def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code)
self._start()
return code
@ -116,6 +118,7 @@ class Wormhole:
return d
def set_code(self, code):
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError
if self.side is not None: raise UsageError
self._set_code_and_channel_id(code)
@ -201,12 +204,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key)
data = box.decrypt(encrypted)
@ -235,6 +242,7 @@ class Wormhole:
def get_data(self, outbound_data):
# only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError
d = self._get_key()
d.addCallback(self._get_data2, outbound_data)

View File

@ -32,7 +32,7 @@ def find_addresses():
commands = _unix_commands
for (pathtotool, args, regex) in commands:
assert os.path.isabs(pathtotool)
assert os.path.isabs(pathtotool), pathtotool
if not os.path.isfile(pathtotool):
continue
try:
@ -46,12 +46,13 @@ def find_addresses():
def _query(path, args, regex):
env = {'LANG': 'en_US.UTF-8'}
TRIES = 5
for trial in xrange(TRIES):
for trial in range(TRIES):
try:
p = subprocess.Popen([path] + list(args),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env)
env=env,
universal_newlines=True)
(output, err) = p.communicate()
break
except OSError as e: