Use twisted implementation all the time.

Merge commit '1a455c0'
This commit is contained in:
Brian Warner 2016-04-18 16:17:26 -07:00
commit 7e1405576e
19 changed files with 382 additions and 1280 deletions

View File

@ -21,10 +21,8 @@ setup(name="magic-wormhole",
entry_points={"console_scripts":
["wormhole = wormhole.scripts.runner:entry"]},
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
"six"],
"six", "twisted >= 16.1.0"],
extras_require={"tor": ["txtorcon", "ipaddr"]},
# for Twisted support, we want Twisted>=15.5.0. Older Twisteds don't
# provide sufficient python3 compatibility.
test_suite="wormhole.test",
cmdclass=commands,
)

View File

@ -1,400 +0,0 @@
from __future__ import print_function
import time, threading, socket
from six.moves import socketserver
from binascii import hexlify, unhexlify
from nacl.secret import SecretBox
from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..errors import UsageError
from ..timing import DebugTiming
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
BadNonce,
build_receiver_handshake,
build_sender_handshake,
build_relay_handshake,
parse_hint_tcp)
TIMEOUT=15
# 1: sender only transmits, receiver only accepts, both wait forever
# 2: sender also accepts, receiver also transmits
# 3: timeouts / stop when no more progress can be made
# 4: add relay
# 5: accelerate shutdown of losing sockets
def send_to(skt, data):
sent = 0
while sent < len(data):
sent += skt.send(data[sent:])
def wait_for_line(skt, max_length, description):
got = b""
while len(got) < max_length:
got += skt.recv(1)
if got.endswith(b"\n"):
return got[:-1]
raise BadHandshake("exceeded max_length, got %r on %s" %
(got, description))
def wait_for(skt, expected, description):
assert isinstance(expected, type(b""))
got = b""
while len(got) < len(expected):
got += skt.recv(1)
if expected[:len(got)] != got:
raise BadHandshake("got %r want %r on %s" %
(got, expected, description))
def debug(msg):
if False:
print(msg)
def since(start):
return time.time() - start
def connector(owner, hint, description,
send_handshake, expected_handshake, relay_handshake=None):
start = time.time()
parsed_hint = parse_hint_tcp(hint)
if not parsed_hint:
return # unparseable
addr,port = parsed_hint
skt = None
debug("+ connector(%s)" % hint)
try:
skt = socket.create_connection((addr,port),
TIMEOUT) # timeout or ECONNREFUSED
skt.settimeout(TIMEOUT)
debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
if relay_handshake:
debug(" - sending relay_handshake")
send_to(skt, relay_handshake)
relay_msg = wait_for_line(skt, 10000, description)
if relay_msg != b"ok":
raise BadHandshake(relay_msg)
debug(" - relay ready CT+%.1f" % (since(start),))
send_to(skt, send_handshake)
wait_for(skt, expected_handshake, description)
debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
except Exception as e:
debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start)))
try:
if skt:
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
if skt:
skt.close()
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
owner._connector_failed(hint)
return
# owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread
def handle(skt, client_address, owner, description,
send_handshake, expected_handshake):
try:
debug("handle %r" % (skt,))
skt.settimeout(TIMEOUT)
send_to(skt, send_handshake)
got = b""
# for the receiver, this includes the "go\n"
while len(got) < len(expected_handshake):
more = skt.recv(1)
if not more:
raise BadHandshake("disconnect after merely '%r'" % got)
got += more
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r'" %
(got, expected_handshake))
debug("handler negotiation finished %r" % (client_address,))
except Exception as e:
debug("handler failed %r" % (client_address,))
try:
# this raises socket.err(EBADF) if the socket was already closed
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
skt.close() # this appears to be idempotent
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt, description) # note thread
class MyTCPServer(socketserver.TCPServer):
allow_reuse_address = True
def process_request(self, request, client_address):
description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
ready_lock = self.owner._ready_for_connections_lock
ready_lock.acquire()
while not (self.owner._ready_for_connections
and self.owner._transit_key):
ready_lock.wait()
# owner._transit_key is either None or set to a value. We don't
# modify it from here, so we can release the condition lock before
# grabbing the key.
ready_lock.release()
# Once it is set, we can get handler_(send|receive)_handshake, which
# is what we actually care about.
t = threading.Thread(target=handle,
args=(request, client_address,
self.owner, description,
self.owner.handler_send_handshake,
self.owner.handler_expected_handshake))
t.daemon = True
t.start()
class ReceiveBuffer:
def __init__(self, skt):
self.skt = skt
self.buf = b""
def read(self, count):
while len(self.buf) < count:
more = self.skt.recv(4096)
if not more:
raise TransitClosed
self.buf += more
rc = self.buf[:count]
self.buf = self.buf[count:]
return rc
class RecordPipe:
def __init__(self, skt, send_key, receive_key, description):
self.skt = skt
self.send_box = SecretBox(send_key)
self.send_nonce = 0
self.receive_buf = ReceiveBuffer(self.skt)
self.receive_box = SecretBox(receive_key)
self.next_receive_nonce = 0
self._description = description
def describe(self):
return self._description
def send_record(self, record):
if not isinstance(record, type(b"")): raise UsageError
assert SecretBox.NONCE_SIZE == 24
assert self.send_nonce < 2**(8*24)
assert len(record) < 2**(8*4)
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
self.send_nonce += 1
encrypted = self.send_box.encrypt(record, nonce)
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
send_to(self.skt, length)
send_to(self.skt, encrypted)
def receive_record(self):
length_buf = self.receive_buf.read(4)
length = int(hexlify(length_buf), 16)
encrypted = self.receive_buf.read(length)
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
nonce = int(hexlify(nonce_buf), 16)
if nonce != self.next_receive_nonce:
raise BadNonce("received out-of-order record")
self.next_receive_nonce += 1
record = self.receive_box.decrypt(encrypted)
return record
def close(self):
self.skt.close()
class Common:
def __init__(self, transit_relay, no_listen=False, timing=None):
if transit_relay:
if not isinstance(transit_relay, type(u"")):
raise UsageError
self._transit_relays = [transit_relay]
else:
self._transit_relays = []
self._no_listen = no_listen
self._timing = timing or DebugTiming()
self._timing_started = self._timing.add_event("transit")
self.winning = threading.Event()
self._negotiation_check_lock = threading.Lock()
self._ready_for_connections_lock = threading.Condition()
self._ready_for_connections = False
self._transit_key = None
self._start_server()
def _start_server(self):
if self._no_listen:
self.my_direct_hints = []
self.listener = None
return
server = MyTCPServer(("", 0), None)
_, port = server.server_address
self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
for addr in ipaddrs.find_addresses()]
server.owner = self
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
self.listener = server
def get_direct_hints(self):
return self.my_direct_hints
def get_relay_hints(self):
return self._transit_relays
def add_their_direct_hints(self, hints):
for h in hints:
if not isinstance(h, type(u"")):
raise TypeError("hint '%r' should be unicode, not %s"
% (h, type(h)))
self._their_direct_hints = list(hints)
def add_their_relay_hints(self, hints):
for h in hints:
if not isinstance(h, type(u"")):
raise TypeError("hint '%r' should be unicode, not %s"
% (h, type(h)))
self._their_relay_hints = list(hints)
def _send_this(self):
if self.is_sender:
return build_sender_handshake(self._transit_key)
else:
return build_receiver_handshake(self._transit_key)
def _expect_this(self):
if self.is_sender:
return build_receiver_handshake(self._transit_key)
else:
return build_sender_handshake(self._transit_key) + b"go\n"
def _sender_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
def _receiver_record_key(self):
if self.is_sender:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_receiver_key")
else:
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
CTXinfo=b"transit_record_sender_key")
def set_transit_key(self, key):
# This _ready_for_connections condition/lock protects us against the
# race where the sender knows the hints and the key, and connects to
# the receiver's transit socket before the receiver gets relay
# message (and thus the key).
self._ready_for_connections_lock.acquire()
self._transit_key = key
self.handler_send_handshake = self._send_this() # no "go"
self.handler_expected_handshake = self._expect_this()
self._ready_for_connections_lock.notify_all()
self._ready_for_connections_lock.release()
def _start_outbound(self):
self._active_connectors = set(self._their_direct_hints)
self._attempted_connectors = set()
for hint in self._their_direct_hints:
self._start_connector(hint)
if not self._their_direct_hints:
self._start_relay_connectors()
def _start_connector(self, hint, is_relay=False):
# Don't try any hint more than once. If all hints fail, we'll
# eventually timeout. We make no attempt to fail any faster.
if hint in self._attempted_connectors:
return
self._attempted_connectors.add(hint)
description = "->%s" % (hint,)
if is_relay:
description = "->relay:%s" % (hint,)
args = (self, hint, description,
self._send_this(), self._expect_this())
if is_relay:
args = args + (build_relay_handshake(self._transit_key),)
t = threading.Thread(target=connector, args=args)
t.daemon = True
t.start()
def _start_relay_connectors(self):
self._active_connectors.update(self._their_direct_hints)
for hint in self._their_relay_hints:
self._start_connector(hint, is_relay=True)
def establish_socket(self):
start = time.time()
self.winning_skt = None
self.winning_skt_description = None
self._ready_for_connections_lock.acquire()
self._ready_for_connections = True
self._ready_for_connections_lock.notify_all()
self._ready_for_connections_lock.release()
self._start_outbound()
# we sit here until one of our inbound or outbound sockets succeeds
flag = self.winning.wait(2*TIMEOUT)
debug("wait returned at %.1f" % (since(start),))
if not flag:
# timeout: self.winning_skt will not be set. ish. race.
pass
if self.listener:
self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
if self.winning_skt:
return self.winning_skt
raise TransitError("timeout")
def _connector_failed(self, hint):
debug("- failed connector %s" % hint)
# XXX this was .remove, and occasionally got KeyError
self._active_connectors.discard(hint)
if not self._active_connectors:
self._start_relay_connectors()
def _negotiation_finished(self, skt, description):
# inbound/outbound sockets call this when they finish negotiation.
# The first one wins and gets a "go". Any subsequent ones lose and
# get a "nevermind" before being closed.
with self._negotiation_check_lock:
if self.winning_skt:
is_winner = False
else:
is_winner = True
self.winning_skt = skt
self.winning_skt_description = description
if is_winner:
if self.is_sender:
send_to(skt, b"go\n")
self.winning.set()
else:
if self.is_sender:
try:
send_to(skt, b"nevermind\n")
except socket.error:
# They realized this connection is not going to win, and
# closed it so fast we didn't get a chance to tell them
# it lost. This happens in unit tests.
pass
skt.close()
def connect(self):
_start = self._timing.add_event("transit connect")
skt = self.establish_socket()
self._timing.finish_event(_start)
return RecordPipe(skt, self._sender_record_key(),
self._receiver_record_key(),
self.winning_skt_description)
class TransitSender(Common):
is_sender = True
class TransitReceiver(Common):
is_sender = False

View File

@ -27,8 +27,6 @@ g.add_argument("--hide-progress", action="store_true",
help="supress progress-bar display")
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
metavar="FILE", help="(debug) write timing data to file")
g.add_argument("--twisted", action="store_true",
help="use Twisted-based implementations, for testing")
g.add_argument("--no-listen", action="store_true",
help="(debug) don't open a listening socket for Transit")
g.add_argument("--tor", action="store_true",

View File

@ -1,13 +1,19 @@
from __future__ import print_function
import io, json
import io, os, sys, json, binascii, six, tempfile, zipfile
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitReceiver
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
from ..errors import TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer"
class RespondError(Exception):
def __init__(self, response):
self.response = response
def receive_twisted_sync(args):
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
# directly)
@ -34,7 +40,14 @@ def receive_twisted_sync(args):
def receive_twisted(args):
return TwistedReceiver(args).go()
class TwistedReceiver(BlockingReceiver):
class TwistedReceiver:
def __init__(self, args):
assert isinstance(args.relay_url, type(u""))
self.args = args
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
# TODO: @handle_server_error
@inlineCallbacks
@ -101,6 +114,11 @@ class TwistedReceiver(BlockingReceiver):
self.args.code_length)
yield w.set_code(code)
def show_verifier(self, verifier):
verifier_hex = binascii.hexlify(verifier).decode("ascii")
if self.args.verify:
self.msg(u"Verifier %s." % verifier_hex)
@inlineCallbacks
def get_data(self, w):
try:
@ -119,6 +137,61 @@ class TwistedReceiver(BlockingReceiver):
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
yield w.send_data(data)
def handle_file(self, them_d):
file_data = them_d["file"]
self.abs_destname = self.decide_destname("file",
file_data["filename"])
self.xfersize = file_data["filesize"]
self.msg(u"Receiving file (%d bytes) into: %s" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
def handle_directory(self, them_d):
file_data = them_d["directory"]
zipmode = file_data["mode"]
if zipmode != "zipfile/deflated":
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
raise RespondError({"error": "unknown mode"})
self.abs_destname = self.decide_destname("directory",
file_data["dirname"])
self.xfersize = file_data["zipsize"]
self.msg(u"Receiving directory (%d bytes) into: %s/" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.msg(u"%d files, %d bytes (uncompressed)" %
(file_data["numfiles"], file_data["numbytes"]))
self.ask_permission()
return tempfile.SpooledTemporaryFile()
def decide_destname(self, mode, destname):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.join(self.args.cwd, destname)
# get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname):
self.msg(u"Error: refusing to overwrite existing %s %s" %
(mode, destname))
raise RespondError({"error": "%s already exists" % mode})
return abs_destname
def ask_permission(self):
_start = self.args.timing.add_event("permission", waiting="user")
while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print(u"transfer rejected", file=sys.stderr)
self.args.timing.finish_event(_start, answer="no")
raise RespondError({"error": "transfer rejected"})
self.args.timing.finish_event(_start, answer="yes")
@inlineCallbacks
def establish_transit(self, w, them_d, tor_manager):
transit_key = w.derive_key(APPID+u"/transit-key")
@ -169,6 +242,27 @@ class TwistedReceiver(BlockingReceiver):
returnValue(1) # TODO: exit properly
assert received == self.xfersize
def write_file(self, f):
tmp_name = f.name
f.close()
os.rename(tmp_name, self.abs_destname)
self.msg(u"Received file written to %s" %
os.path.basename(self.abs_destname))
def write_directory(self, f):
self.msg(u"Unpacking zipfile..")
_start = self.args.timing.add_event("unpack zip")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against
# malicious pathnames. For example, "/tmp/oops" and
# "../tmp/oops" both do the same thing as the (safe)
# "tmp/oops".
self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
f.close()
self.args.timing.finish_event(_start)
@inlineCallbacks
def close_transit(self, record_pipe):
_start = self.args.timing.add_event("ack")

View File

@ -1,216 +0,0 @@
from __future__ import print_function
import io, os, sys, json, binascii, six, tempfile, zipfile
from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitReceiver, TransitError
from ..errors import handle_server_error, TransferError
from .progress import ProgressPrinter
APPID = u"lothar.com/wormhole/text-or-file-xfer"
def receive_blocking(args):
return BlockingReceiver(args).go()
class RespondError(Exception):
def __init__(self, response):
self.response = response
class BlockingReceiver:
def __init__(self, args):
assert isinstance(args.relay_url, type(u""))
self.args = args
def msg(self, *args, **kwargs):
print(*args, file=self.args.stdout, **kwargs)
@handle_server_error
def go(self):
with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w:
self.handle_code(w)
verifier = w.get_verifier()
self.show_verifier(verifier)
them_d = self.get_data(w)
try:
if "message" in them_d:
self.handle_text(them_d, w)
return 0
if "file" in them_d:
f = self.handle_file(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_file(f)
self.close_transit(rp)
elif "directory" in them_d:
f = self.handle_directory(them_d)
rp = self.establish_transit(w, them_d)
self.transfer_data(rp, f)
self.write_directory(f)
self.close_transit(rp)
else:
self.msg(u"I don't know what they're offering\n")
self.msg(u"Offer details:", them_d)
raise RespondError({"error": "unknown offer type"})
except RespondError as r:
data = json.dumps(r.response).encode("utf-8")
w.send_data(data)
return 1
return 0
def handle_code(self, w):
code = self.args.code
if self.args.zeromode:
assert not code
code = u"0-"
if not code:
code = w.input_code("Enter receive wormhole code: ",
self.args.code_length)
w.set_code(code)
def show_verifier(self, verifier):
verifier_hex = binascii.hexlify(verifier).decode("ascii")
if self.args.verify:
self.msg(u"Verifier %s." % verifier_hex)
def get_data(self, w):
try:
them_bytes = w.get_data()
except WrongPasswordError as e:
raise TransferError(u"ERROR: " + e.explain())
them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d:
raise TransferError(u"ERROR: " + them_d["error"])
return them_d
def handle_text(self, them_d, w):
# we're receiving a text message
self.msg(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data)
def handle_file(self, them_d):
file_data = them_d["file"]
self.abs_destname = self.decide_destname("file",
file_data["filename"])
self.xfersize = file_data["filesize"]
self.msg(u"Receiving file (%d bytes) into: %s" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.ask_permission()
tmp_destname = self.abs_destname + ".tmp"
return open(tmp_destname, "wb")
def handle_directory(self, them_d):
file_data = them_d["directory"]
zipmode = file_data["mode"]
if zipmode != "zipfile/deflated":
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
raise RespondError({"error": "unknown mode"})
self.abs_destname = self.decide_destname("directory",
file_data["dirname"])
self.xfersize = file_data["zipsize"]
self.msg(u"Receiving directory (%d bytes) into: %s/" %
(self.xfersize, os.path.basename(self.abs_destname)))
self.msg(u"%d files, %d bytes (uncompressed)" %
(file_data["numfiles"], file_data["numbytes"]))
self.ask_permission()
return tempfile.SpooledTemporaryFile()
def decide_destname(self, mode, destname):
# the basename() is intended to protect us against
# "~/.ssh/authorized_keys" and other attacks
destname = os.path.basename(destname)
if self.args.output_file:
destname = self.args.output_file # override
abs_destname = os.path.join(self.args.cwd, destname)
# get confirmation from the user before writing to the local directory
if os.path.exists(abs_destname):
self.msg(u"Error: refusing to overwrite existing %s %s" %
(mode, destname))
raise RespondError({"error": "%s already exists" % mode})
return abs_destname
def ask_permission(self):
_start = self.args.timing.add_event("permission", waiting="user")
while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"):
break
print(u"transfer rejected", file=sys.stderr)
self.args.timing.finish_event(_start, answer="no")
raise RespondError({"error": "transfer rejected"})
self.args.timing.finish_event(_start, answer="yes")
def establish_transit(self, w, them_d):
transit_key = w.derive_key(APPID+u"/transit-key")
transit_receiver = TransitReceiver(self.args.transit_helper,
no_listen=self.args.no_listen,
timing=self.args.timing)
transit_receiver.set_transit_key(transit_key)
data = json.dumps({
"file_ack": "ok",
"transit": {
"direct_connection_hints": transit_receiver.get_direct_hints(),
"relay_connection_hints": transit_receiver.get_relay_hints(),
},
}).encode("utf-8")
w.send_data(data)
# now receive the rest of the owl
tdata = them_d["transit"]
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_receiver.connect()
return record_pipe
def transfer_data(self, record_pipe, f):
self.msg(u"Receiving (%s).." % record_pipe.describe())
_start = self.args.timing.add_event("rx file")
progress_stdout = self.args.stdout
if self.args.hide_progress:
progress_stdout = io.StringIO()
received = 0
p = ProgressPrinter(self.xfersize, progress_stdout)
p.start()
while received < self.xfersize:
try:
plaintext = record_pipe.receive_record()
except TransitError:
self.msg()
self.msg(u"Connection dropped before full file received")
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
return 1
f.write(plaintext)
received += len(plaintext)
p.update(received)
p.finish()
self.args.timing.finish_event(_start)
assert received == self.xfersize
def write_file(self, f):
tmp_name = f.name
f.close()
os.rename(tmp_name, self.abs_destname)
self.msg(u"Received file written to %s" %
os.path.basename(self.abs_destname))
def write_directory(self, f):
self.msg(u"Unpacking zipfile..")
_start = self.args.timing.add_event("unpack zip")
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against
# malicious pathnames. For example, "/tmp/oops" and
# "../tmp/oops" both do the same thing as the (safe)
# "tmp/oops".
self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname))
f.close()
self.args.timing.finish_event(_start)
def close_transit(self, record_pipe):
_start = self.args.timing.add_event("ack")
record_pipe.send_record(b"ok\n")
record_pipe.close()
self.args.timing.finish_event(_start)

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import io, json, binascii, six
import os, sys, io, json, binascii, six, tempfile, zipfile
from twisted.protocols import basic
from twisted.internet import reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue
@ -7,8 +7,94 @@ from ..errors import TransferError
from .progress import ProgressPrinter
from ..twisted.transcribe import Wormhole, WrongPasswordError
from ..twisted.transit import TransitSender
from .send_common import (APPID, handle_zero, build_other_command,
build_phase1_data)
APPID = u"lothar.com/wormhole/text-or-file-xfer"
def handle_zero(args):
if args.zeromode:
assert not args.code
args.code = u"0-"
def build_other_command(args):
other_cmd = "wormhole receive"
if args.verify:
other_cmd = "wormhole --verify receive"
if args.zeromode:
other_cmd += " -0"
return other_cmd
def build_phase1_data(args):
phase1 = {}
text = args.text
if text == "-":
print(u"Reading text message from stdin..", file=args.stdout)
text = sys.stdin.read()
if not text and not args.what:
text = six.moves.input("Text to send: ")
if text is not None:
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
phase1 = { "message": text }
fd_to_send = None
return phase1, fd_to_send
what = os.path.join(args.cwd, args.what)
what = what.rstrip(os.sep)
if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" %
args.what)
basename = os.path.basename(what)
if os.path.isfile(what):
# we're sending a file
filesize = os.stat(what).st_size
phase1["file"] = {
"filename": basename,
"filesize": filesize,
}
print(u"Sending %d byte file named '%s'" % (filesize, basename),
file=args.stdout)
fd_to_send = open(what, "rb")
return phase1, fd_to_send
if os.path.isdir(what):
print(u"Building zipfile..", file=args.stdout)
# We're sending a directory. Create a zipfile in a tempdir and
# send that.
fd_to_send = tempfile.SpooledTemporaryFile()
# TODO: I think ZIP_DEFLATED means compressed.. check it
num_files = 0
num_bytes = 0
tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
for path,dirs,files in os.walk(what):
# path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain
# "" or "subdir"
localpath = list(path.split(os.sep)[tostrip:])
for fn in files:
archivename = os.path.join(*tuple(localpath+[fn]))
localfilename = os.path.join(path, fn)
zf.write(localfilename, archivename)
num_bytes += os.stat(localfilename).st_size
num_files += 1
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
phase1["directory"] = {
"mode": "zipfile/deflated",
"dirname": basename,
"zipsize": filesize,
"numbytes": num_bytes,
"numfiles": num_files,
}
print(u"Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename), file=args.stdout)
return phase1, fd_to_send
raise TypeError("'%s' is neither file nor directory" % args.what)
def send_twisted_sync(args):
# try to use twisted.internet.task.react(f) here (but it calls sys.exit

View File

@ -1,129 +0,0 @@
from __future__ import print_function
import json, binascii, six
from ..errors import TransferError
from .progress import ProgressPrinter
from ..blocking.transcribe import Wormhole, WrongPasswordError
from ..blocking.transit import TransitSender
from ..errors import handle_server_error
from .send_common import (APPID, handle_zero, build_other_command,
build_phase1_data)
@handle_server_error
def send_blocking(args):
assert isinstance(args.relay_url, type(u""))
handle_zero(args)
phase1, fd_to_send = build_phase1_data(args)
other_cmd = build_other_command(args)
print(u"On the other computer, please run: %s" % other_cmd,
file=args.stdout)
if fd_to_send is not None:
transit_sender = TransitSender(args.transit_helper,
no_listen=args.no_listen,
timing=args.timing)
transit_data = {
"direct_connection_hints": transit_sender.get_direct_hints(),
"relay_connection_hints": transit_sender.get_relay_hints(),
}
phase1["transit"] = transit_data
with Wormhole(APPID, args.relay_url, timing=args.timing) as w:
if args.code:
w.set_code(args.code)
code = args.code
else:
code = w.get_code(args.code_length)
if not args.zeromode:
print(u"Wormhole code is: %s" % code, file=args.stdout)
print(u"", file=args.stdout)
# get the verifier, because that also lets us derive the transit key,
# which we want to set before revealing the connection hints to the
# far side, so we'll be ready for them when they connect
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
if args.verify:
_do_verify(verifier, w)
if fd_to_send is not None:
transit_key = w.derive_key(APPID+"/transit-key")
transit_sender.set_transit_key(transit_key)
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
w.send_data(my_phase1_bytes)
try:
them_phase1_bytes = w.get_data()
except WrongPasswordError as e:
raise TransferError(e.explain())
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if fd_to_send is None:
if them_phase1["message_ack"] == "ok":
print(u"text message sent", file=args.stdout)
return 0
raise TransferError("error sending text: %r" % (them_phase1,))
return _send_file_blocking(them_phase1, fd_to_send,
transit_sender, args.stdout, args.hide_progress,
args.timing)
def _do_verify(verifier, w):
while True:
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
if ok.lower() == "yes":
break
if ok.lower() == "no":
reject_data = json.dumps({"error": "verification rejected",
}).encode("utf-8")
w.send_data(reject_data)
raise TransferError("verification rejected, abandoning transfer")
def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
stdout, hide_progress, timing):
# we're sending a file, if they accept it
if "error" in them_phase1:
raise TransferError("remote error, transfer abandoned: %s"
% them_phase1["error"])
if them_phase1.get("file_ack") != "ok":
raise TransferError("ambiguous response from remote, "
"transfer abandoned: %s" % (them_phase1,))
tdata = them_phase1["transit"]
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = transit_sender.connect()
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
_start = timing.add_event("tx file")
CHUNKSIZE = 64*1024
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
p = ProgressPrinter(filesize, stdout)
with fd_to_send as f:
sent = 0
if not hide_progress:
p.start()
while sent < filesize:
plaintext = f.read(CHUNKSIZE)
record_pipe.send_record(plaintext)
sent += len(plaintext)
if not hide_progress:
p.update(sent)
if not hide_progress:
p.finish()
timing.finish_event(_start)
_start = timing.add_event("get ack")
print(u"File sent.. waiting for confirmation", file=stdout)
ack = record_pipe.receive_record()
record_pipe.close()
if ack == b"ok\n":
print(u"Confirmation received. Transfer complete.", file=stdout)
timing.finish_event(_start, ack="ok")
return 0
timing.finish_event(_start, ack="failed")
raise TransferError("Transfer failed (remote says: %r)" % ack)

View File

@ -21,22 +21,14 @@ def dispatch(args):
from ..servers import cmd_usage
return cmd_usage.tail_usage(args)
if args.tor:
args.twisted = True
if args.func == "send/send":
if args.twisted:
from . import cmd_send_twisted
return cmd_send_twisted.send_twisted_sync(args)
from . import cmd_send_blocking
return cmd_send_blocking.send_blocking(args)
from . import cmd_send
return cmd_send.send_twisted_sync(args)
if args.func == "receive/receive":
if args.twisted:
_start = args.timing.add_event("import c_r_t")
from . import cmd_receive_twisted
args.timing.finish_event(_start)
return cmd_receive_twisted.receive_twisted_sync(args)
from . import cmd_receive_blocking
return cmd_receive_blocking.receive_blocking(args)
_start = args.timing.add_event("import c_r_t")
from . import cmd_receive
args.timing.finish_event(_start)
return cmd_receive.receive_twisted_sync(args)
raise ValueError("unknown args.func %s" % args.func)

View File

@ -1,90 +0,0 @@
from __future__ import print_function
import os, sys, six, tempfile, zipfile
from ..errors import TransferError
APPID = u"lothar.com/wormhole/text-or-file-xfer"
def handle_zero(args):
if args.zeromode:
assert not args.code
args.code = u"0-"
def build_other_command(args):
other_cmd = "wormhole receive"
if args.verify:
other_cmd = "wormhole --verify receive"
if args.zeromode:
other_cmd += " -0"
return other_cmd
def build_phase1_data(args):
phase1 = {}
text = args.text
if text == "-":
print(u"Reading text message from stdin..", file=args.stdout)
text = sys.stdin.read()
if not text and not args.what:
text = six.moves.input("Text to send: ")
if text is not None:
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
phase1 = { "message": text }
fd_to_send = None
return phase1, fd_to_send
what = os.path.join(args.cwd, args.what)
what = what.rstrip(os.sep)
if not os.path.exists(what):
raise TransferError("Cannot send: no file/directory named '%s'" %
args.what)
basename = os.path.basename(what)
if os.path.isfile(what):
# we're sending a file
filesize = os.stat(what).st_size
phase1["file"] = {
"filename": basename,
"filesize": filesize,
}
print(u"Sending %d byte file named '%s'" % (filesize, basename),
file=args.stdout)
fd_to_send = open(what, "rb")
return phase1, fd_to_send
if os.path.isdir(what):
print(u"Building zipfile..", file=args.stdout)
# We're sending a directory. Create a zipfile in a tempdir and
# send that.
fd_to_send = tempfile.SpooledTemporaryFile()
# TODO: I think ZIP_DEFLATED means compressed.. check it
num_files = 0
num_bytes = 0
tostrip = len(what.split(os.sep))
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
for path,dirs,files in os.walk(what):
# path always starts with args.what, then sometimes might
# have "/subdir" appended. We want the zipfile to contain
# "" or "subdir"
localpath = list(path.split(os.sep)[tostrip:])
for fn in files:
archivename = os.path.join(*tuple(localpath+[fn]))
localfilename = os.path.join(path, fn)
zf.write(localfilename, archivename)
num_bytes += os.stat(localfilename).st_size
num_files += 1
fd_to_send.seek(0,2)
filesize = fd_to_send.tell()
fd_to_send.seek(0,0)
phase1["directory"] = {
"mode": "zipfile/deflated",
"dirname": basename,
"zipsize": filesize,
"numbytes": num_bytes,
"numfiles": num_files,
}
print(u"Sending directory (%d bytes compressed) named '%s'"
% (filesize, basename), file=args.stdout)
return phase1, fd_to_send
raise TypeError("'%s' is neither file nor directory" % args.what)

View File

@ -1,7 +1,7 @@
from twisted.application import service
from twisted.internet import reactor, defer
from twisted.python import log
from ..twisted.util import allocate_ports
from ..twisted.transit import allocate_tcp_port
from ..servers.server import RelayServer
from .. import __version__
@ -9,19 +9,16 @@ class ServerBase:
def setUp(self):
self.sp = service.MultiService()
self.sp.startService()
d = allocate_ports()
def _got_ports(ports):
relayport, transitport = ports
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
"tcp:%s:interface=127.0.0.1" % transitport,
__version__)
s.setServiceParent(self.sp)
self._relay_server = s.relay
self._transit_server = s.transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = u"tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports)
return d
relayport = allocate_tcp_port()
transitport = allocate_tcp_port()
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
"tcp:%s:interface=127.0.0.1" % transitport,
__version__)
s.setServiceParent(self.sp)
self._relay_server = s.relay
self._transit_server = s.transit
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = u"tcp:127.0.0.1:%d" % transitport
def tearDown(self):
# Unit tests that spawn a (blocking) client in a thread might still

