add DebugTiming object, --dump-timing= option
This writes timeline data to a .json file, which can be examined later to find likely candidates for optimization.
This commit is contained in:
parent
84def8a54b
commit
8d82726c51
|
@ -10,6 +10,7 @@ from .eventsource import EventSourceFollower
|
|||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..util.hkdf import HKDF
|
||||
from ..channel_monitor import monitor
|
||||
|
||||
|
@ -193,7 +194,8 @@ class Wormhole:
|
|||
version_warning_displayed = False
|
||||
_send_confirm = True
|
||||
|
||||
def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE):
|
||||
def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE,
|
||||
timing=None):
|
||||
if not isinstance(appid, type(u"")): raise TypeError(type(appid))
|
||||
if not isinstance(relay_url, type(u"")):
|
||||
raise TypeError(type(relay_url))
|
||||
|
@ -202,6 +204,7 @@ class Wormhole:
|
|||
self._relay_url = relay_url
|
||||
self._wait = wait
|
||||
self._timeout = timeout
|
||||
self._timing = timing or DebugTiming()
|
||||
side = hexlify(os.urandom(5)).decode("ascii")
|
||||
self._channel_manager = ChannelManager(relay_url, appid, side,
|
||||
self.handle_welcome,
|
||||
|
@ -214,6 +217,7 @@ class Wormhole:
|
|||
self._got_data = set()
|
||||
self._got_confirmation = False
|
||||
self._closed = False
|
||||
self._timing_started = self._timing.add_event("wormhole")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -245,7 +249,9 @@ class Wormhole:
|
|||
|
||||
def get_code(self, code_length=2):
|
||||
if self.code is not None: raise UsageError
|
||||
_start = self._timing.add_event("alloc channel")
|
||||
channelid = self._channel_manager.allocate()
|
||||
self._timing.finish_event(_start)
|
||||
code = codes.make_code(channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
|
@ -258,9 +264,11 @@ class Wormhole:
|
|||
# discover the welcome message (and warn the user about an obsolete
|
||||
# client)
|
||||
initial_channelids = lister()
|
||||
_start = self._timing.add_event("input code", waiting="user")
|
||||
code = codes.input_code_with_completion(prompt,
|
||||
initial_channelids, lister,
|
||||
code_length)
|
||||
self._timing.finish_event(_start)
|
||||
return code
|
||||
|
||||
def set_code(self, code): # used for human-made pre-generated codes
|
||||
|
@ -271,6 +279,7 @@ class Wormhole:
|
|||
|
||||
def _set_code_and_channelid(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
self._timing.add_event("code established")
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
|
@ -308,16 +317,26 @@ class Wormhole:
|
|||
|
||||
def _get_key(self):
|
||||
if not self.key:
|
||||
_sent = self._timing.add_event("send pake")
|
||||
self._channel.send(u"pake", self.msg1)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
_sent = self._timing.add_event("get pake")
|
||||
pake_msg = self._channel.get(u"pake")
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
self.key = self.sp.finish(pake_msg)
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
self._timing.add_event("key established")
|
||||
|
||||
if not self._send_confirm:
|
||||
return
|
||||
_sent = self._timing.add_event("send confirmation")
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
self._channel.send(u"_confirm", confmsg)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
@close_on_error
|
||||
def get_verifier(self):
|
||||
|
@ -337,6 +356,7 @@ class Wormhole:
|
|||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
_sent = self._timing.add_event("API send data", phase=phase)
|
||||
# Without predefined roles, we can't derive predictably unique keys
|
||||
# for each side, so we use the same key for both. We use random
|
||||
# nonces to keep the messages distinct, and the Channel automatically
|
||||
|
@ -345,7 +365,10 @@ class Wormhole:
|
|||
self._get_key()
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
_sent2 = self._timing.add_event("send")
|
||||
self._channel.send(phase, outbound_encrypted)
|
||||
self._timing.finish_event(_sent2)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
@close_on_error
|
||||
def get_data(self, phase=u"data"):
|
||||
|
@ -355,21 +378,27 @@ class Wormhole:
|
|||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
_sent = self._timing.add_event("API get data", phase=phase)
|
||||
self._got_data.add(phase)
|
||||
self._get_key()
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
_sent2 = self._timing.add_event("get", phases=phases)
|
||||
(got_phase, body) = self._channel.get_first_of(phases)
|
||||
self._timing.finish_event(_sent2)
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
_sent2 = self._timing.add_event("get", phases=[phase])
|
||||
(got_phase, body) = self._channel.get_first_of([phase])
|
||||
self._timing.finish_event(_sent2)
|
||||
assert got_phase == phase
|
||||
self._timing.finish_event(_sent)
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
inbound_data = self._decrypt_data(data_key, body)
|
||||
|
@ -382,6 +411,9 @@ class Wormhole:
|
|||
raise TypeError(type(mood))
|
||||
self._closed = True
|
||||
if self._channel:
|
||||
self._timing.finish_event(self._timing_started, mood=mood)
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
_sent = self._timing.add_event("close")
|
||||
c.deallocate(mood)
|
||||
self._timing.finish_event(_sent)
|
||||
|
|
|
@ -6,6 +6,7 @@ from nacl.secret import SecretBox
|
|||
from ..util import ipaddrs
|
||||
from ..util.hkdf import HKDF
|
||||
from ..errors import UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..transit_common import (TransitError, BadHandshake, TransitClosed,
|
||||
BadNonce,
|
||||
build_receiver_handshake,
|
||||
|
@ -206,13 +207,15 @@ class RecordPipe:
|
|||
self.skt.close()
|
||||
|
||||
class Common:
|
||||
def __init__(self, transit_relay):
|
||||
def __init__(self, transit_relay, timing=None):
|
||||
if transit_relay:
|
||||
if not isinstance(transit_relay, type(u"")):
|
||||
raise UsageError
|
||||
self._transit_relays = [transit_relay]
|
||||
else:
|
||||
self._transit_relays = []
|
||||
self._timing = timing or DebugTiming()
|
||||
self._timing_started = self._timing.add_event("transit")
|
||||
self.winning = threading.Event()
|
||||
self._negotiation_check_lock = threading.Lock()
|
||||
self._ready_for_connections_lock = threading.Condition()
|
||||
|
@ -372,7 +375,9 @@ class Common:
|
|||
skt.close()
|
||||
|
||||
def connect(self):
|
||||
_start = self._timing.add_event("transit connect")
|
||||
skt = self.establish_socket()
|
||||
self._timing.finish_event(_start)
|
||||
return RecordPipe(skt, self._sender_record_key(),
|
||||
self._receiver_record_key(),
|
||||
self.winning_skt_description)
|
||||
|
|
|
@ -25,6 +25,9 @@ g.add_argument("-v", "--verify", action="store_true",
|
|||
help="display (and wait for acceptance of) verification string")
|
||||
g.add_argument("--hide-progress", action="store_true",
|
||||
help="supress progress-bar display")
|
||||
g.add_argument("--dump-timing", type=type(u""), # TODO: hide from --help output
|
||||
metavar="FILE", help="(debug) write timing data to file")
|
||||
parser.set_defaults(timing=None)
|
||||
subparsers = parser.add_subparsers(title="subcommands",
|
||||
dest="subcommand")
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ class BlockingReceiver:
|
|||
|
||||
@handle_server_error
|
||||
def go(self):
|
||||
with Wormhole(APPID, self.args.relay_url) as w:
|
||||
with Wormhole(APPID, self.args.relay_url, timing=self.args.timing) as w:
|
||||
self.handle_code(w)
|
||||
verifier = w.get_verifier()
|
||||
self.show_verifier(verifier)
|
||||
|
@ -131,16 +131,20 @@ class BlockingReceiver:
|
|||
return abs_destname
|
||||
|
||||
def ask_permission(self):
|
||||
_start = self.args.timing.add_event("permission", waiting="user")
|
||||
while True and not self.args.accept_file:
|
||||
ok = six.moves.input("ok? (y/n): ")
|
||||
if ok.lower().startswith("y"):
|
||||
break
|
||||
print(u"transfer rejected", file=sys.stderr)
|
||||
self.args.timing.finish_event(_start, answer="no")
|
||||
raise RespondError({"error": "transfer rejected"})
|
||||
self.args.timing.finish_event(_start, answer="yes")
|
||||
|
||||
def establish_transit(self, w, them_d):
|
||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
||||
transit_receiver = TransitReceiver(self.args.transit_helper)
|
||||
transit_receiver = TransitReceiver(self.args.transit_helper,
|
||||
timing=self.args.timing)
|
||||
transit_receiver.set_transit_key(transit_key)
|
||||
data = json.dumps({
|
||||
"file_ack": "ok",
|
||||
|
@ -161,6 +165,7 @@ class BlockingReceiver:
|
|||
def transfer_data(self, record_pipe, f):
|
||||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
||||
|
||||
_start = self.args.timing.add_event("rx file")
|
||||
progress_stdout = self.args.stdout
|
||||
if self.args.hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
|
@ -179,6 +184,7 @@ class BlockingReceiver:
|
|||
received += len(plaintext)
|
||||
p.update(received)
|
||||
p.finish()
|
||||
self.args.timing.finish_event(_start)
|
||||
assert received == self.xfersize
|
||||
|
||||
def write_file(self, f):
|
||||
|
@ -190,6 +196,7 @@ class BlockingReceiver:
|
|||
|
||||
def write_directory(self, f):
|
||||
self.msg(u"Unpacking zipfile..")
|
||||
_start = self.args.timing.add_event("unpack zip")
|
||||
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.extractall(path=self.abs_destname)
|
||||
# extractall() appears to offer some protection against
|
||||
|
@ -199,7 +206,10 @@ class BlockingReceiver:
|
|||
self.msg(u"Received files written to %s/" %
|
||||
os.path.basename(self.abs_destname))
|
||||
f.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
def close_transit(self, record_pipe):
|
||||
_start = self.args.timing.add_event("ack")
|
||||
record_pipe.send_record(b"ok\n")
|
||||
record_pipe.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
|
|
@ -17,7 +17,7 @@ class TwistedReceiver(BlockingReceiver):
|
|||
# TODO: @handle_server_error
|
||||
@inlineCallbacks
|
||||
def go(self):
|
||||
w = Wormhole(APPID, self.args.relay_url)
|
||||
w = Wormhole(APPID, self.args.relay_url, timing=self.args.timing)
|
||||
|
||||
rc = yield self._go(w)
|
||||
yield w.close()
|
||||
|
@ -76,7 +76,8 @@ class TwistedReceiver(BlockingReceiver):
|
|||
@inlineCallbacks
|
||||
def establish_transit(self, w, them_d):
|
||||
transit_key = w.derive_key(APPID+u"/transit-key")
|
||||
transit_receiver = TransitReceiver(self.args.transit_helper)
|
||||
transit_receiver = TransitReceiver(self.args.transit_helper,
|
||||
timing=self.args.timing)
|
||||
transit_receiver.set_transit_key(transit_key)
|
||||
direct_hints = yield transit_receiver.get_direct_hints()
|
||||
relay_hints = yield transit_receiver.get_relay_hints()
|
||||
|
@ -100,6 +101,7 @@ class TwistedReceiver(BlockingReceiver):
|
|||
def transfer_data(self, record_pipe, f):
|
||||
self.msg(u"Receiving (%s).." % record_pipe.describe())
|
||||
|
||||
_start = self.args.timing.add_event("rx file")
|
||||
progress_stdout = self.args.stdout
|
||||
if self.args.hide_progress:
|
||||
progress_stdout = io.StringIO()
|
||||
|
@ -107,6 +109,7 @@ class TwistedReceiver(BlockingReceiver):
|
|||
record_pipe.connectConsumer(pfc)
|
||||
received = yield pfc.when_done
|
||||
record_pipe.disconnectConsumer()
|
||||
self.args.timing.finish_event(_start)
|
||||
# except TransitError
|
||||
if received < self.xfersize:
|
||||
self.msg()
|
||||
|
@ -117,8 +120,10 @@ class TwistedReceiver(BlockingReceiver):
|
|||
|
||||
@inlineCallbacks
|
||||
def close_transit(self, record_pipe):
|
||||
_start = self.args.timing.add_event("ack")
|
||||
yield record_pipe.send_record(b"ok\n")
|
||||
yield record_pipe.close()
|
||||
self.args.timing.finish_event(_start)
|
||||
|
||||
# based on twisted.protocols.ftp.FileConsumer, but:
|
||||
# - finish after 'xfersize' bytes received, instead of connectionLost()
|
||||
|
|
|
@ -18,14 +18,14 @@ def send_blocking(args):
|
|||
file=args.stdout)
|
||||
|
||||
if fd_to_send is not None:
|
||||
transit_sender = TransitSender(args.transit_helper)
|
||||
transit_sender = TransitSender(args.transit_helper, timing=args.timing)
|
||||
transit_data = {
|
||||
"direct_connection_hints": transit_sender.get_direct_hints(),
|
||||
"relay_connection_hints": transit_sender.get_relay_hints(),
|
||||
}
|
||||
phase1["transit"] = transit_data
|
||||
|
||||
with Wormhole(APPID, args.relay_url) as w:
|
||||
with Wormhole(APPID, args.relay_url, timing=args.timing) as w:
|
||||
if args.code:
|
||||
w.set_code(args.code)
|
||||
code = args.code
|
||||
|
@ -62,7 +62,8 @@ def send_blocking(args):
|
|||
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||
|
||||
return _send_file_blocking(them_phase1, fd_to_send,
|
||||
transit_sender, args.stdout, args.hide_progress)
|
||||
transit_sender, args.stdout, args.hide_progress,
|
||||
args.timing)
|
||||
|
||||
def _do_verify(verifier, w):
|
||||
while True:
|
||||
|
@ -76,7 +77,7 @@ def _do_verify(verifier, w):
|
|||
raise TransferError("verification rejected, abandoning transfer")
|
||||
|
||||
def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
||||
stdout, hide_progress):
|
||||
stdout, hide_progress, timing):
|
||||
|
||||
# we're sending a file, if they accept it
|
||||
|
||||
|
@ -94,6 +95,7 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
|||
|
||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||
|
||||
_start = timing.add_event("tx file")
|
||||
CHUNKSIZE = 64*1024
|
||||
fd_to_send.seek(0,2)
|
||||
filesize = fd_to_send.tell()
|
||||
|
@ -111,11 +113,15 @@ def _send_file_blocking(them_phase1, fd_to_send, transit_sender,
|
|||
p.update(sent)
|
||||
if not hide_progress:
|
||||
p.finish()
|
||||
timing.finish_event(_start)
|
||||
|
||||
_start = timing.add_event("get ack")
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
ack = record_pipe.receive_record()
|
||||
record_pipe.close()
|
||||
if ack == b"ok\n":
|
||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
||||
timing.finish_event(_start, ack="ok")
|
||||
return 0
|
||||
timing.finish_event(_start, ack="failed")
|
||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
||||
|
|
|
@ -36,10 +36,10 @@ def send_twisted(args):
|
|||
print(u"On the other computer, please run: %s" % other_cmd,
|
||||
file=args.stdout)
|
||||
|
||||
w = Wormhole(APPID, args.relay_url)
|
||||
w = Wormhole(APPID, args.relay_url, timing=args.timing)
|
||||
|
||||
if fd_to_send:
|
||||
transit_sender = TransitSender(args.transit_helper)
|
||||
transit_sender = TransitSender(args.transit_helper, timing=args.timing)
|
||||
phase1["transit"] = transit_data = {}
|
||||
transit_data["relay_connection_hints"] = transit_sender.get_relay_hints()
|
||||
direct_hints = yield transit_sender.get_direct_hints()
|
||||
|
@ -101,7 +101,7 @@ def send_twisted(args):
|
|||
tdata = them_phase1["transit"]
|
||||
yield w.close()
|
||||
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
args.stdout, args.hide_progress)
|
||||
args.stdout, args.hide_progress, args.timing)
|
||||
returnValue(0)
|
||||
|
||||
class ProgressingFileSender(basic.FileSender):
|
||||
|
@ -124,7 +124,7 @@ class ProgressingFileSender(basic.FileSender):
|
|||
|
||||
@inlineCallbacks
|
||||
def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||
stdout, hide_progress):
|
||||
stdout, hide_progress, timing):
|
||||
transit_sender.add_their_direct_hints(tdata["direct_connection_hints"])
|
||||
transit_sender.add_their_relay_hints(tdata["relay_connection_hints"])
|
||||
|
||||
|
@ -139,10 +139,16 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
|
|||
record_pipe = yield transit_sender.connect()
|
||||
# record_pipe should implement IConsumer, chunks are just records
|
||||
print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
|
||||
_start = timing.add_event("tx file")
|
||||
yield pfs.beginFileTransfer(fd_to_send, record_pipe)
|
||||
timing.finish_event(_start)
|
||||
|
||||
print(u"File sent.. waiting for confirmation", file=stdout)
|
||||
_start = timing.add_event("get ack")
|
||||
ack = yield record_pipe.receive_record()
|
||||
record_pipe.close()
|
||||
if ack != b"ok\n":
|
||||
timing.finish_event(_start, ack="failed")
|
||||
raise TransferError("Transfer failed (remote says: %r)" % ack)
|
||||
print(u"Confirmation received. Transfer complete.", file=stdout)
|
||||
timing.finish_event(_start, ack="ok")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import print_function
|
||||
import os, sys
|
||||
from ..errors import TransferError
|
||||
from ..timing import DebugTiming
|
||||
from .cli_args import parser
|
||||
|
||||
def dispatch(args):
|
||||
|
@ -40,11 +41,19 @@ def run(args, cwd, stdout, stderr, executable=None):
|
|||
args.cwd = cwd
|
||||
args.stdout = stdout
|
||||
args.stderr = stderr
|
||||
args.timing = timing = DebugTiming()
|
||||
|
||||
try:
|
||||
timing.add_event("command dispatch")
|
||||
rc = dispatch(args)
|
||||
timing.add_event("exit")
|
||||
if args.dump_timing:
|
||||
timing.write(args.dump_timing, stderr)
|
||||
return rc
|
||||
except TransferError as e:
|
||||
print(e, file=stderr)
|
||||
if args.dump_timing:
|
||||
timing.write(args.dump_timing, stderr)
|
||||
return 1
|
||||
except ImportError as e:
|
||||
print("--- ImportError ---", file=stderr)
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..scripts import (runner, cmd_send_blocking, cmd_send_twisted,
|
|||
cmd_receive_blocking, cmd_receive_twisted)
|
||||
from ..scripts.send_common import build_phase1_data
|
||||
from ..errors import TransferError
|
||||
from ..timing import DebugTiming
|
||||
|
||||
class Phase1Data(unittest.TestCase):
|
||||
def test_text(self):
|
||||
|
@ -295,10 +296,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
|||
sargs.cwd = send_dir
|
||||
sargs.stdout = io.StringIO()
|
||||
sargs.stderr = io.StringIO()
|
||||
sargs.timing = DebugTiming()
|
||||
rargs = runner.parser.parse_args(receive_args)
|
||||
rargs.cwd = receive_dir
|
||||
rargs.stdout = io.StringIO()
|
||||
rargs.stderr = io.StringIO()
|
||||
rargs.timing = DebugTiming()
|
||||
if sender_twisted:
|
||||
send_d = cmd_send_twisted.send_twisted(sargs)
|
||||
else:
|
||||
|
|
18
src/wormhole/timing.py
Normal file
18
src/wormhole/timing.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from __future__ import print_function
|
||||
import json, time
|
||||
|
||||
class DebugTiming:
|
||||
def __init__(self):
|
||||
self.data = []
|
||||
def add_event(self, name, **details):
|
||||
# [ start, [stop], name, start_details{}, stop_details{} ]
|
||||
self.data.append( [time.time(), None, name, details, {}] )
|
||||
return len(self.data)-1
|
||||
def finish_event(self, index, **details):
|
||||
self.data[index][1] = time.time()
|
||||
self.data[index][4] = details
|
||||
def write(self, fn, stderr):
|
||||
with open(fn, "wb") as f:
|
||||
json.dump(self.data, f)
|
||||
f.write("\n")
|
||||
print("Timing data written to %s" % fn, file=stderr)
|
|
@ -16,6 +16,7 @@ from .eventsource_twisted import ReconnectingEventSource
|
|||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..util.hkdf import HKDF
|
||||
from ..channel_monitor import monitor
|
||||
|
||||
|
@ -227,13 +228,14 @@ class Wormhole:
|
|||
version_warning_displayed = False
|
||||
_send_confirm = True
|
||||
|
||||
def __init__(self, appid, relay_url):
|
||||
def __init__(self, appid, relay_url, timing=None):
|
||||
if not isinstance(appid, type(u"")): raise TypeError(type(appid))
|
||||
if not isinstance(relay_url, type(u"")):
|
||||
raise TypeError(type(relay_url))
|
||||
if not relay_url.endswith(u"/"): raise UsageError
|
||||
self._appid = appid
|
||||
self._relay_url = relay_url
|
||||
self._timing = timing or DebugTiming()
|
||||
self._set_side(hexlify(os.urandom(5)).decode("ascii"))
|
||||
self.code = None
|
||||
self.key = None
|
||||
|
@ -242,6 +244,7 @@ class Wormhole:
|
|||
self._got_data = set()
|
||||
self._got_confirmation = False
|
||||
self._closed = False
|
||||
self._timing_started = self._timing.add_event("wormhole")
|
||||
|
||||
def _set_side(self, side):
|
||||
self._side = side
|
||||
|
@ -276,7 +279,9 @@ class Wormhole:
|
|||
if self.code is not None: raise UsageError
|
||||
if self._started_get_code: raise UsageError
|
||||
self._started_get_code = True
|
||||
_start = self._timing.add_event("alloc channel")
|
||||
channelid = yield self._channel_manager.allocate()
|
||||
self._timing.finish_event(_start)
|
||||
code = codes.make_code(channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
|
@ -291,6 +296,7 @@ class Wormhole:
|
|||
|
||||
def _set_code_and_channelid(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
self._timing.add_event("code established")
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
|
@ -361,17 +367,25 @@ class Wormhole:
|
|||
# TODO: prevent multiple invocation
|
||||
if self.key:
|
||||
returnValue(self.key)
|
||||
_sent = self._timing.add_event("send pake")
|
||||
yield self._channel.send(u"pake", self.msg1)
|
||||
self._timing.finish_event(_sent)
|
||||
_sent = self._timing.add_event("get pake")
|
||||
pake_msg = yield self._channel.get(u"pake")
|
||||
self._timing.finish_event(_sent)
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
self._timing.add_event("key established")
|
||||
|
||||
if not self._send_confirm:
|
||||
returnValue(key)
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
_sent = self._timing.add_event("send confirmation")
|
||||
yield self._channel.send(u"_confirm", confmsg)
|
||||
self._timing.finish_event(_sent)
|
||||
returnValue(key)
|
||||
|
||||
@close_on_error
|
||||
|
@ -393,6 +407,7 @@ class Wormhole:
|
|||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
_sent = self._timing.add_event("API send data", phase=phase)
|
||||
# Without predefined roles, we can't derive predictably unique keys
|
||||
# for each side, so we use the same key for both. We use random
|
||||
# nonces to keep the messages distinct, and the Channel automatically
|
||||
|
@ -401,7 +416,10 @@ class Wormhole:
|
|||
yield self._get_key()
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
_sent2 = self._timing.add_event("send")
|
||||
yield self._channel.send(phase, outbound_encrypted)
|
||||
self._timing.finish_event(_sent2)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
@close_on_error
|
||||
@inlineCallbacks
|
||||
|
@ -412,13 +430,16 @@ class Wormhole:
|
|||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
_sent = self._timing.add_event("API get data", phase=phase)
|
||||
self._got_data.add(phase)
|
||||
yield self._get_key()
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
_sent2 = self._timing.add_event("get", phases=phases)
|
||||
phase_and_body = yield self._channel.get_first_of(phases)
|
||||
self._timing.finish_event(_sent2)
|
||||
(got_phase, body) = phase_and_body
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
|
@ -426,8 +447,11 @@ class Wormhole:
|
|||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
_sent3 = self._timing.add_event("get", phases=[phase])
|
||||
phase_and_body = yield self._channel.get_first_of([phase])
|
||||
self._timing.finish_event(_sent3)
|
||||
(got_phase, body) = phase_and_body
|
||||
self._timing.finish_event(_sent)
|
||||
assert got_phase == phase
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
|
@ -443,7 +467,10 @@ class Wormhole:
|
|||
self._closed = True
|
||||
if not self._channel:
|
||||
returnValue(None)
|
||||
self._timing.finish_event(self._timing_started, mood=mood)
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
_sent = self._timing.add_event("close")
|
||||
yield c.deallocate(mood)
|
||||
self._timing.finish_event(_sent)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from nacl.secret import SecretBox
|
|||
from ..util import ipaddrs
|
||||
from ..util.hkdf import HKDF
|
||||
from ..errors import UsageError
|
||||
from ..timing import DebugTiming
|
||||
from ..transit_common import (BadHandshake,
|
||||
BadNonce,
|
||||
build_receiver_handshake,
|
||||
|
@ -409,7 +410,7 @@ def there_can_be_only_one(contenders):
|
|||
class Common:
|
||||
RELAY_DELAY = 2.0
|
||||
|
||||
def __init__(self, transit_relay, reactor=reactor):
|
||||
def __init__(self, transit_relay, reactor=reactor, timing=None):
|
||||
if transit_relay:
|
||||
if not isinstance(transit_relay, type(u"")):
|
||||
raise UsageError
|
||||
|
@ -421,6 +422,8 @@ class Common:
|
|||
self._listener = None
|
||||
self._winner = None
|
||||
self._reactor = reactor
|
||||
self._timing = timing or DebugTiming()
|
||||
self._timing_started = self._timing.add_event("transit")
|
||||
|
||||
def _build_listener(self):
|
||||
portnum = allocate_tcp_port()
|
||||
|
@ -539,11 +542,13 @@ class Common:
|
|||
|
||||
@inlineCallbacks
|
||||
def connect(self):
|
||||
_start = self._timing.add_event("transit connect")
|
||||
yield self._get_transit_key()
|
||||
# we want to have the transit key before starting any outbound
|
||||
# connections, so those connections will know what to say when they
|
||||
# connect
|
||||
winner = yield self._connect()
|
||||
self._timing.finish_event(_start)
|
||||
returnValue(winner)
|
||||
|
||||
def _connect(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user