Merge branch 'py3'

This adds python3 compatibility for blocking.transcribe and
blocking.transit, enough to allow the four
"wormhole (send|receive)-(text|file)" commands to work. These are all
tested by travis (via "trial wormhole").

"wormhole server" runs under py3, but only with --no-daemon (until
twisted.python.logfile is ported).

twisted.transcribe doesn't work yet (it needs twisted.web.client.Agent,
plus more local porting work).
This commit is contained in:
Brian Warner 2015-09-28 00:53:35 -07:00
commit 0838f9bc43
21 changed files with 283 additions and 137 deletions

View File

@ -53,7 +53,7 @@ The synchronous+blocking flow looks like this:
from wormhole.blocking.transcribe import Wormhole from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"initiator's data" mydata = b"initiator's data"
i = Wormhole("appid", RENDEZVOUS_RELAY) i = Wormhole(b"appid", RENDEZVOUS_RELAY)
code = i.get_code() code = i.get_code()
print("Invitation Code: %s" % code) print("Invitation Code: %s" % code)
theirdata = i.get_data(mydata) theirdata = i.get_data(mydata)
@ -66,7 +66,7 @@ from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"receiver's data" mydata = b"receiver's data"
code = sys.argv[1] code = sys.argv[1]
r = Wormhole("appid", RENDEZVOUS_RELAY) r = Wormhole(b"appid", RENDEZVOUS_RELAY)
r.set_code(code) r.set_code(code)
theirdata = r.get_data(mydata) theirdata = r.get_data(mydata)
print("Their data: %s" % theirdata.decode("ascii")) print("Their data: %s" % theirdata.decode("ascii"))
@ -81,7 +81,7 @@ from twisted.internet import reactor
from wormhole.public_relay import RENDEZVOUS_RELAY from wormhole.public_relay import RENDEZVOUS_RELAY
from wormhole.twisted.transcribe import Wormhole from wormhole.twisted.transcribe import Wormhole
outbound_message = b"outbound data" outbound_message = b"outbound data"
w1 = Wormhole("appid", RENDEZVOUS_RELAY) w1 = Wormhole(b"appid", RENDEZVOUS_RELAY)
d = w1.get_code() d = w1.get_code()
def _got_code(code): def _got_code(code):
print "Invitation Code:", code print "Invitation Code:", code
@ -97,7 +97,7 @@ reactor.run()
On the other side, you call `set_code()` instead of waiting for `get_code()`: On the other side, you call `set_code()` instead of waiting for `get_code()`:
```python ```python
w2 = Wormhole("appid", RENDEZVOUS_RELAY) w2 = Wormhole(b"appid", RENDEZVOUS_RELAY)
w2.set_code(code) w2.set_code(code)
d = w2.get_data(my_message) d = w2.get_data(my_message)
... ...
@ -132,6 +132,8 @@ include randomly-selected words or characters. Dice, coin flips, shuffled
cards, or repeated sampling of a high-resolution stopwatch are all useful cards, or repeated sampling of a high-resolution stopwatch are all useful
techniques. techniques.
Note that the code is a human-readable string (the python "str" type: so
unicode in python3, plain bytes in python2).
## Application Identifier ## Application Identifier
@ -140,7 +142,8 @@ simple bytestring that distinguishes one application from another. To ensure
uniqueness, use a domain name. To use multiple apps for a single domain, just uniqueness, use a domain name. To use multiple apps for a single domain, just
use a string like `example.com/app1`. This string must be the same on both use a string like `example.com/app1`. This string must be the same on both
clients, otherwise they will not see each other. The invitation codes are clients, otherwise they will not see each other. The invitation codes are
scoped to the app-id. scoped to the app-id. Note that the app-id must be a bytestring, not unicode,
so on python3 use `b"appid"`.
Distinct app-ids reduce the size of the connection-id numbers. If fewer than Distinct app-ids reduce the size of the connection-id numbers. If fewer than
ten initiators are active for a given app-id, the connection-id will only ten initiators are active for a given app-id, the connection-id will only
@ -209,6 +212,24 @@ To properly checkpoint the process, you should store the first message
(returned by `start()`) next to the serialized wormhole instance, so you can (returned by `start()`) next to the serialized wormhole instance, so you can
re-send it if necessary. re-send it if necessary.
## Bytes, Strings, Unicode, and Python 3
All cryptographically-sensitive parameters are passed as bytes ("str" in
python2, "bytes" in python3):
* application identifier
* verifier string
* data in/out
* derived-key "purpose" string
* transit records in/out
Some human-readable parameters are passed as strings: "str" in python2, "str"
(i.e. unicode) in python3:
* wormhole code
* relay/transit URLs
* transit connection hints (e.g. "host:port")
## Detailed Example ## Detailed Example
```python ```python

View File

@ -20,7 +20,8 @@ setup(name="magic-wormhole",
package_data={"wormhole": ["db-schemas/*.sql"]}, package_data={"wormhole": ["db-schemas/*.sql"]},
entry_points={"console_scripts": entry_points={"console_scripts":
["wormhole = wormhole.scripts.runner:entry"]}, ["wormhole = wormhole.scripts.runner:entry"]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse"], install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six"],
test_suite="wormhole.test", test_suite="wormhole.test",
cmdclass=commands, cmdclass=commands,
) )

View File

