diff --git a/setup.py b/setup.py index b0ca181..364a84a 100644 --- a/setup.py +++ b/setup.py @@ -21,10 +21,8 @@ setup(name="magic-wormhole", entry_points={"console_scripts": ["wormhole = wormhole.scripts.runner:entry"]}, install_requires=["spake2==0.3", "pynacl", "requests", "argparse", - "six"], + "six", "twisted >= 16.1.0"], extras_require={"tor": ["txtorcon", "ipaddr"]}, - # for Twisted support, we want Twisted>=15.5.0. Older Twisteds don't - # provide sufficient python3 compatibility. test_suite="wormhole.test", cmdclass=commands, ) diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py deleted file mode 100644 index c520437..0000000 --- a/src/wormhole/blocking/transit.py +++ /dev/null @@ -1,400 +0,0 @@ -from __future__ import print_function -import time, threading, socket -from six.moves import socketserver -from binascii import hexlify, unhexlify -from nacl.secret import SecretBox -from ..util import ipaddrs -from ..util.hkdf import HKDF -from ..errors import UsageError -from ..timing import DebugTiming -from ..transit_common import (TransitError, BadHandshake, TransitClosed, - BadNonce, - build_receiver_handshake, - build_sender_handshake, - build_relay_handshake, - parse_hint_tcp) - -TIMEOUT=15 - -# 1: sender only transmits, receiver only accepts, both wait forever -# 2: sender also accepts, receiver also transmits -# 3: timeouts / stop when no more progress can be made -# 4: add relay -# 5: accelerate shutdown of losing sockets - -def send_to(skt, data): - sent = 0 - while sent < len(data): - sent += skt.send(data[sent:]) - -def wait_for_line(skt, max_length, description): - got = b"" - while len(got) < max_length: - got += skt.recv(1) - if got.endswith(b"\n"): - return got[:-1] - raise BadHandshake("exceeded max_length, got %r on %s" % - (got, description)) - -def wait_for(skt, expected, description): - assert isinstance(expected, type(b"")) - got = b"" - while len(got) < len(expected): - got += skt.recv(1) - if expected[:len(got)] != got: - raise BadHandshake("got %r want %r on %s" % - (got, expected, description)) - -def debug(msg): - if False: - print(msg) -def since(start): - return time.time() - start - -def connector(owner, hint, description, - send_handshake, expected_handshake, relay_handshake=None): - start = time.time() - parsed_hint = parse_hint_tcp(hint) - if not parsed_hint: - return # unparseable - addr,port = parsed_hint - skt = None - debug("+ connector(%s)" % hint) - try: - skt = socket.create_connection((addr,port), - TIMEOUT) # timeout or ECONNREFUSED - skt.settimeout(TIMEOUT) - debug(" - socket(%s) connected CT+%.1f" % (description, since(start))) - if relay_handshake: - debug(" - sending relay_handshake") - send_to(skt, relay_handshake) - relay_msg = wait_for_line(skt, 10000, description) - if relay_msg != b"ok": - raise BadHandshake(relay_msg) - debug(" - relay ready CT+%.1f" % (since(start),)) - send_to(skt, send_handshake) - wait_for(skt, expected_handshake, description) - debug(" + connector(%s) ready CT+%.1f" % (hint, since(start))) - except Exception as e: - debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start))) - try: - if skt: - skt.shutdown(socket.SHUT_WR) - except socket.error: - pass - if skt: - skt.close() - # ignore socket errors, warn about coding errors - if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): - raise - debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start))) - owner._connector_failed(hint) - return - # owner is now responsible for the socket - owner._negotiation_finished(skt, description) # note thread - -def handle(skt, client_address, owner, description, - send_handshake, expected_handshake): - try: - debug("handle %r" % (skt,)) - skt.settimeout(TIMEOUT) - send_to(skt, send_handshake) - got = b"" - # for the receiver, this includes the "go\n" - while len(got) < len(expected_handshake): - more = skt.recv(1) - if not more: - raise BadHandshake("disconnect after merely '%r'" % got) - got += more - if expected_handshake[:len(got)] != got: - raise BadHandshake("got '%r' want '%r'" % - (got, expected_handshake)) - debug("handler negotiation finished %r" % (client_address,)) - except Exception as e: - debug("handler failed %r" % (client_address,)) - try: - # this raises socket.err(EBADF) if the socket was already closed - skt.shutdown(socket.SHUT_WR) - except socket.error: - pass - skt.close() # this appears to be idempotent - # ignore socket errors, warn about coding errors - if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): - raise - return - # owner is now responsible for the socket - owner._negotiation_finished(skt, description) # note thread - -class MyTCPServer(socketserver.TCPServer): - allow_reuse_address = True - - def process_request(self, request, client_address): - description = "<-tcp:%s:%d" % (client_address[0], client_address[1]) - ready_lock = self.owner._ready_for_connections_lock - ready_lock.acquire() - while not (self.owner._ready_for_connections - and self.owner._transit_key): - ready_lock.wait() - # owner._transit_key is either None or set to a value. We don't - # modify it from here, so we can release the condition lock before - # grabbing the key. - ready_lock.release() - - # Once it is set, we can get handler_(send|receive)_handshake, which - # is what we actually care about. - t = threading.Thread(target=handle, - args=(request, client_address, - self.owner, description, - self.owner.handler_send_handshake, - self.owner.handler_expected_handshake)) - t.daemon = True - t.start() - - -class ReceiveBuffer: - def __init__(self, skt): - self.skt = skt - self.buf = b"" - - def read(self, count): - while len(self.buf) < count: - more = self.skt.recv(4096) - if not more: - raise TransitClosed - self.buf += more - rc = self.buf[:count] - self.buf = self.buf[count:] - return rc - -class RecordPipe: - def __init__(self, skt, send_key, receive_key, description): - self.skt = skt - self.send_box = SecretBox(send_key) - self.send_nonce = 0 - self.receive_buf = ReceiveBuffer(self.skt) - self.receive_box = SecretBox(receive_key) - self.next_receive_nonce = 0 - self._description = description - - def describe(self): - return self._description - - def send_record(self, record): - if not isinstance(record, type(b"")): raise UsageError - assert SecretBox.NONCE_SIZE == 24 - assert self.send_nonce < 2**(8*24) - assert len(record) < 2**(8*4) - nonce = unhexlify("%048x" % self.send_nonce) # big-endian - self.send_nonce += 1 - encrypted = self.send_box.encrypt(record, nonce) - length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long - send_to(self.skt, length) - send_to(self.skt, encrypted) - - def receive_record(self): - length_buf = self.receive_buf.read(4) - length = int(hexlify(length_buf), 16) - encrypted = self.receive_buf.read(length) - nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended - nonce = int(hexlify(nonce_buf), 16) - if nonce != self.next_receive_nonce: - raise BadNonce("received out-of-order record") - self.next_receive_nonce += 1 - record = self.receive_box.decrypt(encrypted) - return record - - def close(self): - self.skt.close() - -class Common: - def __init__(self, transit_relay, no_listen=False, timing=None): - if transit_relay: - if not isinstance(transit_relay, type(u"")): - raise UsageError - self._transit_relays = [transit_relay] - else: - self._transit_relays = [] - self._no_listen = no_listen - self._timing = timing or DebugTiming() - self._timing_started = self._timing.add_event("transit") - self.winning = threading.Event() - self._negotiation_check_lock = threading.Lock() - self._ready_for_connections_lock = threading.Condition() - self._ready_for_connections = False - self._transit_key = None - self._start_server() - - def _start_server(self): - if self._no_listen: - self.my_direct_hints = [] - self.listener = None - return - server = MyTCPServer(("", 0), None) - _, port = server.server_address - self.my_direct_hints = [u"tcp:%s:%d" % (addr, port) - for addr in ipaddrs.find_addresses()] - server.owner = self - server_thread = threading.Thread(target=server.serve_forever) - server_thread.daemon = True - server_thread.start() - self.listener = server - - def get_direct_hints(self): - return self.my_direct_hints - def get_relay_hints(self): - return self._transit_relays - - def add_their_direct_hints(self, hints): - for h in hints: - if not isinstance(h, type(u"")): - raise TypeError("hint '%r' should be unicode, not %s" - % (h, type(h))) - self._their_direct_hints = list(hints) - def add_their_relay_hints(self, hints): - for h in hints: - if not isinstance(h, type(u"")): - raise TypeError("hint '%r' should be unicode, not %s" - % (h, type(h))) - self._their_relay_hints = list(hints) - - def _send_this(self): - if self.is_sender: - return build_sender_handshake(self._transit_key) - else: - return build_receiver_handshake(self._transit_key) - - def _expect_this(self): - if self.is_sender: - return build_receiver_handshake(self._transit_key) - else: - return build_sender_handshake(self._transit_key) + b"go\n" - - def _sender_record_key(self): - if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") - else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") - - def _receiver_record_key(self): - if self.is_sender: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_receiver_key") - else: - return HKDF(self._transit_key, SecretBox.KEY_SIZE, - CTXinfo=b"transit_record_sender_key") - - def set_transit_key(self, key): - # This _ready_for_connections condition/lock protects us against the - # race where the sender knows the hints and the key, and connects to - # the receiver's transit socket before the receiver gets relay - # message (and thus the key). - self._ready_for_connections_lock.acquire() - self._transit_key = key - self.handler_send_handshake = self._send_this() # no "go" - self.handler_expected_handshake = self._expect_this() - self._ready_for_connections_lock.notify_all() - self._ready_for_connections_lock.release() - - def _start_outbound(self): - self._active_connectors = set(self._their_direct_hints) - self._attempted_connectors = set() - for hint in self._their_direct_hints: - self._start_connector(hint) - if not self._their_direct_hints: - self._start_relay_connectors() - - def _start_connector(self, hint, is_relay=False): - # Don't try any hint more than once. If all hints fail, we'll - # eventually timeout. We make no attempt to fail any faster. - if hint in self._attempted_connectors: - return - self._attempted_connectors.add(hint) - description = "->%s" % (hint,) - if is_relay: - description = "->relay:%s" % (hint,) - args = (self, hint, description, - self._send_this(), self._expect_this()) - if is_relay: - args = args + (build_relay_handshake(self._transit_key),) - t = threading.Thread(target=connector, args=args) - t.daemon = True - t.start() - - def _start_relay_connectors(self): - self._active_connectors.update(self._their_direct_hints) - for hint in self._their_relay_hints: - self._start_connector(hint, is_relay=True) - - def establish_socket(self): - start = time.time() - self.winning_skt = None - self.winning_skt_description = None - self._ready_for_connections_lock.acquire() - self._ready_for_connections = True - self._ready_for_connections_lock.notify_all() - self._ready_for_connections_lock.release() - self._start_outbound() - - # we sit here until one of our inbound or outbound sockets succeeds - flag = self.winning.wait(2*TIMEOUT) - debug("wait returned at %.1f" % (since(start),)) - - if not flag: - # timeout: self.winning_skt will not be set. ish. race. - pass - if self.listener: - self.listener.shutdown() # TODO: waits up to 0.5s. push to thread - if self.winning_skt: - return self.winning_skt - raise TransitError("timeout") - - def _connector_failed(self, hint): - debug("- failed connector %s" % hint) - # XXX this was .remove, and occasionally got KeyError - self._active_connectors.discard(hint) - if not self._active_connectors: - self._start_relay_connectors() - - def _negotiation_finished(self, skt, description): - # inbound/outbound sockets call this when they finish negotiation. - # The first one wins and gets a "go". Any subsequent ones lose and - # get a "nevermind" before being closed. - - with self._negotiation_check_lock: - if self.winning_skt: - is_winner = False - else: - is_winner = True - self.winning_skt = skt - self.winning_skt_description = description - - if is_winner: - if self.is_sender: - send_to(skt, b"go\n") - self.winning.set() - else: - if self.is_sender: - try: - send_to(skt, b"nevermind\n") - except socket.error: - # They realized this connection is not going to win, and - # closed it so fast we didn't get a chance to tell them - # it lost. This happens in unit tests. - pass - skt.close() - - def connect(self): - _start = self._timing.add_event("transit connect") - skt = self.establish_socket() - self._timing.finish_event(_start) - return RecordPipe(skt, self._sender_record_key(), - self._receiver_record_key(), - self.winning_skt_description) - -class TransitSender(Common): - is_sender = True - -class TransitReceiver(Common): - is_sender = False diff --git a/src/wormhole/scripts/cli_args.py b/src/wormhole/scripts/cli_args.py index 90169d4..c650768 100644 --- a/src/wormhole/scripts/cli_args.py +++ b/src/wormhole/scripts/cli_args.py @@ -27,8 +27,6 @@ g.add_argument("--hide-progress", action="store_true", help="supress progress-bar display") g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output metavar="FILE", help="(debug) write timing data to file") -g.add_argument("--twisted", action="store_true", - help="use Twisted-based implementations, for testing") g.add_argument("--no-listen", action="store_true", help="(debug) don't open a listening socket for Transit") g.add_argument("--tor", action="store_true", diff --git a/src/wormhole/scripts/cmd_receive_twisted.py b/src/wormhole/scripts/cmd_receive.py similarity index 62% rename from src/wormhole/scripts/cmd_receive_twisted.py rename to src/wormhole/scripts/cmd_receive.py index fb7dd3d..a48ade7 100644 --- a/src/wormhole/scripts/cmd_receive_twisted.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -1,13 +1,19 @@ from __future__ import print_function -import io, json +import io, os, sys, json, binascii, six, tempfile, zipfile from twisted.internet import reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transit import TransitReceiver -from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID from ..errors import TransferError from .progress import ProgressPrinter + +APPID = u"lothar.com/wormhole/text-or-file-xfer" + +class RespondError(Exception): + def __init__(self, response): + self.response = response + def receive_twisted_sync(args): # try to use twisted.internet.task.react(f) here (but it calls sys.exit # directly) @@ -34,7 +40,14 @@ def receive_twisted_sync(args): def receive_twisted(args): return TwistedReceiver(args).go() -class TwistedReceiver(BlockingReceiver): + +class TwistedReceiver: + def __init__(self, args): + assert isinstance(args.relay_url, type(u"")) + self.args = args + + def msg(self, *args, **kwargs): + print(*args, file=self.args.stdout, **kwargs) # TODO: @handle_server_error @inlineCallbacks @@ -101,6 +114,11 @@ class TwistedReceiver(BlockingReceiver): self.args.code_length) yield w.set_code(code) + def show_verifier(self, verifier): + verifier_hex = binascii.hexlify(verifier).decode("ascii") + if self.args.verify: + self.msg(u"Verifier %s." % verifier_hex) + @inlineCallbacks def get_data(self, w): try: @@ -119,6 +137,61 @@ class TwistedReceiver(BlockingReceiver): data = json.dumps({"message_ack": "ok"}).encode("utf-8") yield w.send_data(data) + def handle_file(self, them_d): + file_data = them_d["file"] + self.abs_destname = self.decide_destname("file", + file_data["filename"]) + self.xfersize = file_data["filesize"] + + self.msg(u"Receiving file (%d bytes) into: %s" % + (self.xfersize, os.path.basename(self.abs_destname))) + self.ask_permission() + tmp_destname = self.abs_destname + ".tmp" + return open(tmp_destname, "wb") + + def handle_directory(self, them_d): + file_data = them_d["directory"] + zipmode = file_data["mode"] + if zipmode != "zipfile/deflated": + self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) + raise RespondError({"error": "unknown mode"}) + self.abs_destname = self.decide_destname("directory", + file_data["dirname"]) + self.xfersize = file_data["zipsize"] + + self.msg(u"Receiving directory (%d bytes) into: %s/" % + (self.xfersize, os.path.basename(self.abs_destname))) + self.msg(u"%d files, %d bytes (uncompressed)" % + (file_data["numfiles"], file_data["numbytes"])) + self.ask_permission() + return tempfile.SpooledTemporaryFile() + + def decide_destname(self, mode, destname): + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + destname = os.path.basename(destname) + if self.args.output_file: + destname = self.args.output_file # override + abs_destname = os.path.join(self.args.cwd, destname) + + # get confirmation from the user before writing to the local directory + if os.path.exists(abs_destname): + self.msg(u"Error: refusing to overwrite existing %s %s" % + (mode, destname)) + raise RespondError({"error": "%s already exists" % mode}) + return abs_destname + + def ask_permission(self): + _start = self.args.timing.add_event("permission", waiting="user") + while True and not self.args.accept_file: + ok = six.moves.input("ok? (y/n): ") + if ok.lower().startswith("y"): + break + print(u"transfer rejected", file=sys.stderr) + self.args.timing.finish_event(_start, answer="no") + raise RespondError({"error": "transfer rejected"}) + self.args.timing.finish_event(_start, answer="yes") + @inlineCallbacks def establish_transit(self, w, them_d, tor_manager): transit_key = w.derive_key(APPID+u"/transit-key") @@ -169,6 +242,27 @@ class TwistedReceiver(BlockingReceiver): returnValue(1) # TODO: exit properly assert received == self.xfersize + def write_file(self, f): + tmp_name = f.name + f.close() + os.rename(tmp_name, self.abs_destname) + self.msg(u"Received file written to %s" % + os.path.basename(self.abs_destname)) + + def write_directory(self, f): + self.msg(u"Unpacking zipfile..") + _start = self.args.timing.add_event("unpack zip") + with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: + zf.extractall(path=self.abs_destname) + # extractall() appears to offer some protection against + # malicious pathnames. For example, "/tmp/oops" and + # "../tmp/oops" both do the same thing as the (safe) + # "tmp/oops". + self.msg(u"Received files written to %s/" % + os.path.basename(self.abs_destname)) + f.close() + self.args.timing.finish_event(_start) + @inlineCallbacks def close_transit(self, record_pipe): _start = self.args.timing.add_event("ack") diff --git a/src/wormhole/scripts/cmd_receive_blocking.py b/src/wormhole/scripts/cmd_receive_blocking.py deleted file mode 100644 index afecd06..0000000 --- a/src/wormhole/scripts/cmd_receive_blocking.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import print_function -import io, os, sys, json, binascii, six, tempfile, zipfile -from ..blocking.transcribe import Wormhole, WrongPasswordError -from ..blocking.transit import TransitReceiver, TransitError -from ..errors import handle_server_error, TransferError -from .progress import ProgressPrinter - -APPID = u"lothar.com/wormhole/text-or-file-xfer" - -def receive_blocking(args): - return BlockingReceiver(args).go() - -class RespondError(Exception): - def __init__(self, response): - self.response = response - -class BlockingReceiver: - def __init__(self, args): - assert isinstance(args.relay_url, type(u"")) - self.args = args - - def msg(self, *args, **kwargs): - print(*args, file=self.args.stdout, **kwargs) - - @handle_server_error - def go(self): - with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w: - self.handle_code(w) - verifier = w.get_verifier() - self.show_verifier(verifier) - them_d = self.get_data(w) - try: - if "message" in them_d: - self.handle_text(them_d, w) - return 0 - if "file" in them_d: - f = self.handle_file(them_d) - rp = self.establish_transit(w, them_d) - self.transfer_data(rp, f) - self.write_file(f) - self.close_transit(rp) - elif "directory" in them_d: - f = self.handle_directory(them_d) - rp = self.establish_transit(w, them_d) - self.transfer_data(rp, f) - self.write_directory(f) - self.close_transit(rp) - else: - self.msg(u"I don't know what they're offering\n") - self.msg(u"Offer details:", them_d) - raise RespondError({"error": "unknown offer type"}) - except RespondError as r: - data = json.dumps(r.response).encode("utf-8") - w.send_data(data) - return 1 - return 0 - - def handle_code(self, w): - code = self.args.code - if self.args.zeromode: - assert not code - code = u"0-" - if not code: - code = w.input_code("Enter receive wormhole code: ", - self.args.code_length) - w.set_code(code) - - def show_verifier(self, verifier): - verifier_hex = binascii.hexlify(verifier).decode("ascii") - if self.args.verify: - self.msg(u"Verifier %s." % verifier_hex) - - def get_data(self, w): - try: - them_bytes = w.get_data() - except WrongPasswordError as e: - raise TransferError(u"ERROR: " + e.explain()) - them_d = json.loads(them_bytes.decode("utf-8")) - if "error" in them_d: - raise TransferError(u"ERROR: " + them_d["error"]) - return them_d - - def handle_text(self, them_d, w): - # we're receiving a text message - self.msg(them_d["message"]) - data = json.dumps({"message_ack": "ok"}).encode("utf-8") - w.send_data(data) - - def handle_file(self, them_d): - file_data = them_d["file"] - self.abs_destname = self.decide_destname("file", - file_data["filename"]) - self.xfersize = file_data["filesize"] - - self.msg(u"Receiving file (%d bytes) into: %s" % - (self.xfersize, os.path.basename(self.abs_destname))) - self.ask_permission() - tmp_destname = self.abs_destname + ".tmp" - return open(tmp_destname, "wb") - - def handle_directory(self, them_d): - file_data = them_d["directory"] - zipmode = file_data["mode"] - if zipmode != "zipfile/deflated": - self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,)) - raise RespondError({"error": "unknown mode"}) - self.abs_destname = self.decide_destname("directory", - file_data["dirname"]) - self.xfersize = file_data["zipsize"] - - self.msg(u"Receiving directory (%d bytes) into: %s/" % - (self.xfersize, os.path.basename(self.abs_destname))) - self.msg(u"%d files, %d bytes (uncompressed)" % - (file_data["numfiles"], file_data["numbytes"])) - self.ask_permission() - return tempfile.SpooledTemporaryFile() - - def decide_destname(self, mode, destname): - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - destname = os.path.basename(destname) - if self.args.output_file: - destname = self.args.output_file # override - abs_destname = os.path.join(self.args.cwd, destname) - - # get confirmation from the user before writing to the local directory - if os.path.exists(abs_destname): - self.msg(u"Error: refusing to overwrite existing %s %s" % - (mode, destname)) - raise RespondError({"error": "%s already exists" % mode}) - return abs_destname - - def ask_permission(self): - _start = self.args.timing.add_event("permission", waiting="user") - while True and not self.args.accept_file: - ok = six.moves.input("ok? (y/n): ") - if ok.lower().startswith("y"): - break - print(u"transfer rejected", file=sys.stderr) - self.args.timing.finish_event(_start, answer="no") - raise RespondError({"error": "transfer rejected"}) - self.args.timing.finish_event(_start, answer="yes") - - def establish_transit(self, w, them_d): - transit_key = w.derive_key(APPID+u"/transit-key") - transit_receiver = TransitReceiver(self.args.transit_helper, - no_listen=self.args.no_listen, - timing=self.args.timing) - transit_receiver.set_transit_key(transit_key) - data = json.dumps({ - "file_ack": "ok", - "transit": { - "direct_connection_hints": transit_receiver.get_direct_hints(), - "relay_connection_hints": transit_receiver.get_relay_hints(), - }, - }).encode("utf-8") - w.send_data(data) - - # now receive the rest of the owl - tdata = them_d["transit"] - transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) - transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) - record_pipe = transit_receiver.connect() - return record_pipe - - def transfer_data(self, record_pipe, f): - self.msg(u"Receiving (%s).." % record_pipe.describe()) - - _start = self.args.timing.add_event("rx file") - progress_stdout = self.args.stdout - if self.args.hide_progress: - progress_stdout = io.StringIO() - received = 0 - p = ProgressPrinter(self.xfersize, progress_stdout) - p.start() - while received < self.xfersize: - try: - plaintext = record_pipe.receive_record() - except TransitError: - self.msg() - self.msg(u"Connection dropped before full file received") - self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize)) - return 1 - f.write(plaintext) - received += len(plaintext) - p.update(received) - p.finish() - self.args.timing.finish_event(_start) - assert received == self.xfersize - - def write_file(self, f): - tmp_name = f.name - f.close() - os.rename(tmp_name, self.abs_destname) - self.msg(u"Received file written to %s" % - os.path.basename(self.abs_destname)) - - def write_directory(self, f): - self.msg(u"Unpacking zipfile..") - _start = self.args.timing.add_event("unpack zip") - with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: - zf.extractall(path=self.abs_destname) - # extractall() appears to offer some protection against - # malicious pathnames. For example, "/tmp/oops" and - # "../tmp/oops" both do the same thing as the (safe) - # "tmp/oops". - self.msg(u"Received files written to %s/" % - os.path.basename(self.abs_destname)) - f.close() - self.args.timing.finish_event(_start) - - def close_transit(self, record_pipe): - _start = self.args.timing.add_event("ack") - record_pipe.send_record(b"ok\n") - record_pipe.close() - self.args.timing.finish_event(_start) diff --git a/src/wormhole/scripts/cmd_send_twisted.py b/src/wormhole/scripts/cmd_send.py similarity index 67% rename from src/wormhole/scripts/cmd_send_twisted.py rename to src/wormhole/scripts/cmd_send.py index 092c8cb..442a591 100644 --- a/src/wormhole/scripts/cmd_send_twisted.py +++ b/src/wormhole/scripts/cmd_send.py @@ -1,5 +1,5 @@ from __future__ import print_function -import io, json, binascii, six +import os, sys, io, json, binascii, six, tempfile, zipfile from twisted.protocols import basic from twisted.internet import reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue @@ -7,8 +7,94 @@ from ..errors import TransferError from .progress import ProgressPrinter from ..twisted.transcribe import Wormhole, WrongPasswordError from ..twisted.transit import TransitSender -from .send_common import (APPID, handle_zero, build_other_command, - build_phase1_data) + +APPID = u"lothar.com/wormhole/text-or-file-xfer" + +def handle_zero(args): + if args.zeromode: + assert not args.code + args.code = u"0-" + +def build_other_command(args): + other_cmd = "wormhole receive" + if args.verify: + other_cmd = "wormhole --verify receive" + if args.zeromode: + other_cmd += " -0" + return other_cmd + +def build_phase1_data(args): + phase1 = {} + + text = args.text + if text == "-": + print(u"Reading text message from stdin..", file=args.stdout) + text = sys.stdin.read() + if not text and not args.what: + text = six.moves.input("Text to send: ") + + if text is not None: + print(u"Sending text message (%d bytes)" % len(text), file=args.stdout) + phase1 = { "message": text } + fd_to_send = None + return phase1, fd_to_send + + what = os.path.join(args.cwd, args.what) + what = what.rstrip(os.sep) + if not os.path.exists(what): + raise TransferError("Cannot send: no file/directory named '%s'" % + args.what) + basename = os.path.basename(what) + + if os.path.isfile(what): + # we're sending a file + filesize = os.stat(what).st_size + phase1["file"] = { + "filename": basename, + "filesize": filesize, + } + print(u"Sending %d byte file named '%s'" % (filesize, basename), + file=args.stdout) + fd_to_send = open(what, "rb") + return phase1, fd_to_send + + if os.path.isdir(what): + print(u"Building zipfile..", file=args.stdout) + # We're sending a directory. Create a zipfile in a tempdir and + # send that. + fd_to_send = tempfile.SpooledTemporaryFile() + # TODO: I think ZIP_DEFLATED means compressed.. check it + num_files = 0 + num_bytes = 0 + tostrip = len(what.split(os.sep)) + with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf: + for path,dirs,files in os.walk(what): + # path always starts with args.what, then sometimes might + # have "/subdir" appended. We want the zipfile to contain + # "" or "subdir" + localpath = list(path.split(os.sep)[tostrip:]) + for fn in files: + archivename = os.path.join(*tuple(localpath+[fn])) + localfilename = os.path.join(path, fn) + zf.write(localfilename, archivename) + num_bytes += os.stat(localfilename).st_size + num_files += 1 + fd_to_send.seek(0,2) + filesize = fd_to_send.tell() + fd_to_send.seek(0,0) + phase1["directory"] = { + "mode": "zipfile/deflated", + "dirname": basename, + "zipsize": filesize, + "numbytes": num_bytes, + "numfiles": num_files, + } + print(u"Sending directory (%d bytes compressed) named '%s'" + % (filesize, basename), file=args.stdout) + return phase1, fd_to_send + + raise TypeError("'%s' is neither file nor directory" % args.what) + def send_twisted_sync(args): # try to use twisted.internet.task.react(f) here (but it calls sys.exit diff --git a/src/wormhole/scripts/cmd_send_blocking.py b/src/wormhole/scripts/cmd_send_blocking.py deleted file mode 100644 index 7259185..0000000 --- a/src/wormhole/scripts/cmd_send_blocking.py +++ /dev/null @@ -1,129 +0,0 @@ -from __future__ import print_function -import json, binascii, six -from ..errors import TransferError -from .progress import ProgressPrinter -from ..blocking.transcribe import Wormhole, WrongPasswordError -from ..blocking.transit import TransitSender -from ..errors import handle_server_error -from .send_common import (APPID, handle_zero, build_other_command, - build_phase1_data) - -@handle_server_error -def send_blocking(args): - assert isinstance(args.relay_url, type(u"")) - handle_zero(args) - phase1, fd_to_send = build_phase1_data(args) - other_cmd = build_other_command(args) - print(u"On the other computer, please run: %s" % other_cmd, - file=args.stdout) - - if fd_to_send is not None: - transit_sender = TransitSender(args.transit_helper, - no_listen=args.no_listen, - timing=args.timing) - transit_data = { - "direct_connection_hints": transit_sender.get_direct_hints(), - "relay_connection_hints": transit_sender.get_relay_hints(), - } - phase1["transit"] = transit_data - - with Wormhole(APPID, args.relay_url, timing=args.timing) as w: - if args.code: - w.set_code(args.code) - code = args.code - else: - code = w.get_code(args.code_length) - if not args.zeromode: - print(u"Wormhole code is: %s" % code, file=args.stdout) - print(u"", file=args.stdout) - - # get the verifier, because that also lets us derive the transit key, - # which we want to set before revealing the connection hints to the - # far side, so we'll be ready for them when they connect - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - if args.verify: - _do_verify(verifier, w) - - if fd_to_send is not None: - transit_key = w.derive_key(APPID+"/transit-key") - transit_sender.set_transit_key(transit_key) - - my_phase1_bytes = json.dumps(phase1).encode("utf-8") - w.send_data(my_phase1_bytes) - try: - them_phase1_bytes = w.get_data() - except WrongPasswordError as e: - raise TransferError(e.explain()) - - them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) - - if fd_to_send is None: - if them_phase1["message_ack"] == "ok": - print(u"text message sent", file=args.stdout) - return 0 - raise TransferError("error sending text: %r" % (them_phase1,)) - - return _send_file_blocking(them_phase1, fd_to_send, - transit_sender, args.stdout, args.hide_progress, - args.timing) - -def _do_verify(verifier, w): - while True: - ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier) - if ok.lower() == "yes": - break - if ok.lower() == "no": - reject_data = json.dumps({"error": "verification rejected", - }).encode("utf-8") - w.send_data(reject_data) - raise TransferError("verification rejected, abandoning transfer") - -def _send_file_blocking(them_phase1, fd_to_send, transit_sender, - stdout, hide_progress, timing): - - # we're sending a file, if they accept it - - if "error" in them_phase1: - raise TransferError("remote error, transfer abandoned: %s" - % them_phase1["error"]) - if them_phase1.get("file_ack") != "ok": - raise TransferError("ambiguous response from remote, " - "transfer abandoned: %s" % (them_phase1,)) - - tdata = them_phase1["transit"] - transit_sender.add_their_direct_hints(tdata["direct_connection_hints"]) - transit_sender.add_their_relay_hints(tdata["relay_connection_hints"]) - record_pipe = transit_sender.connect() - - print(u"Sending (%s).." % record_pipe.describe(), file=stdout) - - _start = timing.add_event("tx file") - CHUNKSIZE = 64*1024 - fd_to_send.seek(0,2) - filesize = fd_to_send.tell() - fd_to_send.seek(0,0) - p = ProgressPrinter(filesize, stdout) - with fd_to_send as f: - sent = 0 - if not hide_progress: - p.start() - while sent < filesize: - plaintext = f.read(CHUNKSIZE) - record_pipe.send_record(plaintext) - sent += len(plaintext) - if not hide_progress: - p.update(sent) - if not hide_progress: - p.finish() - timing.finish_event(_start) - - _start = timing.add_event("get ack") - print(u"File sent.. waiting for confirmation", file=stdout) - ack = record_pipe.receive_record() - record_pipe.close() - if ack == b"ok\n": - print(u"Confirmation received. Transfer complete.", file=stdout) - timing.finish_event(_start, ack="ok") - return 0 - timing.finish_event(_start, ack="failed") - raise TransferError("Transfer failed (remote says: %r)" % ack) diff --git a/src/wormhole/scripts/runner.py b/src/wormhole/scripts/runner.py index 32e9c88..8cff151 100644 --- a/src/wormhole/scripts/runner.py +++ b/src/wormhole/scripts/runner.py @@ -21,22 +21,14 @@ def dispatch(args): from ..servers import cmd_usage return cmd_usage.tail_usage(args) - if args.tor: - args.twisted = True if args.func == "send/send": - if args.twisted: - from . import cmd_send_twisted - return cmd_send_twisted.send_twisted_sync(args) - from . import cmd_send_blocking - return cmd_send_blocking.send_blocking(args) + from . import cmd_send + return cmd_send.send_twisted_sync(args) if args.func == "receive/receive": - if args.twisted: - _start = args.timing.add_event("import c_r_t") - from . import cmd_receive_twisted - args.timing.finish_event(_start) - return cmd_receive_twisted.receive_twisted_sync(args) - from . import cmd_receive_blocking - return cmd_receive_blocking.receive_blocking(args) + _start = args.timing.add_event("import c_r_t") + from . import cmd_receive + args.timing.finish_event(_start) + return cmd_receive.receive_twisted_sync(args) raise ValueError("unknown args.func %s" % args.func) diff --git a/src/wormhole/scripts/send_common.py b/src/wormhole/scripts/send_common.py deleted file mode 100644 index a8f3080..0000000 --- a/src/wormhole/scripts/send_common.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import print_function -import os, sys, six, tempfile, zipfile -from ..errors import TransferError - -APPID = u"lothar.com/wormhole/text-or-file-xfer" - -def handle_zero(args): - if args.zeromode: - assert not args.code - args.code = u"0-" - -def build_other_command(args): - other_cmd = "wormhole receive" - if args.verify: - other_cmd = "wormhole --verify receive" - if args.zeromode: - other_cmd += " -0" - return other_cmd - -def build_phase1_data(args): - phase1 = {} - - text = args.text - if text == "-": - print(u"Reading text message from stdin..", file=args.stdout) - text = sys.stdin.read() - if not text and not args.what: - text = six.moves.input("Text to send: ") - - if text is not None: - print(u"Sending text message (%d bytes)" % len(text), file=args.stdout) - phase1 = { "message": text } - fd_to_send = None - return phase1, fd_to_send - - what = os.path.join(args.cwd, args.what) - what = what.rstrip(os.sep) - if not os.path.exists(what): - raise TransferError("Cannot send: no file/directory named '%s'" % - args.what) - basename = os.path.basename(what) - - if os.path.isfile(what): - # we're sending a file - filesize = os.stat(what).st_size - phase1["file"] = { - "filename": basename, - "filesize": filesize, - } - print(u"Sending %d byte file named '%s'" % (filesize, basename), - file=args.stdout) - fd_to_send = open(what, "rb") - return phase1, fd_to_send - - if os.path.isdir(what): - print(u"Building zipfile..", file=args.stdout) - # We're sending a directory. Create a zipfile in a tempdir and - # send that. - fd_to_send = tempfile.SpooledTemporaryFile() - # TODO: I think ZIP_DEFLATED means compressed.. check it - num_files = 0 - num_bytes = 0 - tostrip = len(what.split(os.sep)) - with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf: - for path,dirs,files in os.walk(what): - # path always starts with args.what, then sometimes might - # have "/subdir" appended. We want the zipfile to contain - # "" or "subdir" - localpath = list(path.split(os.sep)[tostrip:]) - for fn in files: - archivename = os.path.join(*tuple(localpath+[fn])) - localfilename = os.path.join(path, fn) - zf.write(localfilename, archivename) - num_bytes += os.stat(localfilename).st_size - num_files += 1 - fd_to_send.seek(0,2) - filesize = fd_to_send.tell() - fd_to_send.seek(0,0) - phase1["directory"] = { - "mode": "zipfile/deflated", - "dirname": basename, - "zipsize": filesize, - "numbytes": num_bytes, - "numfiles": num_files, - } - print(u"Sending directory (%d bytes compressed) named '%s'" - % (filesize, basename), file=args.stdout) - return phase1, fd_to_send - - raise TypeError("'%s' is neither file nor directory" % args.what) diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 56aa38e..ee2de77 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,7 +1,7 @@ from twisted.application import service from twisted.internet import reactor, defer from twisted.python import log -from ..twisted.util import allocate_ports +from ..twisted.transit import allocate_tcp_port from ..servers.server import RelayServer from .. import __version__ @@ -9,19 +9,16 @@ class ServerBase: def setUp(self): self.sp = service.MultiService() self.sp.startService() - d = allocate_ports() - def _got_ports(ports): - relayport, transitport = ports - s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport, - "tcp:%s:interface=127.0.0.1" % transitport, - __version__) - s.setServiceParent(self.sp) - self._relay_server = s.relay - self._transit_server = s.transit - self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport - self.transit = u"tcp:127.0.0.1:%d" % transitport - d.addCallback(_got_ports) - return d + relayport = allocate_tcp_port() + transitport = allocate_tcp_port() + s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport, + "tcp:%s:interface=127.0.0.1" % transitport, + __version__) + s.setServiceParent(self.sp) + self._relay_server = s.relay + self._transit_server = s.transit + self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport + self.transit = u"tcp:127.0.0.1:%d" % transitport def tearDown(self): # Unit tests that spawn a (blocking) client in a thread might still diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 1fbfaff..e930604 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -1,14 +1,11 @@ from __future__ import print_function import json from twisted.trial import unittest -from twisted.internet.defer import gatherResults, succeed, inlineCallbacks +from twisted.internet.defer import gatherResults, succeed from twisted.internet.threads import deferToThread from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager, WrongPasswordError) from ..blocking.eventsource import EventSourceFollower -from ..blocking.transit import (TransitSender, TransitReceiver, - build_sender_handshake, - build_receiver_handshake) from .common import ServerBase APPID = u"appid" @@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase): (u"message", u"three"), (u"e2", u"four"), ]) - -class Transit(_DoBothMixin, ServerBase, unittest.TestCase): - def test_hints(self): - r = TransitReceiver(self.transit) - hints = r.get_direct_hints() - self.assertTrue(len(hints), hints) - - @inlineCallbacks - def test_direct_to_receiver(self): - s = TransitSender(self.transit) - r = TransitReceiver(self.transit) - key = b"\x00"*32 - - # force the connection to be sender->receiver - s.set_transit_key(key) - # only use 127.0.0.1 - hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1] - s.add_their_direct_hints([hint]) - s.add_their_relay_hints([]) - r.set_transit_key(key) - r.add_their_direct_hints([]) - r.add_their_relay_hints([]) - - # it'd be nice to factor this chunk out with 'yield from', but that - # didn't appear until python-3.3, and isn't in py2 at all. - (sp, rp) = yield self.doBoth([s.connect], [r.connect]) - yield deferToThread(sp.send_record, b"01234") - rec = yield deferToThread(rp.receive_record) - self.assertEqual(rec, b"01234") - yield deferToThread(sp.close) - yield deferToThread(rp.close) - - @inlineCallbacks - def test_direct_to_sender(self): - s = TransitSender(self.transit) - r = TransitReceiver(self.transit) - key = b"\x00"*32 - - # force the connection to be receiver->sender - s.set_transit_key(key) - s.add_their_direct_hints([]) - s.add_their_relay_hints([]) - r.set_transit_key(key) - hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1] - r.add_their_direct_hints([hint]) - r.add_their_relay_hints([]) - - (sp, rp) = yield self.doBoth([s.connect], [r.connect]) - yield deferToThread(sp.send_record, b"01234") - rec = yield deferToThread(rp.receive_record) - self.assertEqual(rec, b"01234") - yield deferToThread(sp.close) - yield deferToThread(rp.close) - - @inlineCallbacks - def test_relay(self): - s = TransitSender(self.transit) - r = TransitReceiver(self.transit) - key = b"\x00"*32 - # force the connection to use the relay by not revealing direct hints - s.set_transit_key(key) - s.add_their_direct_hints([]) - s.add_their_relay_hints(r.get_relay_hints()) - r.set_transit_key(key) - r.add_their_direct_hints([]) - r.add_their_relay_hints(s.get_relay_hints()) - - (sp, rp) = yield self.doBoth([s.connect], [r.connect]) - yield deferToThread(sp.send_record, b"01234") - rec = yield deferToThread(rp.receive_record) - self.assertEqual(rec, b"01234") - yield deferToThread(sp.close) - yield deferToThread(rp.close) - # TODO: this may be racy if we don't poll the server to make sure - # it's witnessed the first connection closing before querying the DB - #import time - #yield deferToThread(time.sleep, 1) - - # check the transit relay's DB, make sure it counted the bytes - db = self._transit_server._db - c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",)) - rows = c.fetchall() - self.assertEqual(len(rows), 1) - row = rows[0] - self.assertEqual(row["result"], u"happy") - # Sender first writes relay_handshake and waits for OK, but that's - # not counted by the transit server. Then sender writes - # sender_handshake and waits for receiver_handshake. Then sender - # writes GO and the body. Body is length-prefixed SecretBox, so - # includes 4-byte length, 24-byte nonce, and 16-byte MAC. - sender_count = (len(build_sender_handshake(b""))+ - len(b"go\n")+ - 4+24+len(b"01234")+16) - # Receiver first writes relay_handshake and waits for OK, but that's - # not counted. Then receiver writes receiver_handshake and waits for - # sender_handshake+GO. - receiver_count = len(build_receiver_handshake(b"")) - self.assertEqual(row["total_bytes"], sender_count+receiver_count) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 2939f67..e5c0847 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -4,12 +4,10 @@ from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet.utils import getProcessOutputAndValue from twisted.internet.defer import inlineCallbacks -from twisted.internet.threads import deferToThread from .. import __version__ from .common import ServerBase -from ..scripts import (runner, cmd_send_blocking, cmd_send_twisted, - cmd_receive_blocking, cmd_receive_twisted) -from ..scripts.send_common import build_phase1_data +from ..scripts import runner, cmd_send, cmd_receive +from ..scripts.cmd_send import build_phase1_data from ..errors import TransferError from ..timing import DebugTiming @@ -219,8 +217,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): @inlineCallbacks def _do_test(self, as_subprocess=False, - mode="text", addslash=False, override_filename=False, - sender_twisted=False, receiver_twisted=False): + mode="text", addslash=False, override_filename=False): assert mode in ("text", "file", "directory") common_args = ["--hide-progress", "--relay-url", self.relayurl, @@ -314,14 +311,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): rargs.stdout = io.StringIO() rargs.stderr = io.StringIO() rargs.timing = DebugTiming() - if sender_twisted: - send_d = cmd_send_twisted.send_twisted(sargs) - else: - send_d = deferToThread(cmd_send_blocking.send_blocking, sargs) - if receiver_twisted: - receive_d = cmd_receive_twisted.receive_twisted(rargs) - else: - receive_d = deferToThread(cmd_receive_blocking.receive_blocking, rargs) + send_d = cmd_send.send_twisted(sargs) + receive_d = cmd_receive.receive_twisted(rargs) # The sender might fail, leaving the receiver hanging, or vice # versa. If either side fails, cancel the other, so it won't @@ -418,24 +409,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test() def test_text_subprocess(self): return self._do_test(as_subprocess=True) - def test_text_twisted_to_blocking(self): - return self._do_test(sender_twisted=True) - def test_text_blocking_to_twisted(self): - return self._do_test(receiver_twisted=True) - def test_text_twisted_to_twisted(self): - return self._do_test(sender_twisted=True, receiver_twisted=True) def test_file(self): return self._do_test(mode="file") def test_file_override(self): return self._do_test(mode="file", override_filename=True) - def test_file_twisted_to_blocking(self): - return self._do_test(mode="file", sender_twisted=True) - def test_file_blocking_to_twisted(self): - return self._do_test(mode="file", receiver_twisted=True) - def test_file_twisted_to_twisted(self): - return self._do_test(mode="file", - sender_twisted=True, receiver_twisted=True) def test_directory(self): return self._do_test(mode="directory") @@ -443,16 +421,3 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="directory", addslash=True) def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) - def test_directory_twisted_to_blocking(self): - return self._do_test(mode="directory", sender_twisted=True) - def test_directory_twisted_to_blocking_addslash(self): - return self._do_test(mode="directory", addslash=True, - sender_twisted=True) - def test_directory_blocking_to_twisted(self): - return self._do_test(mode="directory", receiver_twisted=True) - def test_directory_twisted_to_twisted(self): - return self._do_test(mode="directory", - sender_twisted=True, receiver_twisted=True) - def test_directory_twisted_to_twisted_addslash(self): - return self._do_test(mode="directory", addslash=True, - sender_twisted=True, receiver_twisted=True) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 63dd643..09bea1e 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,10 +1,12 @@ from __future__ import print_function import json import requests +from binascii import hexlify from six.moves.urllib_parse import urlencode from twisted.trial import unittest -from twisted.internet import reactor, defer +from twisted.internet import protocol, reactor, defer from twisted.internet.threads import deferToThread +from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody from .. import __version__ from .common import ServerBase @@ -434,7 +436,30 @@ class Summary(unittest.TestCase): self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41), (1, "scary", 40, 9)) -class Transit(unittest.TestCase): +class Accumulator(protocol.Protocol): + def __init__(self): + self.data = b"" + self.count = 0 + self._wait = None + def waitForBytes(self, more): + assert self._wait is None + self.count = more + self._wait = defer.Deferred() + self._check_done() + return self._wait + def dataReceived(self, data): + self.data = self.data + data + self._check_done() + def _check_done(self): + if self._wait and len(self.data) >= self.count: + d = self._wait + self._wait = None + d.callback(self) + def connectionLost(self, why): + if self._wait: + self._wait.errback(RuntimeError("closed")) + +class Transit(ServerBase, unittest.TestCase): def test_blur_size(self): blur = transit_server.blur_size self.failUnlessEqual(blur(0), 0) @@ -453,3 +478,62 @@ class Transit(unittest.TestCase): self.failUnlessEqual(blur(1100e6), 1100e6) self.failUnlessEqual(blur(1150e6), 1200e6) + @defer.inlineCallbacks + def test_basic(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + a1.transport.write(b"please relay " + hexlify(token1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_bad_handshake(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # the server waits for the exact number of bytes in the expected + # handshake message. to trigger "bad handshake", we must match. + a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") + + exp = b"bad handshake\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() + + @defer.inlineCallbacks + def test_impatience(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + # sending too many bytes is impatience. + a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") + + exp = b"impatient\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + + a1.transport.loseConnection() diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py deleted file mode 100644 index 0ebb750..0000000 --- a/src/wormhole/test/test_transit.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import print_function -from binascii import hexlify -from twisted.trial import unittest -from twisted.internet import protocol, defer, reactor -from twisted.internet.endpoints import clientFromString, connectProtocol -from .common import ServerBase - -class Accumulator(protocol.Protocol): - def __init__(self): - self.data = b"" - self.count = 0 - self._wait = None - def waitForBytes(self, more): - assert self._wait is None - self.count = more - self._wait = defer.Deferred() - self._check_done() - return self._wait - def dataReceived(self, data): - self.data = self.data + data - self._check_done() - def _check_done(self): - if self._wait and len(self.data) >= self.count: - d = self._wait - self._wait = None - d.callback(self) - def connectionLost(self, why): - if self._wait: - self._wait.errback(RuntimeError("closed")) - -class Transit(ServerBase, unittest.TestCase): - @defer.inlineCallbacks - def test_basic(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - a2 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - a1.transport.write(b"please relay " + hexlify(token1) + b"\n") - a2.transport.write(b"please relay " + hexlify(token1) + b"\n") - - # a correct handshake yields an ack, after which we can send - exp = b"ok\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - s1 = b"data1" - a1.transport.write(s1) - - exp = b"ok\n" - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - # all data they sent after the handshake should be given to us - exp = b"ok\n"+s1 - yield a2.waitForBytes(len(exp)) - self.assertEqual(a2.data, exp) - - a1.transport.loseConnection() - a2.transport.loseConnection() - - @defer.inlineCallbacks - def test_bad_handshake(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # the server waits for the exact number of bytes in the expected - # handshake message. to trigger "bad handshake", we must match. - a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n") - - exp = b"bad handshake\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() - - @defer.inlineCallbacks - def test_impatience(self): - ep = clientFromString(reactor, self.transit) - a1 = yield connectProtocol(ep, Accumulator()) - - token1 = b"\x00"*32 - # sending too many bytes is impatience. - a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") - - exp = b"impatient\n" - yield a1.waitForBytes(len(exp)) - self.assertEqual(a1.data, exp) - - a1.transport.loseConnection() diff --git a/src/wormhole/transit_common.py b/src/wormhole/transit_common.py deleted file mode 100644 index eb18c92..0000000 --- a/src/wormhole/transit_common.py +++ /dev/null @@ -1,87 +0,0 @@ -import re -from binascii import hexlify -from .util.hkdf import HKDF - -class TransitError(Exception): - pass - -class BadHandshake(Exception): - pass - -class TransitClosed(TransitError): - pass - -class BadNonce(TransitError): - pass - -# The beginning of each TCP connection consists of the following handshake -# messages. The sender transmits the same text regardless of whether it is on -# the initiating/connecting end of the TCP connection, or on the -# listening/accepting side. Same for the receiver. -# -# sender -> receiver: transit sender TXID_HEX ready\n\n -# receiver -> sender: transit receiver RXID_HEX ready\n\n -# -# Any deviations from this result in the socket being closed. The handshake -# messages are designed to provoke an invalid response from other sorts of -# servers (HTTP, SMTP, echo). -# -# If the sender is satisfied with the handshake, and this is the first socket -# to complete negotiation, the sender does: -# -# sender -> receiver: go\n -# -# and the next byte on the wire will be from the application. -# -# If this is not the first socket, the sender does: -# -# sender -> receiver: nevermind\n -# -# and closes the socket. - -# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs -# up upon the first wrong byte. The sender lookgs for "transit receiver -# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending -# "go\n" or "nevermind\n"+close(). - -def build_receiver_handshake(key): - hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") - return b"transit receiver "+hexlify(hexid)+b" ready\n\n" - -def build_sender_handshake(key): - hexid = HKDF(key, 32, CTXinfo=b"transit_sender") - return b"transit sender "+hexlify(hexid)+b" ready\n\n" - -def build_relay_handshake(key): - token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b"\n" - -# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends -# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one -# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons. -# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint -# publisher wants anonymity, their only hint's ADDR will end in .onion . - -def parse_hint_tcp(hint): - assert isinstance(hint, type(u"")) - # return tuple or None for an unparseable hint - mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) - if not mo: - print("unparseable hint '%s'" % (hint,)) - return None - hint_type = mo.group(1) - if hint_type != "tcp": - print("unknown hint type '%s' in '%s'" % (hint_type, hint)) - return None - hint_value = mo.group(2) - mo = re.search(r'^(.*):(\d+)$', hint_value) - if not mo: - print("unparseable TCP hint '%s'" % (hint,)) - return None - hint_host = mo.group(1) - try: - hint_port = int(mo.group(2)) - except ValueError: - print("non-numeric port in TCP hint '%s'" % (hint,)) - return None - return hint_host, hint_port diff --git a/src/wormhole/twisted/transit.py b/src/wormhole/twisted/transit.py index cc8bed8..199b361 100644 --- a/src/wormhole/twisted/transit.py +++ b/src/wormhole/twisted/transit.py @@ -1,5 +1,5 @@ from __future__ import print_function -import sys, time, socket, collections +import re, sys, time, socket, collections from binascii import hexlify, unhexlify from zope.interface import implementer from twisted.python.runtime import platformType @@ -12,11 +12,6 @@ from ..util import ipaddrs from ..util.hkdf import HKDF from ..errors import UsageError from ..timing import DebugTiming -from ..transit_common import (BadHandshake, - BadNonce, - build_receiver_handshake, - build_sender_handshake, - build_relay_handshake) def debug(msg): if False: @@ -24,6 +19,90 @@ def debug(msg): def since(start): return time.time() - start +class TransitError(Exception): + pass + +class BadHandshake(Exception): + pass + +class TransitClosed(TransitError): + pass + +class BadNonce(TransitError): + pass + +# The beginning of each TCP connection consists of the following handshake +# messages. The sender transmits the same text regardless of whether it is on +# the initiating/connecting end of the TCP connection, or on the +# listening/accepting side. Same for the receiver. +# +# sender -> receiver: transit sender TXID_HEX ready\n\n +# receiver -> sender: transit receiver RXID_HEX ready\n\n +# +# Any deviations from this result in the socket being closed. The handshake +# messages are designed to provoke an invalid response from other sorts of +# servers (HTTP, SMTP, echo). +# +# If the sender is satisfied with the handshake, and this is the first socket +# to complete negotiation, the sender does: +# +# sender -> receiver: go\n +# +# and the next byte on the wire will be from the application. +# +# If this is not the first socket, the sender does: +# +# sender -> receiver: nevermind\n +# +# and closes the socket. + +# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs +# up upon the first wrong byte. The sender lookgs for "transit receiver +# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending +# "go\n" or "nevermind\n"+close(). + +def build_receiver_handshake(key): + hexid = HKDF(key, 32, CTXinfo=b"transit_receiver") + return b"transit receiver "+hexlify(hexid)+b" ready\n\n" + +def build_sender_handshake(key): + hexid = HKDF(key, 32, CTXinfo=b"transit_sender") + return b"transit sender "+hexlify(hexid)+b" ready\n\n" + +def build_relay_handshake(key): + token = HKDF(key, 32, CTXinfo=b"transit_relay_token") + return b"please relay "+hexlify(token)+b"\n" + +# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends +# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one +# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons. +# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint +# publisher wants anonymity, their only hint's ADDR will end in .onion . + +def parse_hint_tcp(hint): + assert isinstance(hint, type(u"")) + # return tuple or None for an unparseable hint + mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) + if not mo: + print("unparseable hint '%s'" % (hint,)) + return None + hint_type = mo.group(1) + if hint_type != "tcp": + print("unknown hint type '%s' in '%s'" % (hint_type, hint)) + return None + hint_value = mo.group(2) + mo = re.search(r'^(.*):(\d+)$', hint_value) + if not mo: + print("unparseable TCP hint '%s'" % (hint,)) + return None + hint_host = mo.group(1) + try: + hint_port = int(mo.group(2)) + except ValueError: + print("non-numeric port in TCP hint '%s'" % (hint,)) + return None + return hint_host, hint_port + TIMEOUT=15 @implementer(interfaces.IProducer, interfaces.IConsumer) @@ -684,7 +763,7 @@ class Common: return d def _endpoint_from_hint(self, hint): - # TODO: use transit_common.parse_hint_tcp + # TODO: use parse_hint_tcp if ":" not in hint: return None pieces = hint.split(":") diff --git a/src/wormhole/twisted/util.py b/src/wormhole/twisted/util.py deleted file mode 100644 index 182fc34..0000000 --- a/src/wormhole/twisted/util.py +++ /dev/null @@ -1,21 +0,0 @@ -from twisted.internet import defer, protocol, endpoints, reactor - -def allocate_port(): - ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1") - d = ep.listen(protocol.Factory()) - def _listening(lp): - port = lp.getHost().port - d2 = lp.stopListening() - d2.addCallback(lambda _: port) - return d2 - d.addCallback(_listening) - return d - -def allocate_ports(): - d = defer.DeferredList([allocate_port(), allocate_port()]) - def _done(results): - port1 = results[0][1] - port2 = results[1][1] - return (port1, port2) - d.addCallback(_done) - return d diff --git a/src/wormhole/util/observer.py b/src/wormhole/util/observer.py deleted file mode 100644 index 25ce716..0000000 --- a/src/wormhole/util/observer.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- test-case-name: foolscap.test_observer -*- - -# many thanks to AllMyData for contributing the initial version of this code - -from twisted.internet import defer -from foolscap import eventual - -class OneShotObserverList(object): - """A one-shot event distributor. - - Subscribers can get a Deferred that will fire with the results of the - event once it finally occurs. The caller does not need to know whether - the event has happened yet or not: they get a Deferred in either case. - - The Deferreds returned to subscribers are guaranteed to not fire in the - current reactor turn; instead, eventually() is used to fire them in a - later turn. Look at Mark Miller's 'Concurrency Among Strangers' paper on - erights.org for a description of why this property is useful. - - I can only be fired once.""" - - def __init__(self): - self._fired = False - self._result = None - self._watchers = [] - self.__repr__ = self._unfired_repr - - def _unfired_repr(self): - return "" % (self._watchers, ) - - def _fired_repr(self): - return " %s>" % (self._result, ) - - def whenFired(self): - if self._fired: - return eventual.fireEventually(self._result) - d = defer.Deferred() - self._watchers.append(d) - return d - - def fire(self, result): - assert not self._fired - self._fired = True - self._result = result - - for w in self._watchers: - eventual.eventually(w.callback, result) - del self._watchers - self.__repr__ = self._fired_repr - diff --git a/tox.ini b/tox.ini index 6caebfc..ebef7a7 100644 --- a/tox.ini +++ b/tox.ini @@ -7,12 +7,6 @@ envlist = py27,py33,py34,py35 skip_missing_interpreters = True -# There's a race-condition bug in Twisted-15.5.0 (#8014, fixed in trunk) that -# manifests as various intermittent magic-wormhole test failures that always -# include twisted.internet.endpoints "iterateEndpoint" or "checkDone". So run -# all builds with a copy of Twisted from git 'trunk' until Twisted-16.0.0 is -# released and we can just depend on that. - # On windows we need "pypiwin32" installed. It's supposedly possible to make # Twisted do this by depending upon "twisted[windows]" instead of just # "twisted", but when I try this via Appveyor, the extra is ignored. @@ -24,7 +18,6 @@ skip_missing_interpreters = True [testenv] deps = - twisted >= 16.1.0 pyflakes {env:EXTRA_DEPENDENCY:} commands =