simplify timing, add msgid
This commit is contained in:
		
							parent
							
								
									5530c33185
								
							
						
					
					
						commit
						5501a6bf1c
					
				| 
						 | 
				
			
			@ -69,12 +69,11 @@ class Channel:
 | 
			
		|||
                   "phase": phase,
 | 
			
		||||
                   "body": hexlify(msg).decode("ascii")}
 | 
			
		||||
        data = json.dumps(payload).encode("utf-8")
 | 
			
		||||
        with self._timing.add("send %s" % phase) as t:
 | 
			
		||||
        with self._timing.add("send %s" % phase):
 | 
			
		||||
            r = requests.post(self._relay_url+"add", data=data,
 | 
			
		||||
                              timeout=self._timeout)
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
            resp = r.json()
 | 
			
		||||
            t.server_sent(resp.get("sent"))
 | 
			
		||||
        if "welcome" in resp:
 | 
			
		||||
            self._handle_welcome(resp["welcome"])
 | 
			
		||||
        self._add_inbound_messages(resp["messages"])
 | 
			
		||||
| 
						 | 
				
			
			@ -92,7 +91,7 @@ class Channel:
 | 
			
		|||
        # wasn't one of our own messages. It will either come from
 | 
			
		||||
        # previously-received messages, or from an EventSource that we attach
 | 
			
		||||
        # to the corresponding URL
 | 
			
		||||
        with self._timing.add("get %s" % "/".join(sorted(phases))) as t:
 | 
			
		||||
        with self._timing.add("get %s" % "/".join(sorted(phases))):
 | 
			
		||||
            phase_and_body = self._find_inbound_message(phases)
 | 
			
		||||
            while phase_and_body is None:
 | 
			
		||||
                remaining = self._started + self._timeout - time.time()
 | 
			
		||||
| 
						 | 
				
			
			@ -113,7 +112,6 @@ class Channel:
 | 
			
		|||
                        phase_and_body = self._find_inbound_message(phases)
 | 
			
		||||
                        if phase_and_body:
 | 
			
		||||
                            f.close()
 | 
			
		||||
                            t.server_sent(data.get("sent"))
 | 
			
		||||
                            break
 | 
			
		||||
                if not phase_and_body:
 | 
			
		||||
                    time.sleep(self._wait)
 | 
			
		||||
| 
						 | 
				
			
			@ -133,11 +131,10 @@ class Channel:
 | 
			
		|||
        try:
 | 
			
		||||
            # ignore POST failure, don't call r.raise_for_status(), set a
 | 
			
		||||
            # short timeout and ignore failures
 | 
			
		||||
            with self._timing.add("close") as t:
 | 
			
		||||
            with self._timing.add("close"):
 | 
			
		||||
                r = requests.post(self._relay_url+"deallocate", data=data,
 | 
			
		||||
                                  timeout=5)
 | 
			
		||||
                resp = r.json()
 | 
			
		||||
                t.server_sent(resp.get("sent"))
 | 
			
		||||
                r.json()
 | 
			
		||||
        except requests.exceptions.RequestException:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -154,28 +151,26 @@ class ChannelManager:
 | 
			
		|||
 | 
			
		||||
    def list_channels(self):
 | 
			
		||||
        queryargs = urlencode([("appid", self._appid)])
 | 
			
		||||
        with self._timing.add("list") as t:
 | 
			
		||||
        with self._timing.add("list"):
 | 
			
		||||
            r = requests.get(self._relay_url+"list?%s" % queryargs,
 | 
			
		||||
                             timeout=self._timeout)
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
            data = r.json()
 | 
			
		||||
            if "welcome" in data:
 | 
			
		||||
                self._handle_welcome(data["welcome"])
 | 
			
		||||
            t.server_sent(data.get("sent"))
 | 
			
		||||
        channelids = data["channelids"]
 | 
			
		||||
        return channelids
 | 
			
		||||
 | 
			
		||||
    def allocate(self):
 | 
			
		||||
        data = json.dumps({"appid": self._appid,
 | 
			
		||||
                           "side": self._side}).encode("utf-8")
 | 
			
		||||
        with self._timing.add("allocate") as t:
 | 
			
		||||
        with self._timing.add("allocate"):
 | 
			
		||||
            r = requests.post(self._relay_url+"allocate", data=data,
 | 
			
		||||
                              timeout=self._timeout)
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
            data = r.json()
 | 
			
		||||
            if "welcome" in data:
 | 
			
		||||
                self._handle_welcome(data["welcome"])
 | 
			
		||||
            t.server_sent(data.get("sent"))
 | 
			
		||||
        channelid = data["channelid"]
 | 
			
		||||
        return channelid
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,22 +4,16 @@ import json, time
 | 
			
		|||
class Event:
 | 
			
		||||
    def __init__(self, name, when, **details):
 | 
			
		||||
        # data fields that will be dumped to JSON later
 | 
			
		||||
        self._start = time.time() if when is None else float(when)
 | 
			
		||||
        self._server_sent = None
 | 
			
		||||
        self._stop = None
 | 
			
		||||
        self._name = name
 | 
			
		||||
        self._start = time.time() if when is None else float(when)
 | 
			
		||||
        self._stop = None
 | 
			
		||||
        self._details = details
 | 
			
		||||
 | 
			
		||||
    def server_sent(self, when):
 | 
			
		||||
        self._server_sent = when
 | 
			
		||||
 | 
			
		||||
    def detail(self, **details):
 | 
			
		||||
        self._details.update(details)
 | 
			
		||||
 | 
			
		||||
    def finish(self, when=None, server_sent=None, **details):
 | 
			
		||||
    def finish(self, when=None, **details):
 | 
			
		||||
        self._stop = time.time() if when is None else float(when)
 | 
			
		||||
        if server_sent:
 | 
			
		||||
            self.server_sent(server_sent)
 | 
			
		||||
        self.detail(**details)
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -50,8 +44,11 @@ class DebugTiming:
 | 
			
		|||
 | 
			
		||||
    def write(self, fn, stderr):
 | 
			
		||||
        with open(fn, "wb") as f:
 | 
			
		||||
            data = [ [e._start, e._server_sent, e._stop, e._name, e._details]
 | 
			
		||||
            data = [ dict(name=e._name,
 | 
			
		||||
                          start=e._start, stop=e._stop,
 | 
			
		||||
                          details=e._details,
 | 
			
		||||
                          )
 | 
			
		||||
                     for e in self._events ]
 | 
			
		||||
            json.dump(data, f)
 | 
			
		||||
            json.dump(data, f, indent=1)
 | 
			
		||||
            f.write("\n")
 | 
			
		||||
        print("Timing data written to %s" % fn, file=stderr)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -78,7 +78,6 @@ class Wormhole:
 | 
			
		|||
        self._sent_messages = set() # (phase, body_bytes)
 | 
			
		||||
        self._delivered_messages = set() # (phase, body_bytes)
 | 
			
		||||
        self._received_messages = {} # phase -> body_bytes
 | 
			
		||||
        self._received_messages_sent = {} # phase -> server timestamp
 | 
			
		||||
        self._sent_phases = set() # phases, to prohibit double-send
 | 
			
		||||
        self._got_phases = set() # phases, to prohibit double-read
 | 
			
		||||
        self._sleepers = []
 | 
			
		||||
| 
						 | 
				
			
			@ -125,12 +124,19 @@ class Wormhole:
 | 
			
		|||
    @inlineCallbacks
 | 
			
		||||
    def _ws_send(self, mtype, **kwargs):
 | 
			
		||||
        ws = yield self._get_websocket()
 | 
			
		||||
        # msgid is used by misc/dump-timing.py to correlate our sends with
 | 
			
		||||
        # their receives, and vice versa. They are also correlated with the
 | 
			
		||||
        # ACKs we get back from the server (which we otherwise ignore). There
 | 
			
		||||
        # are so few messages, 16 bits is enough to be mostly-unique.
 | 
			
		||||
        kwargs["id"] = hexlify(os.urandom(2)).decode("ascii")
 | 
			
		||||
        kwargs["type"] = mtype
 | 
			
		||||
        payload = json.dumps(kwargs).encode("utf-8")
 | 
			
		||||
        self._timing.add("ws_send", _side=self._side, **kwargs)
 | 
			
		||||
        ws.sendMessage(payload, False)
 | 
			
		||||
 | 
			
		||||
    def _ws_dispatch_msg(self, payload):
 | 
			
		||||
        msg = json.loads(payload.decode("utf-8"))
 | 
			
		||||
        self._timing.add("ws_receive", _side=self._side, message=msg)
 | 
			
		||||
        mtype = msg["type"]
 | 
			
		||||
        meth = getattr(self, "_ws_handle_"+mtype, None)
 | 
			
		||||
        if not meth:
 | 
			
		||||
| 
						 | 
				
			
			@ -143,7 +149,6 @@ class Wormhole:
 | 
			
		|||
        pass
 | 
			
		||||
 | 
			
		||||
    def _ws_handle_welcome(self, msg):
 | 
			
		||||
        self._timing.add("welcome").server_sent(msg["server_tx"])
 | 
			
		||||
        welcome = msg["welcome"]
 | 
			
		||||
        if ("motd" in welcome and
 | 
			
		||||
            not self.motd_displayed):
 | 
			
		||||
| 
						 | 
				
			
			@ -194,7 +199,6 @@ class Wormhole:
 | 
			
		|||
        self._wakeup()
 | 
			
		||||
 | 
			
		||||
    def _ws_handle_error(self, msg):
 | 
			
		||||
        self._timing.add("error").server_sent(msg["server_tx"])
 | 
			
		||||
        err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
 | 
			
		||||
                          self._ws_url)
 | 
			
		||||
        return self._signal_error(err)
 | 
			
		||||
| 
						 | 
				
			
			@ -215,8 +219,7 @@ class Wormhole:
 | 
			
		|||
        if self._started_get_code: raise UsageError
 | 
			
		||||
        self._started_get_code = True
 | 
			
		||||
        with self._timing.add("API get_code"):
 | 
			
		||||
            with self._timing.add("allocate") as t:
 | 
			
		||||
                self._allocate_t = t
 | 
			
		||||
            with self._timing.add("allocate"):
 | 
			
		||||
                yield self._ws_send(u"allocate")
 | 
			
		||||
                while self._channelid is None:
 | 
			
		||||
                    yield self._sleep()
 | 
			
		||||
| 
						 | 
				
			
			@ -227,7 +230,6 @@ class Wormhole:
 | 
			
		|||
        returnValue(code)
 | 
			
		||||
 | 
			
		||||
    def _ws_handle_allocated(self, msg):
 | 
			
		||||
        self._allocate_t.server_sent(msg["server_tx"])
 | 
			
		||||
        if self._channelid is not None:
 | 
			
		||||
            return self._signal_error("got duplicate channelid")
 | 
			
		||||
        self._channelid = msg["channelid"]
 | 
			
		||||
| 
						 | 
				
			
			@ -422,7 +424,6 @@ class Wormhole:
 | 
			
		|||
            err = ServerError("got duplicate phase %s" % phase, self._ws_url)
 | 
			
		||||
            return self._signal_error(err)
 | 
			
		||||
        self._received_messages[phase] = body
 | 
			
		||||
        self._received_messages_sent[phase] = msg.get(u"sent")
 | 
			
		||||
        if phase == u"_confirm":
 | 
			
		||||
            # TODO: we might not have a master key yet, if the caller wasn't
 | 
			
		||||
            # waiting in _get_master_key() when a back-to-back pake+_confirm
 | 
			
		||||
| 
						 | 
				
			
			@ -437,14 +438,11 @@ class Wormhole:
 | 
			
		|||
 | 
			
		||||
    @inlineCallbacks
 | 
			
		||||
    def _msg_get(self, phase):
 | 
			
		||||
        with self._timing.add("get", phase=phase) as t:
 | 
			
		||||
        with self._timing.add("get", phase=phase):
 | 
			
		||||
            while phase not in self._received_messages:
 | 
			
		||||
                yield self._sleep() # we can wait a long time here
 | 
			
		||||
                # that will throw an error if something goes wrong
 | 
			
		||||
            msg = self._received_messages[phase]
 | 
			
		||||
            sent = self._received_messages_sent[phase]
 | 
			
		||||
            if sent:
 | 
			
		||||
                t.server_sent(sent)
 | 
			
		||||
        returnValue(msg)
 | 
			
		||||
 | 
			
		||||
    def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user