@ -1,3 +1,4 @@
import six
import requests import requests
class EventSourceFollower: class EventSourceFollower:
@ -13,11 +14,12 @@ class EventSourceFollower:
def _get_fields(self, lines): def _get_fields(self, lines):
while True: while True:
first_line = lines.next() # raises StopIteration when closed first_line = next(lines) # raises StopIteration when closed
assert isinstance(first_line, type(six.u(""))), type(first_line)
fieldname, data = first_line.split(": ", 1) fieldname, data = first_line.split(": ", 1)
data_lines = [data] data_lines = [data]
while True: while True:
next_line = lines.next() next_line = next(lines)
if not next_line: # empty string, original was "\n" if not next_line: # empty string, original was "\n"
yield (fieldname, "\n".join(data_lines)) yield (fieldname, "\n".join(data_lines))
break break
@ -30,12 +32,16 @@ class EventSourceFollower:
# for a long time. I'd prefer that chunk_size behaved like # for a long time. I'd prefer that chunk_size behaved like
# read(size), and gave you 1<=x<=size bytes in response. # read(size), and gave you 1<=x<=size bytes in response.
eventtype = "message" eventtype = "message"
lines_iter = self.resp.iter_lines(chunk_size=1) lines_iter = self.resp.iter_lines(chunk_size=1, decode_unicode=True)
for (fieldname, data) in self._get_fields(lines_iter): for (fieldname, data) in self._get_fields(lines_iter):
# fieldname/data are unicode on both py2 and py3. On py2, where
# ("abc"==u"abc" is True), this compares unicode against str,
# which matches. On py3, where (b"abc"=="abc" is False), this
# compares unicode against unicode, which matches.
if fieldname == "data": if fieldname == "data":
yield (eventtype, data) yield (eventtype, data)
eventtype = "message" eventtype = "message"
elif fieldname == "event": elif fieldname == "event":
eventtype = data eventtype = data
else: else:
print("weird fieldname", fieldname, data) print("weird fieldname", fieldname, type(fieldname), data)

View File

@ -29,6 +29,7 @@ class Wormhole:
version_warning_displayed = False version_warning_displayed = False
def __init__(self, appid, relay): def __init__(self, appid, relay):
if not isinstance(appid, type(b"")): raise UsageError
self.appid = appid self.appid = appid
self.relay = relay self.relay = relay
if not self.relay.endswith("/"): raise UsageError if not self.relay.endswith("/"): raise UsageError
@ -89,9 +90,10 @@ class Wormhole:
def get_code(self, code_length=2): def get_code(self, code_length=2):
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
self.side = hexlify(os.urandom(5)) self.side = hexlify(os.urandom(5)).decode("ascii")
channel_id = self._allocate_channel() # allocate channel channel_id = self._allocate_channel() # allocate channel
code = codes.make_code(channel_id, code_length) code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self._start() self._start()
return code return code
@ -108,10 +110,11 @@ class Wormhole:
return code return code
def set_code(self, code): # used for human-made pre-generated codes def set_code(self, code): # used for human-made pre-generated codes
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
if self.side is not None: raise UsageError if self.side is not None: raise UsageError
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self.side = hexlify(os.urandom(5)) self.side = hexlify(os.urandom(5)).decode("ascii")
self._start() self._start()
def _set_code_and_channel_id(self, code): def _set_code_and_channel_id(self, code):
@ -164,12 +167,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose) return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data): def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE) nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted): def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
@ -191,6 +198,7 @@ class Wormhole:
def get_data(self, outbound_data): def get_data(self, outbound_data):
# only call this once # only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self.channel_id is None: raise UsageError if self.channel_id is None: raise UsageError
try: try:

View File

@ -1,9 +1,11 @@
from __future__ import print_function from __future__ import print_function
import re, time, threading, socket, SocketServer import re, time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from nacl.secret import SecretBox from nacl.secret import SecretBox
from ..util import ipaddrs from ..util import ipaddrs
from ..util.hkdf import HKDF from ..util.hkdf import HKDF
from ..errors import UsageError
class TransitError(Exception): class TransitError(Exception):
pass pass
@ -40,15 +42,15 @@ class TransitError(Exception):
def build_receiver_handshake(key): def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return "transit receiver %s ready\n\n" % hexlify(hexid) return b"transit receiver %s ready\n\n" % hexlify(hexid)
def build_sender_handshake(key): def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return "transit sender %s ready\n\n" % hexlify(hexid) return b"transit sender %s ready\n\n" % hexlify(hexid)
def build_relay_handshake(key): def build_relay_handshake(key):
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return "please relay %s\n" % hexlify(token) return b"please relay %s\n" % hexlify(token)
TIMEOUT=15 TIMEOUT=15
@ -62,11 +64,6 @@ TIMEOUT=15
class BadHandshake(Exception): class BadHandshake(Exception):
pass pass
def force_ascii(s):
if isinstance(s, type(u"")):
return s.encode("ascii")
return s
def send_to(skt, data): def send_to(skt, data):
sent = 0 sent = 0
while sent < len(data): while sent < len(data):
@ -87,6 +84,7 @@ def wait_for(skt, expected, description):
# publisher wants anonymity, their only hint's ADDR will end in .onion . # publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint): def parse_hint_tcp(hint):
assert isinstance(hint, str)
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo: if not mo:
@ -187,7 +185,7 @@ def handle(skt, client_address, owner, description,
# owner is now responsible for the socket # owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(SocketServer.TCPServer): class MyTCPServer(socketserver.TCPServer):
allow_reuse_address = True allow_reuse_address = True
def process_request(self, request, client_address): def process_request(self, request, client_address):
@ -243,6 +241,7 @@ class RecordPipe:
self.next_receive_nonce = 0 self.next_receive_nonce = 0
def send_record(self, record): def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24 assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24) assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4) assert len(record) < 2**(8*4)
@ -294,9 +293,9 @@ class Common:
return [self._transit_relay] return [self._transit_relay]
def add_their_direct_hints(self, hints): def add_their_direct_hints(self, hints):
self._their_direct_hints = [force_ascii(h) for h in hints] self._their_direct_hints = [str(h) for h in hints]
def add_their_relay_hints(self, hints): def add_their_relay_hints(self, hints):
self._their_relay_hints = [force_ascii(h) for h in hints] self._their_relay_hints = [str(h) for h in hints]
def _send_this(self): def _send_this(self):
if self.is_sender: if self.is_sender:
@ -308,7 +307,7 @@ class Common:
if self.is_sender: if self.is_sender:
return build_receiver_handshake(self._transit_key) return build_receiver_handshake(self._transit_key)
else: else:
return build_sender_handshake(self._transit_key) + "go\n" return build_sender_handshake(self._transit_key) + b"go\n"
def _sender_record_key(self): def _sender_record_key(self):
if self.is_sender: if self.is_sender:
@ -407,11 +406,11 @@ class Common:
if is_winner: if is_winner:
if self.is_sender: if self.is_sender:
send_to(skt, "go\n") send_to(skt, b"go\n")
self.winning.set() self.winning.set()
else: else:
if self.is_sender: if self.is_sender:
send_to(skt, "nevermind\n") send_to(skt, b"nevermind\n")
skt.close() skt.close()
def connect(self): def connect(self):

View File

@ -1,5 +1,5 @@
from __future__ import print_function from __future__ import print_function
import os import os, six
from .wordlist import (byte_to_even_word, byte_to_odd_word, from .wordlist import (byte_to_even_word, byte_to_odd_word,
even_words_lowercase, odd_words_lowercase) even_words_lowercase, odd_words_lowercase)
@ -81,7 +81,7 @@ def input_code_with_completion(prompt, get_channel_ids, code_length):
readline.parse_and_bind("tab: complete") readline.parse_and_bind("tab: complete")
readline.set_completer(c.wrap_completer) readline.set_completer(c.wrap_completer)
readline.set_completer_delims("") readline.set_completer_delims("")
code = raw_input(prompt) code = six.moves.input(prompt)
return code return code
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import sys, os, json, binascii import sys, os, json, binascii
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = "lothar.com/wormhole/file-xfer" APPID = b"lothar.com/wormhole/file-xfer"
@handle_server_error @handle_server_error
def receive_file(args): def receive_file(args):
@ -50,7 +50,7 @@ def receive_file(args):
# now receive the rest of the owl # now receive the rest of the owl
tdata = data["transit"] tdata = data["transit"]
transit_key = w.derive_key(APPID+"/transit-key") transit_key = w.derive_key(APPID+b"/transit-key")
transit_receiver.set_transit_key(transit_key) transit_receiver.set_transit_key(transit_key)
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
@ -68,12 +68,12 @@ def receive_file(args):
if os.path.dirname(target) != here: if os.path.dirname(target) != here:
print("Error: suggested filename (%s) would be outside current directory" print("Error: suggested filename (%s) would be outside current directory"
% (filename,)) % (filename,))
record_pipe.send_record("bad filename\n") record_pipe.send_record(b"bad filename\n")
record_pipe.close() record_pipe.close()
return 1 return 1
if os.path.exists(target) and not args.overwrite: if os.path.exists(target) and not args.overwrite:
print("Error: refusing to overwrite existing file %s" % (filename,)) print("Error: refusing to overwrite existing file %s" % (filename,))
record_pipe.send_record("file already exists\n") record_pipe.send_record(b"file already exists\n")
record_pipe.close() record_pipe.close()
return 1 return 1
tmp = target + ".tmp" tmp = target + ".tmp"
@ -98,6 +98,6 @@ def receive_file(args):
os.rename(tmp, target) os.rename(tmp, target)
print("Received file written to %s" % target) print("Received file written to %s" % target)
record_pipe.send_record("ok\n") record_pipe.send_record(b"ok\n")
record_pipe.close() record_pipe.close()
return 0 return 0

View File

@ -2,7 +2,7 @@ from __future__ import print_function
import sys, json, binascii import sys, json, binascii
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = "lothar.com/wormhole/text-xfer" APPID = b"lothar.com/wormhole/text-xfer"
@handle_server_error @handle_server_error
def receive_text(args): def receive_text(args):

View File

@ -1,8 +1,8 @@
from __future__ import print_function from __future__ import print_function
import os, sys, json, binascii import os, sys, json, binascii, six
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = "lothar.com/wormhole/file-xfer" APPID = b"lothar.com/wormhole/file-xfer"
@handle_server_error @handle_server_error
def send_file(args): def send_file(args):
@ -37,7 +37,7 @@ def send_file(args):
if args.verify: if args.verify:
verifier = binascii.hexlify(w.get_verifier()) verifier = binascii.hexlify(w.get_verifier())
while True: while True:
ok = raw_input("Verifier %s. ok? (yes/no): " % verifier) ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes": if ok.lower() == "yes":
break break
if ok.lower() == "no": if ok.lower() == "no":
@ -70,7 +70,7 @@ def send_file(args):
tdata = them_d["transit"] tdata = them_d["transit"]
transit_key = w.derive_key(APPID+"/transit-key") transit_key = w.derive_key(APPID+b"/transit-key")
transit_sender.set_transit_key(transit_key) transit_sender.set_transit_key(transit_key)
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
@ -91,7 +91,7 @@ def send_file(args):
print("File sent.. waiting for confirmation") print("File sent.. waiting for confirmation")
ack = record_pipe.receive_record() ack = record_pipe.receive_record()
if ack == "ok\n": if ack == b"ok\n":
print("Confirmation received. Transfer complete.") print("Confirmation received. Transfer complete.")
return 0 return 0
else: else:

View File

@ -1,8 +1,8 @@
from __future__ import print_function from __future__ import print_function
import sys, json, binascii import sys, json, binascii, six
from ..errors import handle_server_error from ..errors import handle_server_error
APPID = "lothar.com/wormhole/text-xfer" APPID = b"lothar.com/wormhole/text-xfer"
@handle_server_error @handle_server_error
def send_text(args): def send_text(args):
@ -31,7 +31,7 @@ def send_text(args):
if args.verify: if args.verify:
verifier = binascii.hexlify(w.get_verifier()) verifier = binascii.hexlify(w.get_verifier())
while True: while True:
ok = raw_input("Verifier %s. ok? (yes/no): " % verifier) ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes": if ok.lower() == "yes":
break break
if ok.lower() == "no": if ok.lower() == "no":

View File

@ -1,3 +1,4 @@
from __future__ import print_function
import sys, argparse import sys, argparse
from textwrap import dedent from textwrap import dedent
from .. import public_relay from .. import public_relay
@ -39,6 +40,7 @@ sp_start.add_argument("--transit", default="tcp:3001", metavar="tcp:PORT",
help="endpoint specification for the transit-relay port") help="endpoint specification for the transit-relay port")
sp_start.add_argument("--advertise-version", metavar="VERSION", sp_start.add_argument("--advertise-version", metavar="VERSION",
help="version to recommend to clients") help="version to recommend to clients")
sp_start.add_argument("-n", "--no-daemon", action="store_true")
#sp_start.add_argument("twistd_args", nargs="*", default=None, #sp_start.add_argument("twistd_args", nargs="*", default=None,
# metavar="[TWISTD-ARGS..]", # metavar="[TWISTD-ARGS..]",
# help=dedent("""\ # help=dedent("""\
@ -120,14 +122,19 @@ def run(args, stdout, stderr, executable=None):
also invoked by entry() below.""" also invoked by entry() below."""
args = parser.parse_args() args = parser.parse_args()
if not getattr(args, "func", None):
# So far this only works on py3. py2 exits with a really terse
# "error: too few arguments" during parse_args().
parser.print_help()
sys.exit(0)
try: try:
#rc = command.func(args, stdout, stderr) #rc = command.func(args, stdout, stderr)
rc = args.func(args) rc = args.func(args)
return rc return rc
except ImportError as e: except ImportError as e:
print >>stderr, "--- ImportError ---" print("--- ImportError ---", file=stderr)
print >>stderr, e print(e, file=stderr)
print >>stderr, "Please run 'python setup.py build'" print("Please run 'python setup.py build'", file=stderr)
raise raise
return 1 return 1
@ -138,4 +145,4 @@ def entry():
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print args print(args)

View File

@ -22,8 +22,11 @@ def start_server(args):
c = MyTwistdConfig() c = MyTwistdConfig()
#twistd_args = tuple(args.twistd_args) + ("XYZ",) #twistd_args = tuple(args.twistd_args) + ("XYZ",)
twistd_args = ("XYZ",) # TODO: allow user to add twistd-specific args base_args = []
c.parseOptions(twistd_args) if args.no_daemon:
base_args.append("--nodaemon")
twistd_args = base_args + ["XYZ"]
c.parseOptions(tuple(twistd_args))
c.loadedPlugins = {"XYZ": MyPlugin(args)} c.loadedPlugins = {"XYZ": MyPlugin(args)}
print("starting wormhole relay server") print("starting wormhole relay server")

View File

