Use twisted implementation all the time.
Merge commit '1a455c0'
This commit is contained in:
commit
7e1405576e
4
setup.py
4
setup.py
|
@ -21,10 +21,8 @@ setup(name="magic-wormhole",
|
||||||
entry_points={"console_scripts":
|
entry_points={"console_scripts":
|
||||||
["wormhole = wormhole.scripts.runner:entry"]},
|
["wormhole = wormhole.scripts.runner:entry"]},
|
||||||
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
||||||
"six"],
|
"six", "twisted >= 16.1.0"],
|
||||||
extras_require={"tor": ["txtorcon", "ipaddr"]},
|
extras_require={"tor": ["txtorcon", "ipaddr"]},
|
||||||
# for Twisted support, we want Twisted>=15.5.0. Older Twisteds don't
|
|
||||||
# provide sufficient python3 compatibility.
|
|
||||||
test_suite="wormhole.test",
|
test_suite="wormhole.test",
|
||||||
cmdclass=commands,
|
cmdclass=commands,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,400 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import time, threading, socket
|
|
||||||
from six.moves import socketserver
|
|
||||||
from binascii import hexlify, unhexlify
|
|
||||||
from nacl.secret import SecretBox
|
|
||||||
from ..util import ipaddrs
|
|
||||||
from ..util.hkdf import HKDF
|
|
||||||
from ..errors import UsageError
|
|
||||||
from ..timing import DebugTiming
|
|
||||||
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
|
|
||||||
BadNonce,
|
|
||||||
build_receiver_handshake,
|
|
||||||
build_sender_handshake,
|
|
||||||
build_relay_handshake,
|
|
||||||
parse_hint_tcp)
|
|
||||||
|
|
||||||
TIMEOUT=15
|
|
||||||
|
|
||||||
# 1: sender only transmits, receiver only accepts, both wait forever
|
|
||||||
# 2: sender also accepts, receiver also transmits
|
|
||||||
# 3: timeouts / stop when no more progress can be made
|
|
||||||
# 4: add relay
|
|
||||||
# 5: accelerate shutdown of losing sockets
|
|
||||||
|
|
||||||
def send_to(skt, data):
|
|
||||||
sent = 0
|
|
||||||
while sent < len(data):
|
|
||||||
sent += skt.send(data[sent:])
|
|
||||||
|
|
||||||
def wait_for_line(skt, max_length, description):
|
|
||||||
got = b""
|
|
||||||
while len(got) < max_length:
|
|
||||||
got += skt.recv(1)
|
|
||||||
if got.endswith(b"\n"):
|
|
||||||
return got[:-1]
|
|
||||||
raise BadHandshake("exceeded max_length, got %r on %s" %
|
|
||||||
(got, description))
|
|
||||||
|
|
||||||
def wait_for(skt, expected, description):
|
|
||||||
assert isinstance(expected, type(b""))
|
|
||||||
got = b""
|
|
||||||
while len(got) < len(expected):
|
|
||||||
got += skt.recv(1)
|
|
||||||
if expected[:len(got)] != got:
|
|
||||||
raise BadHandshake("got %r want %r on %s" %
|
|
||||||
(got, expected, description))
|
|
||||||
|
|
||||||
def debug(msg):
|
|
||||||
if False:
|
|
||||||
print(msg)
|
|
||||||
def since(start):
|
|
||||||
return time.time() - start
|
|
||||||
|
|
||||||
def connector(owner, hint, description,
|
|
||||||
send_handshake, expected_handshake, relay_handshake=None):
|
|
||||||
start = time.time()
|
|
||||||
parsed_hint = parse_hint_tcp(hint)
|
|
||||||
if not parsed_hint:
|
|
||||||
return # unparseable
|
|
||||||
addr,port = parsed_hint
|
|
||||||
skt = None
|
|
||||||
debug("+ connector(%s)" % hint)
|
|
||||||
try:
|
|
||||||
skt = socket.create_connection((addr,port),
|
|
||||||
TIMEOUT) # timeout or ECONNREFUSED
|
|
||||||
skt.settimeout(TIMEOUT)
|
|
||||||
debug(" - socket(%s) connected CT+%.1f" % (description, since(start)))
|
|
||||||
if relay_handshake:
|
|
||||||
debug(" - sending relay_handshake")
|
|
||||||
send_to(skt, relay_handshake)
|
|
||||||
relay_msg = wait_for_line(skt, 10000, description)
|
|
||||||
if relay_msg != b"ok":
|
|
||||||
raise BadHandshake(relay_msg)
|
|
||||||
debug(" - relay ready CT+%.1f" % (since(start),))
|
|
||||||
send_to(skt, send_handshake)
|
|
||||||
wait_for(skt, expected_handshake, description)
|
|
||||||
debug(" + connector(%s) ready CT+%.1f" % (hint, since(start)))
|
|
||||||
except Exception as e:
|
|
||||||
debug(" - error(%s)(%r) CT+%.1f" % (hint, e, since(start)))
|
|
||||||
try:
|
|
||||||
if skt:
|
|
||||||
skt.shutdown(socket.SHUT_WR)
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
if skt:
|
|
||||||
skt.close()
|
|
||||||
# ignore socket errors, warn about coding errors
|
|
||||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
|
||||||
raise
|
|
||||||
debug(" - notifying owner._connector_failed(%s) CT+%.1f" % (hint, since(start)))
|
|
||||||
owner._connector_failed(hint)
|
|
||||||
return
|
|
||||||
# owner is now responsible for the socket
|
|
||||||
owner._negotiation_finished(skt, description) # note thread
|
|
||||||
|
|
||||||
def handle(skt, client_address, owner, description,
|
|
||||||
send_handshake, expected_handshake):
|
|
||||||
try:
|
|
||||||
debug("handle %r" % (skt,))
|
|
||||||
skt.settimeout(TIMEOUT)
|
|
||||||
send_to(skt, send_handshake)
|
|
||||||
got = b""
|
|
||||||
# for the receiver, this includes the "go\n"
|
|
||||||
while len(got) < len(expected_handshake):
|
|
||||||
more = skt.recv(1)
|
|
||||||
if not more:
|
|
||||||
raise BadHandshake("disconnect after merely '%r'" % got)
|
|
||||||
got += more
|
|
||||||
if expected_handshake[:len(got)] != got:
|
|
||||||
raise BadHandshake("got '%r' want '%r'" %
|
|
||||||
(got, expected_handshake))
|
|
||||||
debug("handler negotiation finished %r" % (client_address,))
|
|
||||||
except Exception as e:
|
|
||||||
debug("handler failed %r" % (client_address,))
|
|
||||||
try:
|
|
||||||
# this raises socket.err(EBADF) if the socket was already closed
|
|
||||||
skt.shutdown(socket.SHUT_WR)
|
|
||||||
except socket.error:
|
|
||||||
pass
|
|
||||||
skt.close() # this appears to be idempotent
|
|
||||||
# ignore socket errors, warn about coding errors
|
|
||||||
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
|
|
||||||
raise
|
|
||||||
return
|
|
||||||
# owner is now responsible for the socket
|
|
||||||
owner._negotiation_finished(skt, description) # note thread
|
|
||||||
|
|
||||||
class MyTCPServer(socketserver.TCPServer):
|
|
||||||
allow_reuse_address = True
|
|
||||||
|
|
||||||
def process_request(self, request, client_address):
|
|
||||||
description = "<-tcp:%s:%d" % (client_address[0], client_address[1])
|
|
||||||
ready_lock = self.owner._ready_for_connections_lock
|
|
||||||
ready_lock.acquire()
|
|
||||||
while not (self.owner._ready_for_connections
|
|
||||||
and self.owner._transit_key):
|
|
||||||
ready_lock.wait()
|
|
||||||
# owner._transit_key is either None or set to a value. We don't
|
|
||||||
# modify it from here, so we can release the condition lock before
|
|
||||||
# grabbing the key.
|
|
||||||
ready_lock.release()
|
|
||||||
|
|
||||||
# Once it is set, we can get handler_(send|receive)_handshake, which
|
|
||||||
# is what we actually care about.
|
|
||||||
t = threading.Thread(target=handle,
|
|
||||||
args=(request, client_address,
|
|
||||||
self.owner, description,
|
|
||||||
self.owner.handler_send_handshake,
|
|
||||||
self.owner.handler_expected_handshake))
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
class ReceiveBuffer:
|
|
||||||
def __init__(self, skt):
|
|
||||||
self.skt = skt
|
|
||||||
self.buf = b""
|
|
||||||
|
|
||||||
def read(self, count):
|
|
||||||
while len(self.buf) < count:
|
|
||||||
more = self.skt.recv(4096)
|
|
||||||
if not more:
|
|
||||||
raise TransitClosed
|
|
||||||
self.buf += more
|
|
||||||
rc = self.buf[:count]
|
|
||||||
self.buf = self.buf[count:]
|
|
||||||
return rc
|
|
||||||
|
|
||||||
class RecordPipe:
|
|
||||||
def __init__(self, skt, send_key, receive_key, description):
|
|
||||||
self.skt = skt
|
|
||||||
self.send_box = SecretBox(send_key)
|
|
||||||
self.send_nonce = 0
|
|
||||||
self.receive_buf = ReceiveBuffer(self.skt)
|
|
||||||
self.receive_box = SecretBox(receive_key)
|
|
||||||
self.next_receive_nonce = 0
|
|
||||||
self._description = description
|
|
||||||
|
|
||||||
def describe(self):
|
|
||||||
return self._description
|
|
||||||
|
|
||||||
def send_record(self, record):
|
|
||||||
if not isinstance(record, type(b"")): raise UsageError
|
|
||||||
assert SecretBox.NONCE_SIZE == 24
|
|
||||||
assert self.send_nonce < 2**(8*24)
|
|
||||||
assert len(record) < 2**(8*4)
|
|
||||||
nonce = unhexlify("%048x" % self.send_nonce) # big-endian
|
|
||||||
self.send_nonce += 1
|
|
||||||
encrypted = self.send_box.encrypt(record, nonce)
|
|
||||||
length = unhexlify("%08x" % len(encrypted)) # always 4 bytes long
|
|
||||||
send_to(self.skt, length)
|
|
||||||
send_to(self.skt, encrypted)
|
|
||||||
|
|
||||||
def receive_record(self):
|
|
||||||
length_buf = self.receive_buf.read(4)
|
|
||||||
length = int(hexlify(length_buf), 16)
|
|
||||||
encrypted = self.receive_buf.read(length)
|
|
||||||
nonce_buf = encrypted[:SecretBox.NONCE_SIZE] # assume it's prepended
|
|
||||||
nonce = int(hexlify(nonce_buf), 16)
|
|
||||||
if nonce != self.next_receive_nonce:
|
|
||||||
raise BadNonce("received out-of-order record")
|
|
||||||
self.next_receive_nonce += 1
|
|
||||||
record = self.receive_box.decrypt(encrypted)
|
|
||||||
return record
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.skt.close()
|
|
||||||
|
|
||||||
class Common:
|
|
||||||
def __init__(self, transit_relay, no_listen=False, timing=None):
|
|
||||||
if transit_relay:
|
|
||||||
if not isinstance(transit_relay, type(u"")):
|
|
||||||
raise UsageError
|
|
||||||
self._transit_relays = [transit_relay]
|
|
||||||
else:
|
|
||||||
self._transit_relays = []
|
|
||||||
self._no_listen = no_listen
|
|
||||||
self._timing = timing or DebugTiming()
|
|
||||||
self._timing_started = self._timing.add_event("transit")
|
|
||||||
self.winning = threading.Event()
|
|
||||||
self._negotiation_check_lock = threading.Lock()
|
|
||||||
self._ready_for_connections_lock = threading.Condition()
|
|
||||||
self._ready_for_connections = False
|
|
||||||
self._transit_key = None
|
|
||||||
self._start_server()
|
|
||||||
|
|
||||||
def _start_server(self):
|
|
||||||
if self._no_listen:
|
|
||||||
self.my_direct_hints = []
|
|
||||||
self.listener = None
|
|
||||||
return
|
|
||||||
server = MyTCPServer(("", 0), None)
|
|
||||||
_, port = server.server_address
|
|
||||||
self.my_direct_hints = [u"tcp:%s:%d" % (addr, port)
|
|
||||||
for addr in ipaddrs.find_addresses()]
|
|
||||||
server.owner = self
|
|
||||||
server_thread = threading.Thread(target=server.serve_forever)
|
|
||||||
server_thread.daemon = True
|
|
||||||
server_thread.start()
|
|
||||||
self.listener = server
|
|
||||||
|
|
||||||
def get_direct_hints(self):
|
|
||||||
return self.my_direct_hints
|
|
||||||
def get_relay_hints(self):
|
|
||||||
return self._transit_relays
|
|
||||||
|
|
||||||
def add_their_direct_hints(self, hints):
|
|
||||||
for h in hints:
|
|
||||||
if not isinstance(h, type(u"")):
|
|
||||||
raise TypeError("hint '%r' should be unicode, not %s"
|
|
||||||
% (h, type(h)))
|
|
||||||
self._their_direct_hints = list(hints)
|
|
||||||
def add_their_relay_hints(self, hints):
|
|
||||||
for h in hints:
|
|
||||||
if not isinstance(h, type(u"")):
|
|
||||||
raise TypeError("hint '%r' should be unicode, not %s"
|
|
||||||
% (h, type(h)))
|
|
||||||
self._their_relay_hints = list(hints)
|
|
||||||
|
|
||||||
def _send_this(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return build_sender_handshake(self._transit_key)
|
|
||||||
else:
|
|
||||||
return build_receiver_handshake(self._transit_key)
|
|
||||||
|
|
||||||
def _expect_this(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return build_receiver_handshake(self._transit_key)
|
|
||||||
else:
|
|
||||||
return build_sender_handshake(self._transit_key) + b"go\n"
|
|
||||||
|
|
||||||
def _sender_record_key(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_sender_key")
|
|
||||||
else:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_receiver_key")
|
|
||||||
|
|
||||||
def _receiver_record_key(self):
|
|
||||||
if self.is_sender:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_receiver_key")
|
|
||||||
else:
|
|
||||||
return HKDF(self._transit_key, SecretBox.KEY_SIZE,
|
|
||||||
CTXinfo=b"transit_record_sender_key")
|
|
||||||
|
|
||||||
def set_transit_key(self, key):
|
|
||||||
# This _ready_for_connections condition/lock protects us against the
|
|
||||||
# race where the sender knows the hints and the key, and connects to
|
|
||||||
# the receiver's transit socket before the receiver gets relay
|
|
||||||
# message (and thus the key).
|
|
||||||
self._ready_for_connections_lock.acquire()
|
|
||||||
self._transit_key = key
|
|
||||||
self.handler_send_handshake = self._send_this() # no "go"
|
|
||||||
self.handler_expected_handshake = self._expect_this()
|
|
||||||
self._ready_for_connections_lock.notify_all()
|
|
||||||
self._ready_for_connections_lock.release()
|
|
||||||
|
|
||||||
def _start_outbound(self):
|
|
||||||
self._active_connectors = set(self._their_direct_hints)
|
|
||||||
self._attempted_connectors = set()
|
|
||||||
for hint in self._their_direct_hints:
|
|
||||||
self._start_connector(hint)
|
|
||||||
if not self._their_direct_hints:
|
|
||||||
self._start_relay_connectors()
|
|
||||||
|
|
||||||
def _start_connector(self, hint, is_relay=False):
|
|
||||||
# Don't try any hint more than once. If all hints fail, we'll
|
|
||||||
# eventually timeout. We make no attempt to fail any faster.
|
|
||||||
if hint in self._attempted_connectors:
|
|
||||||
return
|
|
||||||
self._attempted_connectors.add(hint)
|
|
||||||
description = "->%s" % (hint,)
|
|
||||||
if is_relay:
|
|
||||||
description = "->relay:%s" % (hint,)
|
|
||||||
args = (self, hint, description,
|
|
||||||
self._send_this(), self._expect_this())
|
|
||||||
if is_relay:
|
|
||||||
args = args + (build_relay_handshake(self._transit_key),)
|
|
||||||
t = threading.Thread(target=connector, args=args)
|
|
||||||
t.daemon = True
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
def _start_relay_connectors(self):
|
|
||||||
self._active_connectors.update(self._their_direct_hints)
|
|
||||||
for hint in self._their_relay_hints:
|
|
||||||
self._start_connector(hint, is_relay=True)
|
|
||||||
|
|
||||||
def establish_socket(self):
|
|
||||||
start = time.time()
|
|
||||||
self.winning_skt = None
|
|
||||||
self.winning_skt_description = None
|
|
||||||
self._ready_for_connections_lock.acquire()
|
|
||||||
self._ready_for_connections = True
|
|
||||||
self._ready_for_connections_lock.notify_all()
|
|
||||||
self._ready_for_connections_lock.release()
|
|
||||||
self._start_outbound()
|
|
||||||
|
|
||||||
# we sit here until one of our inbound or outbound sockets succeeds
|
|
||||||
flag = self.winning.wait(2*TIMEOUT)
|
|
||||||
debug("wait returned at %.1f" % (since(start),))
|
|
||||||
|
|
||||||
if not flag:
|
|
||||||
# timeout: self.winning_skt will not be set. ish. race.
|
|
||||||
pass
|
|
||||||
if self.listener:
|
|
||||||
self.listener.shutdown() # TODO: waits up to 0.5s. push to thread
|
|
||||||
if self.winning_skt:
|
|
||||||
return self.winning_skt
|
|
||||||
raise TransitError("timeout")
|
|
||||||
|
|
||||||
def _connector_failed(self, hint):
|
|
||||||
debug("- failed connector %s" % hint)
|
|
||||||
# XXX this was .remove, and occasionally got KeyError
|
|
||||||
self._active_connectors.discard(hint)
|
|
||||||
if not self._active_connectors:
|
|
||||||
self._start_relay_connectors()
|
|
||||||
|
|
||||||
def _negotiation_finished(self, skt, description):
|
|
||||||
# inbound/outbound sockets call this when they finish negotiation.
|
|
||||||
# The first one wins and gets a "go". Any subsequent ones lose and
|
|
||||||
# get a "nevermind" before being closed.
|
|
||||||
|
|
||||||
with self._negotiation_check_lock:
|
|
||||||
if self.winning_skt:
|
|
||||||
is_winner = False
|
|
||||||
else:
|
|
||||||
is_winner = True
|
|
||||||
self.winning_skt = skt
|
|
||||||
self.winning_skt_description = description
|
|
||||||
|
|
||||||
if is_winner:
|
|
||||||
if self.is_sender:
|
|
||||||
send_to(skt, b"go\n")
|
|
||||||
self.winning.set()
|
|
||||||
else:
|
|
||||||
if self.is_sender:
|
|
||||||
try:
|
|
||||||
send_to(skt, b"nevermind\n")
|
|
||||||
except socket.error:
|
|
||||||
# They realized this connection is not going to win, and
|
|
||||||
# closed it so fast we didn't get a chance to tell them
|
|
||||||
# it lost. This happens in unit tests.
|
|
||||||
pass
|
|
||||||
skt.close()
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
_start = self._timing.add_event("transit connect")
|
|
||||||
skt = self.establish_socket()
|
|
||||||
self._timing.finish_event(_start)
|
|
||||||
return RecordPipe(skt, self._sender_record_key(),
|
|
||||||
self._receiver_record_key(),
|
|
||||||
self.winning_skt_description)
|
|
||||||
|
|
||||||
class TransitSender(Common):
|
|
||||||
is_sender = True
|
|
||||||
|
|
||||||
class TransitReceiver(Common):
|
|
||||||
is_sender = False
|
|
|
@ -27,8 +27,6 @@ g.add_argument("--hide-progress", action="store_true",
|
||||||
help="supress progress-bar display")
|
help="supress progress-bar display")
|
||||||
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
|
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
|
||||||
metavar="FILE", help="(debug) write timing data to file")
|
metavar="FILE", help="(debug) write timing data to file")
|
||||||
g.add_argument("--twisted", action="store_true",
|
|
||||||
help="use Twisted-based implementations, for testing")
|
|
||||||
g.add_argument("--no-listen", action="store_true",
|
g.add_argument("--no-listen", action="store_true",
|
||||||
help="(debug) don't open a listening socket for Transit")
|
help="(debug) don't open a listening socket for Transit")
|
||||||
g.add_argument("--tor", action="store_true",
|
g.add_argument("--tor", action="store_true",
|
||||||
|
|
|
@ -1,13 +1,19 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import io, json
|
import io, os, sys, json, binascii, six, tempfile, zipfile
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
from ..twisted.transit import TransitReceiver
|
from ..twisted.transit import TransitReceiver
|
||||||
from .cmd_receive_blocking import BlockingReceiver, RespondError, APPID
|
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from .progress import ProgressPrinter
|
from .progress import ProgressPrinter
|
||||||
|
|
||||||
|
|
||||||
|
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||||
|
|
||||||
|
class RespondError(Exception):
|
||||||
|
def __init__(self, response):
|
||||||
|
self.response = response
|
||||||
|
|
||||||
def receive_twisted_sync(args):
|
def receive_twisted_sync(args):
|
||||||
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
||||||
# directly)
|
# directly)
|
||||||
|
@ -34,7 +40,14 @@ def receive_twisted_sync(args):
|
||||||
def receive_twisted(args):
|
def receive_twisted(args):
|
||||||
return TwistedReceiver(args).go()
|
return TwistedReceiver(args).go()
|
||||||
|
|
||||||
class TwistedReceiver(BlockingReceiver):
|
|
||||||
|
class TwistedReceiver:
|
||||||
|
def __init__(self, args):
|
||||||
|
assert isinstance(args.relay_url, type(u""))
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def msg(self, *args, **kwargs):
|
||||||
|
print(*args, file=self.args.stdout, **kwargs)
|
||||||
|
|
||||||
# TODO: @handle_server_error
|
# TODO: @handle_server_error
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
|
@ -101,6 +114,11 @@ class TwistedReceiver(BlockingReceiver):
|
||||||
self.args.code_length)
|
self.args.code_length)
|
||||||
yield w.set_code(code)
|
yield w.set_code(code)
|
||||||
|
|
||||||
|
def show_verifier(self, verifier):
|
||||||
|
verifier_hex = binascii.hexlify(verifier).decode("ascii")
|
||||||
|
if self.args.verify:
|
||||||
|
self.msg(u"Verifier %s." % verifier_hex)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def get_data(self, w):
|
def get_data(self, w):
|
||||||
try:
|
try:
|
||||||
|
@ -119,6 +137,61 @@ class TwistedReceiver(BlockingReceiver):
|
||||||
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
||||||
yield w.send_data(data)
|
yield w.send_data(data)
|
||||||
|
|
||||||
|
def handle_file(self, them_d):
|
||||||
|
file_data = them_d["file"]
|
||||||
|
self.abs_destname = self.decide_destname("file",
|
||||||
|
file_data["filename"])
|
||||||
|
self.xfersize = file_data["filesize"]
|
||||||
|
|
||||||
|
self.msg(u"Receiving file (%d bytes) into: %s" %
|
||||||
|
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||||
|
self.ask_permission()
|
||||||
|
tmp_destname = self.abs_destname + ".tmp"
|
||||||
|
return open(tmp_destname, "wb")
|
||||||
|
|
||||||
|
def handle_directory(self, them_d):
|
||||||
|
file_data = them_d["directory"]
|
||||||
|
zipmode = file_data["mode"]
|
||||||
|
if zipmode != "zipfile/deflated":
|
||||||
|
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
|
||||||
|
raise RespondError({"error": "unknown mode"})
|
||||||
|
self.abs_destname = self.decide_destname("directory",
|
||||||
|
file_data["dirname"])
|
||||||
|
self.xfersize = file_data["zipsize"]
|
||||||
|
|
||||||
|
self.msg(u"Receiving directory (%d bytes) into: %s/" %
|
||||||
|
(self.xfersize, os.path.basename(self.abs_destname)))
|
||||||
|
self.msg(u"%d files, %d bytes (uncompressed)" %
|
||||||
|
(file_data["numfiles"], file_data["numbytes"]))
|
||||||
|
self.ask_permission()
|
||||||
|
return tempfile.SpooledTemporaryFile()
|
||||||
|
|
||||||
|
def decide_destname(self, mode, destname):
|
||||||
|
# the basename() is intended to protect us against
|
||||||
|
# "~/.ssh/authorized_keys" and other attacks
|
||||||
|
destname = os.path.basename(destname)
|
||||||
|
if self.args.output_file:
|
||||||
|
destname = self.args.output_file # override
|
||||||
|
abs_destname = os.path.join(self.args.cwd, destname)
|
||||||
|
|
||||||
|
# get confirmation from the user before writing to the local directory
|
||||||
|
if os.path.exists(abs_destname):
|
||||||
|
self.msg(u"Error: refusing to overwrite existing %s %s" %
|
||||||
|
(mode, destname))
|
||||||
|
raise RespondError({"error": "%s already exists" % mode})
|
||||||
|
return abs_destname
|
||||||
|
|
||||||
|
def ask_permission(self):
|
||||||
|
_start = self.args.timing.add_event("permission", waiting="user")
|
||||||
|
while True and not self.args.accept_file:
|
||||||
|
ok = six.moves.input("ok? (y/n): ")
|
||||||
|
if ok.lower().startswith("y"):
|
||||||
|
break
|
||||||
|
print(u"transfer rejected", file=sys.stderr)
|
||||||
|
self.args.timing.finish_event(_start, answer="no")
|
||||||
|
raise RespondError({"error": "transfer rejected"})
|
||||||
|
self.args.timing.finish_event(_start, answer="yes")
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def establish_transit(self, w, them_d, tor_manager):
|
def establish_transit(self, w, them_d, tor_manager):
|
||||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
transit_key = w.derive_key(APPID+u"/transit-key")
|
||||||
|
@ -169,6 +242,27 @@ class TwistedReceiver(BlockingReceiver):
|
||||||
returnValue(1) # TODO: exit properly
|
returnValue(1) # TODO: exit properly
|
||||||
assert received == self.xfersize
|
assert received == self.xfersize
|
||||||
|
|
||||||
|
def write_file(self, f):
|
||||||
|
tmp_name = f.name
|
||||||
|
f.close()
|
||||||
|
os.rename(tmp_name, self.abs_destname)
|
||||||
|
self.msg(u"Received file written to %s" %
|
||||||
|
os.path.basename(self.abs_destname))
|
||||||
|
|
||||||
|
def write_directory(self, f):
|
||||||
|
self.msg(u"Unpacking zipfile..")
|
||||||
|
_start = self.args.timing.add_event("unpack zip")
|
||||||
|
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||||
|
zf.extractall(path=self.abs_destname)
|
||||||
|
# extractall() appears to offer some protection against
|
||||||
|
# malicious pathnames. For example, "/tmp/oops" and
|
||||||
|
# "../tmp/oops" both do the same thing as the (safe)
|
||||||
|
# "tmp/oops".
|
||||||
|
self.msg(u"Received files written to %s/" %
|
||||||
|
os.path.basename(self.abs_destname))
|
||||||
|
f.close()
|
||||||
|
self.args.timing.finish_event(_start)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def close_transit(self, record_pipe):
|
def close_transit(self, record_pipe):
|
||||||
_start = self.args.timing.add_event("ack")
|
_start = self.args.timing.add_event("ack")
|
|
@ -1,216 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import io, os, sys, json, binascii, six, tempfile, zipfile
|
|
||||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
|
||||||
from ..blocking.transit import TransitReceiver, TransitError
|
|
||||||
from ..errors import handle_server_error, TransferError
|
|
||||||
from .progress import ProgressPrinter
|
|
||||||
|
|
||||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
|
||||||
|
|
||||||
def receive_blocking(args):
|
|
||||||
return BlockingReceiver(args).go()
|
|
||||||
|
|
||||||
class RespondError(Exception):
|
|
||||||
def __init__(self, response):
|
|
||||||
self.response = response
|
|
||||||
|
|
||||||
class BlockingReceiver:
|
|
||||||
def __init__(self, args):
|
|
||||||
assert isinstance(args.relay_url, type(u""))
|
|
||||||
self.args = args
|
|
||||||
|
|
||||||
def msg(self, *args, **kwargs):
|
|
||||||
print(*args, file=self.args.stdout, **kwargs)
|
|
||||||
|
|
||||||
@handle_server_error
|
|
||||||
def go(self):
|
|
||||||
with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w:
|
|
||||||
self.handle_code(w)
|
|
||||||
verifier = w.get_verifier()
|
|
||||||
self.show_verifier(verifier)
|
|
||||||
them_d = self.get_data(w)
|
|
||||||
try:
|
|
||||||
if "message" in them_d:
|
|
||||||
self.handle_text(them_d, w)
|
|
||||||
return 0
|
|
||||||
if "file" in them_d:
|
|
||||||
f = self.handle_file(them_d)
|
|
||||||
rp = self.establish_transit(w, them_d)
|
|
||||||
self.transfer_data(rp, f)
|
|
||||||
self.write_file(f)
|
|
||||||
self.close_transit(rp)
|
|
||||||
elif "directory" in them_d:
|
|
||||||
f = self.handle_directory(them_d)
|
|
||||||
rp = self.establish_transit(w, them_d)
|
|
||||||
self.transfer_data(rp, f)
|
|
||||||
self.write_directory(f)
|
|
||||||
self.close_transit(rp)
|
|
||||||
else:
|
|
||||||
self.msg(u"I don't know what they're offering\n")
|
|
||||||
self.msg(u"Offer details:", them_d)
|
|
||||||
raise RespondError({"error": "unknown offer type"})
|
|
||||||
except RespondError as r:
|
|
||||||
data = json.dumps(r.response).encode("utf-8")
|
|
||||||
w.send_data(data)
|
|
||||||
return 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def handle_code(self, w):
|
|
||||||
code = self.args.code
|
|
||||||
if self.args.zeromode:
|
|
||||||
assert not code
|
|
||||||
code = u"0-"
|
|
||||||
if not code:
|
|
||||||
code = w.input_code("Enter receive wormhole code: ",
|
|
||||||
self.args.code_length)
|
|
||||||
w.set_code(code)
|
|
||||||
|
|
||||||
def show_verifier(self, verifier):
|
|
||||||
verifier_hex = binascii.hexlify(verifier).decode("ascii")
|
|
||||||
if self.args.verify:
|
|
||||||
self.msg(u"Verifier %s." % verifier_hex)
|
|
||||||
|
|
||||||
def get_data(self, w):
|
|
||||||
try:
|
|
||||||
them_bytes = w.get_data()
|
|
||||||
except WrongPasswordError as e:
|
|
||||||
raise TransferError(u"ERROR: " + e.explain())
|
|
||||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
|
||||||
if "error" in them_d:
|
|
||||||
raise TransferError(u"ERROR: " + them_d["error"])
|
|
||||||
return them_d
|
|
||||||
|
|
||||||
def handle_text(self, them_d, w):
|
|
||||||
# we're receiving a text message
|
|
||||||
self.msg(them_d["message"])
|
|
||||||
data = json.dumps({"message_ack": "ok"}).encode("utf-8")
|
|
||||||
w.send_data(data)
|
|
||||||
|
|
||||||
def handle_file(self, them_d):
|
|
||||||
file_data = them_d["file"]
|
|
||||||
self.abs_destname = self.decide_destname("file",
|
|
||||||
file_data["filename"])
|
|
||||||
self.xfersize = file_data["filesize"]
|
|
||||||
|
|
||||||
self.msg(u"Receiving file (%d bytes) into: %s" %
|
|
||||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
|
||||||
self.ask_permission()
|
|
||||||
tmp_destname = self.abs_destname + ".tmp"
|
|
||||||
return open(tmp_destname, "wb")
|
|
||||||
|
|
||||||
def handle_directory(self, them_d):
|
|
||||||
file_data = them_d["directory"]
|
|
||||||
zipmode = file_data["mode"]
|
|
||||||
if zipmode != "zipfile/deflated":
|
|
||||||
self.msg(u"Error: unknown directory-transfer mode '%s'" % (zipmode,))
|
|
||||||
raise RespondError({"error": "unknown mode"})
|
|
||||||
self.abs_destname = self.decide_destname("directory",
|
|
||||||
file_data["dirname"])
|
|
||||||
self.xfersize = file_data["zipsize"]
|
|
||||||
|
|
||||||
self.msg(u"Receiving directory (%d bytes) into: %s/" %
|
|
||||||
(self.xfersize, os.path.basename(self.abs_destname)))
|
|
||||||
self.msg(u"%d files, %d bytes (uncompressed)" %
|
|
||||||
(file_data["numfiles"], file_data["numbytes"]))
|
|
||||||
self.ask_permission()
|
|
||||||
return tempfile.SpooledTemporaryFile()
|
|
||||||
|
|
||||||
def decide_destname(self, mode, destname):
|
|
||||||
# the basename() is intended to protect us against
|
|
||||||
# "~/.ssh/authorized_keys" and other attacks
|
|
||||||
destname = os.path.basename(destname)
|
|
||||||
if self.args.output_file:
|
|
||||||
destname = self.args.output_file # override
|
|
||||||
abs_destname = os.path.join(self.args.cwd, destname)
|
|
||||||
|
|
||||||
# get confirmation from the user before writing to the local directory
|
|
||||||
if os.path.exists(abs_destname):
|
|
||||||
self.msg(u"Error: refusing to overwrite existing %s %s" %
|
|
||||||
(mode, destname))
|
|
||||||
raise RespondError({"error": "%s already exists" % mode})
|
|
||||||
return abs_destname
|
|
||||||
|
|
||||||
def ask_permission(self):
|
|
||||||
_start = self.args.timing.add_event("permission", waiting="user")
|
|
||||||
while True and not self.args.accept_file:
|
|
||||||
ok = six.moves.input("ok? (y/n): ")
|
|
||||||
if ok.lower().startswith("y"):
|
|
||||||
break
|
|
||||||
print(u"transfer rejected", file=sys.stderr)
|
|
||||||
self.args.timing.finish_event(_start, answer="no")
|
|
||||||
raise RespondError({"error": "transfer rejected"})
|
|
||||||
self.args.timing.finish_event(_start, answer="yes")
|
|
||||||
|
|
||||||
def establish_transit(self, w, them_d):
|
|
||||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
|
||||||
transit_receiver = TransitReceiver(self.args.transit_helper,
|
|
||||||
no_listen=self.args.no_listen,
|
|
||||||
timing=self.args.timing)
|
|
||||||
transit_receiver.set_transit_key(transit_key)
|
|
||||||
data = json.dumps({
|
|
||||||
"file_ack": "ok",
|
|
||||||
"transit": {
|
|
||||||
"direct_connection_hints": transit_receiver.get_direct_hints(),
|
|
||||||
"relay_connection_hints": transit_receiver.get_relay_hints(),
|
|
||||||
},
|
|
||||||
}).encode("utf-8")
|
|
||||||
w.send_data(data)
|
|
||||||
|
|
||||||
# now receive the rest of the owl
|
|
||||||
tdata = them_d["transit"]
|
|
||||||
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
|
|
||||||
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
|
|
||||||
record_pipe = transit_receiver.connect()
|
|
||||||
return record_pipe
|
|
||||||
|
|
||||||
def transfer_data(self, record_pipe, f):
|
|
||||||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
|
||||||
|
|
||||||
_start = self.args.timing.add_event("rx file")
|
|
||||||
progress_stdout = self.args.stdout
|
|
||||||
if self.args.hide_progress:
|
|
||||||
progress_stdout = io.StringIO()
|
|
||||||
received = 0
|
|
||||||
p = ProgressPrinter(self.xfersize, progress_stdout)
|
|
||||||
p.start()
|
|
||||||
while received < self.xfersize:
|
|
||||||
try:
|
|
||||||
plaintext = record_pipe.receive_record()
|
|
||||||
except TransitError:
|
|
||||||
self.msg()
|
|
||||||
self.msg(u"Connection dropped before full file received")
|
|
||||||
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
|
|
||||||
return 1
|
|
||||||
f.write(plaintext)
|
|
||||||
received += len(plaintext)
|
|
||||||
p.update(received)
|
|
||||||
p.finish()
|
|
||||||
self.args.timing.finish_event(_start)
|
|
||||||
assert received == self.xfersize
|
|
||||||
|
|
||||||
def write_file(self, f):
|
|
||||||
tmp_name = f.name
|
|
||||||
f.close()
|
|
||||||
os.rename(tmp_name, self.abs_destname)
|
|
||||||
self.msg(u"Received file written to %s" %
|
|
||||||
os.path.basename(self.abs_destname))
|
|
||||||
|
|
||||||
def write_directory(self, f):
|
|
||||||
self.msg(u"Unpacking zipfile..")
|
|
||||||
_start = self.args.timing.add_event("unpack zip")
|
|
||||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
|
||||||
zf.extractall(path=self.abs_destname)
|
|
||||||
# extractall() appears to offer some protection against
|
|
||||||
# malicious pathnames. For example, "/tmp/oops" and
|
|
||||||
# "../tmp/oops" both do the same thing as the (safe)
|
|
||||||
# "tmp/oops".
|
|
||||||
self.msg(u"Received files written to %s/" %
|
|
||||||
os.path.basename(self.abs_destname))
|
|
||||||
f.close()
|
|
||||||
self.args.timing.finish_event(_start)
|
|
||||||
|
|
||||||
def close_transit(self, record_pipe):
|
|
||||||
_start = self.args.timing.add_event("ack")
|
|
||||||
record_pipe.send_record(b"ok\n")
|
|
||||||
record_pipe.close()
|
|
||||||
self.args.timing.finish_event(_start)
|
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import io, json, binascii, six
|
import os, sys, io, json, binascii, six, tempfile, zipfile
|
||||||
from twisted.protocols import basic
|
from twisted.protocols import basic
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
|
@ -7,8 +7,94 @@ from ..errors import TransferError
|
||||||
from .progress import ProgressPrinter
|
from .progress import ProgressPrinter
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
||||||
from ..twisted.transit import TransitSender
|
from ..twisted.transit import TransitSender
|
||||||
from .send_common import (APPID, handle_zero, build_other_command,
|
|
||||||
build_phase1_data)
|
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||||
|
|
||||||
|
def handle_zero(args):
|
||||||
|
if args.zeromode:
|
||||||
|
assert not args.code
|
||||||
|
args.code = u"0-"
|
||||||
|
|
||||||
|
def build_other_command(args):
|
||||||
|
other_cmd = "wormhole receive"
|
||||||
|
if args.verify:
|
||||||
|
other_cmd = "wormhole --verify receive"
|
||||||
|
if args.zeromode:
|
||||||
|
other_cmd += " -0"
|
||||||
|
return other_cmd
|
||||||
|
|
||||||
|
def build_phase1_data(args):
|
||||||
|
phase1 = {}
|
||||||
|
|
||||||
|
text = args.text
|
||||||
|
if text == "-":
|
||||||
|
print(u"Reading text message from stdin..", file=args.stdout)
|
||||||
|
text = sys.stdin.read()
|
||||||
|
if not text and not args.what:
|
||||||
|
text = six.moves.input("Text to send: ")
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
|
||||||
|
phase1 = { "message": text }
|
||||||
|
fd_to_send = None
|
||||||
|
return phase1, fd_to_send
|
||||||
|
|
||||||
|
what = os.path.join(args.cwd, args.what)
|
||||||
|
what = what.rstrip(os.sep)
|
||||||
|
if not os.path.exists(what):
|
||||||
|
raise TransferError("Cannot send: no file/directory named '%s'" %
|
||||||
|
args.what)
|
||||||
|
basename = os.path.basename(what)
|
||||||
|
|
||||||
|
if os.path.isfile(what):
|
||||||
|
# we're sending a file
|
||||||
|
filesize = os.stat(what).st_size
|
||||||
|
phase1["file"] = {
|
||||||
|
"filename": basename,
|
||||||
|
"filesize": filesize,
|
||||||
|
}
|
||||||
|
print(u"Sending %d byte file named '%s'" % (filesize, basename),
|
||||||
|
file=args.stdout)
|
||||||
|
fd_to_send = open(what, "rb")
|
||||||
|
return phase1, fd_to_send
|
||||||
|
|
||||||
|
if os.path.isdir(what):
|
||||||
|
print(u"Building zipfile..", file=args.stdout)
|
||||||
|
# We're sending a directory. Create a zipfile in a tempdir and
|
||||||
|
# send that.
|
||||||
|
fd_to_send = tempfile.SpooledTemporaryFile()
|
||||||
|
# TODO: I think ZIP_DEFLATED means compressed.. check it
|
||||||
|
num_files = 0
|
||||||
|
num_bytes = 0
|
||||||
|
tostrip = len(what.split(os.sep))
|
||||||
|
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
|
for path,dirs,files in os.walk(what):
|
||||||
|
# path always starts with args.what, then sometimes might
|
||||||
|
# have "/subdir" appended. We want the zipfile to contain
|
||||||
|
# "" or "subdir"
|
||||||
|
localpath = list(path.split(os.sep)[tostrip:])
|
||||||
|
for fn in files:
|
||||||
|
archivename = os.path.join(*tuple(localpath+[fn]))
|
||||||
|
localfilename = os.path.join(path, fn)
|
||||||
|
zf.write(localfilename, archivename)
|
||||||
|
num_bytes += os.stat(localfilename).st_size
|
||||||
|
num_files += 1
|
||||||
|
fd_to_send.seek(0,2)
|
||||||
|
filesize = fd_to_send.tell()
|
||||||
|
fd_to_send.seek(0,0)
|
||||||
|
phase1["directory"] = {
|
||||||
|
"mode": "zipfile/deflated",
|
||||||
|
"dirname": basename,
|
||||||
|
"zipsize": filesize,
|
||||||
|
"numbytes": num_bytes,
|
||||||
|
"numfiles": num_files,
|
||||||
|
}
|
||||||
|
print(u"Sending directory (%d bytes compressed) named '%s'"
|
||||||
|
% (filesize, basename), file=args.stdout)
|
||||||
|
return phase1, fd_to_send
|
||||||
|
|
||||||
|
raise TypeError("'%s' is neither file nor directory" % args.what)
|
||||||
|
|
||||||
|
|
||||||
def send_twisted_sync(args):
|
def send_twisted_sync(args):
|
||||||
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
# try to use twisted.internet.task.react(f) here (but it calls sys.exit
|
|
@ -1,129 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import json, binascii, six
|
|
||||||
from ..errors import TransferError
|
|
||||||
from .progress import ProgressPrinter
|
|
||||||
from ..blocking.transcribe import Wormhole, WrongPasswordError
|
|
||||||
from ..blocking.transit import TransitSender
|
|
||||||
from ..errors import handle_server_error
|
|
||||||
from .send_common import (APPID, handle_zero, build_other_command,
|
|
||||||
build_phase1_data)
|
|
||||||
|
|
||||||
@handle_server_error
|
|
||||||
def send_blocking(args):
|
|
||||||
assert isinstance(args.relay_url, type(u""))
|
|
||||||
handle_zero(args)
|
|
||||||
phase1, fd_to_send = build_phase1_data(args)
|
|
||||||
other_cmd = build_other_command(args)
|
|
||||||
print(u"On the other computer, please run: %s" % other_cmd,
|
|
||||||
file=args.stdout)
|
|
||||||
|
|
||||||
if fd_to_send is not None:
|
|
||||||
transit_sender = TransitSender(args.transit_helper,
|
|
||||||
no_listen=args.no_listen,
|
|
||||||
timing=args.timing)
|
|
||||||
transit_data = {
|
|
||||||
"direct_connection_hints": transit_sender.get_direct_hints(),
|
|
||||||
"relay_connection_hints": transit_sender.get_relay_hints(),
|
|
||||||
}
|
|
||||||
phase1["transit"] = transit_data
|
|
||||||
|
|
||||||
with Wormhole(APPID, args.relay_url, timing=args.timing) as w:
|
|
||||||
if args.code:
|
|
||||||
w.set_code(args.code)
|
|
||||||
code = args.code
|
|
||||||
else:
|
|
||||||
code = w.get_code(args.code_length)
|
|
||||||
if not args.zeromode:
|
|
||||||
print(u"Wormhole code is: %s" % code, file=args.stdout)
|
|
||||||
print(u"", file=args.stdout)
|
|
||||||
|
|
||||||
# get the verifier, because that also lets us derive the transit key,
|
|
||||||
# which we want to set before revealing the connection hints to the
|
|
||||||
# far side, so we'll be ready for them when they connect
|
|
||||||
verifier = binascii.hexlify(w.get_verifier()).decode("ascii")
|
|
||||||
if args.verify:
|
|
||||||
_do_verify(verifier, w)
|
|
||||||
|
|
||||||
if fd_to_send is not None:
|
|
||||||
transit_key = w.derive_key(APPID+"/transit-key")
|
|
||||||
transit_sender.set_transit_key(transit_key)
|
|
||||||
|
|
||||||
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
|
||||||
w.send_data(my_phase1_bytes)
|
|
||||||
try:
|
|
||||||
them_phase1_bytes = w.get_data()
|
|
||||||
except WrongPasswordError as e:
|
|
||||||
raise TransferError(e.explain())
|
|
||||||
|
|
||||||
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
|
||||||
|
|
||||||
if fd_to_send is None:
|
|
||||||
if them_phase1["message_ack"] == "ok":
|
|
||||||
print(u"text message sent", file=args.stdout)
|
|
||||||
return 0
|
|
||||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
|
||||||
|
|
||||||
return _send_file_blocking(them_phase1, fd_to_send,
|
|
||||||
transit_sender, args.stdout, args.hide_progress,
|
|
||||||
args.timing)
|
|
||||||
|
|
||||||
def _do_verify(verifier, w):
|
|
||||||
while True:
|
|
||||||
ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier)
|
|
||||||
if ok.lower() == "yes":
|
|
||||||
break
|
|
||||||
if ok.lower() == "no":
|
|
||||||
reject_data = json.dumps({"error": "verification rejected",
|
|
||||||
}).encode("utf-8")
|
|
||||||
w.send_data(reject_data)
|
|
||||||
raise TransferError("verification rejected, abandoning transfer")
|
|
||||||
|
|
||||||
def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
|
||||||
stdout, hide_progress, timing):
|
|
||||||
|
|
||||||
# we're sending a file, if they accept it
|
|
||||||
|
|
||||||
if "error" in them_phase1:
|
|
||||||
raise TransferError("remote error, transfer abandoned: %s"
|
|
||||||
% them_phase1["error"])
|
|
||||||
if them_phase1.get("file_ack") != "ok":
|
|
||||||
raise TransferError("ambiguous response from remote, "
|
|
||||||
"transfer abandoned: %s" % (them_phase1,))
|
|
||||||
|
|
||||||
tdata = them_phase1["transit"]
|
|
||||||
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
|
||||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
|
||||||
record_pipe = transit_sender.connect()
|
|
||||||
|
|
||||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
|
||||||
|
|
||||||
_start = timing.add_event("tx file")
|
|
||||||
CHUNKSIZE = 64*1024
|
|
||||||
fd_to_send.seek(0,2)
|
|
||||||
filesize = fd_to_send.tell()
|
|
||||||
fd_to_send.seek(0,0)
|
|
||||||
p = ProgressPrinter(filesize, stdout)
|
|
||||||
with fd_to_send as f:
|
|
||||||
sent = 0
|
|
||||||
if not hide_progress:
|
|
||||||
p.start()
|
|
||||||
while sent < filesize:
|
|
||||||
plaintext = f.read(CHUNKSIZE)
|
|
||||||
record_pipe.send_record(plaintext)
|
|
||||||
sent += len(plaintext)
|
|
||||||
if not hide_progress:
|
|
||||||
p.update(sent)
|
|
||||||
if not hide_progress:
|
|
||||||
p.finish()
|
|
||||||
timing.finish_event(_start)
|
|
||||||
|
|
||||||
_start = timing.add_event("get ack")
|
|
||||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
|
||||||
ack = record_pipe.receive_record()
|
|
||||||
record_pipe.close()
|
|
||||||
if ack == b"ok\n":
|
|
||||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
|
||||||
timing.finish_event(_start, ack="ok")
|
|
||||||
return 0
|
|
||||||
timing.finish_event(_start, ack="failed")
|
|
||||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
|
|
@ -21,22 +21,14 @@ def dispatch(args):
|
||||||
from ..servers import cmd_usage
|
from ..servers import cmd_usage
|
||||||
return cmd_usage.tail_usage(args)
|
return cmd_usage.tail_usage(args)
|
||||||
|
|
||||||
if args.tor:
|
|
||||||
args.twisted = True
|
|
||||||
if args.func == "send/send":
|
if args.func == "send/send":
|
||||||
if args.twisted:
|
from . import cmd_send
|
||||||
from . import cmd_send_twisted
|
return cmd_send.send_twisted_sync(args)
|
||||||
return cmd_send_twisted.send_twisted_sync(args)
|
|
||||||
from . import cmd_send_blocking
|
|
||||||
return cmd_send_blocking.send_blocking(args)
|
|
||||||
if args.func == "receive/receive":
|
if args.func == "receive/receive":
|
||||||
if args.twisted:
|
|
||||||
_start = args.timing.add_event("import c_r_t")
|
_start = args.timing.add_event("import c_r_t")
|
||||||
from . import cmd_receive_twisted
|
from . import cmd_receive
|
||||||
args.timing.finish_event(_start)
|
args.timing.finish_event(_start)
|
||||||
return cmd_receive_twisted.receive_twisted_sync(args)
|
return cmd_receive.receive_twisted_sync(args)
|
||||||
from . import cmd_receive_blocking
|
|
||||||
return cmd_receive_blocking.receive_blocking(args)
|
|
||||||
|
|
||||||
raise ValueError("unknown args.func %s" % args.func)
|
raise ValueError("unknown args.func %s" % args.func)
|
||||||
|
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
import os, sys, six, tempfile, zipfile
|
|
||||||
from ..errors import TransferError
|
|
||||||
|
|
||||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
|
||||||
|
|
||||||
def handle_zero(args):
|
|
||||||
if args.zeromode:
|
|
||||||
assert not args.code
|
|
||||||
args.code = u"0-"
|
|
||||||
|
|
||||||
def build_other_command(args):
|
|
||||||
other_cmd = "wormhole receive"
|
|
||||||
if args.verify:
|
|
||||||
other_cmd = "wormhole --verify receive"
|
|
||||||
if args.zeromode:
|
|
||||||
other_cmd += " -0"
|
|
||||||
return other_cmd
|
|
||||||
|
|
||||||
def build_phase1_data(args):
|
|
||||||
phase1 = {}
|
|
||||||
|
|
||||||
text = args.text
|
|
||||||
if text == "-":
|
|
||||||
print(u"Reading text message from stdin..", file=args.stdout)
|
|
||||||
text = sys.stdin.read()
|
|
||||||
if not text and not args.what:
|
|
||||||
text = six.moves.input("Text to send: ")
|
|
||||||
|
|
||||||
if text is not None:
|
|
||||||
print(u"Sending text message (%d bytes)" % len(text), file=args.stdout)
|
|
||||||
phase1 = { "message": text }
|
|
||||||
fd_to_send = None
|
|
||||||
return phase1, fd_to_send
|
|
||||||
|
|
||||||
what = os.path.join(args.cwd, args.what)
|
|
||||||
what = what.rstrip(os.sep)
|
|
||||||
if not os.path.exists(what):
|
|
||||||
raise TransferError("Cannot send: no file/directory named '%s'" %
|
|
||||||
args.what)
|
|
||||||
basename = os.path.basename(what)
|
|
||||||
|
|
||||||
if os.path.isfile(what):
|
|
||||||
# we're sending a file
|
|
||||||
filesize = os.stat(what).st_size
|
|
||||||
phase1["file"] = {
|
|
||||||
"filename": basename,
|
|
||||||
"filesize": filesize,
|
|
||||||
}
|
|
||||||
print(u"Sending %d byte file named '%s'" % (filesize, basename),
|
|
||||||
file=args.stdout)
|
|
||||||
fd_to_send = open(what, "rb")
|
|
||||||
return phase1, fd_to_send
|
|
||||||
|
|
||||||
if os.path.isdir(what):
|
|
||||||
print(u"Building zipfile..", file=args.stdout)
|
|
||||||
# We're sending a directory. Create a zipfile in a tempdir and
|
|
||||||
# send that.
|
|
||||||
fd_to_send = tempfile.SpooledTemporaryFile()
|
|
||||||
# TODO: I think ZIP_DEFLATED means compressed.. check it
|
|
||||||
num_files = 0
|
|
||||||
num_bytes = 0
|
|
||||||
tostrip = len(what.split(os.sep))
|
|
||||||
with zipfile.ZipFile(fd_to_send, "w", zipfile.ZIP_DEFLATED) as zf:
|
|
||||||
for path,dirs,files in os.walk(what):
|
|
||||||
# path always starts with args.what, then sometimes might
|
|
||||||
# have "/subdir" appended. We want the zipfile to contain
|
|
||||||
# "" or "subdir"
|
|
||||||
localpath = list(path.split(os.sep)[tostrip:])
|
|
||||||
for fn in files:
|
|
||||||
archivename = os.path.join(*tuple(localpath+[fn]))
|
|
||||||
localfilename = os.path.join(path, fn)
|
|
||||||
zf.write(localfilename, archivename)
|
|
||||||
num_bytes += os.stat(localfilename).st_size
|
|
||||||
num_files += 1
|
|
||||||
fd_to_send.seek(0,2)
|
|
||||||
filesize = fd_to_send.tell()
|
|
||||||
fd_to_send.seek(0,0)
|
|
||||||
phase1["directory"] = {
|
|
||||||
"mode": "zipfile/deflated",
|
|
||||||
"dirname": basename,
|
|
||||||
"zipsize": filesize,
|
|
||||||
"numbytes": num_bytes,
|
|
||||||
"numfiles": num_files,
|
|
||||||
}
|
|
||||||
print(u"Sending directory (%d bytes compressed) named '%s'"
|
|
||||||
% (filesize, basename), file=args.stdout)
|
|
||||||
return phase1, fd_to_send
|
|
||||||
|
|
||||||
raise TypeError("'%s' is neither file nor directory" % args.what)
|
|
|
@ -1,7 +1,7 @@
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor, defer
|
||||||
from twisted.python import log
|
from twisted.python import log
|
||||||
from ..twisted.util import allocate_ports
|
from ..twisted.transit import allocate_tcp_port
|
||||||
from ..servers.server import RelayServer
|
from ..servers.server import RelayServer
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
|
|
||||||
|
@ -9,9 +9,8 @@ class ServerBase:
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.sp = service.MultiService()
|
self.sp = service.MultiService()
|
||||||
self.sp.startService()
|
self.sp.startService()
|
||||||
d = allocate_ports()
|
relayport = allocate_tcp_port()
|
||||||
def _got_ports(ports):
|
transitport = allocate_tcp_port()
|
||||||
relayport, transitport = ports
|
|
||||||
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
|
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
|
||||||
"tcp:%s:interface=127.0.0.1" % transitport,
|
"tcp:%s:interface=127.0.0.1" % transitport,
|
||||||
__version__)
|
__version__)
|
||||||
|
@ -20,8 +19,6 @@ class ServerBase:
|
||||||
self._transit_server = s.transit
|
self._transit_server = s.transit
|
||||||
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
|
self.relayurl = u"http://127.0.0.1:%d/wormhole-relay/" % relayport
|
||||||
self.transit = u"tcp:127.0.0.1:%d" % transitport
|
self.transit = u"tcp:127.0.0.1:%d" % transitport
|
||||||
d.addCallback(_got_ports)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# Unit tests that spawn a (blocking) client in a thread might still
|
# Unit tests that spawn a (blocking) client in a thread might still
|
||||||
|
|
|
@ -1,14 +1,11 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import json
|
import json
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet.defer import gatherResults, succeed, inlineCallbacks
|
from twisted.internet.defer import gatherResults, succeed
|
||||||
from twisted.internet.threads import deferToThread
|
from twisted.internet.threads import deferToThread
|
||||||
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
||||||
WrongPasswordError)
|
WrongPasswordError)
|
||||||
from ..blocking.eventsource import EventSourceFollower
|
from ..blocking.eventsource import EventSourceFollower
|
||||||
from ..blocking.transit import (TransitSender, TransitReceiver,
|
|
||||||
build_sender_handshake,
|
|
||||||
build_receiver_handshake)
|
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
|
|
||||||
APPID = u"appid"
|
APPID = u"appid"
|
||||||
|
@ -447,101 +444,3 @@ class EventSourceClient(unittest.TestCase):
|
||||||
(u"message", u"three"),
|
(u"message", u"three"),
|
||||||
(u"e2", u"four"),
|
(u"e2", u"four"),
|
||||||
])
|
])
|
||||||
|
|
||||||
class Transit(_DoBothMixin, ServerBase, unittest.TestCase):
|
|
||||||
def test_hints(self):
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
hints = r.get_direct_hints()
|
|
||||||
self.assertTrue(len(hints), hints)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_direct_to_receiver(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
|
|
||||||
# force the connection to be sender->receiver
|
|
||||||
s.set_transit_key(key)
|
|
||||||
# only use 127.0.0.1
|
|
||||||
hint = u"tcp:127.0.0.1:%d" % r.listener.server_address[1]
|
|
||||||
s.add_their_direct_hints([hint])
|
|
||||||
s.add_their_relay_hints([])
|
|
||||||
r.set_transit_key(key)
|
|
||||||
r.add_their_direct_hints([])
|
|
||||||
r.add_their_relay_hints([])
|
|
||||||
|
|
||||||
# it'd be nice to factor this chunk out with 'yield from', but that
|
|
||||||
# didn't appear until python-3.3, and isn't in py2 at all.
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_direct_to_sender(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
|
|
||||||
# force the connection to be receiver->sender
|
|
||||||
s.set_transit_key(key)
|
|
||||||
s.add_their_direct_hints([])
|
|
||||||
s.add_their_relay_hints([])
|
|
||||||
r.set_transit_key(key)
|
|
||||||
hint = u"tcp:127.0.0.1:%d" % s.listener.server_address[1]
|
|
||||||
r.add_their_direct_hints([hint])
|
|
||||||
r.add_their_relay_hints([])
|
|
||||||
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
|
|
||||||
@inlineCallbacks
|
|
||||||
def test_relay(self):
|
|
||||||
s = TransitSender(self.transit)
|
|
||||||
r = TransitReceiver(self.transit)
|
|
||||||
key = b"\x00"*32
|
|
||||||
# force the connection to use the relay by not revealing direct hints
|
|
||||||
s.set_transit_key(key)
|
|
||||||
s.add_their_direct_hints([])
|
|
||||||
s.add_their_relay_hints(r.get_relay_hints())
|
|
||||||
r.set_transit_key(key)
|
|
||||||
r.add_their_direct_hints([])
|
|
||||||
r.add_their_relay_hints(s.get_relay_hints())
|
|
||||||
|
|
||||||
(sp, rp) = yield self.doBoth([s.connect], [r.connect])
|
|
||||||
yield deferToThread(sp.send_record, b"01234")
|
|
||||||
rec = yield deferToThread(rp.receive_record)
|
|
||||||
self.assertEqual(rec, b"01234")
|
|
||||||
yield deferToThread(sp.close)
|
|
||||||
yield deferToThread(rp.close)
|
|
||||||
# TODO: this may be racy if we don't poll the server to make sure
|
|
||||||
# it's witnessed the first connection closing before querying the DB
|
|
||||||
#import time
|
|
||||||
#yield deferToThread(time.sleep, 1)
|
|
||||||
|
|
||||||
# check the transit relay's DB, make sure it counted the bytes
|
|
||||||
db = self._transit_server._db
|
|
||||||
c = db.execute("SELECT * FROM `usage` WHERE `type`=?", (u"transit",))
|
|
||||||
rows = c.fetchall()
|
|
||||||
self.assertEqual(len(rows), 1)
|
|
||||||
row = rows[0]
|
|
||||||
self.assertEqual(row["result"], u"happy")
|
|
||||||
# Sender first writes relay_handshake and waits for OK, but that's
|
|
||||||
# not counted by the transit server. Then sender writes
|
|
||||||
# sender_handshake and waits for receiver_handshake. Then sender
|
|
||||||
# writes GO and the body. Body is length-prefixed SecretBox, so
|
|
||||||
# includes 4-byte length, 24-byte nonce, and 16-byte MAC.
|
|
||||||
sender_count = (len(build_sender_handshake(b""))+
|
|
||||||
len(b"go\n")+
|
|
||||||
4+24+len(b"01234")+16)
|
|
||||||
# Receiver first writes relay_handshake and waits for OK, but that's
|
|
||||||
# not counted. Then receiver writes receiver_handshake and waits for
|
|
||||||
# sender_handshake+GO.
|
|
||||||
receiver_count = len(build_receiver_handshake(b""))
|
|
||||||
self.assertEqual(row["total_bytes"], sender_count+receiver_count)
|
|
||||||
|
|
|
@ -4,12 +4,10 @@ from twisted.trial import unittest
|
||||||
from twisted.python import procutils, log
|
from twisted.python import procutils, log
|
||||||
from twisted.internet.utils import getProcessOutputAndValue
|
from twisted.internet.utils import getProcessOutputAndValue
|
||||||
from twisted.internet.defer import inlineCallbacks
|
from twisted.internet.defer import inlineCallbacks
|
||||||
from twisted.internet.threads import deferToThread
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from ..scripts import (runner, cmd_send_blocking, cmd_send_twisted,
|
from ..scripts import runner, cmd_send, cmd_receive
|
||||||
cmd_receive_blocking, cmd_receive_twisted)
|
from ..scripts.cmd_send import build_phase1_data
|
||||||
from ..scripts.send_common import build_phase1_data
|
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from ..timing import DebugTiming
|
from ..timing import DebugTiming
|
||||||
|
|
||||||
|
@ -219,8 +217,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _do_test(self, as_subprocess=False,
|
def _do_test(self, as_subprocess=False,
|
||||||
mode="text", addslash=False, override_filename=False,
|
mode="text", addslash=False, override_filename=False):
|
||||||
sender_twisted=False, receiver_twisted=False):
|
|
||||||
assert mode in ("text", "file", "directory")
|
assert mode in ("text", "file", "directory")
|
||||||
common_args = ["--hide-progress",
|
common_args = ["--hide-progress",
|
||||||
"--relay-url", self.relayurl,
|
"--relay-url", self.relayurl,
|
||||||
|
@ -314,14 +311,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
rargs.stdout = io.StringIO()
|
rargs.stdout = io.StringIO()
|
||||||
rargs.stderr = io.StringIO()
|
rargs.stderr = io.StringIO()
|
||||||
rargs.timing = DebugTiming()
|
rargs.timing = DebugTiming()
|
||||||
if sender_twisted:
|
send_d = cmd_send.send_twisted(sargs)
|
||||||
send_d = cmd_send_twisted.send_twisted(sargs)
|
receive_d = cmd_receive.receive_twisted(rargs)
|
||||||
else:
|
|
||||||
send_d = deferToThread(cmd_send_blocking.send_blocking, sargs)
|
|
||||||
if receiver_twisted:
|
|
||||||
receive_d = cmd_receive_twisted.receive_twisted(rargs)
|
|
||||||
else:
|
|
||||||
receive_d = deferToThread(cmd_receive_blocking.receive_blocking, rargs)
|
|
||||||
|
|
||||||
# The sender might fail, leaving the receiver hanging, or vice
|
# The sender might fail, leaving the receiver hanging, or vice
|
||||||
# versa. If either side fails, cancel the other, so it won't
|
# versa. If either side fails, cancel the other, so it won't
|
||||||
|
@ -418,24 +409,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
return self._do_test()
|
return self._do_test()
|
||||||
def test_text_subprocess(self):
|
def test_text_subprocess(self):
|
||||||
return self._do_test(as_subprocess=True)
|
return self._do_test(as_subprocess=True)
|
||||||
def test_text_twisted_to_blocking(self):
|
|
||||||
return self._do_test(sender_twisted=True)
|
|
||||||
def test_text_blocking_to_twisted(self):
|
|
||||||
return self._do_test(receiver_twisted=True)
|
|
||||||
def test_text_twisted_to_twisted(self):
|
|
||||||
return self._do_test(sender_twisted=True, receiver_twisted=True)
|
|
||||||
|
|
||||||
def test_file(self):
|
def test_file(self):
|
||||||
return self._do_test(mode="file")
|
return self._do_test(mode="file")
|
||||||
def test_file_override(self):
|
def test_file_override(self):
|
||||||
return self._do_test(mode="file", override_filename=True)
|
return self._do_test(mode="file", override_filename=True)
|
||||||
def test_file_twisted_to_blocking(self):
|
|
||||||
return self._do_test(mode="file", sender_twisted=True)
|
|
||||||
def test_file_blocking_to_twisted(self):
|
|
||||||
return self._do_test(mode="file", receiver_twisted=True)
|
|
||||||
def test_file_twisted_to_twisted(self):
|
|
||||||
return self._do_test(mode="file",
|
|
||||||
sender_twisted=True, receiver_twisted=True)
|
|
||||||
|
|
||||||
def test_directory(self):
|
def test_directory(self):
|
||||||
return self._do_test(mode="directory")
|
return self._do_test(mode="directory")
|
||||||
|
@ -443,16 +421,3 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
return self._do_test(mode="directory", addslash=True)
|
return self._do_test(mode="directory", addslash=True)
|
||||||
def test_directory_override(self):
|
def test_directory_override(self):
|
||||||
return self._do_test(mode="directory", override_filename=True)
|
return self._do_test(mode="directory", override_filename=True)
|
||||||
def test_directory_twisted_to_blocking(self):
|
|
||||||
return self._do_test(mode="directory", sender_twisted=True)
|
|
||||||
def test_directory_twisted_to_blocking_addslash(self):
|
|
||||||
return self._do_test(mode="directory", addslash=True,
|
|
||||||
sender_twisted=True)
|
|
||||||
def test_directory_blocking_to_twisted(self):
|
|
||||||
return self._do_test(mode="directory", receiver_twisted=True)
|
|
||||||
def test_directory_twisted_to_twisted(self):
|
|
||||||
return self._do_test(mode="directory",
|
|
||||||
sender_twisted=True, receiver_twisted=True)
|
|
||||||
def test_directory_twisted_to_twisted_addslash(self):
|
|
||||||
return self._do_test(mode="directory", addslash=True,
|
|
||||||
sender_twisted=True, receiver_twisted=True)
|
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
from binascii import hexlify
|
||||||
from six.moves.urllib_parse import urlencode
|
from six.moves.urllib_parse import urlencode
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import protocol, reactor, defer
|
||||||
from twisted.internet.threads import deferToThread
|
from twisted.internet.threads import deferToThread
|
||||||
|
from twisted.internet.endpoints import clientFromString, connectProtocol
|
||||||
from twisted.web.client import getPage, Agent, readBody
|
from twisted.web.client import getPage, Agent, readBody
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
|
@ -434,7 +436,30 @@ class Summary(unittest.TestCase):
|
||||||
self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41),
|
self.failUnlessEqual(c._summarize(make_moods(None, "scary"), 41),
|
||||||
(1, "scary", 40, 9))
|
(1, "scary", 40, 9))
|
||||||
|
|
||||||
class Transit(unittest.TestCase):
|
class Accumulator(protocol.Protocol):
|
||||||
|
def __init__(self):
|
||||||
|
self.data = b""
|
||||||
|
self.count = 0
|
||||||
|
self._wait = None
|
||||||
|
def waitForBytes(self, more):
|
||||||
|
assert self._wait is None
|
||||||
|
self.count = more
|
||||||
|
self._wait = defer.Deferred()
|
||||||
|
self._check_done()
|
||||||
|
return self._wait
|
||||||
|
def dataReceived(self, data):
|
||||||
|
self.data = self.data + data
|
||||||
|
self._check_done()
|
||||||
|
def _check_done(self):
|
||||||
|
if self._wait and len(self.data) >= self.count:
|
||||||
|
d = self._wait
|
||||||
|
self._wait = None
|
||||||
|
d.callback(self)
|
||||||
|
def connectionLost(self, why):
|
||||||
|
if self._wait:
|
||||||
|
self._wait.errback(RuntimeError("closed"))
|
||||||
|
|
||||||
|
class Transit(ServerBase, unittest.TestCase):
|
||||||
def test_blur_size(self):
|
def test_blur_size(self):
|
||||||
blur = transit_server.blur_size
|
blur = transit_server.blur_size
|
||||||
self.failUnlessEqual(blur(0), 0)
|
self.failUnlessEqual(blur(0), 0)
|
||||||
|
@ -453,3 +478,62 @@ class Transit(unittest.TestCase):
|
||||||
self.failUnlessEqual(blur(1100e6), 1100e6)
|
self.failUnlessEqual(blur(1100e6), 1100e6)
|
||||||
self.failUnlessEqual(blur(1150e6), 1200e6)
|
self.failUnlessEqual(blur(1150e6), 1200e6)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_basic(self):
|
||||||
|
ep = clientFromString(reactor, self.transit)
|
||||||
|
a1 = yield connectProtocol(ep, Accumulator())
|
||||||
|
a2 = yield connectProtocol(ep, Accumulator())
|
||||||
|
|
||||||
|
token1 = b"\x00"*32
|
||||||
|
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||||
|
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||||
|
|
||||||
|
# a correct handshake yields an ack, after which we can send
|
||||||
|
exp = b"ok\n"
|
||||||
|
yield a1.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a1.data, exp)
|
||||||
|
s1 = b"data1"
|
||||||
|
a1.transport.write(s1)
|
||||||
|
|
||||||
|
exp = b"ok\n"
|
||||||
|
yield a2.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a2.data, exp)
|
||||||
|
|
||||||
|
# all data they sent after the handshake should be given to us
|
||||||
|
exp = b"ok\n"+s1
|
||||||
|
yield a2.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a2.data, exp)
|
||||||
|
|
||||||
|
a1.transport.loseConnection()
|
||||||
|
a2.transport.loseConnection()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_bad_handshake(self):
|
||||||
|
ep = clientFromString(reactor, self.transit)
|
||||||
|
a1 = yield connectProtocol(ep, Accumulator())
|
||||||
|
|
||||||
|
token1 = b"\x00"*32
|
||||||
|
# the server waits for the exact number of bytes in the expected
|
||||||
|
# handshake message. to trigger "bad handshake", we must match.
|
||||||
|
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
||||||
|
|
||||||
|
exp = b"bad handshake\n"
|
||||||
|
yield a1.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a1.data, exp)
|
||||||
|
|
||||||
|
a1.transport.loseConnection()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_impatience(self):
|
||||||
|
ep = clientFromString(reactor, self.transit)
|
||||||
|
a1 = yield connectProtocol(ep, Accumulator())
|
||||||
|
|
||||||
|
token1 = b"\x00"*32
|
||||||
|
# sending too many bytes is impatience.
|
||||||
|
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
||||||
|
|
||||||
|
exp = b"impatient\n"
|
||||||
|
yield a1.waitForBytes(len(exp))
|
||||||
|
self.assertEqual(a1.data, exp)
|
||||||
|
|
||||||
|
a1.transport.loseConnection()
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
from __future__ import print_function
|
|
||||||
from binascii import hexlify
|
|
||||||
from twisted.trial import unittest
|
|
||||||
from twisted.internet import protocol, defer, reactor
|
|
||||||
from twisted.internet.endpoints import clientFromString, connectProtocol
|
|
||||||
from .common import ServerBase
|
|
||||||
|
|
||||||
class Accumulator(protocol.Protocol):
|
|
||||||
def __init__(self):
|
|
||||||
self.data = b""
|
|
||||||
self.count = 0
|
|
||||||
self._wait = None
|
|
||||||
def waitForBytes(self, more):
|
|
||||||
assert self._wait is None
|
|
||||||
self.count = more
|
|
||||||
self._wait = defer.Deferred()
|
|
||||||
self._check_done()
|
|
||||||
return self._wait
|
|
||||||
def dataReceived(self, data):
|
|
||||||
self.data = self.data + data
|
|
||||||
self._check_done()
|
|
||||||
def _check_done(self):
|
|
||||||
if self._wait and len(self.data) >= self.count:
|
|
||||||
d = self._wait
|
|
||||||
self._wait = None
|
|
||||||
d.callback(self)
|
|
||||||
def connectionLost(self, why):
|
|
||||||
if self._wait:
|
|
||||||
self._wait.errback(RuntimeError("closed"))
|
|
||||||
|
|
||||||
class Transit(ServerBase, unittest.TestCase):
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_basic(self):
|
|
||||||
ep = clientFromString(reactor, self.transit)
|
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
a2 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
|
||||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
|
||||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
|
||||||
|
|
||||||
# a correct handshake yields an ack, after which we can send
|
|
||||||
exp = b"ok\n"
|
|
||||||
yield a1.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
s1 = b"data1"
|
|
||||||
a1.transport.write(s1)
|
|
||||||
|
|
||||||
exp = b"ok\n"
|
|
||||||
yield a2.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
# all data they sent after the handshake should be given to us
|
|
||||||
exp = b"ok\n"+s1
|
|
||||||
yield a2.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a2.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
||||||
a2.transport.loseConnection()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_bad_handshake(self):
|
|
||||||
ep = clientFromString(reactor, self.transit)
|
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
|
||||||
# the server waits for the exact number of bytes in the expected
|
|
||||||
# handshake message. to trigger "bad handshake", we must match.
|
|
||||||
a1.transport.write(b"please DELAY " + hexlify(token1) + b"\n")
|
|
||||||
|
|
||||||
exp = b"bad handshake\n"
|
|
||||||
yield a1.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_impatience(self):
|
|
||||||
ep = clientFromString(reactor, self.transit)
|
|
||||||
a1 = yield connectProtocol(ep, Accumulator())
|
|
||||||
|
|
||||||
token1 = b"\x00"*32
|
|
||||||
# sending too many bytes is impatience.
|
|
||||||
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
|
||||||
|
|
||||||
exp = b"impatient\n"
|
|
||||||
yield a1.waitForBytes(len(exp))
|
|
||||||
self.assertEqual(a1.data, exp)
|
|
||||||
|
|
||||||
a1.transport.loseConnection()
|
|
|
@ -1,87 +0,0 @@
|
||||||
import re
|
|
||||||
from binascii import hexlify
|
|
||||||
from .util.hkdf import HKDF
|
|
||||||
|
|
||||||
class TransitError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadHandshake(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class TransitClosed(TransitError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadNonce(TransitError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# The beginning of each TCP connection consists of the following handshake
|
|
||||||
# messages. The sender transmits the same text regardless of whether it is on
|
|
||||||
# the initiating/connecting end of the TCP connection, or on the
|
|
||||||
# listening/accepting side. Same for the receiver.
|
|
||||||
#
|
|
||||||
# sender -> receiver: transit sender TXID_HEX ready\n\n
|
|
||||||
# receiver -> sender: transit receiver RXID_HEX ready\n\n
|
|
||||||
#
|
|
||||||
# Any deviations from this result in the socket being closed. The handshake
|
|
||||||
# messages are designed to provoke an invalid response from other sorts of
|
|
||||||
# servers (HTTP, SMTP, echo).
|
|
||||||
#
|
|
||||||
# If the sender is satisfied with the handshake, and this is the first socket
|
|
||||||
# to complete negotiation, the sender does:
|
|
||||||
#
|
|
||||||
# sender -> receiver: go\n
|
|
||||||
#
|
|
||||||
# and the next byte on the wire will be from the application.
|
|
||||||
#
|
|
||||||
# If this is not the first socket, the sender does:
|
|
||||||
#
|
|
||||||
# sender -> receiver: nevermind\n
|
|
||||||
#
|
|
||||||
# and closes the socket.
|
|
||||||
|
|
||||||
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
|
|
||||||
# up upon the first wrong byte. The sender lookgs for "transit receiver
|
|
||||||
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
|
|
||||||
# "go\n" or "nevermind\n"+close().
|
|
||||||
|
|
||||||
def build_receiver_handshake(key):
|
|
||||||
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
|
|
||||||
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
|
|
||||||
|
|
||||||
def build_sender_handshake(key):
|
|
||||||
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
|
||||||
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
|
||||||
|
|
||||||
def build_relay_handshake(key):
|
|
||||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
|
||||||
return b"please relay "+hexlify(token)+b"\n"
|
|
||||||
|
|
||||||
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
|
|
||||||
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
|
|
||||||
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
|
|
||||||
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
|
|
||||||
# publisher wants anonymity, their only hint's ADDR will end in .onion .
|
|
||||||
|
|
||||||
def parse_hint_tcp(hint):
|
|
||||||
assert isinstance(hint, type(u""))
|
|
||||||
# return tuple or None for an unparseable hint
|
|
||||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
|
||||||
if not mo:
|
|
||||||
print("unparseable hint '%s'" % (hint,))
|
|
||||||
return None
|
|
||||||
hint_type = mo.group(1)
|
|
||||||
if hint_type != "tcp":
|
|
||||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
|
|
||||||
return None
|
|
||||||
hint_value = mo.group(2)
|
|
||||||
mo = re.search(r'^(.*):(\d+)$', hint_value)
|
|
||||||
if not mo:
|
|
||||||
print("unparseable TCP hint '%s'" % (hint,))
|
|
||||||
return None
|
|
||||||
hint_host = mo.group(1)
|
|
||||||
try:
|
|
||||||
hint_port = int(mo.group(2))
|
|
||||||
except ValueError:
|
|
||||||
print("non-numeric port in TCP hint '%s'" % (hint,))
|
|
||||||
return None
|
|
||||||
return hint_host, hint_port
|
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import sys, time, socket, collections
|
import re, sys, time, socket, collections
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
from twisted.python.runtime import platformType
|
from twisted.python.runtime import platformType
|
||||||
|
@ -12,11 +12,6 @@ from ..util import ipaddrs
|
||||||
from ..util.hkdf import HKDF
|
from ..util.hkdf import HKDF
|
||||||
from ..errors import UsageError
|
from ..errors import UsageError
|
||||||
from ..timing import DebugTiming
|
from ..timing import DebugTiming
|
||||||
from ..transit_common import (BadHandshake,
|
|
||||||
BadNonce,
|
|
||||||
build_receiver_handshake,
|
|
||||||
build_sender_handshake,
|
|
||||||
build_relay_handshake)
|
|
||||||
|
|
||||||
def debug(msg):
|
def debug(msg):
|
||||||
if False:
|
if False:
|
||||||
|
@ -24,6 +19,90 @@ def debug(msg):
|
||||||
def since(start):
|
def since(start):
|
||||||
return time.time() - start
|
return time.time() - start
|
||||||
|
|
||||||
|
class TransitError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadHandshake(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TransitClosed(TransitError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BadNonce(TransitError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The beginning of each TCP connection consists of the following handshake
|
||||||
|
# messages. The sender transmits the same text regardless of whether it is on
|
||||||
|
# the initiating/connecting end of the TCP connection, or on the
|
||||||
|
# listening/accepting side. Same for the receiver.
|
||||||
|
#
|
||||||
|
# sender -> receiver: transit sender TXID_HEX ready\n\n
|
||||||
|
# receiver -> sender: transit receiver RXID_HEX ready\n\n
|
||||||
|
#
|
||||||
|
# Any deviations from this result in the socket being closed. The handshake
|
||||||
|
# messages are designed to provoke an invalid response from other sorts of
|
||||||
|
# servers (HTTP, SMTP, echo).
|
||||||
|
#
|
||||||
|
# If the sender is satisfied with the handshake, and this is the first socket
|
||||||
|
# to complete negotiation, the sender does:
|
||||||
|
#
|
||||||
|
# sender -> receiver: go\n
|
||||||
|
#
|
||||||
|
# and the next byte on the wire will be from the application.
|
||||||
|
#
|
||||||
|
# If this is not the first socket, the sender does:
|
||||||
|
#
|
||||||
|
# sender -> receiver: nevermind\n
|
||||||
|
#
|
||||||
|
# and closes the socket.
|
||||||
|
|
||||||
|
# So the receiver looks for "transit sender TXID_HEX ready\n\ngo\n" and hangs
|
||||||
|
# up upon the first wrong byte. The sender lookgs for "transit receiver
|
||||||
|
# RXID_HEX ready\n\n" and then makes a first/not-first decision about sending
|
||||||
|
# "go\n" or "nevermind\n"+close().
|
||||||
|
|
||||||
|
def build_receiver_handshake(key):
|
||||||
|
hexid = HKDF(key, 32, CTXinfo=b"transit_receiver")
|
||||||
|
return b"transit receiver "+hexlify(hexid)+b" ready\n\n"
|
||||||
|
|
||||||
|
def build_sender_handshake(key):
|
||||||
|
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
||||||
|
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
||||||
|
|
||||||
|
def build_relay_handshake(key):
|
||||||
|
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||||
|
return b"please relay "+hexlify(token)+b"\n"
|
||||||
|
|
||||||
|
# The hint format is: TYPE,VALUE= /^([a-zA-Z0-9]+):(.*)$/ . VALUE depends
|
||||||
|
# upon TYPE, and it can have more colons in it. For TYPE=tcp (the only one
|
||||||
|
# currently defined), ADDR,PORT = /^(.*):(\d+)$/ , so ADDR can have colons.
|
||||||
|
# ADDR can be a hostname, ipv4 dotted-quad, or ipv6 colon-hex. If the hint
|
||||||
|
# publisher wants anonymity, their only hint's ADDR will end in .onion .
|
||||||
|
|
||||||
|
def parse_hint_tcp(hint):
|
||||||
|
assert isinstance(hint, type(u""))
|
||||||
|
# return tuple or None for an unparseable hint
|
||||||
|
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||||
|
if not mo:
|
||||||
|
print("unparseable hint '%s'" % (hint,))
|
||||||
|
return None
|
||||||
|
hint_type = mo.group(1)
|
||||||
|
if hint_type != "tcp":
|
||||||
|
print("unknown hint type '%s' in '%s'" % (hint_type, hint))
|
||||||
|
return None
|
||||||
|
hint_value = mo.group(2)
|
||||||
|
mo = re.search(r'^(.*):(\d+)$', hint_value)
|
||||||
|
if not mo:
|
||||||
|
print("unparseable TCP hint '%s'" % (hint,))
|
||||||
|
return None
|
||||||
|
hint_host = mo.group(1)
|
||||||
|
try:
|
||||||
|
hint_port = int(mo.group(2))
|
||||||
|
except ValueError:
|
||||||
|
print("non-numeric port in TCP hint '%s'" % (hint,))
|
||||||
|
return None
|
||||||
|
return hint_host, hint_port
|
||||||
|
|
||||||
TIMEOUT=15
|
TIMEOUT=15
|
||||||
|
|
||||||
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
||||||
|
@ -684,7 +763,7 @@ class Common:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _endpoint_from_hint(self, hint):
|
def _endpoint_from_hint(self, hint):
|
||||||
# TODO: use transit_common.parse_hint_tcp
|
# TODO: use parse_hint_tcp
|
||||||
if ":" not in hint:
|
if ":" not in hint:
|
||||||
return None
|
return None
|
||||||
pieces = hint.split(":")
|
pieces = hint.split(":")
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
from twisted.internet import defer, protocol, endpoints, reactor
|
|
||||||
|
|
||||||
def allocate_port():
|
|
||||||
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
|
|
||||||
d = ep.listen(protocol.Factory())
|
|
||||||
def _listening(lp):
|
|
||||||
port = lp.getHost().port
|
|
||||||
d2 = lp.stopListening()
|
|
||||||
d2.addCallback(lambda _: port)
|
|
||||||
return d2
|
|
||||||
d.addCallback(_listening)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def allocate_ports():
|
|
||||||
d = defer.DeferredList([allocate_port(), allocate_port()])
|
|
||||||
def _done(results):
|
|
||||||
port1 = results[0][1]
|
|
||||||
port2 = results[1][1]
|
|
||||||
return (port1, port2)
|
|
||||||
d.addCallback(_done)
|
|
||||||
return d
|
|
|
@ -1,50 +0,0 @@
|
||||||
# -*- test-case-name: foolscap.test_observer -*-
|
|
||||||
|
|
||||||
# many thanks to AllMyData for contributing the initial version of this code
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from foolscap import eventual
|
|
||||||
|
|
||||||
class OneShotObserverList(object):
|
|
||||||
"""A one-shot event distributor.
|
|
||||||
|
|
||||||
Subscribers can get a Deferred that will fire with the results of the
|
|
||||||
event once it finally occurs. The caller does not need to know whether
|
|
||||||
the event has happened yet or not: they get a Deferred in either case.
|
|
||||||
|
|
||||||
The Deferreds returned to subscribers are guaranteed to not fire in the
|
|
||||||
current reactor turn; instead, eventually() is used to fire them in a
|
|
||||||
later turn. Look at Mark Miller's 'Concurrency Among Strangers' paper on
|
|
||||||
erights.org for a description of why this property is useful.
|
|
||||||
|
|
||||||
I can only be fired once."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._fired = False
|
|
||||||
self._result = None
|
|
||||||
self._watchers = []
|
|
||||||
self.__repr__ = self._unfired_repr
|
|
||||||
|
|
||||||
def _unfired_repr(self):
|
|
||||||
return "<OneShotObserverList [%s]>" % (self._watchers, )
|
|
||||||
|
|
||||||
def _fired_repr(self):
|
|
||||||
return "<OneShotObserverList -> %s>" % (self._result, )
|
|
||||||
|
|
||||||
def whenFired(self):
|
|
||||||
if self._fired:
|
|
||||||
return eventual.fireEventually(self._result)
|
|
||||||
d = defer.Deferred()
|
|
||||||
self._watchers.append(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
def fire(self, result):
|
|
||||||
assert not self._fired
|
|
||||||
self._fired = True
|
|
||||||
self._result = result
|
|
||||||
|
|
||||||
for w in self._watchers:
|
|
||||||
eventual.eventually(w.callback, result)
|
|
||||||
del self._watchers
|
|
||||||
self.__repr__ = self._fired_repr
|
|
||||||
|
|
7
tox.ini
7
tox.ini
|
@ -7,12 +7,6 @@
|
||||||
envlist = py27,py33,py34,py35
|
envlist = py27,py33,py34,py35
|
||||||
skip_missing_interpreters = True
|
skip_missing_interpreters = True
|
||||||
|
|
||||||
# There's a race-condition bug in Twisted-15.5.0 (#8014, fixed in trunk) that
|
|
||||||
# manifests as various intermittent magic-wormhole test failures that always
|
|
||||||
# include twisted.internet.endpoints "iterateEndpoint" or "checkDone". So run
|
|
||||||
# all builds with a copy of Twisted from git 'trunk' until Twisted-16.0.0 is
|
|
||||||
# released and we can just depend on that.
|
|
||||||
|
|
||||||
# On windows we need "pypiwin32" installed. It's supposedly possible to make
|
# On windows we need "pypiwin32" installed. It's supposedly possible to make
|
||||||
# Twisted do this by depending upon "twisted[windows]" instead of just
|
# Twisted do this by depending upon "twisted[windows]" instead of just
|
||||||
# "twisted", but when I try this via Appveyor, the extra is ignored.
|
# "twisted", but when I try this via Appveyor, the extra is ignored.
|
||||||
|
@ -24,7 +18,6 @@ skip_missing_interpreters = True
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
deps =
|
deps =
|
||||||
twisted >= 16.1.0
|
|
||||||
pyflakes
|
pyflakes
|
||||||
{env:EXTRA_DEPENDENCY:}
|
{env:EXTRA_DEPENDENCY:}
|
||||||
commands =
|
commands =
|
||||||
|
|
Loading…
Reference in New Issue
Block a user