Use twisted implementation all the time.
Merge commit '1a455c0'
This commit is contained in:
commit
7e1405576e
4
setup.py
4
setup.py
|
@ -21,10 +21,8 @@ setup(name="magic-wormhole",
|
|||
entry_points={"console_scripts":
|
||||
["wormhole = wormhole.scripts.runner:entry"]},
|
||||
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
||||
"six"],
|
||||
"six", "twisted >= 16.1.0"],
|
||||
extras_require={"tor": ["txtorcon", "ipaddr"]},
|
||||
# for Twisted support, we want Twisted>=15.5.0. Older Twisteds don't
|
||||
# provide sufficient python3 compatibility.
|
||||
test_suite="wormhole.test",
|
||||
cmdclass=commands,
|
||||
)
|
||||
|
|
|
@ -1,400 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import time, threading, socket
|
||||
from six.moves import socketserver
|
||||
from binascii import hexlify, unhexlify
|
||||
from nacl.secret import SecretBox
|
||||
from ..util import ipaddrs
|
||||
from ..util.hkdf import HKDF
|
||||
from ..errors import UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
|
||||
BadNonce,
|
||||
build_receiver_handshake,
|
||||
build_sender_handshake,
|
||||
build_relay_handshake,
|
||||
parse_hint_tcp)
|
||||
|
||||
TIMEOUT=15
|
||||
|
||||
# 1: sender only transmits, receiver only accepts, both wait forever
|
||||
# 2: sender also accepts, receiver also transmits
|
||||
# 3: timeouts / stop when no more progress can be made
|
||||
# 4: add relay
|
||||
# 5: accelerate shutdown of losing sockets
|
||||
|
||||
def send_to(skt, data):
|
||||
sent = 0
|
||||
while sent < len(data):
|
||||
sent += skt.send(data[sent:])
|
||||
|
||||
def wait_for_line(skt, max_length, description):
|
||||
got = b""
|
||||
while len(got) < max_length:
|
||||
got += skt.recv(1)
|
||||
if got.endswith(b"\n"):
|
||||
return got[:-1]
|
||||
raise BadHandshake("exceeded max_length, got %r on %s" %
|
||||
(got, description))
|
||||
|
||||
def wait_for(skt, expected, description):
|
||||
assert isinstance(expected, type(b""))
|
||||
got = b""
|
||||
while len(got) < len(expected):
|
||||
got += skt.recv(1)
|
||||
if expected[:len(got)] != got:
|
||||
raise BadHandshake("got %r want %r on %s" %
|
||||
(got, expected, description))
|
||||
|
||||
def debug(msg):
|
||||
if False:
|
||||
print(msg)
|
||||
def since(start):
|
||||
return time.time() - start
|
||||
|
||||
def connector(owner, hint, description,
|
||||
send_handshake, expected_handshake, relay_handshake=None):
|
||||
start = time.time()
|
||||
parsed_hint = parse_hint_tcp(hint)
|
||||
if not parsed_hint:
|
||||
return # unparseable
|
||||
addr,port = parsed_hint
|
||||
skt = None
|
||||
debug("+ connector(%s)" % hint)
|
||||
try:
|
||||
skt = socket.create_connection((addr,port),
|
||||
TIMEOUT) # timeout or ECONNREFUSED
|
||||
skt.settimeout(TIMEOUT)
|
||||
debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
|
||||
if relay_handshake:
|
||||
debug(" - sending relay_handshake")
|
||||
send_to(skt, relay_handshake)
|
||||
relay_msg = wait_for_line(skt, 10000, description)
|
||||
if relay_msg != b"ok":
|
||||
raise BadHandshake(relay_msg)
|
||||
debug(" - relay ready CT+%.1f" % (since(start),))
|
||||
send_to(skt, send_handshake)
|
||||
wait_for(skt, expected_handshake, description)
|
||||
debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
|
||||
except Exception as e:
|
||||
debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start)))
|
||||
try:
|
||||
if skt:
|
||||
skt.shutdown(socket.SHUT_WR)
|
||||
except socket.error:
|
||||
pass
|
||||
if skt:
|
||||
skt.close()
|
||||
# ignore socket errors, warn about coding errors
|
||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
||||
raise
|
||||
debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
|
||||
owner._connector_failed(hint)
|
||||
return
|
||||
# owner is now responsible for the socket
|
||||
owner._negotiation_finished(skt, description) # note thread
|
||||
|
||||
def handle(skt, client_address, owner, description,
|
||||
send_handshake, expected_handshake):
|
||||
try:
|
||||
debug("handle %r" % (skt,))
|
||||
skt.settimeout(TIMEOUT)
|
||||
send_to(skt, send_handshake)
|
||||
got = b""
|
||||
# for the receiver, this includes the "go\n"
|
||||
while len(got) < len(expected_handshake):
|
||||
more = skt.recv(1)
|
||||
if not more:
|
||||
raise BadHandshake("disconnect after merely '%r'" % got)
|
||||
got += more
|
||||
if expected_handshake[:len(got)] != got:
|
||||
raise BadHandshake("got '%r' want '%r'" %
|
||||
(got, expected_handshake))
|
||||
debug("handler negotiation finished %r" % (client_address,))
|
||||
except Exception as e:
|
||||
debug("handler failed %r" % (client_address,))
|
||||
try:
|
||||
# this raises socket.err(EBADF) if the socket was already closed
|
||||
skt.shutdown(socket.SHUT_WR)
|
||||
except socket.error:
|
||||
pass
|
||||
skt.close() # this appears to be idempotent
|
||||
# ignore socket errors, warn about coding errors
|
||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
||||
raise
|
||||
return
|
||||
# owner is now responsible for the socket
|
||||
owner._negotiation_finished(skt, description) # note thread
|
||||
|
||||
class MyTCPServer(socketserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
|
||||
def process_request(self, request, client_address):
|
||||
description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
|
||||
ready_lock = self.owner._ready_for_connections_lock
|
||||
ready_lock.acquire()
|
||||
while not (self.owner._ready_for_connections
|
||||
and self.owner._transit_key):
|
||||
ready_lock.wait()
|
||||
# owner._transit_key is either None or set to a value. We don't
|
||||
# modify it from here, so we can release the condition lock before
|
||||
# grabbing the key.
|
||||
ready_lock.release()
|
||||
|
||||
# Once it is set, we can get handler_(send|receive)_handshake, which
|
||||
# is what we actually care about.
|
||||
t = threading.Thread(target=handle,
|
||||
args=(request, client_address,
|
||||
self.owner, description,
|
||||
self.owner.handler_send_handshake,
|
||||
self.owner.handler_expected_handshake))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
|
||||
class ReceiveBuffer:
|
||||
def __init__(self, skt):
|
||||
self.skt = skt
|
||||
self.buf = b""
|
||||
|
||||
def read(self, count):
|
||||
while len(self.buf) < count:
|
||||
more = self.skt.recv(4096)
|
||||
if not more:
|
||||
raise TransitClosed
|
||||
self.buf += more
|
||||
rc = self.buf[:count]
|
||||
self.buf = self.buf[count:]
|
||||
return rc
|
||||
|
||||
class RecordPipe:
|
||||
def __init__(self, skt, send_key, receive_key, description):
|
||||
self.skt = skt
|
||||
self.send_box = SecretBox(send_key)
|
||||
self.send_nonce = 0
|
||||
self.receive_buf = ReceiveBuffer(self.skt)
|
||||
self.receive_box = SecretBox(receive_key)
|
||||
self.next_receive_nonce = 0
|
||||
self._description = description
|
||||
|
||||
def describe(self):
|
||||
return self._description
|
||||
|
||||
def send_record(self, record):
|
||||
if not isinstance(record, type(b"")): raise UsageError
|
||||
assert SecretBox.NONCE_SIZE == 24
|
||||
assert self.send_nonce < 2**(8*24)
|
||||
assert len(record) < 2**(8*4)
|
||||
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
|
||||
self.send_nonce += 1
|
||||
encrypted = self.send_box.encrypt(record, nonce)
|
||||
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
|
||||
send_to(self.skt, length)
|
||||
send_to(self.skt, encrypted)
|
||||
|
||||
def receive_record(self):
|
||||
length_buf = self.receive_buf.read(4)
|
||||
length = int(hexlify(length_buf), 16)
|
||||
encrypted = self.receive_buf.read(length)
|
||||
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
|
||||
nonce = int(hexlify(nonce_buf), 16)
|
||||
if nonce != self.next_receive_nonce:
|
||||
raise BadNonce("received out-of-order record")
|
||||
self.next_receive_nonce += 1
|
||||
record = self.receive_box.decrypt(encrypted)
|
||||
return record
|
||||
|
||||
def close(self):
|
||||
self.skt.close()
|
||||
|
||||
class Common:
|
||||
def __init__(self, transit_relay, no_listen=False, timing=None):
|
||||
if transit_relay:
|
||||
if not isinstance(transit_relay, type(u"")):
|
||||
raise UsageError
|
||||
self._transit_relays = [transit_relay]
|
||||
else:
|
||||
self._transit_relays = []
|
||||
self._no_listen = no_listen
|
||||
self._timing = timing or DebugTiming()
|
||||
self._timing_started = self._timing.add_event("transit")
|
||||
self.winning = threading.Event()
|
||||
self._negotiation_check_lock = threading.Lock()
|
||||
self._ready_for_connections_lock = threading.Condition()
|
||||
self._ready_for_connections = False
|
||||
self._transit_key = None
|
||||
self._start_server()
|
||||
|
||||
def _start_server(self):
|
||||
if self._no_listen:
|
||||
self.my_direct_hints = []
|
||||
self.listener = None
|
||||
return
|
||||
server = MyTCPServer(("", 0), None)
|
||||
_, port = server.server_address
|
||||
self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
|
||||
for addr in ipaddrs.find_addresses()]
|
||||
server.owner = self
|
||||
server_thread = threading.Thread(target=server.serve_forever)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
self.listener = server
|
||||
|
||||
def get_direct_hints(self):
|
||||
return self.my_direct_hints
|
||||
def get_relay_hints(self):
|
||||
return self._transit_relays
|
||||
|
||||
def add_their_direct_hints(self, hints):
|
||||
for h in hints:
|
||||
if not isinstance(h, type(u"")):
|
||||
raise TypeError("hint '%r' should be unicode, not %s"
|
||||
% (h, type(h)))
|
||||
self._their_direct_hints = list(hints)
|
||||
def add_their_relay_hints(self, hints):
|
||||
for h in hints:
|
||||
if not isinstance(h, type(u"")):
|
||||
raise TypeError("hint '%r' should be unicode, not %s"
|
||||
% (h, type(h)))
|
||||
self._their_relay_hints = list(hints)
|
||||
|
||||
def _send_this(self):
|
||||
if self.is_sender:
|
||||
return build_sender_handshake(self._transit_key)
|
||||
else:
|
||||
return build_receiver_handshake(self._transit_key)
|
||||
|
||||
def _expect_this(self):
|
||||
if self.is_sender:
|
||||
return build_receiver_handshake(self._transit_key)
|
||||
else:
|
||||
return build_sender_handshake(self._transit_key) + b"go\n"
|
||||
|
||||
def _sender_record_key(self):
|
||||
if self.is_sender:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_sender_key")
|
||||
else:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_receiver_key")
|
||||
|
||||
def _receiver_record_key(self):
|
||||
if self.is_sender:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_receiver_key")
|
||||
else:
|
||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
||||
CTXinfo=b"transit_record_sender_key")
|
||||
|
||||
def set_transit_key(self, key):
|
||||
# This _ready_for_connections condition/lock protects us against the
|
||||
# race where the sender knows the hints and the key, and connects to
|
||||
# the receiver's transit socket before the receiver gets relay
|
||||
# message (and thus the key).
|
||||
self._ready_for_connections_lock.acquire()
|
||||
self._transit_key = key
|
||||
self.handler_send_handshake = self._send_this() # no "go"
|
||||
self.handler_expected_handshake = self._expect_this()
|
||||
self._ready_for_connections_lock.notify_all()
|
||||
self._ready_for_connections_lock.release()
|
||||
|
||||
def _start_outbound(self):
|
||||
self._active_connectors = set(self._their_direct_hints)
|
||||
self._attempted_connectors = set()
|
||||
for hint in self._their_direct_hints:
|
||||
self._start_connector(hint)
|
||||
if not self._their_direct_hints:
|
||||
self._start_relay_connectors()
|
||||
|
||||
def _start_connector(self, hint, is_relay=False):
|
||||
# Don't try any hint more than once. If all hints fail, we'll
|
||||
# eventually timeout. We make no attempt to fail any faster.
|
||||
if hint in self._attempted_connectors:
|
||||
return
|
||||
self._attempted_connectors.add(hint)
|
||||
description = "->%s" % (hint,)
|
||||
if is_relay:
|
||||
description = "->relay:%s" % (hint,)
|
||||
args = (self, hint, description,
|
||||
self._send_this(), self._expect_this())
|
||||
if is_relay:
|
||||
args = args + (build_relay_handshake(self._transit_key),)
|
||||
t = threading.Thread(target=connector, args=args)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
def _start_relay_connectors(self):
|
||||
self._active_connectors.update(self._their_direct_hints)
|
||||
for hint in self._their_relay_hints:
|
||||
self._start_connector(hint, is_relay=True)
|
||||
|
||||
def establish_socket(self):
|
||||
start = time.time()
|
||||
self.winning_skt = None
|
||||
self.winning_skt_description = None
|
||||
self._ready_for_connections_lock.acquire()
|
||||
self._ready_for_connections = True
|
||||
self._ready_for_connections_lock.notify_all()
|
||||
self._ready_for_connections_lock.release()
|
||||
self._start_outbound()
|
||||
|
||||
# we sit here until one of our inbound or outbound sockets succeeds
|
||||
flag = self.winning.wait(2*TIMEOUT)
|
||||
debug("wait returned at %.1f" % (since(start),))
|
||||
|
||||
if not flag:
|
||||
# timeout: self.winning_skt will not be set. ish. race.
|
||||
pass
|
||||
if self.listener:
|
||||
self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
|
||||
if self.winning_skt:
|
||||
return self.winning_skt
|
||||
raise TransitError("timeout")
|
||||
|
||||
def _connector_failed(self, hint):
|
||||
debug("- failed connector %s" % hint)
|
||||
# XXX this was .remove, and occasionally got KeyError
|
||||
self._active_connectors.discard(hint)
|
||||
if not self._active_connectors:
|
||||
self._start_relay_connectors()
|
||||
|
||||
def _negotiation_finished(self, skt, description):
|
||||
# inbound/outbound sockets call this when they finish negotiation.
|
||||
# The first one wins and gets a "go". Any subsequent ones lose and
|
||||
# get a "nevermind" before being closed.
|
||||
|
||||
with self._negotiation_check_lock:
|
||||
if self.winning_skt:
|
||||
is_winner = False
|
||||
else:
|
||||
is_winner = True
|
||||
self.winning_skt = skt
|
||||
self.winning_skt_description = description
|
||||
|
||||
if is_winner:
|
||||
if self.is_sender:
|
||||
send_to(skt, b"go\n")
|
||||
self.winning.set()
|
||||
else:
|
||||
if self.is_sender:
|
||||
try:
|
||||
send_to(skt, b"nevermind\n")
|
||||
except socket.error:
|
||||
# They realized this connection is not going to win, and
|
||||
# closed it so fast we didn't get a chance to tell them
|
||||
# it lost. This happens in unit tests.
|
||||
pass
|
||||
skt.close()
|
||||
|
||||
def connect(self):
|
||||
_start = self._timing.add_event("transit connect")
|
||||
skt = self.establish_socket()
|
||||
self._timing.finish_event(_start)
|
||||
return RecordPipe(skt, self._sender_record_key(),
|
||||
self._receiver_record_key(),
|
||||
self.winning_skt_description)
|
||||
|
||||
class TransitSender(Common):
|
||||
is_sender = True
|
||||
|
||||
class TransitReceiver(Common):
|
||||
is_sender = False
|
|
@ -27,8 +27,6 @@ g.add_argument("--hide-progress", action="store_true",
|
|||
help="supress progress-bar display")
|
||||
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
|
||||
metavar="FILE", help="(debug) write timing data to file")
|
||||
g.add_argument("--twisted", action="store_true",
|
||||
help="use Twisted-based implementations, for testing")
|
||||
g.add_argument("--no-listen", action="store_true",
|
||||
help="(debug) don't open a listening socket for Transit")
|
||||
g.add_argument("--tor", action="store_true",
|
||||
|
|
|
@ -1,13 +1,19 @@
|
|||
from __future__ import print_function
|
||||
import io, json
|
||||
import io, os, sys, json, binascii, six, tempfile, zipfile
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitReceiver
|
||||
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
class RespondError(Exception):
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def receive_twisted_sync(args):
|
||||
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
||||
# directly)
|
||||
|
@ -34,7 +40,14 @@ def receive_twisted_sync(args):
|
|||
def receive_twisted(args):
|
||||
return TwistedReceiver(args).go()
|
||||
|
||||
class TwistedReceiver(BlockingReceiver):
|
||||
|
||||
class TwistedReceiver:
|
||||
def __init__(self, args):
|
||||
assert isinstance(args.relay_url, type(u""))
|
||||
self.args = args
|
||||
|
||||
def msg(self, *args, **kwargs):
|
||||
print(*args, file=self.args.stdout, **kwargs)
|
||||
|
||||
# TODO: @handle_server_error
|
||||
@inlineCallbacks
|
||||
|
@ -101,6 +114,11 @@ class TwistedReceiver(BlockingReceiver):
|
|||
self.args.code_length)
|
||||
yield w.set_code(code)
|
||||
|
||||
def show_verifier(self, verifier):
|
||||
verifier_hex = binascii.hexlify(verifier).decode("ascii")
|
||||
if self.args.verify:
|
||||
self.msg(u"Verifier %s." % verifier_hex)
|
||||
|
||||
@inlineCallbacks
|
||||
def get_data(self, w):
|
||||
try:
|
||||
|
@ -119,6 +137,61 @@ class TwistedReceiver(BlockingReceiver):
|
|||
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
||||
yield w.send_data(data)
|
||||
|
||||
def handle_file(self, them_d):
|
||||
file_data = them_d["file"]
|
||||
self.abs_destname = self.decide_destname("file",
|
||||
file_data["filename"])
|
||||
self.xfersize = file_data["filesize"]
|
||||
|
||||
self.msg(u"Receiving file (%d bytes) into: %s" %
|
||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||
self.ask_permission()
|
||||
tmp_destname = self.abs_destname + ".tmp"
|
||||
return open(tmp_destname, "wb")
|
||||
|
||||
def handle_directory(self, them_d):
|
||||
file_data = them_d["directory"]
|
||||
zipmode = file_data["mode"]
|
||||
if zipmode != "zipfile/deflated":
|
||||
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
|
||||
raise RespondError({"error": "unknown mode"})
|
||||
self.abs_destname = self.decide_destname("directory",
|
||||
file_data["dirname"])
|
||||
self.xfersize = file_data["zipsize"]
|
||||
|
||||
self.msg(u"Receiving directory (%d bytes) into: %s/" %
|
||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||
self.msg(u"%d files, %d bytes (uncompressed)" %
|
||||
(file_data["numfiles"], file_data["numbytes"]))
|
||||
self.ask_permission()
|
||||
return tempfile.SpooledTemporaryFile()
|
||||
|
||||
def decide_destname(self, mode, destname):
|
||||
# the basename() is intended to protect us against
|
||||
# "~/.ssh/authorized_keys" and other attacks
|
||||
destname = os.path.basename(destname)
|
||||
if self.args.output_file:
|
||||
destname = self.args.output_file # override
|
||||
abs_destname = os.path.join(self.args.cwd, destname)
|
||||
|
||||
# get confirmation from the user before writing to the local directory
|
||||
if os.path.exists(abs_destname):
|
||||
self.msg(u"Error: refusing to overwrite existing %s %s" %
|
||||
(mode, destname))
|
||||
raise RespondError({"error": "%s already exists" % mode})
|
||||
return abs_destname
|
||||
|
||||
def ask_permission(self):
|
||||
_start = self.args.timing.add_event("permission", waiting="user")
|
||||
while True and not self.args.accept_file:
|
||||
ok = six.moves.input("ok? (y/n): ")
|
||||
if ok.lower().startswith("y"):
|
||||
break
|
||||
print(u"transfer rejected", file=sys.stderr)
|
||||
self.args.timing.finish_event(_start, answer="no")
|
||||
raise RespondError({"error": "transfer rejected"})
|
||||
self.args.timing.finish_event(_start, answer="yes")
|
||||
|
||||
@inlineCallbacks
|
||||
def establish_transit(self, w, them_d, tor_manager):
|
||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
||||
|
@ -169,6 +242,27 @@ class TwistedReceiver(BlockingReceiver):
|
|||
returnValue(1) # TODO: exit properly
|
||||
assert received == self.xfersize
|
||||
|
||||
def write_file(self, f):
|
||||
tmp_name = f.name
|
||||
f.close()
|
||||
os.rename(tmp_name, self.abs_destname)
|
||||
self.msg(u"Received file written to %s" %
|
||||
os.path.basename(self.abs_destname))
|
||||
|
||||
def write_directory(self, f):
|
||||
self.msg(u"Unpacking zipfile..")
|
||||
_start = self.args.timing.add_event("unpack zip")
|
||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.extractall(path=self.abs_destname)
|
||||
# extractall() appears to offer some protection against
|
||||
# malicious pathnames. For example, "/tmp/oops" and
|
||||
# "../tmp/oops" both do the same thing as the (safe)
|
||||
# "tmp/oops".
|
||||
self.msg(u"Received files written to %s/" %
|
||||
os.path.basename(self.abs_destname))
|
||||
f.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
@inlineCallbacks
|
||||
def close_transit(self, record_pipe):
|
||||
_start = self.args.timing.add_event("ack")
|
|
@ -1,216 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import io, os, sys, json, binascii, six, tempfile, zipfile
|
||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
||||
from ..blocking.transit import TransitReceiver, TransitError
|
||||
from ..errors import handle_server_error, TransferError
|
||||
from .progress import ProgressPrinter
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
def receive_blocking(args):
|
||||
return BlockingReceiver(args).go()
|
||||
|
||||
class RespondError(Exception):
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
class BlockingReceiver:
|
||||
def __init__(self, args):
|
||||
assert isinstance(args.relay_url, type(u""))
|
||||
self.args = args
|
||||
|
||||
def msg(self, *args, **kwargs):
|
||||
print(*args, file=self.args.stdout, **kwargs)
|
||||
|
||||
@handle_server_error
|
||||
def go(self):
|
||||
with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w:
|
||||
self.handle_code(w)
|
||||
verifier = w.get_verifier()
|
||||
self.show_verifier(verifier)
|
||||
them_d = self.get_data(w)
|
||||
try:
|
||||
if "message" in them_d:
|
||||
self.handle_text(them_d, w)
|
||||
return 0
|
||||
if "file" in them_d:
|
||||
f = self.handle_file(them_d)
|
||||
rp = self.establish_transit(w, them_d)
|
||||
self.transfer_data(rp, f)
|
||||
self.write_file(f)
|
||||
self.close_transit(rp)
|
||||
elif "directory" in them_d:
|
||||
f = self.handle_directory(them_d)
|
||||
rp = self.establish_transit(w, them_d)
|
||||
self.transfer_data(rp, f)
|
||||
self.write_directory(f)
|
||||
self.close_transit(rp)
|
||||
else:
|
||||
self.msg(u"I don't know what they're offering\n")
|
||||
self.msg(u"Offer details:", them_d)
|
||||
raise RespondError({"error": "unknown offer type"})
|
||||
except RespondError as r:
|
||||
data = json.dumps(r.response).encode("utf-8")
|
||||
w.send_data(data)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def handle_code(self, w):
|
||||
code = self.args.code
|
||||
if self.args.zeromode:
|
||||
assert not code
|
||||
code = u"0-"
|
||||
if not code:
|
||||
code = w.input_code("Enter receive wormhole code: ",
|
||||
self.args.code_length)
|
||||
w.set_code(code)
|
||||
|
||||
def show_verifier(self, verifier):
|
||||
verifier_hex = binascii.hexlify(verifier).decode("ascii")
|
||||
if self.args.verify:
|
||||
self.msg(u"Verifier %s." % verifier_hex)
|
||||
|
||||
def get_data(self, w):
|
||||
try:
|
||||
them_bytes = w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
raise TransferError(u"ERROR: " + e.explain())
|
||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
||||
if "error" in them_d:
|
||||
raise TransferError(u"ERROR: " + them_d["error"])
|
||||
return them_d
|
||||
|
||||
def handle_text(self, them_d, w):
|
||||
# we're receiving a text message
|
||||
self.msg(them_d["message"])
|
||||
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
|
||||
def handle_file(self, them_d):
|
||||
file_data = them_d["file"]
|
||||
self.abs_destname = self.decide_destname("file",
|
||||
file_data["filename"])
|
||||
self.xfersize = file_data["filesize"]
|
||||
|
||||
self.msg(u"Receiving file (%d bytes) into: %s" %
|
||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||
self.ask_permission()
|
||||
tmp_destname = self.abs_destname + ".tmp"
|
||||
return open(tmp_destname, "wb")
|
||||
|
||||
def handle_directory(self, them_d):
|
||||
file_data = them_d["directory"]
|
||||
zipmode = file_data["mode"]
|
||||
if zipmode != "zipfile/deflated":
|
||||
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
|
||||
raise RespondError({"error": "unknown mode"})
|
||||
self.abs_destname = self.decide_destname("directory",
|
||||
file_data["dirname"])
|
||||
self.xfersize = file_data["zipsize"]
|
||||
|
||||
self.msg(u"Receiving directory (%d bytes) into: %s/" %
|
||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||
self.msg(u"%d files, %d bytes (uncompressed)" %
|
||||
(file_data["numfiles"], file_data["numbytes"]))
|
||||
self.ask_permission()
|
||||
return tempfile.SpooledTemporaryFile()
|
||||
|
||||
def decide_destname(self, mode, destname):
|
||||
# the basename() is intended to protect us against
|
||||
# "~/.ssh/authorized_keys" and other attacks
|
||||
destname = os.path.basename(destname)
|
||||
if self.args.output_file:
|
||||
destname = self.args.output_file # override
|
||||
abs_destname = os.path.join(self.args.cwd, destname)
|
||||
|
||||
# get confirmation from the user before writing to the local directory
|
||||
if os.path.exists(abs_destname):
|
||||
self.msg(u"Error: refusing to overwrite existing %s %s" %
|
||||
(mode, destname))
|
||||
raise RespondError({"error": "%s already exists" % mode})
|
||||
return abs_destname
|
||||
|
||||
def ask_permission(self):
|
||||
_start = self.args.timing.add_event("permission", waiting="user")
|
||||
while True and not self.args.accept_file:
|
||||
ok = six.moves.input("ok? (y/n): ")
|
||||
if ok.lower().startswith("y"):
|
||||
break
|
||||
print(u"transfer rejected", file=sys.stderr)
|
||||
self.args.timing.finish_event(_start, answer="no")
|
||||
raise RespondError({"error": "transfer rejected"})
|
||||
self.args.timing.finish_event(_start, answer="yes")
|
||||
|
||||
def establish_transit(self, w, them_d):
|
||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
||||
transit_receiver = TransitReceiver(self.args.transit_helper,
|
||||
no_listen=self.args.no_listen,
|
||||
timing=self.args.timing)
|
||||
transit_receiver.set_transit_key(transit_key)
|
||||
data = json.dumps({
|
||||
"file_ack": "ok",
|
||||
"transit": {
|
||||
"direct_connection_hints": transit_receiver.get_direct_hints(),
|
||||
"relay_connection_hints": transit_receiver.get_relay_hints(),
|
||||
},
|
||||
}).encode("utf-8")
|
||||
w.send_data(data)
|
||||
|
||||
# now receive the rest of the owl
|
||||
tdata = them_d["transit"]
|
||||
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
record_pipe = transit_receiver.connect()
|
||||
return record_pipe
|
||||
|
||||
def transfer_data(self, record_pipe, f):
|
||||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
||||
|
||||
_start = self.args.timing.add_event("rx file")
|
||||
progress_stdout = self.args.stdout
|
||||
if self.args.hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
received = 0
|
||||
p = ProgressPrinter(self.xfersize, progress_stdout)
|
||||
p.start()
|
||||
while received < self.xfersize:
|
||||
try:
|
||||
plaintext = record_pipe.receive_record()
|
||||
except TransitError:
|
||||
self.msg()
|
||||
self.msg(u"Connection dropped before full file received")
|
||||
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
|
||||
return 1
|
||||
f.write(plaintext)
|
||||
received += len(plaintext)
|
||||
p.update(received)
|
||||
p.finish()
|
||||
self.args.timing.finish_event(_start)
|
||||
assert received == self.xfersize
|
||||
|
||||
def write_file(self, f):
|
||||
tmp_name = f.name
|
||||
f.close()
|
||||
os.rename(tmp_name, self.abs_destname)
|
||||
self.msg(u"Received file written to %s" %
|
||||
os.path.basename(self.abs_destname))
|
||||
|
||||
def write_directory(self, f):
|
||||
self.msg(u"Unpacking zipfile..")
|
||||
_start = self.args.timing.add_event("unpack zip")
|
||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.extractall(path=self.abs_destname)
|
||||
# extractall() appears to offer some protection against
|
||||
# malicious pathnames. For example, "/tmp/oops" and
|
||||
# "../tmp/oops" both do the same thing as the (safe)
|
||||
# "tmp/oops".
|
||||
self.msg(u"Received files written to %s/" %
|
||||
os.path.basename(self.abs_destname))
|
||||
f.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
def close_transit(self, record_pipe):
|
||||
_start = self.args.timing.add_event("ack")
|
||||
record_pipe.send_record(b"ok\n")
|
||||
record_pipe.close()
|
||||
self.args.timing.finish_event(_start)
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import io, json, binascii, six
|
||||
import os, sys, io, json, binascii, six, tempfile, zipfile
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
|
@ -7,8 +7,94 @@ from ..errors import TransferError
|
|||
from .progress import ProgressPrinter
|
||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||
from ..twisted.transit import TransitSender
|
||||
from .send_common import (APPID, handle_zero, build_other_command,
|
||||
build_phase1_data)
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
def handle_zero(args):
|
||||
if args.zeromode:
|
||||
assert not args.code
|
||||
args.code = u"0-"
|
||||
|
||||
def build_other_command(args):
|
||||
other_cmd = "wormhole receive"
|
||||
if args.verify:
|
||||
other_cmd = "wormhole --verify receive"
|
||||
if args.zeromode:
|
||||
other_cmd += " -0"
|
||||
return other_cmd
|
||||
|
||||
def build_phase1_data(args):
|
||||
phase1 = {}
|
||||
|
||||
text = args.text
|
||||
if text == "-":
|
||||
print(u"Reading text message from stdin..", file=args.stdout)
|
||||
text = sys.stdin.read()
|
||||
if not text and not args.what:
|
||||
text = six.moves.input("Text to send: ")
|
||||
|
||||
if text is not None:
|
||||
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
|
||||
phase1 = { "message": text }
|
||||
fd_to_send = None
|
||||
return phase1, fd_to_send
|
||||
|
||||
what = os.path.join(args.cwd, args.what)
|
||||
what = what.rstrip(os.sep)
|
||||
if not os.path.exists(what):
|
||||
raise TransferError("Cannot send: no file/directory named '%s'" %
|
||||
args.what)
|
||||
basename = os.path.basename(what)
|
||||
|
||||
if os.path.isfile(what):
|
||||
# we're sending a file
|
||||
filesize = os.stat(what).st_size
|
||||
phase1["file"] = {
|
||||
"filename": basename,
|
||||
"filesize": filesize,
|
||||
}
|
||||
print(u"Sending %d byte file named '%s'" % (filesize, basename),
|
||||
file=args.stdout)
|
||||
fd_to_send = open(what, "rb")
|
||||
return phase1, fd_to_send
|
||||
|
||||
if os.path.isdir(what):
|
||||
print(u"Building zipfile..", file=args.stdout)
|
||||
# We're sending a directory. Create a zipfile in a tempdir and
|
||||
# send that.
|
||||
fd_to_send = tempfile.SpooledTemporaryFile()
|
||||
# TODO: I think ZIP_DEFLATED means compressed.. check it
|
||||
num_files = 0
|
||||
num_bytes = 0
|
||||
tostrip = len(what.split(os.sep))
|
||||
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for path,dirs,files in os.walk(what):
|
||||
# path always starts with args.what, then sometimes might
|
||||
# have "/subdir" appended. We want the zipfile to contain
|
||||
# "" or "subdir"
|
||||
localpath = list(path.split(os.sep)[tostrip:])
|
||||
for fn in files:
|
||||
archivename = os.path.join(*tuple(localpath+[fn]))
|
||||
localfilename = os.path.join(path, fn)
|
||||
zf.write(localfilename, archivename)
|
||||
num_bytes += os.stat(localfilename).st_size
|
||||
num_files += 1
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
phase1["directory"] = {
|
||||
"mode": "zipfile/deflated",
|
||||
"dirname": basename,
|
||||
"zipsize": filesize,
|
||||
"numbytes": num_bytes,
|
||||
"numfiles": num_files,
|
||||
}
|
||||
print(u"Sending directory (%d bytes compressed) named '%s'"
|
||||
% (filesize, basename), file=args.stdout)
|
||||
return phase1, fd_to_send
|
||||
|
||||
raise TypeError("'%s' is neither file nor directory" % args.what)
|
||||
|
||||
|
||||
def send_twisted_sync(args):
|
||||
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
|
@ -1,129 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import json, binascii, six
|
||||
from ..errors import TransferError
|
||||
from .progress import ProgressPrinter
|
||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
||||
from ..blocking.transit import TransitSender
|
||||
from ..errors import handle_server_error
|
||||
from .send_common import (APPID, handle_zero, build_other_command,
|
||||
build_phase1_data)
|
||||
|
||||
@handle_server_error
|
||||
def send_blocking(args):
|
||||
assert isinstance(args.relay_url, type(u""))
|
||||
handle_zero(args)
|
||||
phase1, fd_to_send = build_phase1_data(args)
|
||||
other_cmd = build_other_command(args)
|
||||
print(u"On the other computer, please run: %s" % other_cmd,
|
||||
file=args.stdout)
|
||||
|
||||
if fd_to_send is not None:
|
||||
transit_sender = TransitSender(args.transit_helper,
|
||||
no_listen=args.no_listen,
|
||||
timing=args.timing)
|
||||
transit_data = {
|
||||
"direct_connection_hints": transit_sender.get_direct_hints(),
|
||||
"relay_connection_hints": transit_sender.get_relay_hints(),
|
||||
}
|
||||
phase1["transit"] = transit_data
|
||||
|
||||
with Wormhole(APPID, args.relay_url, timing=args.timing) as w:
|
||||
if args.code:
|
||||
w.set_code(args.code)
|
||||
code = args.code
|
||||
else:
|
||||
code = w.get_code(args.code_length)
|
||||
if not args.zeromode:
|
||||
print(u"Wormhole code is: %s" % code, file=args.stdout)
|
||||
print(u"", file=args.stdout)
|
||||
|
||||
# get the verifier, because that also lets us derive the transit key,
|
||||
# which we want to set before revealing the connection hints to the
|
||||
# far side, so we'll be ready for them when they connect
|
||||
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
|
||||
if args.verify:
|
||||
_do_verify(verifier, w)
|
||||
|
||||
if fd_to_send is not None:
|
||||
transit_key = w.derive_key(APPID+"/transit-key")
|
||||
transit_sender.set_transit_key(transit_key)
|
||||
|
||||
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
||||
w.send_data(my_phase1_bytes)
|
||||
try:
|
||||
them_phase1_bytes = w.get_data()
|
||||
except WrongPasswordError as e:
|
||||
raise TransferError(e.explain())
|
||||
|
||||
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
||||
|
||||
if fd_to_send is None:
|
||||
if them_phase1["message_ack"] == "ok":
|
||||
print(u"text message sent", file=args.stdout)
|
||||
return 0
|
||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||
|
||||
return _send_file_blocking(them_phase1, fd_to_send,
|
||||
transit_sender, args.stdout, args.hide_progress,
|
||||
args.timing)
|
||||
|
||||
def _do_verify(verifier, w):
|
||||
while True:
|
||||
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
|
||||
if ok.lower() == "yes":
|
||||
break
|
||||
if ok.lower() == "no":
|
||||
reject_data = json.dumps({"error": "verification rejected",
|
||||
}).encode("utf-8")
|
||||
w.send_data(reject_data)
|
||||
raise TransferError("verification rejected, abandoning transfer")
|
||||
|
||||
def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
||||
stdout, hide_progress, timing):
|
||||
|
||||
# we're sending a file, if they accept it
|
||||
|
||||
if "error" in them_phase1:
|
||||
raise TransferError("remote error, transfer abandoned: %s"
|
||||
% them_phase1["error"])
|
||||
if them_phase1.get("file_ack") != "ok":
|
||||
raise TransferError("ambiguous response from remote, "
|
||||
"transfer abandoned: %s" % (them_phase1,))
|
||||
|
||||
tdata = them_phase1["transit"]
|
||||
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
record_pipe = transit_sender.connect()
|
||||
|
||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||
|
||||
_start = timing.add_event("tx file")
|
||||
CHUNKSIZE = 64*1024
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
p = ProgressPrinter(filesize, stdout)
|
||||
with fd_to_send as f:
|
||||
sent = 0
|
||||
if not hide_progress:
|
||||
p.start()
|
||||
while sent < filesize:
|
||||
plaintext = f.read(CHUNKSIZE)
|
||||
record_pipe.send_record(plaintext)
|
||||
sent += len(plaintext)
|
||||
if not hide_progress:
|
||||
p.update(sent)
|
||||
if not hide_progress:
|
||||
p.finish()
|
||||
timing.finish_event(_start)
|
||||
|
||||
_start = timing.add_event("get ack")
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
ack = record_pipe.receive_record()
|
||||
record_pipe.close()
|
||||
if ack == b"ok\n":
|
||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
||||
timing.finish_event(_start, ack="ok")
|
||||
return 0
|
||||
timing.finish_event(_start, ack="failed")
|
||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
|
@ -21,22 +21,14 @@ def dispatch(args):
|
|||
from ..servers import cmd_usage
|
||||
return cmd_usage.tail_usage(args)
|
||||
|
||||
if args.tor:
|
||||
args.twisted = True
|
||||
if args.func == "send/send":
|
||||
if args.twisted:
|
||||
from . import cmd_send_twisted
|
||||
return cmd_send_twisted.send_twisted_sync(args)
|
||||
from . import cmd_send_blocking
|
||||
return cmd_send_blocking.send_blocking(args)
|
||||
from . import cmd_send
|
||||
return cmd_send.send_twisted_sync(args)
|
||||
if args.func == "receive/receive":
|
||||
if args.twisted:
|
||||
_start = args.timing.add_event("import c_r_t")
|
||||
from . import cmd_receive_twisted
|
||||
from . import cmd_receive
|
||||
args.timing.finish_event(_start)
|
||||
return cmd_receive_twisted.receive_twisted_sync(args)
|
||||
from . import cmd_receive_blocking
|
||||
return cmd_receive_blocking.receive_blocking(args)
|
||||
return cmd_receive.receive_twisted_sync(args)
|
||||
|
||||
raise ValueError("unknown args.func %s" % args.func)
|
||||
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, six, tempfile, zipfile
|
||||
from ..errors import TransferError
|
||||
|
||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||
|
||||
def handle_zero(args):
|
||||
if args.zeromode:
|
||||
assert not args.code
|
||||
args.code = u"0-"
|
||||
|
||||
def build_other_command(args):
|
||||
other_cmd = "wormhole receive"
|
||||
if args.verify:
|
||||
other_cmd = "wormhole --verify receive"
|
||||
if args.zeromode:
|
||||
other_cmd += " -0"
|
||||
return other_cmd
|
||||
|
||||
def build_phase1_data(args):
|
||||
phase1 = {}
|
||||
|
||||
text = args.text
|
||||
if text == "-":
|
||||
print(u"Reading text message from stdin..", file=args.stdout)
|
||||
text = sys.stdin.read()
|
||||
if not text and not args.what:
|
||||
text = six.moves.input("Text to send: ")
|
||||
|
||||
if text is not None:
|
||||
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
|
||||
phase1 = { "message": text }
|
||||
fd_to_send = None
|
||||
return phase1, fd_to_send
|
||||
|
||||
what = os.path.join(args.cwd, args.what)
|
||||
what = what.rstrip(os.sep)
|
||||
if not os.path.exists(what):
|
||||
raise TransferError("Cannot send: no file/directory named '%s'" %
|
||||
args.what)
|
||||
basename = os.path.basename(what)
|
||||
|
||||
if os.path.isfile(what):
|
||||
# we're sending a file
|
||||
filesize = os.stat(what).st_size
|
||||
phase1["file"] = {
|
||||
"filename": basename,
|
||||
"filesize": filesize,
|
||||
}
|
||||
print(u"Sending %d byte file named '%s'" % (filesize, basename),
|
||||
file=args.stdout)
|
||||
fd_to_send = open(what, "rb")
|
||||
return phase1, fd_to_send
|
||||
|
||||
if os.path.isdir(what):
|
||||
print(u"Building zipfile..", file=args.stdout)
|
||||
# We're sending a directory. Create a zipfile in a tempdir and
|
||||
# send that.
|
||||
fd_to_send = tempfile.SpooledTemporaryFile()
|
||||
# TODO: I think ZIP_DEFLATED means compressed.. check it
|
||||
num_files = 0
|
||||
num_bytes = 0
|
||||
tostrip = len(what.split(os.sep))
|
||||
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for path,dirs,files in os.walk(what):
|
||||
# path always starts with args.what, then sometimes might
|
||||
# have "/subdir" appended. We want the zipfile to contain
|
||||
# "" or "subdir"
|
||||
localpath = list(path.split(os.sep)[tostrip:])
|
||||
for fn in files:
|
||||
archivename = os.path.join(*tuple(localpath+[fn]))
|
||||
localfilename = os.path.join(path, fn)
|
||||
zf.write(localfilename, archivename)
|
||||
num_bytes += os.stat(localfilename).st_size
|
||||
num_files += 1
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
fd_to_send.seek(0,0)
|
||||
phase1["directory"] = {
|
||||
"mode": "zipfile/deflated",
|
||||
"dirname": basename,
|
||||
"zipsize": filesize,
|
||||
"numbytes": num_bytes,
|
||||
"numfiles": num_files,
|
||||
}
|
||||
print(u"Sending directory (%d bytes compressed) named '%s'"
|
||||
% (filesize, basename), file=args.stdout)
|
||||
return phase1, fd_to_send
|
||||
|
||||
raise TypeError("'%s' is neither file nor directory" % args.what)
|
|
@ -1,7 +1,7 @@
|
|||
from twisted.application import service
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.python import log
|
||||
from ..twisted.util import allocate_ports
|
||||
from ..twisted.transit import allocate_tcp_port
|
||||
from ..servers.server import RelayServer
|
||||
from .. import __version__
|
||||
|
||||
|
@ -9,9 +9,8 @@ class ServerBase:
|
|||
def setUp(self):
|
||||
self.sp = service.MultiService()
|
||||
self.sp.startService()
|
||||
d = allocate_ports()
|
||||
def _got_ports(ports):
|
||||
relayport, transitport = ports
|
||||
relayport = allocate_tcp_port()
|
||||
transitport = allocate_tcp_port()
|
||||
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
|
||||
"tcp:%s:interface=127.0.0.1" % transitport,
|
||||
__version__)
|
||||
|
@ -20,8 +19,6 @@ class ServerBase:
|
|||
self._transit_server = s.transit
|
||||
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
|
||||
self.transit = u"tcp:127.0.0.1:%d" % transitport
|
||||
d.addCallback(_got_ports)
|
||||
return d
|
||||
|
||||
def tearDown(self):
|
||||
# Unit tests that spawn a (blocking) client in a thread might still
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks
|
||||
from twisted.internet.defer import gatherResults, succeed
|
||||
from twisted.internet.threads import deferToThread
|
||||
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
||||
WrongPasswordError)
|
||||
from ..blocking.eventsource import EventSourceFollower
|
||||
from ..blocking.transit import (TransitSender, TransitReceiver,
|
||||
build_sender_handshake,
|
||||
build_receiver_handshake)
|
||||
from .common import ServerBase
|
||||
|
||||
APPID = u"appid"
|
||||
|
@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase):
|
|||
(u"message", u"three"),
|
||||
(u"e2", u"four"),
|
||||
])
|
||||
|
||||
class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
|
||||
def test_hints(self):
|
||||
r = TransitReceiver(self.transit)
|
||||
hints = r.get_direct_hints()
|
||||
self.assertTrue(len(hints), hints)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_direct_to_receiver(self):
|
||||
s = TransitSender(self.transit)
|
||||
r = TransitReceiver(self.transit)
|
||||
key = b"\x00"*32
|
||||
|
||||
# force the connection to be sender->receiver
|
||||
s.set_transit_key(key)
|
||||
# only use 127.0.0.1
|
||||
hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1]
|
||||
s.add_their_direct_hints([hint])
|
||||
s.add_their_relay_hints([])
|
||||
r.set_transit_key(key)
|
||||
r.add_their_direct_hints([])
|
||||
r.add_their_relay_hints([])
|
||||
|
||||
# it'd be nice to factor this chunk out with 'yield from', but that
|
||||
# didn't appear until python-3.3, and isn't in py2 at all.
|
||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
||||
yield deferToThread(sp.send_record, b"01234")
|
||||
rec = yield deferToThread(rp.receive_record)
|
||||
self.assertEqual(rec, b"01234")
|
||||
yield deferToThread(sp.close)
|
||||
yield deferToThread(rp.close)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_direct_to_sender(self):
|
||||
s = TransitSender(self.transit)
|
||||
r = TransitReceiver(self.transit)
|
||||
key = b"\x00"*32
|
||||
|
||||
# force the connection to be receiver->sender
|
||||
s.set_transit_key(key)
|
||||
s.add_their_direct_hints([])
|
||||
s.add_their_relay_hints([])
|
||||
r.set_transit_key(key)
|
||||
hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1]
|
||||
r.add_their_direct_hints([hint])
|
||||
r.add_their_relay_hints([])
|
||||
|
||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
||||
yield deferToThread(sp.send_record, b"01234")
|
||||
rec = yield deferToThread(rp.receive_record)
|
||||
self.assertEqual(rec, b"01234")
|
||||
yield deferToThread(sp.close)
|
||||
yield deferToThread(rp.close)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_relay(self):
|
||||
s = TransitSender(self.transit)
|
||||
r = TransitReceiver(self.transit)
|
||||
key = b"\x00"*32
|
||||
# force the connection to use the relay by not revealing direct hints
|
||||
s.set_transit_key(key)
|
||||
s.add_their_direct_hints([])
|
||||
s.add_their_relay_hints(r.get_relay_hints())
|
||||
r.set_transit_key(key)
|
||||
r.add_their_direct_hints([])
|
||||
r.add_their_relay_hints(s.get_relay_hints())
|
||||
|
||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
||||
yield deferToThread(sp.send_record, b"01234")
|
||||
rec = yield deferToThread(rp.receive_record)
|
||||
self.assertEqual(rec, b"01234")
|
||||
yield deferToThread(sp.close)
|
||||
yield deferToThread(rp.close)
|
||||
# TODO: this may be racy if we don't poll the server to make sure
|
||||
# it's witnessed the first connection closing before querying the DB
|
||||
#import time
|
||||
#yield deferToThread(time.sleep, 1)
|
||||
|
||||
# check the transit relay's DB, make sure it counted the bytes
|
||||
db = self._transit_server._db
|
||||
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
|
||||
rows = c.fetchall()
|
||||
self.assertEqual(len(rows), 1)
|
||||
row = rows[0]
|
||||
self.assertEqual(row["result"], u"happy")
|
||||
# Sender first writes relay_handshake and waits for OK, but that's
|
||||
# not counted by the transit server. Then sender writes
|
||||
# sender_handshake and waits for receiver_handshake. Then sender
|
||||
# writes GO and the body. Body is length-prefixed SecretBox, so
|
||||
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
|
||||
sender_count = (len(build_sender_handshake(b""))+
|
||||
len(b"go\n")+
|
||||
4+24+len(b"01234")+16)
|
||||
# Receiver first writes relay_handshake and waits for OK, but that's
|
||||
# not counted. Then receiver writes receiver_handshake and waits for
|
||||
# sender_handshake+GO.
|
||||
receiver_count = len(build_receiver_handshake(b""))
|
||||
self.assertEqual(row["total_bytes"], sender_count+receiver_count)
|
||||
|
|
|
@ -4,12 +4,10 @@ from twisted.trial import unittest
|
|||
from twisted.python import procutils, log
|
||||
from twisted.internet.utils import getProcessOutputAndValue
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
from twisted.internet.threads import deferToThread
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
from ..scripts import (runner, cmd_send_blocking, cmd_send_twisted,
|
||||
cmd_receive_blocking, cmd_receive_twisted)
|
||||
from ..scripts.send_common import build_phase1_data
|
||||
from ..scripts import runner, cmd_send, cmd_receive
|
||||
from ..scripts.cmd_send import build_phase1_data
|
||||
from ..errors import TransferError
|
||||
from ..timing import DebugTiming
|
||||
|
||||
|
@ -219,8 +217,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
|
||||
@inlineCallbacks
|
||||
def _do_test(self, as_subprocess=False,
|
||||
mode="text", addslash=False, override_filename=False,
|
||||
sender_twisted=False, receiver_twisted=False):
|
||||
mode="text", addslash=False, override_filename=False):
|
||||
assert mode in ("text", "file", "directory")
|
||||
common_args = ["--hide-progress",
|
||||
"--relay-url", self.relayurl,
|
||||
|
@ -314,14 +311,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
rargs.stdout = io.StringIO()
|
||||
rargs.stderr = io.StringIO()
|
||||
rargs.timing = DebugTiming()
|
||||
if sender_twisted:
|
||||
send_d = cmd_send_twisted.send_twisted(sargs)
|
||||
else:
|
||||
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
|
||||
if receiver_twisted:
|
||||
receive_d = cmd_receive_twisted.receive_twisted(rargs)
|
||||
else:
|
||||
receive_d = deferToThread(cmd_receive_blocking.receive_blocking, rargs)
|
||||
send_d = cmd_send.send_twisted(sargs)
|
||||
receive_d = cmd_receive.receive_twisted(rargs)
|
||||
|
||||
# The sender might fail, leaving the receiver hanging, or vice
|
||||
# versa. If either side fails, cancel the other, so it won't
|
||||
|
@ -418,24 +409,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
return self._do_test()
|
||||
def test_text_subprocess(self):
|
||||
return self._do_test(as_subprocess=True)
|
||||
def test_text_twisted_to_blocking(self):
|
||||
return self._do_test(sender_twisted=True)
|
||||
def test_text_blocking_to_twisted(self):
|
||||
return self._do_test(receiver_twisted=True)
|
||||
def test_text_twisted_to_twisted(self):
|
||||
return self._do_test(sender_twisted=True, receiver_twisted=True)
|
||||
|
||||
def test_file(self):
|
||||
return self._do_test(mode="file")
|
||||
def test_file_override(self):
|
||||
return self._do_test(mode="file", override_filename=True)
|
||||
def test_file_twisted_to_blocking(self):
|
||||
return self._do_test(mode="file", sender_twisted=True)
|
||||
def test_file_blocking_to_twisted(self):
|
||||
return self._do_test(mode="file", receiver_twisted=True)
|
||||
def test_file_twisted_to_twisted(self):
|
||||
return self._do_test(mode="file",
|
||||
sender_twisted=True, receiver_twisted=True)
|
||||
|
||||
def test_directory(self):
|
||||
return self._do_test(mode="directory")
|
||||
|
@ -443,16 +421,3 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
return self._do_test(mode="directory", addslash=True)
|
||||
def test_directory_override(self):
|
||||
return self._do_test(mode="directory", override_filename=True)
|
||||
def test_directory_twisted_to_blocking(self):
|
||||
return self._do_test(mode="directory", sender_twisted=True)
|
||||
def test_directory_twisted_to_blocking_addslash(self):
|
||||
return self._do_test(mode="directory", addslash=True,
|
||||
sender_twisted=True)
|
||||
def test_directory_blocking_to_twisted(self):
|
||||
return self._do_test(mode="directory", receiver_twisted=True)
|
||||
def test_directory_twisted_to_twisted(self):
|
||||
return self._do_test(mode="directory",
|
||||
sender_twisted=True, receiver_twisted=True)
|
||||
def test_directory_twisted_to_twisted_addslash(self):
|
||||
return self._do_test(mode="directory", addslash=True,
|
||||
sender_twisted=True, receiver_twisted=True)
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
import requests
|
||||
from binascii import hexlify
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.internet import protocol, reactor, defer
|
||||
from twisted.internet.threads import deferToThread
|
||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
||||
from twisted.web.client import getPage, Agent, readBody
|
||||
from .. import __version__
|
||||
from .common import ServerBase
|
||||
|
@ -434,7 +436,30 @@ class Summary(unittest.TestCase):
|
|||
self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41),
|
||||
(1, "scary", 40, 9))
|
||||
|
||||
class Transit(unittest.TestCase):
|
||||
class Accumulator(protocol.Protocol):
|
||||
def __init__(self):
|
||||
self.data = b""
|
||||
self.count = 0
|
||||
self._wait = None
|
||||
def waitForBytes(self, more):
|
||||
assert self._wait is None
|
||||
self.count = more
|
||||
self._wait = defer.Deferred()
|
||||
self._check_done()
|
||||
return self._wait
|
||||
def dataReceived(self, data):
|
||||
self.data = self.data + data
|
||||
self._check_done()
|
||||
def _check_done(self):
|
||||
if self._wait and len(self.data) >= self.count:
|
||||
d = self._wait
|
||||
self._wait = None
|
||||
d.callback(self)
|
||||
def connectionLost(self, why):
|
||||
if self._wait:
|
||||
self._wait.errback(RuntimeError("closed"))
|
||||
|
||||
class Transit(ServerBase, unittest.TestCase):
|
||||
def test_blur_size(self):
|
||||
blur = transit_server.blur_size
|
||||
self.failUnlessEqual(blur(0), 0)
|
||||
|
@ -453,3 +478,62 @@ class Transit(unittest.TestCase):
|
|||
self.failUnlessEqual(blur(1100e6), 1100e6)
|
||||
self.failUnlessEqual(blur(1150e6), 1200e6)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_basic(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# the server waits for the exact number of bytes in the expected
|
||||
# handshake message. to trigger "bad handshake", we must match.
|
||||
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_impatience(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
from __future__ import print_function
|
||||
from binascii import hexlify
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import protocol, defer, reactor
|
||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
||||
from .common import ServerBase
|
||||
|
||||
class Accumulator(protocol.Protocol):
|
||||
def __init__(self):
|
||||
self.data = b""
|
||||
self.count = 0
|
||||
self._wait = None
|
||||
def waitForBytes(self, more):
|
||||
assert self._wait is None
|
||||
self.count = more
|
||||
self._wait = defer.Deferred()
|
||||
self._check_done()
|
||||
return self._wait
|
||||
def dataReceived(self, data):
|
||||
self.data = self.data + data
|
||||
self._check_done()
|
||||
def _check_done(self):
|
||||
if self._wait and len(self.data) >= self.count:
|
||||
d = self._wait
|
||||
self._wait = None
|
||||
d.callback(self)
|
||||
def connectionLost(self, why):
|
||||
if self._wait:
|
||||
self._wait.errback(RuntimeError("closed"))
|
||||
|
||||
class Transit(ServerBase, unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_basic(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# the server waits for the exact number of bytes in the expected
|
||||
# handshake message. to trigger "bad handshake", we must match.
|
||||
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"bad handshake\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_impatience(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
|
@ -1,87 +0,0 @@
|
|||
import re
|
||||
from binascii import hexlify
|
||||
from .util.hkdf import HKDF
|
||||
|
||||
class TransitError(Exception):
|
||||
pass
|
||||
|
||||
class BadHandshake(Exception):
|
||||
pass
|
||||
|
||||
class TransitClosed(TransitError):
|
||||
pass
|
||||
|
||||
class BadNonce(TransitError):
|
||||
pass
|
||||
|
||||
# The beginning of each TCP connection consists of the following handshake
|
||||
# messages. The sender transmits the same text regardless of whether it is on
|
||||
# the initiating/connecting end of the TCP connection, or on the
|
||||
# listening/accepting side. Same for the receiver.
|
||||
#
|
||||
# sender -> receiver: transit sender TXID_HEX ready\n\n
|
||||
# receiver -> sender: transit receiver RXID_HEX ready\n\n
|
||||
#
|
||||
# Any deviations from this result in the socket being closed. The handshake
|
||||
# messages are designed to provoke an invalid response from other sorts of
|
||||
# servers (HTTP, SMTP, echo).
|
||||
#
|
||||
# If the sender is satisfied with the handshake, and this is the first socket
|
||||
# to complete negotiation, the sender does:
|
||||
#
|
||||
# sender -> receiver: go\n
|
||||
#
|
||||
# and the next byte on the wire will be from the application.
|
||||
#
|
||||
# If this is not the first socket, the sender does:
|
||||
#
|
||||
# sender -> receiver: nevermind\n
|
||||
#
|
||||
# and closes the socket.
|
||||
|
||||
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
|
||||
# up upon the first wrong byte. The sender lookgs for "transit receiver
|
||||
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
|
||||
# "go\n" or "nevermind\n"+close().
|
||||
|
||||
def build_receiver_handshake(key):
|
||||
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
|
||||
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
|
||||
|
||||
def build_sender_handshake(key):
|
||||
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
||||
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
||||
|
||||
def build_relay_handshake(key):
|
||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||
return b"please relay "+hexlify(token)+b"\n"
|
||||
|
||||
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
|
||||
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
|
||||
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
|
||||
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
|
||||
# publisher wants anonymity, their only hint's ADDR will end in .onion .
|
||||
|
||||
def parse_hint_tcp(hint):
|
||||
assert isinstance(hint, type(u""))
|
||||
# return tuple or None for an unparseable hint
|
||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||
if not mo:
|
||||
print("unparseable hint '%s'" % (hint,))
|
||||
return None
|
||||
hint_type = mo.group(1)
|
||||
if hint_type != "tcp":
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
|
||||
return None
|
||||
hint_value = mo.group(2)
|
||||
mo = re.search(r'^(.*):(\d+)$', hint_value)
|
||||
if not mo:
|
||||
print("unparseable TCP hint '%s'" % (hint,))
|
||||
return None
|
||||
hint_host = mo.group(1)
|
||||
try:
|
||||
hint_port = int(mo.group(2))
|
||||
except ValueError:
|
||||
print("non-numeric port in TCP hint '%s'" % (hint,))
|
||||
return None
|
||||
return hint_host, hint_port
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import sys, time, socket, collections
|
||||
import re, sys, time, socket, collections
|
||||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.python.runtime import platformType
|
||||
|
@ -12,11 +12,6 @@ from ..util import ipaddrs
|
|||
from ..util.hkdf import HKDF
|
||||
from ..errors import UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..transit_common import (BadHandshake,
|
||||
BadNonce,
|
||||
build_receiver_handshake,
|
||||
build_sender_handshake,
|
||||
build_relay_handshake)
|
||||
|
||||
def debug(msg):
|
||||
if False:
|
||||
|
@ -24,6 +19,90 @@ def debug(msg):
|
|||
def since(start):
|
||||
return time.time() - start
|
||||
|
||||
class TransitError(Exception):
|
||||
pass
|
||||
|
||||
class BadHandshake(Exception):
|
||||
pass
|
||||
|
||||
class TransitClosed(TransitError):
|
||||
pass
|
||||
|
||||
class BadNonce(TransitError):
|
||||
pass
|
||||
|
||||
# The beginning of each TCP connection consists of the following handshake
|
||||
# messages. The sender transmits the same text regardless of whether it is on
|
||||
# the initiating/connecting end of the TCP connection, or on the
|
||||
# listening/accepting side. Same for the receiver.
|
||||
#
|
||||
# sender -> receiver: transit sender TXID_HEX ready\n\n
|
||||
# receiver -> sender: transit receiver RXID_HEX ready\n\n
|
||||
#
|
||||
# Any deviations from this result in the socket being closed. The handshake
|
||||
# messages are designed to provoke an invalid response from other sorts of
|
||||
# servers (HTTP, SMTP, echo).
|
||||
#
|
||||
# If the sender is satisfied with the handshake, and this is the first socket
|
||||
# to complete negotiation, the sender does:
|
||||
#
|
||||
# sender -> receiver: go\n
|
||||
#
|
||||
# and the next byte on the wire will be from the application.
|
||||
#
|
||||
# If this is not the first socket, the sender does:
|
||||
#
|
||||
# sender -> receiver: nevermind\n
|
||||
#
|
||||
# and closes the socket.
|
||||
|
||||
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
|
||||
# up upon the first wrong byte. The sender lookgs for "transit receiver
|
||||
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
|
||||
# "go\n" or "nevermind\n"+close().
|
||||
|
||||
def build_receiver_handshake(key):
|
||||
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
|
||||
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
|
||||
|
||||
def build_sender_handshake(key):
|
||||
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
||||
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
||||
|
||||
def build_relay_handshake(key):
|
||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||
return b"please relay "+hexlify(token)+b"\n"
|
||||
|
||||
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
|
||||
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
|
||||
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
|
||||
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
|
||||
# publisher wants anonymity, their only hint's ADDR will end in .onion .
|
||||
|
||||
def parse_hint_tcp(hint):
|
||||
assert isinstance(hint, type(u""))
|
||||
# return tuple or None for an unparseable hint
|
||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||
if not mo:
|
||||
print("unparseable hint '%s'" % (hint,))
|
||||
return None
|
||||
hint_type = mo.group(1)
|
||||
if hint_type != "tcp":
|
||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
|
||||
return None
|
||||
hint_value = mo.group(2)
|
||||
mo = re.search(r'^(.*):(\d+)$', hint_value)
|
||||
if not mo:
|
||||
print("unparseable TCP hint '%s'" % (hint,))
|
||||
return None
|
||||
hint_host = mo.group(1)
|
||||
try:
|
||||
hint_port = int(mo.group(2))
|
||||
except ValueError:
|
||||
print("non-numeric port in TCP hint '%s'" % (hint,))
|
||||
return None
|
||||
return hint_host, hint_port
|
||||
|
||||
TIMEOUT=15
|
||||
|
||||
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
||||
|
@ -684,7 +763,7 @@ class Common:
|
|||
return d
|
||||
|
||||
def _endpoint_from_hint(self, hint):
|
||||
# TODO: use transit_common.parse_hint_tcp
|
||||
# TODO: use parse_hint_tcp
|
||||
if ":" not in hint:
|
||||
return None
|
||||
pieces = hint.split(":")
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
from twisted.internet import defer, protocol, endpoints, reactor
|
||||
|
||||
def allocate_port():
|
||||
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
|
||||
d = ep.listen(protocol.Factory())
|
||||
def _listening(lp):
|
||||
port = lp.getHost().port
|
||||
d2 = lp.stopListening()
|
||||
d2.addCallback(lambda _: port)
|
||||
return d2
|
||||
d.addCallback(_listening)
|
||||
return d
|
||||
|
||||
def allocate_ports():
|
||||
d = defer.DeferredList([allocate_port(), allocate_port()])
|
||||
def _done(results):
|
||||
port1 = results[0][1]
|
||||
port2 = results[1][1]
|
||||
return (port1, port2)
|
||||
d.addCallback(_done)
|
||||
return d
|
|
@ -1,50 +0,0 @@
|
|||
# -*- test-case-name: foolscap.test_observer -*-
|
||||
|
||||
# many thanks to AllMyData for contributing the initial version of this code
|
||||
|
||||
from twisted.internet import defer
|
||||
from foolscap import eventual
|
||||
|
||||
class OneShotObserverList(object):
|
||||
"""A one-shot event distributor.
|
||||
|
||||
Subscribers can get a Deferred that will fire with the results of the
|
||||
event once it finally occurs. The caller does not need to know whether
|
||||
the event has happened yet or not: they get a Deferred in either case.
|
||||
|
||||
The Deferreds returned to subscribers are guaranteed to not fire in the
|
||||
current reactor turn; instead, eventually() is used to fire them in a
|
||||
later turn. Look at Mark Miller's 'Concurrency Among Strangers' paper on
|
||||
erights.org for a description of why this property is useful.
|
||||
|
||||
I can only be fired once."""
|
||||
|
||||
def __init__(self):
|
||||
self._fired = False
|
||||
self._result = None
|
||||
self._watchers = []
|
||||
self.__repr__ = self._unfired_repr
|
||||
|
||||
def _unfired_repr(self):
|
||||
return "<OneShotObserverList [%s]>" % (self._watchers, )
|
||||
|
||||
def _fired_repr(self):
|
||||
return "<OneShotObserverList -> %s>" % (self._result, )
|
||||
|
||||
def whenFired(self):
|
||||
if self._fired:
|
||||
return eventual.fireEventually(self._result)
|
||||
d = defer.Deferred()
|
||||
self._watchers.append(d)
|
||||
return d
|
||||
|
||||
def fire(self, result):
|
||||
assert not self._fired
|
||||
self._fired = True
|
||||
self._result = result
|
||||
|
||||
for w in self._watchers:
|
||||
eventual.eventually(w.callback, result)
|
||||
del self._watchers
|
||||
self.__repr__ = self._fired_repr
|
||||
|
7
tox.ini
7
tox.ini
|
@ -7,12 +7,6 @@
|
|||
envlist = py27,py33,py34,py35
|
||||
skip_missing_interpreters = True
|
||||
|
||||
# There's a race-condition bug in Twisted-15.5.0 (#8014, fixed in trunk) that
|
||||
# manifests as various intermittent magic-wormhole test failures that always
|
||||
# include twisted.internet.endpoints "iterateEndpoint" or "checkDone". So run
|
||||
# all builds with a copy of Twisted from git 'trunk' until Twisted-16.0.0 is
|
||||
# released and we can just depend on that.
|
||||
|
||||
# On windows we need "pypiwin32" installed. It's supposedly possible to make
|
||||
# Twisted do this by depending upon "twisted[windows]" instead of just
|
||||
# "twisted", but when I try this via Appveyor, the extra is ignored.
|
||||
|
@ -24,7 +18,6 @@ skip_missing_interpreters = True
|
|||
|
||||
[testenv]
|
||||
deps =
|
||||
twisted >= 16.1.0
|
||||
pyflakes
|
||||
{env:EXTRA_DEPENDENCY:}
|
||||
commands =
|
||||
|
|
Loading…
Reference in New Issue
Block a user