@ -26,24 +26,24 @@ class EventsProtocol:
# face of firewall/NAT timeouts. It also helps unit tests, since # face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection # apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the reponse body. # to be established until it sees the first byte of the reponse body.
self.request.write(": %s\n\n" % comment) self.request.write(b": %s\n\n" % comment)
def sendEvent(self, data, name=None, id=None, retry=None): def sendEvent(self, data, name=None, id=None, retry=None):
if name: if name:
self.request.write("event: %s\n" % name.encode("utf-8")) self.request.write(b"event: %s\n" % name.encode("utf-8"))
# e.g. if name=foo, then the client web page should do: # e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc) # (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message". # Note that this basically defaults to "message".
self.request.write("\n") self.request.write(b"\n")
if id: if id:
self.request.write("id: %s\n" % id.encode("utf-8")) self.request.write(b"id: %s\n" % id.encode("utf-8"))
self.request.write("\n") self.request.write(b"\n")
if retry: if retry:
self.request.write("retry: %d\n" % retry) # milliseconds self.request.write(b"retry: %d\n" % retry) # milliseconds
self.request.write("\n") self.request.write(b"\n")
for line in data.splitlines(): for line in data.splitlines():
self.request.write("data: %s\n" % line.encode("utf-8")) self.request.write(b"data: %s\n" % line.encode("utf-8"))
self.request.write("\n") self.request.write(b"\n")
def stop(self): def stop(self):
self.request.finish() self.request.finish()
@ -72,15 +72,15 @@ class Channel(resource.Resource):
def render_GET(self, request): def render_GET(self, request):
# rest of URL is: SIDE/poll/MSGNUM # rest of URL is: SIDE/poll/MSGNUM
their_side = request.postpath[0] their_side = request.postpath[0].decode("utf-8")
if request.postpath[1] != "poll": if request.postpath[1] != b"poll":
request.setResponseCode(http.BAD_REQUEST, "GET to wrong URL") request.setResponseCode(http.BAD_REQUEST, b"GET to wrong URL")
return "GET is only for /SIDE/poll/MSGNUM" return b"GET is only for /SIDE/poll/MSGNUM"
their_msgnum = request.postpath[2] their_msgnum = request.postpath[2].decode("utf-8")
if "text/event-stream" not in (request.getHeader("accept") or ""): if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
request.setResponseCode(http.BAD_REQUEST, "Must use EventSource") request.setResponseCode(http.BAD_REQUEST, b"Must use EventSource")
return "Must use EventSource (Content-Type: text/event-stream)" return b"Must use EventSource (Content-Type: text/event-stream)"
request.setHeader("content-type", "text/event-stream") request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request) ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self.welcome), name="welcome") ep.sendEvent(json.dumps(self.welcome), name="welcome")
handle = (their_side, their_msgnum, ep) handle = (their_side, their_msgnum, ep)
@ -107,20 +107,20 @@ class Channel(resource.Resource):
def render_POST(self, request): def render_POST(self, request):
# rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll) # rest of URL is: SIDE/(MSGNUM|deallocate)/(post|poll)
side = request.postpath[0] side = request.postpath[0].decode("utf-8")
verb = request.postpath[1] verb = request.postpath[1].decode("utf-8")
if verb == "deallocate": if verb == "deallocate":
deleted = self.relay.maybe_free_child(self.channel_id, side) deleted = self.relay.maybe_free_child(self.channel_id, side)
if deleted: if deleted:
return "deleted\n" return b"deleted\n"
return "waiting\n" return b"waiting\n"
if verb not in ("post", "poll"): if verb not in ("post", "poll"):
request.setResponseCode(http.BAD_REQUEST) request.setResponseCode(http.BAD_REQUEST)
return "bad verb, want 'post' or 'poll'\n" return b"bad verb, want 'post' or 'poll'\n"
msgnum = request.postpath[2] msgnum = request.postpath[2].decode("utf-8")
other_messages = [] other_messages = []
for row in self.db.execute("SELECT `message` FROM `messages`" for row in self.db.execute("SELECT `message` FROM `messages`"
@ -131,7 +131,9 @@ class Channel(resource.Resource):
other_messages.append(row["message"]) other_messages.append(row["message"])
if verb == "post": if verb == "post":
data = json.load(request.content) #data = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
self.db.execute("INSERT INTO `messages`" self.db.execute("INSERT INTO `messages`"
" (`channel_id`, `side`, `msgnum`, `message`, `when`)" " (`channel_id`, `side`, `msgnum`, `message`, `when`)"
" VALUES (?,?,?,?,?)", " VALUES (?,?,?,?,?)",
@ -144,9 +146,10 @@ class Channel(resource.Resource):
self.db.commit() self.db.commit()
self.message_added(side, msgnum, data["message"]) self.message_added(side, msgnum, data["message"])
request.setHeader("content-type", "application/json; charset=utf-8") request.setHeader(b"content-type", b"application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome, data = {"welcome": self.welcome,
"messages": other_messages})+"\n" "messages": other_messages}
return (json.dumps(data)+"\n").encode("utf-8")
def get_allocated(db): def get_allocated(db):
c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") c = db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
@ -183,9 +186,10 @@ class Allocator(resource.Resource):
self.db.commit() self.db.commit()
log.msg("allocated #%d, now have %d DB channels" % log.msg("allocated #%d, now have %d DB channels" %
(channel_id, len(get_allocated(self.db)))) (channel_id, len(get_allocated(self.db))))
request.setHeader("content-type", "application/json; charset=utf-8") request.setHeader(b"content-type", b"application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome, data = {"welcome": self.welcome,
"channel-id": channel_id})+"\n" "channel-id": channel_id}
return (json.dumps(data)+"\n").encode("utf-8")
class ChannelList(resource.Resource): class ChannelList(resource.Resource):
def __init__(self, db, welcome): def __init__(self, db, welcome):
@ -195,9 +199,10 @@ class ChannelList(resource.Resource):
def render_GET(self, request): def render_GET(self, request):
c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`") c = self.db.execute("SELECT DISTINCT `channel_id` FROM `allocations`")
allocated = sorted(set([row["channel_id"] for row in c.fetchall()])) allocated = sorted(set([row["channel_id"] for row in c.fetchall()]))
request.setHeader("content-type", "application/json; charset=utf-8") request.setHeader(b"content-type", b"application/json; charset=utf-8")
return json.dumps({"welcome": self.welcome, data = {"welcome": self.welcome,
"channel-ids": allocated})+"\n" "channel-ids": allocated}
return (json.dumps(data)+"\n").encode("utf-8")
class Relay(resource.Resource): class Relay(resource.Resource):
def __init__(self, db, welcome): def __init__(self, db, welcome):
@ -207,11 +212,11 @@ class Relay(resource.Resource):
self.channels = {} self.channels = {}
def getChild(self, path, request): def getChild(self, path, request):
if path == "allocate": if path == b"allocate":
return Allocator(self.db, self.welcome) return Allocator(self.db, self.welcome)
if path == "list": if path == b"list":
return ChannelList(self.db, self.welcome) return ChannelList(self.db, self.welcome)
if not re.search(r'^\d+$', path): if not re.search(br'^\d+$', path):
return resource.ErrorPage(http.BAD_REQUEST, return resource.ErrorPage(http.BAD_REQUEST,
"invalid channel id", "invalid channel id",
"invalid channel id") "invalid channel id")
@ -381,7 +386,7 @@ class Root(resource.Resource):
# child_FOO is a nevow thing, not a twisted.web.resource thing # child_FOO is a nevow thing, not a twisted.web.resource thing
def __init__(self): def __init__(self):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.putChild("", static.Data("Wormhole Relay\n", "text/plain")) self.putChild(b"", static.Data(b"Wormhole Relay\n", "text/plain"))
class RelayServer(service.MultiService): class RelayServer(service.MultiService):
def __init__(self, relayport, transitport, advertise_version, def __init__(self, relayport, transitport, advertise_version,
@ -405,7 +410,7 @@ class RelayServer(service.MultiService):
self.relayport_service = EndpointServerService(r, site) self.relayport_service = EndpointServerService(r, site)
self.relayport_service.setServiceParent(self) self.relayport_service.setServiceParent(self)
self.relay = Relay(self.db, welcome) # accessible from tests self.relay = Relay(self.db, welcome) # accessible from tests
self.root.putChild("wormhole-relay", self.relay) self.root.putChild(b"wormhole-relay", self.relay)
t = internet.TimerService(EXPIRATION_CHECK_PERIOD, t = internet.TimerService(EXPIRATION_CHECK_PERIOD,
self.relay.prune_old_channels) self.relay.prune_old_channels)
t.setServiceParent(self) t.setServiceParent(self)

View File

@ -10,14 +10,14 @@ class Blocking(ServerBase, unittest.TestCase):
# with deferToThread() # with deferToThread()
def test_basic(self): def test_basic(self):
appid = "appid" appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl) w1 = BlockingWormhole(appid, self.relayurl)
w2 = BlockingWormhole(appid, self.relayurl) w2 = BlockingWormhole(appid, self.relayurl)
d = deferToThread(w1.get_code) d = deferToThread(w1.get_code)
def _got_code(code): def _got_code(code):
w2.set_code(code) w2.set_code(code)
d1 = deferToThread(w1.get_data, "data1") d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -25,35 +25,35 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
def test_fixed_code(self): def test_fixed_code(self):
appid = "appid" appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl) w1 = BlockingWormhole(appid, self.relayurl)
w2 = BlockingWormhole(appid, self.relayurl) w2 = BlockingWormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant") w2.set_code("123-purple-elephant")
d1 = deferToThread(w1.get_data, "data1") d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False) d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl): def _done(dl):
((success1, dataX), (success2, dataY)) = dl ((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
def test_errors(self): def test_errors(self):
appid = "appid" appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl) w1 = BlockingWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier) self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data") self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code) self.assertRaises(UsageError, w1.get_code)
@ -65,7 +65,7 @@ class Blocking(ServerBase, unittest.TestCase):
return d return d
def test_serialize(self): def test_serialize(self):
appid = "appid" appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl) w1 = BlockingWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early self.assertRaises(UsageError, w1.serialize) # too early
w2 = BlockingWormhole(appid, self.relayurl) w2 = BlockingWormhole(appid, self.relayurl)
@ -79,8 +79,8 @@ class Blocking(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict) self.assertEqual(type(unpacked), dict)
new_w1 = BlockingWormhole.from_serialized(s) new_w1 = BlockingWormhole.from_serialized(s)
d1 = deferToThread(new_w1.get_data, "data1") d1 = deferToThread(new_w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -88,8 +88,8 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done) d.addCallback(_done)
return d return d

View File

@ -55,12 +55,21 @@ class ScriptVersion(ServerBase, ScriptsBase, unittest.TestCase):
d = getProcessOutputAndValue(wormhole, ["--version"]) d = getProcessOutputAndValue(wormhole, ["--version"])
def _check(res): def _check(res):
out, err, rc = res out, err, rc = res
self.failUnlessEqual(out, "") # argparse on py2 sends --version to stderr
# argparse on py3 sends --version to stdout
# aargh
out = out.decode("utf-8")
err = err.decode("utf-8")
if "DistributionNotFound" in err: if "DistributionNotFound" in err:
log.msg("stderr was %s" % err) log.msg("stderr was %s" % err)
last = err.strip().split("\n")[-1] last = err.strip().split("\n")[-1]
self.fail("wormhole not runnable: %s" % last) self.fail("wormhole not runnable: %s" % last)
if sys.version_info[0] == 2:
self.failUnlessEqual(out, "")
self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__) self.failUnlessEqual(err, "magic-wormhole %s\n" % __version__)
else:
self.failUnlessEqual(err, "")
self.failUnlessEqual(out, "magic-wormhole %s\n" % __version__)
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
d.addCallback(_check) d.addCallback(_check)
return d return d
@ -92,6 +101,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args) d2 = getProcessOutputAndValue(wormhole, receive_args)
def _check_sender(res): def _check_sender(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, self.failUnlessEqual(out,
"On the other computer, please run: " "On the other computer, please run: "
"wormhole receive-text\n" "wormhole receive-text\n"
@ -104,6 +115,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender) d1.addCallback(_check_sender)
def _check_receiver(res): def _check_receiver(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessEqual(out, message+"\n") self.failUnlessEqual(out, message+"\n")
self.failUnlessEqual(err, "") self.failUnlessEqual(err, "")
self.failUnlessEqual(rc, 0) self.failUnlessEqual(rc, 0)
@ -137,6 +150,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir) d2 = getProcessOutputAndValue(wormhole, receive_args, path=receive_dir)
def _check_sender(res): def _check_sender(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("On the other computer, please run: " self.failUnlessIn("On the other computer, please run: "
"wormhole receive-file\n" "wormhole receive-file\n"
"Wormhole code is '%s'\n\n" % code, "Wormhole code is '%s'\n\n" % code,
@ -150,6 +165,8 @@ class Scripts(ServerBase, ScriptsBase, unittest.TestCase):
d1.addCallback(_check_sender) d1.addCallback(_check_sender)
def _check_receiver(res): def _check_receiver(res):
out, err, rc = res out, err, rc = res
out = out.decode("utf-8")
err = err.decode("utf-8")
self.failUnlessIn("Receiving %d bytes for 'testfile'" % len(message), self.failUnlessIn("Receiving %d bytes for 'testfile'" % len(message),
out) out)
self.failUnlessIn("Received file written to ", out) self.failUnlessIn("Received file written to ", out)

View File

@ -0,0 +1,49 @@
import sys
import requests
from twisted.trial import unittest
from twisted.internet import reactor
from twisted.internet.threads import deferToThread
from twisted.web.client import getPage, Agent, readBody
from .common import ServerBase
class Reachable(ServerBase, unittest.TestCase):
def test_getPage(self):
# client.getPage requires str/unicode URL, returns bytes
url = self.relayurl.replace("wormhole-relay/", "").encode("ascii")
d = getPage(url)
def _got(res):
self.failUnlessEqual(res, b"Wormhole Relay\n")
d.addCallback(_got)
return d
def test_agent(self):
# client.Agent is not yet ported: it wants URLs to be both unicode
# and bytes at the same time.
# https://twistedmatrix.com/trac/ticket/7407
if sys.version_info[0] > 2:
raise unittest.SkipTest("twisted.web.client.Agent does not yet support py3")
url = self.relayurl.replace("wormhole-relay/", "").encode("ascii")
agent = Agent(reactor)
d = agent.request("GET", url)
def _check(resp):
self.failUnlessEqual(resp.code, 200)
return readBody(resp)
d.addCallback(_check)
def _got(res):
self.failUnlessEqual(res, b"Wormhole Relay\n")
d.addCallback(_got)
return d
def test_requests(self):
# requests requires bytes URL, returns unicode
url = self.relayurl.replace("wormhole-relay/", "")
def _get(url):
r = requests.get(url)
r.raise_for_status()
return r.text
d = deferToThread(_get, url)
def _got(res):
self.failUnlessEqual(res, "Wormhole Relay\n")
d.addCallback(_got)
return d

View File

@ -1,22 +1,19 @@
import json import sys, json
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
from ..twisted.transcribe import Wormhole, UsageError from ..twisted.transcribe import Wormhole, UsageError
from .common import ServerBase from .common import ServerBase
#from twisted.python import log
#import sys
#log.startLogging(sys.stdout)
class Basic(ServerBase, unittest.TestCase): class Basic(ServerBase, unittest.TestCase):
def test_basic(self): def test_basic(self):
appid = "appid" appid = b"appid"
w1 = Wormhole(appid, self.relayurl) w1 = Wormhole(appid, self.relayurl)
w2 = Wormhole(appid, self.relayurl) w2 = Wormhole(appid, self.relayurl)
d = w1.get_code() d = w1.get_code()
def _got_code(code): def _got_code(code):
w2.set_code(code) w2.set_code(code)
d1 = w1.get_data("data1") d1 = w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -24,35 +21,35 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
def test_fixed_code(self): def test_fixed_code(self):
appid = "appid" appid = b"appid"
w1 = Wormhole(appid, self.relayurl) w1 = Wormhole(appid, self.relayurl)
w2 = Wormhole(appid, self.relayurl) w2 = Wormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant") w2.set_code("123-purple-elephant")
d1 = w1.get_data("data1") d1 = w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False) d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl): def _done(dl):
((success1, dataX), (success2, dataY)) = dl ((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
def test_errors(self): def test_errors(self):
appid = "appid" appid = b"appid"
w1 = Wormhole(appid, self.relayurl) w1 = Wormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier) self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data") self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code) self.assertRaises(UsageError, w1.get_code)
@ -62,7 +59,7 @@ class Basic(ServerBase, unittest.TestCase):
return d return d
def test_serialize(self): def test_serialize(self):
appid = "appid" appid = b"appid"
w1 = Wormhole(appid, self.relayurl) w1 = Wormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early self.assertRaises(UsageError, w1.serialize) # too early
w2 = Wormhole(appid, self.relayurl) w2 = Wormhole(appid, self.relayurl)
@ -76,8 +73,8 @@ class Basic(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict) self.assertEqual(type(unpacked), dict)
new_w1 = Wormhole.from_serialized(s) new_w1 = Wormhole.from_serialized(s)
d1 = new_w1.get_data("data1") d1 = new_w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -85,8 +82,14 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done) d.addCallback(_done)
return d return d
if sys.version_info[0] >= 3:
Basic.skip = "twisted is not yet sufficiently ported to py3"
# as of 15.4.0, Twisted is still missing:
# * web.client.Agent (for all non-EventSource POSTs in transcribe.py)
# * python.logfile (to allow daemonization of 'wormhole server')

View File

@ -1,3 +1,4 @@
from __future__ import print_function
import sys, json import sys, json
from twisted.internet import reactor from twisted.internet import reactor
from .transcribe import Wormhole from .transcribe import Wormhole
@ -12,15 +13,15 @@ if sys.argv[1] == "send-text":
data = json.dumps({"message": message}).encode("utf-8") data = json.dumps({"message": message}).encode("utf-8")
d = w.get_code() d = w.get_code()
def _got_code(code): def _got_code(code):
print "code is:", code print("code is:", code)
return w.get_data(data) return w.get_data(data)
d.addCallback(_got_code) d.addCallback(_got_code)
def _got_data(them_bytes): def _got_data(them_bytes):
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if them_d["message"] == "ok": if them_d["message"] == "ok":
print "text sent" print("text sent")
else: else:
print "error sending text: %r" % (them_d,) print("error sending text: %r" % (them_d,))
d.addCallback(_got_data) d.addCallback(_got_data)
elif sys.argv[1] == "receive-text": elif sys.argv[1] == "receive-text":
code = sys.argv[2] code = sys.argv[2]
@ -30,9 +31,9 @@ elif sys.argv[1] == "receive-text":
def _got_data(them_bytes): def _got_data(them_bytes):
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d: if "error" in them_d:
print >>sys.stderr, "ERROR: " + them_d["error"] print("ERROR: " + them_d["error"], file=sys.stderr)
return 1 return 1
print them_d["message"] print(them_d["message"])
d.addCallback(_got_data) d.addCallback(_got_data)
else: else:
raise ValueError("bad command") raise ValueError("bad command")

View File

@ -1,11 +1,18 @@
#import sys
from twisted.python import log, failure from twisted.python import log, failure
from twisted.internet import reactor, defer, protocol from twisted.internet import reactor, defer, protocol
from twisted.application import service from twisted.application import service
from twisted.protocols import basic from twisted.protocols import basic
from twisted.web.client import Agent, ResponseDone from twisted.web.client import Agent, ResponseDone
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from cgi import parse_header
from ..util.eventual import eventually from ..util.eventual import eventually
#if sys.version_info[0] == 2:
# to_unicode = unicode
#else:
# to_unicode = str
class EventSourceParser(basic.LineOnlyReceiver): class EventSourceParser(basic.LineOnlyReceiver):
delimiter = "\n" delimiter = "\n"
@ -15,6 +22,10 @@ class EventSourceParser(basic.LineOnlyReceiver):
self.handler = handler self.handler = handler
self.done_deferred = defer.Deferred() self.done_deferred = defer.Deferred()
self.eventtype = "message" self.eventtype = "message"
self.encoding = "utf-8"
def set_encoding(self, encoding):
self.encoding = encoding
def connectionLost(self, why): def connectionLost(self, why):
if why.check(ResponseDone): if why.check(ResponseDone):
@ -40,6 +51,8 @@ class EventSourceParser(basic.LineOnlyReceiver):
self.current_field = None self.current_field = None
self.current_lines[:] = [] self.current_lines[:] = []
return return
line = line.decode(self.encoding)
#line = to_unicode(line, self.encoding)
if self.current_field is None: if self.current_field is None:
self.current_field, data = line.split(": ", 1) self.current_field, data = line.split(": ", 1)
self.current_lines.append(data) self.current_lines.append(data)
@ -90,7 +103,11 @@ class EventSource: # TODO: service.Service
raise EventSourceError("%d: %s" % (resp.code, resp.phrase)) raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
if self.when_connected: if self.when_connected:
self.when_connected() self.when_connected()
#if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]: default_ct = "text/event-stream; charset=utf-8"
ct_headers = resp.headers.getRawHeaders("content-type", [default_ct])
ct, ct_params = parse_header(ct_headers[0])
assert ct == "text/event-stream", ct
self.proto.set_encoding(ct_params.get("charset", "utf-8"))
resp.deliverBody(self.proto) resp.deliverBody(self.proto)
if self.cancelled: if self.cancelled:
self.kill_connection() self.kill_connection()

View File

@ -38,6 +38,7 @@ class Wormhole:
version_warning_displayed = False version_warning_displayed = False
def __init__(self, appid, relay): def __init__(self, appid, relay):
if not isinstance(appid, type(b"")): raise UsageError
self.appid = appid self.appid = appid
self.relay = relay self.relay = relay
self.agent = web_client.Agent(reactor) self.agent = web_client.Agent(reactor)
@ -109,6 +110,7 @@ class Wormhole:
d = self._allocate_channel() d = self._allocate_channel()
def _got_channel_id(channel_id): def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length) code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self._start() self._start()
return code return code
@ -116,6 +118,7 @@ class Wormhole:
return d return d
def set_code(self, code): def set_code(self, code):
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
if self.side is not None: raise UsageError if self.side is not None: raise UsageError
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
@ -201,12 +204,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose) return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data): def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE) nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted): def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
@ -235,6 +242,7 @@ class Wormhole:
def get_data(self, outbound_data): def get_data(self, outbound_data):
# only call this once # only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
d = self._get_key() d = self._get_key()
d.addCallback(self._get_data2, outbound_data) d.addCallback(self._get_data2, outbound_data)

View File

@ -32,7 +32,7 @@ def find_addresses():
commands = _unix_commands commands = _unix_commands
for (pathtotool, args, regex) in commands: for (pathtotool, args, regex) in commands:
assert os.path.isabs(pathtotool) assert os.path.isabs(pathtotool), pathtotool
if not os.path.isfile(pathtotool): if not os.path.isfile(pathtotool):
continue continue
try: try:
@ -46,12 +46,13 @@ def find_addresses():
def _query(path, args, regex): def _query(path, args, regex):
env = {'LANG': 'en_US.UTF-8'} env = {'LANG': 'en_US.UTF-8'}
TRIES = 5 TRIES = 5
for trial in xrange(TRIES): for trial in range(TRIES):
try: try:
p = subprocess.Popen([path] + list(args), p = subprocess.Popen([path] + list(args),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
env=env) env=env,
universal_newlines=True)
(output, err) = p.communicate() (output, err) = p.communicate()
break break
except OSError as e: except OSError as e: