diff --git a/src/wormhole/servers/relay_server.py b/src/wormhole/servers/relay_server.py
index b0a2681..1ab1124 100644
--- a/src/wormhole/servers/relay_server.py
+++ b/src/wormhole/servers/relay_server.py
@@ -160,9 +160,9 @@ class GetterOrWatcher(RelayResource):
         request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
         ep = EventsProtocol(request)
         ep.sendEvent(json.dumps(self._welcome), name="welcome")
-        old_events = channel.add_listener(ep.sendEvent)
+        old_events = channel.add_listener(ep)
         request.notifyFinish().addErrback(lambda f:
-                                          channel.remove_listener(ep.sendEvent))
+                                          channel.remove_listener(ep))
         for old_event in old_events:
             ep.sendEvent(old_event)
         return server.NOT_DONE_YET
@@ -179,9 +179,9 @@ class Watcher(RelayResource):
         request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
         ep = EventsProtocol(request)
         ep.sendEvent(json.dumps(self._welcome), name="welcome")
-        old_events = channel.add_listener(ep.sendEvent)
+        old_events = channel.add_listener(ep)
         request.notifyFinish().addErrback(lambda f:
-                                          channel.remove_listener(ep.sendEvent))
+                                          channel.remove_listener(ep))
         for old_event in old_events:
             ep.sendEvent(old_event)
         return server.NOT_DONE_YET
@@ -218,7 +218,8 @@ class Channel:
         self._log_requests = log_requests
         self._appid = appid
         self._channelid = channelid
-        self._listeners = set() # callbacks that take JSONable object
+        self._listeners = set() # EventsProtocol instances, with a .sendEvent
+                                # that takes a JSONable object
 
     def get_messages(self):
         messages = []
@@ -233,8 +234,8 @@ class Channel:
         data = {"welcome": self._welcome, "messages": messages}
         return data
 
-    def add_listener(self, listener):
-        self._listeners.add(listener)
+    def add_listener(self, ep):
+        self._listeners.add(ep)
         db = self._db
         for row in db.execute("SELECT * FROM `messages`"
                               " WHERE `appid`=? AND `channelid`=?"
@@ -243,13 +244,13 @@ class Channel:
             if row["phase"] in (u"_allocate", u"_deallocate"):
                 continue
             yield json.dumps({"phase": row["phase"], "body": row["body"]})
-    def remove_listener(self, listener):
-        self._listeners.discard(listener)
+    def remove_listener(self, ep):
+        self._listeners.discard(ep)
 
     def broadcast_message(self, phase, body):
         data = json.dumps({"phase": phase, "body": body})
-        for listener in self._listeners:
-            listener(data)
+        for ep in self._listeners:
+            ep.sendEvent(data)
 
     def _add_message(self, side, phase, body):
         db = self._db
@@ -375,18 +376,17 @@ class Channel:
                    (self._appid, self._channelid))
         db.commit()
 
-        # It'd be nice to shut down any EventSource listeners here. But we
-        # don't hang on to the EventsProtocol, so we can't really shut it
-        # down here: any listeners will stick around until they shut down
-        # from the client side. That will keep the Channel object in memory,
-        # but it won't be reachable from the AppNamespace, so no further
-        # messages will be sent to it. Eventually, when they close the TCP
-        # connection, self.remove_listener() will be called, ep.sendEvent
-        # will be removed from self._listeners, breaking the circular
-        # reference, and everything will get freed.
+        # Shut down any EventSource listeners, just in case they're still
+        # lingering around.
+        for ep in self._listeners:
+            ep.stop()
         self._app.free_channel(self._channelid)
 
+    def _shutdown(self):
+        # used at test shutdown to accelerate client disconnects
+        for ep in self._listeners:
+            ep.stop()
 
 class AppNamespace:
     def __init__(self, db, welcome, blur_usage, log_requests, appid):
@@ -465,6 +465,10 @@ class AppNamespace:
         log.msg(" channel prune done, %r left" % (self._channels.keys(),))
         return bool(self._channels)
 
+    def _shutdown(self):
+        for channel in self._channels.values():
+            channel._shutdown()
+
 class Relay(resource.Resource, service.MultiService):
     def __init__(self, db, welcome, blur_usage):
         resource.Resource.__init__(self)
@@ -516,3 +520,14 @@ class Relay(resource.Resource, service.MultiService):
             log.msg("prune pops app %r" % (appid,))
             self._apps.pop(appid)
         log.msg("app prune ends, %d remaining apps" % len(self._apps))
+
+    def stopService(self):
+        # This forcibly boots any clients that are still connected, which
+        # helps with unit tests that use threads for both clients. One client
+        # hits an exception, which terminates the test (and .tearDown calls
+        # stopService on the relay), but the other client (in its thread) is
+        # still waiting for a message. By killing off all connections, that
+        # other client gets an error, and exits promptly.
+        for app in self._apps.values():
+            app._shutdown()
+        return service.MultiService.stopService(self)
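Note on the listener change: the key move in this patch is storing whole EventsProtocol instances in Channel._listeners instead of bound ep.sendEvent callbacks, which is what makes ep.stop() reachable from the Channel deallocation path, Channel._shutdown(), and ultimately Relay.stopService(). The sketch below is only an illustration of the interface the patch assumes (sendEvent to write a Server-Sent Events frame, stop to finish the response); it is not the project's actual EventsProtocol implementation.

# Hypothetical sketch of the assumed EventsProtocol interface, wrapping a
# Twisted web request; the real class in relay_server.py may differ.
class EventsProtocol:
    def __init__(self, request):
        self.request = request

    def sendEvent(self, data, name=None):
        # write one text/event-stream frame to the still-open response
        if name:
            self.request.write(b"event: " + name.encode("utf-8") + b"\n")
        for line in data.splitlines():
            self.request.write(b"data: " + line.encode("utf-8") + b"\n")
        self.request.write(b"\n")

    def stop(self):
        # finishing the request closes the connection, so a lingering
        # EventSource client sees an error instead of waiting forever
        self.request.finish()

With that interface, the shutdown added here is a three-level cascade: Relay.stopService() walks its apps, AppNamespace._shutdown() walks its channels, and Channel._shutdown() calls ep.stop() on each connected listener.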