View File

@ -1,14 +1,11 @@
from __future__ import print_function
import json
from twisted.trial import unittest
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks
from twisted.internet.defer import gatherResults, succeed
from twisted.internet.threads import deferToThread
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
WrongPasswordError)
from ..blocking.eventsource import EventSourceFollower
from ..blocking.transit import (TransitSender, TransitReceiver,
build_sender_handshake,
build_receiver_handshake)
from .common import ServerBase
APPID = u"appid"
@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase):
(u"message", u"three"),
(u"e2", u"four"),
])
class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
def test_hints(self):
r = TransitReceiver(self.transit)
hints = r.get_direct_hints()
self.assertTrue(len(hints), hints)
@inlineCallbacks
def test_direct_to_receiver(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to be sender->receiver
s.set_transit_key(key)
# only use 127.0.0.1
hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1]
s.add_their_direct_hints([hint])
s.add_their_relay_hints([])
r.set_transit_key(key)
r.add_their_direct_hints([])
r.add_their_relay_hints([])
# it'd be nice to factor this chunk out with 'yield from', but that
# didn't appear until python-3.3, and isn't in py2 at all.
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
@inlineCallbacks
def test_direct_to_sender(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to be receiver->sender
s.set_transit_key(key)
s.add_their_direct_hints([])
s.add_their_relay_hints([])
r.set_transit_key(key)
hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1]
r.add_their_direct_hints([hint])
r.add_their_relay_hints([])
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
@inlineCallbacks
def test_relay(self):
s = TransitSender(self.transit)
r = TransitReceiver(self.transit)
key = b"\x00"*32
# force the connection to use the relay by not revealing direct hints
s.set_transit_key(key)
s.add_their_direct_hints([])
s.add_their_relay_hints(r.get_relay_hints())
r.set_transit_key(key)
r.add_their_direct_hints([])
r.add_their_relay_hints(s.get_relay_hints())
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
yield deferToThread(sp.send_record, b"01234")
rec = yield deferToThread(rp.receive_record)
self.assertEqual(rec, b"01234")
yield deferToThread(sp.close)
yield deferToThread(rp.close)
# TODO: this may be racy if we don't poll the server to make sure
# it's witnessed the first connection closing before querying the DB
#import time
#yield deferToThread(time.sleep, 1)
# check the transit relay's DB, make sure it counted the bytes
db = self._transit_server._db
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
rows = c.fetchall()
self.assertEqual(len(rows), 1)
row = rows[0]
self.assertEqual(row["result"], u"happy")
# Sender first writes relay_handshake and waits for OK, but that's
# not counted by the transit server. Then sender writes
# sender_handshake and waits for receiver_handshake. Then sender
# writes GO and the body. Body is length-prefixed SecretBox, so
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
sender_count = (len(build_sender_handshake(b""))+
len(b"go\n")+
4+24+len(b"01234")+16)
# Receiver first writes relay_handshake and waits for OK, but that's
# not counted. Then receiver writes receiver_handshake and waits for
# sender_handshake+GO.
receiver_count = len(build_receiver_handshake(b""))
self.assertEqual(row["total_bytes"], sender_count+receiver_count)

View File

@ -4,12 +4,10 @@ from twisted.trial import unittest
from twisted.python import procutils, log
from twisted.internet.utils import getProcessOutputAndValue
from twisted.internet.defer import inlineCallbacks
from twisted.internet.threads import deferToThread
from .. import __version__
from .common import ServerBase
from ..scripts import (runner, cmd_send_blocking, cmd_send_twisted,
cmd_receive_blocking, cmd_receive_twisted)
from ..scripts.send_common import build_phase1_data
from ..scripts import runner, cmd_send, cmd_receive
from ..scripts.cmd_send import build_phase1_data
from ..errors import TransferError
from ..timing import DebugTiming
@ -219,8 +217,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
@inlineCallbacks
def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False,
sender_twisted=False, receiver_twisted=False):
mode="text", addslash=False, override_filename=False):
assert mode in ("text", "file", "directory")
common_args = ["--hide-progress",
"--relay-url", self.relayurl,
@ -314,14 +311,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
rargs.stdout = io.StringIO()
rargs.stderr = io.StringIO()
rargs.timing = DebugTiming()
if sender_twisted:
send_d = cmd_send_twisted.send_twisted(sargs)
else:
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
if receiver_twisted:
receive_d = cmd_receive_twisted.receive_twisted(rargs)
else:
receive_d = deferToThread(cmd_receive_blocking.receive_blocking, rargs)
send_d = cmd_send.send_twisted(sargs)
receive_d = cmd_receive.receive_twisted(rargs)
# The sender might fail, leaving the receiver hanging, or vice
# versa. If either side fails, cancel the other, so it won't
@ -418,24 +409,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test()
def test_text_subprocess(self):
return self._do_test(as_subprocess=True)
def test_text_twisted_to_blocking(self):
return self._do_test(sender_twisted=True)
def test_text_blocking_to_twisted(self):
return self._do_test(receiver_twisted=True)
def test_text_twisted_to_twisted(self):
return self._do_test(sender_twisted=True, receiver_twisted=True)
def test_file(self):
return self._do_test(mode="file")
def test_file_override(self):
return self._do_test(mode="file", override_filename=True)
def test_file_twisted_to_blocking(self):
return self._do_test(mode="file", sender_twisted=True)
def test_file_blocking_to_twisted(self):
return self._do_test(mode="file", receiver_twisted=True)
def test_file_twisted_to_twisted(self):
return self._do_test(mode="file",
sender_twisted=True, receiver_twisted=True)
def test_directory(self):
return self._do_test(mode="directory")
@ -443,16 +421,3 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="directory", addslash=True)
def test_directory_override(self):
return self._do_test(mode="directory", override_filename=True)
def test_directory_twisted_to_blocking(self):
return self._do_test(mode="directory", sender_twisted=True)
def test_directory_twisted_to_blocking_addslash(self):
return self._do_test(mode="directory", addslash=True,
sender_twisted=True)
def test_directory_blocking_to_twisted(self):
return self._do_test(mode="directory", receiver_twisted=True)
def test_directory_twisted_to_twisted(self):
return self._do_test(mode="directory",
sender_twisted=True, receiver_twisted=True)
def test_directory_twisted_to_twisted_addslash(self):
return self._do_test(mode="directory", addslash=True,
sender_twisted=True, receiver_twisted=True)

View File

@ -1,10 +1,12 @@
from __future__ import print_function
import json
import requests
from binascii import hexlify
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest
from twisted.internet import reactor, defer
from twisted.internet import protocol, reactor, defer
from twisted.internet.threads import deferToThread
from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody
from .. import __version__
from .common import ServerBase
@ -434,7 +436,30 @@ class Summary(unittest.TestCase):
self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41),
(1, "scary", 40, 9))
class Transit(unittest.TestCase):
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self):
blur = transit_server.blur_size
self.failUnlessEqual(blur(0), 0)
@ -453,3 +478,62 @@ class Transit(unittest.TestCase):
self.failUnlessEqual(blur(1100e6), 1100e6)
self.failUnlessEqual(blur(1150e6), 1200e6)
@defer.inlineCallbacks
def test_basic(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@ -1,90 +0,0 @@
from __future__ import print_function
from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import protocol, defer, reactor
from twisted.internet.endpoints import clientFromString, connectProtocol
from .common import ServerBase
class Accumulator(protocol.Protocol):
def __init__(self):
self.data = b""
self.count = 0
self._wait = None
def waitForBytes(self, more):
assert self._wait is None
self.count = more
self._wait = defer.Deferred()
self._check_done()
return self._wait
def dataReceived(self, data):
self.data = self.data + data
self._check_done()
def _check_done(self):
if self._wait and len(self.data) >= self.count:
d = self._wait
self._wait = None
d.callback(self)
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
class Transit(ServerBase, unittest.TestCase):
@defer.inlineCallbacks
def test_basic(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# the server waits for the exact number of bytes in the expected
# handshake message. to trigger "bad handshake", we must match.
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
exp = b"bad handshake\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()
@defer.inlineCallbacks
def test_impatience(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
a1.transport.loseConnection()

View File

@ -1,87 +0,0 @@
import re
from binascii import hexlify
from .util.hkdf import HKDF
class TransitError(Exception):
pass
class BadHandshake(Exception):
pass
class TransitClosed(TransitError):
pass
class BadNonce(TransitError):
pass
# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
# listening/accepting side. Same for the receiver.
#
# sender -> receiver: transit sender TXID_HEX ready\n\n
# receiver -> sender: transit receiver RXID_HEX ready\n\n
#
# Any deviations from this result in the socket being closed. The handshake
# messages are designed to provoke an invalid response from other sorts of
# servers (HTTP, SMTP, echo).
#
# If the sender is satisfied with the handshake, and this is the first socket
# to complete negotiation, the sender does:
#
# sender -> receiver: go\n
#
# and the next byte on the wire will be from the application.
#
# If this is not the first socket, the sender does:
#
# sender -> receiver: nevermind\n
#
# and closes the socket.
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender lookgs for "transit receiver
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
def build_relay_handshake(key):
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b"\n"
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
# publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
print("unparseable hint '%s'" % (hint,))
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
return None
hint_value = mo.group(2)
mo = re.search(r'^(.*):(\d+)$', hint_value)
if not mo:
print("unparseable TCP hint '%s'" % (hint,))
return None
hint_host = mo.group(1)
try:
hint_port = int(mo.group(2))
except ValueError:
print("non-numeric port in TCP hint '%s'" % (hint,))
return None
return hint_host, hint_port

View File

@ -1,5 +1,5 @@
from __future__ import print_function
import sys, time, socket, collections
import re, sys, time, socket, collections
from binascii import hexlify, unhexlify
from zope.interface import implementer
from twisted.python.runtime import platformType
@ -12,11 +12,6 @@ from ..util import ipaddrs
from ..util.hkdf import HKDF
from ..errors import UsageError
from ..timing import DebugTiming
from ..transit_common import (BadHandshake,
BadNonce,
build_receiver_handshake,
build_sender_handshake,
build_relay_handshake)
def debug(msg):
if False:
@ -24,6 +19,90 @@ def debug(msg):
def since(start):
return time.time() - start
class TransitError(Exception):
pass
class BadHandshake(Exception):
pass
class TransitClosed(TransitError):
pass
class BadNonce(TransitError):
pass
# The beginning of each TCP connection consists of the following handshake
# messages. The sender transmits the same text regardless of whether it is on
# the initiating/connecting end of the TCP connection, or on the
# listening/accepting side. Same for the receiver.
#
# sender -> receiver: transit sender TXID_HEX ready\n\n
# receiver -> sender: transit receiver RXID_HEX ready\n\n
#
# Any deviations from this result in the socket being closed. The handshake
# messages are designed to provoke an invalid response from other sorts of
# servers (HTTP, SMTP, echo).
#
# If the sender is satisfied with the handshake, and this is the first socket
# to complete negotiation, the sender does:
#
# sender -> receiver: go\n
#
# and the next byte on the wire will be from the application.
#
# If this is not the first socket, the sender does:
#
# sender -> receiver: nevermind\n
#
# and closes the socket.
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
# up upon the first wrong byte. The sender lookgs for "transit receiver
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
# "go\n" or "nevermind\n"+close().
def build_receiver_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
def build_relay_handshake(key):
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b"\n"
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
# publisher wants anonymity, their only hint's ADDR will end in .onion .
def parse_hint_tcp(hint):
assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo:
print("unparseable hint '%s'" % (hint,))
return None
hint_type = mo.group(1)
if hint_type != "tcp":
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
return None
hint_value = mo.group(2)
mo = re.search(r'^(.*):(\d+)$', hint_value)
if not mo:
print("unparseable TCP hint '%s'" % (hint,))
return None
hint_host = mo.group(1)
try:
hint_port = int(mo.group(2))
except ValueError:
print("non-numeric port in TCP hint '%s'" % (hint,))
return None
return hint_host, hint_port
TIMEOUT=15
@implementer(interfaces.IProducer, interfaces.IConsumer)
@ -684,7 +763,7 @@ class Common:
return d
def _endpoint_from_hint(self, hint):
# TODO: use transit_common.parse_hint_tcp
# TODO: use parse_hint_tcp
if ":" not in hint:
return None
pieces = hint.split(":")

View File

@ -1,21 +0,0 @@
from twisted.internet import defer, protocol, endpoints, reactor
def allocate_port():
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
d = ep.listen(protocol.Factory())
def _listening(lp):
port = lp.getHost().port
d2 = lp.stopListening()
d2.addCallback(lambda _: port)
return d2
d.addCallback(_listening)
return d
def allocate_ports():
d = defer.DeferredList([allocate_port(), allocate_port()])
def _done(results):
port1 = results[0][1]
port2 = results[1][1]
return (port1, port2)
d.addCallback(_done)
return d

View File

@ -1,50 +0,0 @@
# -*- test-case-name: foolscap.test_observer -*-
# many thanks to AllMyData for contributing the initial version of this code
from twisted.internet import defer
from foolscap import eventual
class OneShotObserverList(object):
"""A one-shot event distributor.
Subscribers can get a Deferred that will fire with the results of the
event once it finally occurs. The caller does not need to know whether
the event has happened yet or not: they get a Deferred in either case.
The Deferreds returned to subscribers are guaranteed to not fire in the
current reactor turn; instead, eventually() is used to fire them in a
later turn. Look at Mark Miller's 'Concurrency Among Strangers' paper on
erights.org for a description of why this property is useful.
I can only be fired once."""
def __init__(self):
self._fired = False
self._result = None
self._watchers = []
self.__repr__ = self._unfired_repr
def _unfired_repr(self):
return "<OneShotObserverList [%s]>" % (self._watchers, )
def _fired_repr(self):
return "<OneShotObserverList -> %s>" % (self._result, )
def whenFired(self):
if self._fired:
return eventual.fireEventually(self._result)
d = defer.Deferred()
self._watchers.append(d)
return d
def fire(self, result):
assert not self._fired
self._fired = True
self._result = result
for w in self._watchers:
eventual.eventually(w.callback, result)
del self._watchers
self.__repr__ = self._fired_repr

View File

@ -7,12 +7,6 @@
envlist = py27,py33,py34,py35
skip_missing_interpreters = True
# There's a race-condition bug in Twisted-15.5.0 (#8014, fixed in trunk) that
# manifests as various intermittent magic-wormhole test failures that always
# include twisted.internet.endpoints "iterateEndpoint" or "checkDone". So run
# all builds with a copy of Twisted from git 'trunk' until Twisted-16.0.0 is
# released and we can just depend on that.
# On windows we need "pypiwin32" installed. It's supposedly possible to make
# Twisted do this by depending upon "twisted[windows]" instead of just
# "twisted", but when I try this via Appveyor, the extra is ignored.
@ -24,7 +18,6 @@ skip_missing_interpreters = True
[testenv]
deps =
twisted >= 16.1.0
pyflakes
{env:EXTRA_DEPENDENCY:}
commands =