rewrite timing instrumentation: use context managers

This commit is contained in:
Brian Warner 2016-04-29 14:27:29 -07:00
parent b70c2f8868
commit 24e52c0320
8 changed files with 328 additions and 286 deletions

View File

@ -69,12 +69,12 @@ class Channel:
"phase": phase, "phase": phase,
"body": hexlify(msg).decode("ascii")} "body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8") data = json.dumps(payload).encode("utf-8")
_sent = self._timing.add_event("send %s" % phase) with self._timing.add("send %s" % phase) as t:
r = requests.post(self._relay_url+"add", data=data, r = requests.post(self._relay_url+"add", data=data,
timeout=self._timeout) timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
resp = r.json() resp = r.json()
self._timing.finish_event(_sent, resp.get("sent")) t.server_sent(resp.get("sent"))
if "welcome" in resp: if "welcome" in resp:
self._handle_welcome(resp["welcome"]) self._handle_welcome(resp["welcome"])
self._add_inbound_messages(resp["messages"]) self._add_inbound_messages(resp["messages"])
@ -92,9 +92,7 @@ class Channel:
# wasn't one of our own messages. It will either come from # wasn't one of our own messages. It will either come from
# previously-received messages, or from an EventSource that we attach # previously-received messages, or from an EventSource that we attach
# to the corresponding URL # to the corresponding URL
_sent = self._timing.add_event("get %s" % "/".join(sorted(phases))) with self._timing.add("get %s" % "/".join(sorted(phases))) as t:
_server_sent = None
phase_and_body = self._find_inbound_message(phases) phase_and_body = self._find_inbound_message(phases)
while phase_and_body is None: while phase_and_body is None:
remaining = self._started + self._timeout - time.time() remaining = self._started + self._timeout - time.time()
@ -115,11 +113,10 @@ class Channel:
phase_and_body = self._find_inbound_message(phases) phase_and_body = self._find_inbound_message(phases)
if phase_and_body: if phase_and_body:
f.close() f.close()
_server_sent = data.get("sent") t.server_sent(data.get("sent"))
break break
if not phase_and_body: if not phase_and_body:
time.sleep(self._wait) time.sleep(self._wait)
self._timing.finish_event(_sent, _server_sent)
return phase_and_body return phase_and_body
def get(self, phase): def get(self, phase):
@ -136,11 +133,11 @@ class Channel:
try: try:
# ignore POST failure, don't call r.raise_for_status(), set a # ignore POST failure, don't call r.raise_for_status(), set a
# short timeout and ignore failures # short timeout and ignore failures
_sent = self._timing.add_event("close") with self._timing.add("close") as t:
r = requests.post(self._relay_url+"deallocate", data=data, r = requests.post(self._relay_url+"deallocate", data=data,
timeout=5) timeout=5)
resp = r.json() resp = r.json()
self._timing.finish_event(_sent, resp.get("sent")) t.server_sent(resp.get("sent"))
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
pass pass
@ -157,28 +154,28 @@ class ChannelManager:
def list_channels(self): def list_channels(self):
queryargs = urlencode([("appid", self._appid)]) queryargs = urlencode([("appid", self._appid)])
_sent = self._timing.add_event("list") with self._timing.add("list") as t:
r = requests.get(self._relay_url+"list?%s" % queryargs, r = requests.get(self._relay_url+"list?%s" % queryargs,
timeout=self._timeout) timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "welcome" in data: if "welcome" in data:
self._handle_welcome(data["welcome"]) self._handle_welcome(data["welcome"])
self._timing.finish_event(_sent, data.get("sent")) t.server_sent(data.get("sent"))
channelids = data["channelids"] channelids = data["channelids"]
return channelids return channelids
def allocate(self): def allocate(self):
data = json.dumps({"appid": self._appid, data = json.dumps({"appid": self._appid,
"side": self._side}).encode("utf-8") "side": self._side}).encode("utf-8")
_sent = self._timing.add_event("allocate") with self._timing.add("allocate") as t:
r = requests.post(self._relay_url+"allocate", data=data, r = requests.post(self._relay_url+"allocate", data=data,
timeout=self._timeout) timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "welcome" in data: if "welcome" in data:
self._handle_welcome(data["welcome"]) self._handle_welcome(data["welcome"])
self._timing.finish_event(_sent, data.get("sent")) t.server_sent(data.get("sent"))
channelid = data["channelid"] channelid = data["channelid"]
return channelid return channelid
@ -239,7 +236,7 @@ class Wormhole:
self._got_data = set() self._got_data = set()
self._got_confirmation = False self._got_confirmation = False
self._closed = False self._closed = False
self._timing_started = self._timing.add_event("wormhole") self._timing_started = self._timing.add("wormhole")
def __enter__(self): def __enter__(self):
return self return self
@ -284,11 +281,10 @@ class Wormhole:
# discover the welcome message (and warn the user about an obsolete # discover the welcome message (and warn the user about an obsolete
# client) # client)
initial_channelids = lister() initial_channelids = lister()
_start = self._timing.add_event("input code", waiting="user") with self._timing.add("input code", waiting="user"):
code = codes.input_code_with_completion(prompt, code = codes.input_code_with_completion(prompt,
initial_channelids, lister, initial_channelids, lister,
code_length) code_length)
self._timing.finish_event(_start)
return code return code
def set_code(self, code): # used for human-made pre-generated codes def set_code(self, code): # used for human-made pre-generated codes
@ -299,7 +295,7 @@ class Wormhole:
def _set_code_and_channelid(self, code): def _set_code_and_channelid(self, code):
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
self._timing.add_event("code established") self._timing.add("code established")
mo = re.search(r'^(\d+)-', code) mo = re.search(r'^(\d+)-', code)
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
@ -342,7 +338,7 @@ class Wormhole:
self.key = self.sp.finish(pake_msg) self.key = self.sp.finish(pake_msg)
self.verifier = self.derive_key(u"wormhole:verifier") self.verifier = self.derive_key(u"wormhole:verifier")
self._timing.add_event("key established") self._timing.add("key established")
if not self._send_confirm: if not self._send_confirm:
return return
@ -369,17 +365,16 @@ class Wormhole:
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self._channel is None: raise UsageError if self._channel is None: raise UsageError
_sent = self._timing.add_event("API send data", phase=phase) with self._timing.add("API send data", phase=phase):
# Without predefined roles, we can't derive predictably unique keys # Without predefined roles, we can't derive predictably unique
# for each side, so we use the same key for both. We use random # keys for each side, so we use the same key for both. We use
# nonces to keep the messages distinct, and the Channel automatically # random nonces to keep the messages distinct, and the Channel
# ignores reflections. # automatically ignores reflections.
self._sent_data.add(phase) self._sent_data.add(phase)
self._get_key() self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
self._channel.send(phase, outbound_encrypted) self._channel.send(phase, outbound_encrypted)
self._timing.finish_event(_sent)
@close_on_error @close_on_error
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
@ -389,7 +384,7 @@ class Wormhole:
if self._closed: raise UsageError if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self._channel is None: raise UsageError if self._channel is None: raise UsageError
_sent = self._timing.add_event("API get data", phase=phase) with self._timing.add("API get data", phase=phase):
self._got_data.add(phase) self._got_data.add(phase)
self._get_key() self._get_key()
phases = [] phases = []
@ -405,7 +400,6 @@ class Wormhole:
self._got_confirmation = True self._got_confirmation = True
(got_phase, body) = self._channel.get_first_of([phase]) (got_phase, body) = self._channel.get_first_of([phase])
assert got_phase == phase assert got_phase == phase
self._timing.finish_event(_sent)
try: try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body) inbound_data = self._decrypt_data(data_key, body)
@ -418,7 +412,7 @@ class Wormhole:
raise TypeError(type(mood)) raise TypeError(type(mood))
self._closed = True self._closed = True
if self._channel: if self._channel:
self._timing.finish_event(self._timing_started, mood=mood) self._timing_started.finish(mood=mood)
c, self._channel = self._channel, None c, self._channel = self._channel, None
monitor.close(c) monitor.close(c)
c.deallocate(mood) c.deallocate(mood)

View File

@ -37,9 +37,8 @@ class TwistedReceiver:
def go(self): def go(self):
tor_manager = None tor_manager = None
if self.args.tor: if self.args.tor:
_start = self.args.timing.add_event("import TorManager") with self.args.timing.add("import", which="tor_manager"):
from ..twisted.tor_manager import TorManager from ..twisted.tor_manager import TorManager
self.args.timing.finish_event(_start)
tor_manager = TorManager(self._reactor, timing=self.args.timing) tor_manager = TorManager(self._reactor, timing=self.args.timing)
# For now, block everything until Tor has started. Soon: launch # For now, block everything until Tor has started. Soon: launch
# tor in parallel with everything else, make sure the TorManager # tor in parallel with everything else, make sure the TorManager
@ -172,15 +171,15 @@ class TwistedReceiver:
return abs_destname return abs_destname
def ask_permission(self): def ask_permission(self):
_start = self.args.timing.add_event("permission", waiting="user") with self.args.timing.add("permission", waiting="user") as t:
while True and not self.args.accept_file: while True and not self.args.accept_file:
ok = six.moves.input("ok? (y/n): ") ok = six.moves.input("ok? (y/n): ")
if ok.lower().startswith("y"): if ok.lower().startswith("y"):
break break
print(u"transfer rejected", file=sys.stderr) print(u"transfer rejected", file=sys.stderr)
self.args.timing.finish_event(_start, answer="no") t.detail(answer="no")
raise RespondError("transfer rejected") raise RespondError("transfer rejected")
self.args.timing.finish_event(_start, answer="yes") t.detail(answer="yes")
@inlineCallbacks @inlineCallbacks
def establish_transit(self, w, them_d, tor_manager): def establish_transit(self, w, them_d, tor_manager):
@ -207,20 +206,20 @@ class TwistedReceiver:
transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"]) transit_receiver.add_their_direct_hints(tdata["direct_connection_hints"])
transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"]) transit_receiver.add_their_relay_hints(tdata["relay_connection_hints"])
record_pipe = yield transit_receiver.connect() record_pipe = yield transit_receiver.connect()
self.args.timing.add("transit connected")
returnValue(record_pipe) returnValue(record_pipe)
@inlineCallbacks @inlineCallbacks
def transfer_data(self, record_pipe, f): def transfer_data(self, record_pipe, f):
self.msg(u"Receiving (%s).." % record_pipe.describe()) self.msg(u"Receiving (%s).." % record_pipe.describe())
_start = self.args.timing.add_event("rx file") with self.args.timing.add("rx file"):
progress = tqdm(file=self.args.stdout, progress = tqdm(file=self.args.stdout,
disable=self.args.hide_progress, disable=self.args.hide_progress,
unit="B", unit_scale=True, total=self.xfersize) unit="B", unit_scale=True, total=self.xfersize)
with progress: with progress:
received = yield record_pipe.writeToFile(f, self.xfersize, received = yield record_pipe.writeToFile(f, self.xfersize,
progress.update) progress.update)
self.args.timing.finish_event(_start)
# except TransitError # except TransitError
if received < self.xfersize: if received < self.xfersize:
@ -239,7 +238,7 @@ class TwistedReceiver:
def write_directory(self, f): def write_directory(self, f):
self.msg(u"Unpacking zipfile..") self.msg(u"Unpacking zipfile..")
_start = self.args.timing.add_event("unpack zip") with self.args.timing.add("unpack zip"):
with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(f, "r", zipfile.ZIP_DEFLATED) as zf:
zf.extractall(path=self.abs_destname) zf.extractall(path=self.abs_destname)
# extractall() appears to offer some protection against # extractall() appears to offer some protection against
@ -249,11 +248,9 @@ class TwistedReceiver:
self.msg(u"Received files written to %s/" % self.msg(u"Received files written to %s/" %
os.path.basename(self.abs_destname)) os.path.basename(self.abs_destname))
f.close() f.close()
self.args.timing.finish_event(_start)
@inlineCallbacks @inlineCallbacks
def close_transit(self, record_pipe): def close_transit(self, record_pipe):
_start = self.args.timing.add_event("ack") with self.args.timing.add("send ack"):
yield record_pipe.send_record(b"ok\n") yield record_pipe.send_record(b"ok\n")
yield record_pipe.close() yield record_pipe.close()
self.args.timing.finish_event(_start)

View File

@ -210,6 +210,7 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
fd_to_send.seek(0,0) fd_to_send.seek(0,0)
record_pipe = yield transit_sender.connect() record_pipe = yield transit_sender.connect()
timing.add("transit connected")
# record_pipe should implement IConsumer, chunks are just records # record_pipe should implement IConsumer, chunks are just records
print(u"Sending (%s).." % record_pipe.describe(), file=stdout) print(u"Sending (%s).." % record_pipe.describe(), file=stdout)
@ -221,17 +222,17 @@ def _send_file_twisted(tdata, transit_sender, fd_to_send,
return data return data
fs = basic.FileSender() fs = basic.FileSender()
_start = timing.add_event("tx file") with timing.add("tx file"):
with progress: with progress:
yield fs.beginFileTransfer(fd_to_send, record_pipe, transform=_count) yield fs.beginFileTransfer(fd_to_send, record_pipe,
timing.finish_event(_start) transform=_count)
print(u"File sent.. waiting for confirmation", file=stdout) print(u"File sent.. waiting for confirmation", file=stdout)
_start = timing.add_event("get ack") with timing.add("get ack") as t:
ack = yield record_pipe.receive_record() ack = yield record_pipe.receive_record()
record_pipe.close() record_pipe.close()
if ack != b"ok\n": if ack != b"ok\n":
timing.finish_event(_start, ack="failed") t.detail(ack="failed")
raise TransferError("Transfer failed (remote says: %r)" % ack) raise TransferError("Transfer failed (remote says: %r)" % ack)
print(u"Confirmation received. Transfer complete.", file=stdout) print(u"Confirmation received. Transfer complete.", file=stdout)
timing.finish_event(_start, ack="ok") t.detail(ack="ok")

View File

@ -1,19 +1,22 @@
from __future__ import print_function from __future__ import print_function
import time
start = time.time()
import os, sys import os, sys
from twisted.internet.defer import maybeDeferred from twisted.internet.defer import maybeDeferred
from twisted.internet.task import react from twisted.internet.task import react
from ..errors import TransferError, WrongPasswordError, Timeout from ..errors import TransferError, WrongPasswordError, Timeout
from ..timing import DebugTiming from ..timing import DebugTiming
from .cli_args import parser from .cli_args import parser
top_import_finish = time.time()
def dispatch(args): # returns Deferred def dispatch(args): # returns Deferred
if args.func == "send/send": if args.func == "send/send":
with args.timing.add("import", which="cmd_send"):
from . import cmd_send from . import cmd_send
return cmd_send.send(args) return cmd_send.send(args)
if args.func == "receive/receive": if args.func == "receive/receive":
_start = args.timing.add_event("import c_r_t") with args.timing.add("import", which="cmd_receive"):
from . import cmd_receive from . import cmd_receive
args.timing.finish_event(_start)
return cmd_receive.receive(args) return cmd_receive.receive(args)
raise ValueError("unknown args.func %s" % args.func) raise ValueError("unknown args.func %s" % args.func)
@ -34,11 +37,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
args.stderr = stderr args.stderr = stderr
args.timing = timing = DebugTiming() args.timing = timing = DebugTiming()
timing.add_event("command dispatch") timing.add("command dispatch")
timing.add("import", when=start, which="top").finish(when=top_import_finish)
# fires with None, or raises an error # fires with None, or raises an error
d = maybeDeferred(dispatch, args) d = maybeDeferred(dispatch, args)
def _maybe_dump_timing(res): def _maybe_dump_timing(res):
timing.add_event("exit") timing.add("exit")
if args.dump_timing: if args.dump_timing:
timing.write(args.dump_timing, stderr) timing.write(args.dump_timing, stderr)
return res return res

View File

@ -1,23 +1,57 @@
from __future__ import print_function from __future__ import print_function, absolute_import
import json, time import json, time
class Event:
def __init__(self, name, when, **details):
# data fields that will be dumped to JSON later
self._start = time.time() if when is None else float(when)
self._server_sent = None
self._stop = None
self._name = name
self._details = details
def server_sent(self, when):
self._server_sent = when
def detail(self, **details):
self._details.update(details)
def finish(self, server_sent=None, **details):
self._stop = time.time()
if server_sent:
self.server_sent(server_sent)
self.detail(**details)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
if exc_type:
# inlineCallbacks uses a special exception (defer._DefGen_Return)
# to deliver returnValue(), so if returnValue is used inside our
# with: block, we'll mistakenly think it means something broke.
# I've moved all returnValue() calls outside the 'with
# timing.add()' blocks to avoid this, but if a new one
# accidentally pops up, it'll get marked as an error. I used to
# catch-and-release _DefGen_Return to avoid this, but removed it
# because it requires referencing defer.py's private class
self.finish(exception=str(exc_type))
else:
self.finish()
class DebugTiming: class DebugTiming:
def __init__(self): def __init__(self):
self.data = [] self._events = []
def add_event(self, name, when=None, **details):
# [ start, [server_sent], [stop], name, start_details{}, stop_details{} ] def add(self, name, when=None, **details):
if when is None: ev = Event(name, when, **details)
when = time.time() self._events.append(ev)
when = float(when) return ev
self.data.append( [when, None, None, name, details, {}] )
return len(self.data)-1
def finish_event(self, index, server_sent=None, **details):
if server_sent is not None:
self.data[index][1] = float(server_sent)
self.data[index][2] = time.time()
self.data[index][5] = details
def write(self, fn, stderr): def write(self, fn, stderr):
with open(fn, "wb") as f: with open(fn, "wb") as f:
json.dump(self.data, f) data = [ [e._start, e._server_sent, e._stop, e._name, e._details]
for e in self._events ]
json.dump(data, f)
f.write("\n") f.write("\n")
print("Timing data written to %s" % fn, file=stderr) print("Timing data written to %s" % fn, file=stderr)

View File

@ -44,10 +44,10 @@ class TorManager:
self._can_run_service = False self._can_run_service = False
returnValue(True) returnValue(True)
_start_find = self._timing.add_event("find tor") _start_find = self._timing.add("find tor")
# try port 9051, then try /var/run/tor/control . Throws on failure. # try port 9051, then try /var/run/tor/control . Throws on failure.
state = None state = None
_start_tcp = self._timing.add_event("tor localhost") with self._timing.add("tor localhost"):
try: try:
connection = (self._reactor, "127.0.0.1", self._tor_control_port) connection = (self._reactor, "127.0.0.1", self._tor_control_port)
state = yield txtorcon.build_tor_connection(connection) state = yield txtorcon.build_tor_connection(connection)
@ -55,20 +55,18 @@ class TorManager:
except ConnectError: except ConnectError:
print("unable to reach Tor on %d" % self._tor_control_port) print("unable to reach Tor on %d" % self._tor_control_port)
pass pass
self._timing.finish_event(_start_tcp)
if not state: if not state:
_start_unix = self._timing.add_event("tor unix") with self._timing.add("tor unix"):
try: try:
connection = (self._reactor, "/var/run/tor/control") connection = (self._reactor, "/var/run/tor/control")
# add build_state=False to get back a Protocol object instead # add build_state=False to get back a Protocol object
# of a State object # instead of a State object
state = yield txtorcon.build_tor_connection(connection) state = yield txtorcon.build_tor_connection(connection)
self._tor_protocol = state.protocol self._tor_protocol = state.protocol
except (ValueError, ConnectError): except (ValueError, ConnectError):
print("unable to reach Tor on /var/run/tor/control") print("unable to reach Tor on /var/run/tor/control")
pass pass
self._timing.finish_event(_start_unix)
if state: if state:
print("connected to pre-existing Tor process") print("connected to pre-existing Tor process")
@ -78,19 +76,19 @@ class TorManager:
yield self._create_my_own_tor() yield self._create_my_own_tor()
# that sets self._tor_socks_port and self._tor_protocol # that sets self._tor_socks_port and self._tor_protocol
self._timing.finish_event(_start_find) _start_find.finish()
self._can_run_service = True self._can_run_service = True
returnValue(True) returnValue(True)
@inlineCallbacks @inlineCallbacks
def _create_my_own_tor(self): def _create_my_own_tor(self):
_start_launch = self._timing.add_event("launch tor") with self._timing.add("launch tor"):
start = time.time() start = time.time()
config = self.config = txtorcon.TorConfig() config = self.config = txtorcon.TorConfig()
if 0: if 0:
# The default is for launch_tor to create a tempdir itself, and # The default is for launch_tor to create a tempdir itself,
# delete it when done. We only need to set a DataDirectory if we # and delete it when done. We only need to set a
# want it to be persistent. # DataDirectory if we want it to be persistent.
import tempfile import tempfile
datadir = tempfile.mkdtemp() datadir = tempfile.mkdtemp()
config.DataDirectory = datadir config.DataDirectory = datadir
@ -108,7 +106,6 @@ class TorManager:
self._tor_protocol = tpp.tor_protocol self._tor_protocol = tpp.tor_protocol
print("tp:", self._tor_protocol) print("tp:", self._tor_protocol)
print("elapsed:", time.time() - start) print("elapsed:", time.time() - start)
self._timing.finish_event(_start_launch)
returnValue(True) returnValue(True)
def is_non_public_numeric_address(self, host): def is_non_public_numeric_address(self, host):

View File

@ -77,14 +77,16 @@ class Wormhole:
self._sent_messages = set() # (phase, body_bytes) self._sent_messages = set() # (phase, body_bytes)
self._delivered_messages = set() # (phase, body_bytes) self._delivered_messages = set() # (phase, body_bytes)
self._received_messages = {} # phase -> body_bytes self._received_messages = {} # phase -> body_bytes
self._received_messages_sent = {} # phase -> server timestamp
self._sent_phases = set() # phases, to prohibit double-send self._sent_phases = set() # phases, to prohibit double-send
self._got_phases = set() # phases, to prohibit double-read self._got_phases = set() # phases, to prohibit double-read
self._sleepers = [] self._sleepers = []
self._confirmation_failed = False self._confirmation_failed = False
self._closed = False self._closed = False
self._deallocated_status = None self._deallocated_status = None
self._timing_started = self._timing.add_event("wormhole") self._timing_started = self._timing.add("wormhole")
self._ws = None self._ws = None
self._ws_t = None # timing Event
self._ws_channel_claimed = False self._ws_channel_claimed = False
self._error = None self._error = None
@ -112,6 +114,7 @@ class Wormhole:
ep = self._make_endpoint(p.hostname, p.port or 80) ep = self._make_endpoint(p.hostname, p.port or 80)
# .connect errbacks if the TCP connection fails # .connect errbacks if the TCP connection fails
self._ws = yield ep.connect(f) self._ws = yield ep.connect(f)
self._ws_t = self._timing.add("websocket")
# f.d is errbacked if WebSocket negotiation fails # f.d is errbacked if WebSocket negotiation fails
yield f.d # WebSocket drops data sent before onOpen() fires yield f.d # WebSocket drops data sent before onOpen() fires
self._ws_send(u"bind", appid=self._appid, side=self._side) self._ws_send(u"bind", appid=self._appid, side=self._side)
@ -134,6 +137,7 @@ class Wormhole:
return meth(msg) return meth(msg)
def _ws_handle_welcome(self, msg): def _ws_handle_welcome(self, msg):
self._timing.add("welcome").server_sent(msg["sent"])
welcome = msg["welcome"] welcome = msg["welcome"]
if ("motd" in welcome and if ("motd" in welcome and
not self.motd_displayed): not self.motd_displayed):
@ -184,6 +188,7 @@ class Wormhole:
self._wakeup() self._wakeup()
def _ws_handle_error(self, msg): def _ws_handle_error(self, msg):
self._timing.add("error").server_sent(msg["sent"])
err = ServerError("%s: %s" % (msg["error"], msg["orig"]), err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
self._ws_url) self._ws_url)
return self._signal_error(err) return self._signal_error(err)
@ -203,11 +208,12 @@ class Wormhole:
if self._code is not None: raise UsageError if self._code is not None: raise UsageError
if self._started_get_code: raise UsageError if self._started_get_code: raise UsageError
self._started_get_code = True self._started_get_code = True
_sent = self._timing.add_event("allocate") with self._timing.add("API get_code"):
with self._timing.add("allocate") as t:
self._allocate_t = t
yield self._ws_send(u"allocate") yield self._ws_send(u"allocate")
while self._channelid is None: while self._channelid is None:
yield self._sleep() yield self._sleep()
self._timing.finish_event(_sent)
code = codes.make_code(self._channelid, code_length) code = codes.make_code(self._channelid, code_length)
assert isinstance(code, type(u"")), type(code) assert isinstance(code, type(u"")), type(code)
self._set_code(code) self._set_code(code)
@ -215,6 +221,7 @@ class Wormhole:
returnValue(code) returnValue(code)
def _ws_handle_allocated(self, msg): def _ws_handle_allocated(self, msg):
self._allocate_t.server_sent(msg["sent"])
if self._channelid is not None: if self._channelid is not None:
return self._signal_error("got duplicate channelid") return self._signal_error("got duplicate channelid")
self._channelid = msg["channelid"] self._channelid = msg["channelid"]
@ -222,6 +229,7 @@ class Wormhole:
def _start(self): def _start(self):
# allocate the rest now too, so it can be serialized # allocate the rest now too, so it can be serialized
with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(self._code), self._sp = SPAKE2_Symmetric(to_bytes(self._code),
idSymmetric=to_bytes(self._appid)) idSymmetric=to_bytes(self._appid))
self._msg1 = self._sp.start() self._msg1 = self._sp.start()
@ -238,8 +246,9 @@ class Wormhole:
# TODO: send the request early, show the prompt right away, hide the # TODO: send the request early, show the prompt right away, hide the
# latency in the user's indecision and slow typing. If we're lucky # latency in the user's indecision and slow typing. If we're lucky
# the answer will come back before they hit TAB. # the answer will come back before they hit TAB.
with self._timing.add("API input_code"):
initial_channelids = yield self._list_channels() initial_channelids = yield self._list_channels()
_start = self._timing.add_event("input code", waiting="user") with self._timing.add("input code", waiting="user"):
t = self._reactor.addSystemEventTrigger("before", "shutdown", t = self._reactor.addSystemEventTrigger("before", "shutdown",
self._warn_readline) self._warn_readline)
code = yield deferToThread(codes.input_code_with_completion, code = yield deferToThread(codes.input_code_with_completion,
@ -247,7 +256,6 @@ class Wormhole:
initial_channelids, _lister, initial_channelids, _lister,
code_length) code_length)
self._reactor.removeSystemEventTrigger(t) self._reactor.removeSystemEventTrigger(t)
self._timing.finish_event(_start)
returnValue(code) # application will give this to set_code() returnValue(code) # application will give this to set_code()
def _warn_readline(self): def _warn_readline(self):
@ -286,12 +294,11 @@ class Wormhole:
@inlineCallbacks @inlineCallbacks
def _list_channels(self): def _list_channels(self):
_sent = self._timing.add_event("list") with self._timing.add("list"):
self._latest_channelids = None self._latest_channelids = None
yield self._ws_send(u"list") yield self._ws_send(u"list")
while self._latest_channelids is None: while self._latest_channelids is None:
yield self._sleep() yield self._sleep()
self._timing.finish_event(_sent)
returnValue(self._latest_channelids) returnValue(self._latest_channelids)
def _ws_handle_channelids(self, msg): def _ws_handle_channelids(self, msg):
@ -305,13 +312,14 @@ class Wormhole:
mo = re.search(r'^(\d+)-', code) mo = re.search(r'^(\d+)-', code)
if not mo: if not mo:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
with self._timing.add("API set_code"):
self._channelid = int(mo.group(1)) self._channelid = int(mo.group(1))
self._set_code(code) self._set_code(code)
self._start() self._start()
def _set_code(self, code): def _set_code(self, code):
if self._code is not None: raise UsageError if self._code is not None: raise UsageError
self._timing.add_event("code established") self._timing.add("code established")
self._code = code self._code = code
def serialize(self): def serialize(self):
@ -349,15 +357,16 @@ class Wormhole:
def get_verifier(self): def get_verifier(self):
if self._closed: raise UsageError if self._closed: raise UsageError
if self._code is None: raise UsageError if self._code is None: raise UsageError
with self._timing.add("API get_verifier"):
yield self._get_master_key() yield self._get_master_key()
# If the caller cares about the verifier, then they'll probably also # If the caller cares about the verifier, then they'll probably
# willing to wait a moment to see the _confirm message. Each side # also willing to wait a moment to see the _confirm message. Each
# sends this as soon as it sees the other's PAKE message. So the # side sends this as soon as it sees the other's PAKE message. So
# sender should see this hot on the heels of the inbound PAKE message # the sender should see this hot on the heels of the inbound PAKE
# (a moment after _get_master_key() returns). The receiver will see # message (a moment after _get_master_key() returns). The
# this a round-trip after they send their PAKE (because the sender is # receiver will see this a round-trip after they send their PAKE
# using wait=True inside _get_master_key, below: otherwise the sender # (because the sender is using wait=True inside _get_master_key,
# might go do some blocking call). # below: otherwise the sender might go do some blocking call).
yield self._msg_get(u"_confirm") yield self._msg_get(u"_confirm")
returnValue(self._verifier) returnValue(self._verifier)
@ -369,9 +378,10 @@ class Wormhole:
yield self._msg_send(u"pake", self._msg1) yield self._msg_send(u"pake", self._msg1)
pake_msg = yield self._msg_get(u"pake") pake_msg = yield self._msg_get(u"pake")
with self._timing.add("pake2", waiting="crypto"):
self._key = self._sp.finish(pake_msg) self._key = self._sp.finish(pake_msg)
self._verifier = self.derive_key(u"wormhole:verifier") self._verifier = self.derive_key(u"wormhole:verifier")
self._timing.add_event("key established") self._timing.add("key established")
if self._send_confirm: if self._send_confirm:
# both sides send different (random) confirmation messages # both sides send different (random) confirmation messages
@ -385,11 +395,13 @@ class Wormhole:
self._sent_messages.add( (phase, body) ) self._sent_messages.add( (phase, body) )
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline. # against the rendezvous server being temporarily offline.
t = self._timing.add("add", phase=phase, wait=wait)
yield self._ws_send(u"add", phase=phase, yield self._ws_send(u"add", phase=phase,
body=hexlify(body).decode("ascii")) body=hexlify(body).decode("ascii"))
if wait: if wait:
while (phase, body) not in self._delivered_messages: while (phase, body) not in self._delivered_messages:
yield self._sleep() yield self._sleep()
t.finish()
def _ws_handle_message(self, msg): def _ws_handle_message(self, msg):
m = msg["message"] m = msg["message"]
@ -404,6 +416,7 @@ class Wormhole:
err = ServerError("got duplicate phase %s" % phase, self._ws_url) err = ServerError("got duplicate phase %s" % phase, self._ws_url)
return self._signal_error(err) return self._signal_error(err)
self._received_messages[phase] = body self._received_messages[phase] = body
self._received_messages_sent[phase] = msg.get(u"sent")
if phase == u"_confirm": if phase == u"_confirm":
# TODO: we might not have a master key yet, if the caller wasn't # TODO: we might not have a master key yet, if the caller wasn't
# waiting in _get_master_key() when a back-to-back pake+_confirm # waiting in _get_master_key() when a back-to-back pake+_confirm
@ -418,12 +431,15 @@ class Wormhole:
@inlineCallbacks @inlineCallbacks
def _msg_get(self, phase): def _msg_get(self, phase):
_start = self._timing.add_event("get(%s)" % phase) with self._timing.add("get", phase=phase) as t:
while phase not in self._received_messages: while phase not in self._received_messages:
yield self._sleep() # we can wait a long time here yield self._sleep() # we can wait a long time here
# that will throw an error if something goes wrong # that will throw an error if something goes wrong
self._timing.finish_event(_start) msg = self._received_messages[phase]
returnValue(self._received_messages[phase]) sent = self._received_messages_sent[phase]
if sent:
t.server_sent(sent)
returnValue(msg)
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
@ -459,16 +475,15 @@ class Wormhole:
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._sent_phases: raise UsageError # only call this once if phase in self._sent_phases: raise UsageError # only call this once
self._sent_phases.add(phase) self._sent_phases.add(phase)
_sent = self._timing.add_event("API send data", phase=phase, wait=wait) with self._timing.add("API send_data", phase=phase, wait=wait):
# Without predefined roles, we can't derive predictably unique keys # Without predefined roles, we can't derive predictably unique
# for each side, so we use the same key for both. We use random # keys for each side, so we use the same key for both. We use
# nonces to keep the messages distinct, and we automatically ignore # random nonces to keep the messages distinct, and we
# reflections. # automatically ignore reflections.
yield self._get_master_key() yield self._get_master_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._msg_send(phase, outbound_encrypted, wait) yield self._msg_send(phase, outbound_encrypted, wait)
self._timing.finish_event(_sent)
@inlineCallbacks @inlineCallbacks
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
@ -478,10 +493,9 @@ class Wormhole:
if phase.startswith(u"_"): raise UsageError # reserved for internals if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._got_phases: raise UsageError # only call this once if phase in self._got_phases: raise UsageError # only call this once
self._got_phases.add(phase) self._got_phases.add(phase)
_sent = self._timing.add_event("API get data", phase=phase) with self._timing.add("API get_data", phase=phase):
yield self._get_master_key() yield self._get_master_key()
body = yield self._msg_get(phase) # we can wait a long time here body = yield self._msg_get(phase) # we can wait a long time here
self._timing.finish_event(_sent)
try: try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_data = self._decrypt_data(data_key, body) inbound_data = self._decrypt_data(data_key, body)
@ -491,6 +505,7 @@ class Wormhole:
def _ws_closed(self, wasClean, code, reason): def _ws_closed(self, wasClean, code, reason):
self._ws = None self._ws = None
self._ws_t.finish()
# TODO: schedule reconnect, unless we're done # TODO: schedule reconnect, unless we're done
@inlineCallbacks @inlineCallbacks
@ -517,20 +532,21 @@ class Wormhole:
if not isinstance(mood, (type(None), type(u""))): if not isinstance(mood, (type(None), type(u""))):
raise TypeError(type(mood)) raise TypeError(type(mood))
self._timing.finish_event(self._timing_started, mood=mood) with self._timing.add("API close"):
yield self._deallocate(mood) yield self._deallocate(mood)
# TODO: mark WebSocket as don't-reconnect # TODO: mark WebSocket as don't-reconnect
self._ws.transport.loseConnection() # probably flushes self._ws.transport.loseConnection() # probably flushes
del self._ws del self._ws
self._ws_t.finish()
self._timing_started.finish(mood=mood)
returnValue(f) returnValue(f)
@inlineCallbacks @inlineCallbacks
def _deallocate(self, mood): def _deallocate(self, mood):
_sent = self._timing.add_event("close") with self._timing.add("deallocate"):
yield self._ws_send(u"deallocate", mood=mood) yield self._ws_send(u"deallocate", mood=mood)
while self._deallocated_status is None: while self._deallocated_status is None:
yield self._sleep(wake_on_error=False) yield self._sleep(wake_on_error=False)
self._timing.finish_event(_sent)
# TODO: set a timeout, don't wait forever for an ack # TODO: set a timeout, don't wait forever for an ack
# TODO: if the connection is lost, let it go # TODO: if the connection is lost, let it go
returnValue(self._deallocated_status) returnValue(self._deallocated_status)

View File

@ -565,7 +565,7 @@ class Common:
self._winner = None self._winner = None
self._reactor = reactor self._reactor = reactor
self._timing = timing or DebugTiming() self._timing = timing or DebugTiming()
self._timing_started = self._timing.add_event("transit") self._timing.add("transit")
def _build_listener(self): def _build_listener(self):
if self._no_listen or self._tor_manager: if self._no_listen or self._tor_manager:
@ -690,13 +690,12 @@ class Common:
@inlineCallbacks @inlineCallbacks
def connect(self): def connect(self):
_start = self._timing.add_event("transit connect") with self._timing.add("transit connect"):
yield self._get_transit_key() yield self._get_transit_key()
# we want to have the transit key before starting any outbound # we want to have the transit key before starting any outbound
# connections, so those connections will know what to say when they # connections, so those connections will know what to say when
# connect # they connect
winner = yield self._connect() winner = yield self._connect()
self._timing.finish_event(_start)
returnValue(winner) returnValue(winner)
def _connect(self): def _connect(self):