diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index bc97f68..54c1b1a 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -68,12 +68,11 @@ class Mailbox: return messages def add_listener(self, handle, send_f, stop_f): - # TODO: update 'updated' self._listeners[handle] = (send_f, stop_f) return self.get_messages() def remove_listener(self, handle): - self._listeners.pop(handle) + self._listeners.pop(handle, None) def has_listeners(self): return bool(self._listeners) @@ -135,12 +134,14 @@ class Mailbox: # around. for (send_f, stop_f) in self._listeners.values(): stop_f() + self._listeners = {} self._app.free_mailbox(self._mailbox_id) def _shutdown(self): # used at test shutdown to accelerate client disconnects for (send_f, stop_f) in self._listeners.values(): stop_f() + self._listeners = {} class AppNamespace: def __init__(self, db, blur_usage, log_requests, app_id): @@ -410,6 +411,7 @@ class AppNamespace: for mailbox in self._mailboxes.values(): if mailbox.has_listeners(): + log.msg("touch %s because listeners" % mailbox._mailbox_id) mailbox._touch(now) db.commit() # make sure the updates are visible below @@ -418,11 +420,13 @@ class AppNamespace: for row in db.execute("SELECT * FROM `mailboxes` WHERE `app_id`=?", (self._app_id,)).fetchall(): mailbox_id = row["id"] + log.msg(" 1: age=%s, old=%s, %s" % + (now - row["updated"], now - old, mailbox_id)) if row["updated"] > old: new_mailboxes.add(mailbox_id) else: old_mailboxes.add(mailbox_id) - #log.msg(" 2: mailboxes:", new_mailboxes, old_mailboxes) + log.msg(" 2: mailboxes:", new_mailboxes, old_mailboxes) old_nameplates = set() for row in db.execute("SELECT * FROM `nameplates` WHERE `app_id`=?", @@ -431,7 +435,7 @@ class AppNamespace: mailbox_id = row["mailbox_id"] if mailbox_id in old_mailboxes: old_nameplates.add(npid) - #log.msg(" 3: old_nameplates", old_nameplates) + log.msg(" 3: old_nameplates", old_nameplates) for npid in old_nameplates: log.msg(" deleting nameplate", npid) diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index 2ab2643..69e7567 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -88,6 +88,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = None self._side = None self._did_allocate = False # only one allocate() per websocket + self._listening = False self._nameplate_id = None self._mailbox = None @@ -203,6 +204,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): body=sm.body, server_rx=sm.server_rx, id=sm.msg_id) def _stop(): pass + self._listening = True for old_sm in self._mailbox.add_listener(self, _send, _stop): _send(old_sm) @@ -233,7 +235,8 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self.sendMessage(payload, False) def onClose(self, wasClean, code, reason): - pass + if self._mailbox and self._listening: + self._mailbox.remove_listener(self) class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 611f2c3..c79ebf0 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,5 +1,5 @@ from __future__ import print_function, unicode_literals -import json, itertools +import json, itertools, time from binascii import hexlify import mock from twisted.trial import unittest @@ -484,6 +484,14 @@ class WSClient(websocket.WebSocketClientProtocol): return self.events.append(event) + def close(self): + self.d = defer.Deferred() + self.transport.loseConnection() + return self.d + def onClose(self, wasClean, code, reason): + if self.d: + self.d.callback((wasClean, code, reason)) + def next_event(self): assert not self.d if self.events: @@ -840,6 +848,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): m = yield c1.next_non_ack() self.assertEqual(m["type"], "message") self.assertEqual(m["body"], "body") + self.assertTrue(mb1.has_listeners()) mb1.add_message(SidedMessage(side="side2", phase="phase2", body="body2", server_rx=0, @@ -893,6 +902,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase): c1 = yield self.make_client() yield c1.next_non_ack() c1.send("bind", appid="appid", side="side") + app = self._rendezvous.get_app("appid") c1.send("close", mood="mood") # must open first err = yield c1.next_non_ack() @@ -900,15 +910,41 @@ class WebSocketAPI(ServerBase, unittest.TestCase): self.assertEqual(err["error"], "must open mailbox before closing") c1.send("open", mailbox="mb1") + yield c1.sync() + mb1 = app._mailboxes["mb1"] + self.assertTrue(mb1.has_listeners()) + c1.send("close", mood="mood") m = yield c1.next_non_ack() self.assertEqual(m["type"], "closed") + self.assertFalse(mb1.has_listeners()) c1.send("close", mood="mood") # already closed err = yield c1.next_non_ack() self.assertEqual(err["type"], "error") self.assertEqual(err["error"], "must open mailbox before closing") + @inlineCallbacks + def test_disconnect(self): + c1 = yield self.make_client() + yield c1.next_non_ack() + c1.send("bind", appid="appid", side="side") + app = self._rendezvous.get_app("appid") + + c1.send("open", mailbox="mb1") + yield c1.sync() + mb1 = app._mailboxes["mb1"] + self.assertTrue(mb1.has_listeners()) + + yield c1.close() + # wait for the server to notice the socket has closed + started = time.time() + while mb1.has_listeners() and (time.time()-started < 5.0): + d = defer.Deferred() + reactor.callLater(0.01, d.callback, None) + yield d + self.assertFalse(mb1.has_listeners()) + class Summary(unittest.TestCase): def test_mailbox(self):