diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 122f780..3401008 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -24,9 +24,9 @@ class Channel: self._log_requests = log_requests self._appid = appid self._channelid = channelid - self._listeners = set() # instances with .send_rendezvous_event (that - # takes a JSONable object) and - # .stop_rendezvous_watcher() + self._listeners = {} # handle -> (send_f, stop_f) + # "handle" is a hashable object, for deregistration + # send_f() takes a JSONable object, stop_f() has no args def get_channelid(self): return self._channelid @@ -44,17 +44,17 @@ class Channel: "server_rx": row["server_rx"], "id": row["msgid"]}) return messages - def add_listener(self, ep): - self._listeners.add(ep) + def add_listener(self, handle, send_f, stop_f): + self._listeners[handle] = (send_f, stop_f) return self.get_messages() - def remove_listener(self, ep): - self._listeners.discard(ep) + def remove_listener(self, handle): + self._listeners.pop(handle) def broadcast_message(self, phase, body, server_rx, msgid): - for ep in self._listeners: - ep.send_rendezvous_event({"phase": phase, "body": body, - "server_rx": server_rx, "id": msgid}) + for (send_f, stop_f) in self._listeners.values(): + send_f({"phase": phase, "body": body, + "server_rx": server_rx, "id": msgid}) def _add_message(self, side, phase, body, server_rx, msgid): db = self._db @@ -183,15 +183,15 @@ class Channel: # Shut down any listeners, just in case they're still lingering # around. - for ep in self._listeners: - ep.stop_rendezvous_watcher() + for (send_f, stop_f) in self._listeners.values(): + stop_f() self._app.free_channel(self._channelid) def _shutdown(self): # used at test shutdown to accelerate client disconnects - for ep in self._listeners: - ep.stop_rendezvous_watcher() + for (send_f, stop_f) in self._listeners.values(): + stop_f() class AppNamespace: def __init__(self, db, welcome, blur_usage, log_requests, appid): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 93abf2c..6bd291b 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -150,8 +150,12 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if self._watching: raise Error("already watching") self._watching = True - for old_message in channel.add_listener(self): - self.send_rendezvous_event(old_message) + def _send(event): + self.send_rendezvous_event(event) + def _stop(): + self.stop_rendezvous_watcher() + for old_message in channel.add_listener(self, _send, _stop): + _send(old_message) def handle_add(self, channel, msg, server_rx): if "phase" not in msg: