server: remove listeners on disconnect

This wasn't happening before, so channels were staying alive until
reboot.
This commit is contained in:
Brian Warner 2016-06-24 18:47:16 -07:00
parent ffb1a9b9c9
commit 6a2cbf9014
3 changed files with 49 additions and 6 deletions

View File

@ -68,12 +68,11 @@ class Mailbox:
return messages return messages
def add_listener(self, handle, send_f, stop_f): def add_listener(self, handle, send_f, stop_f):
# TODO: update 'updated'
self._listeners[handle] = (send_f, stop_f) self._listeners[handle] = (send_f, stop_f)
return self.get_messages() return self.get_messages()
def remove_listener(self, handle): def remove_listener(self, handle):
self._listeners.pop(handle) self._listeners.pop(handle, None)
def has_listeners(self): def has_listeners(self):
return bool(self._listeners) return bool(self._listeners)
@ -135,12 +134,14 @@ class Mailbox:
# around. # around.
for (send_f, stop_f) in self._listeners.values(): for (send_f, stop_f) in self._listeners.values():
stop_f() stop_f()
self._listeners = {}
self._app.free_mailbox(self._mailbox_id) self._app.free_mailbox(self._mailbox_id)
def _shutdown(self): def _shutdown(self):
# used at test shutdown to accelerate client disconnects # used at test shutdown to accelerate client disconnects
for (send_f, stop_f) in self._listeners.values(): for (send_f, stop_f) in self._listeners.values():
stop_f() stop_f()
self._listeners = {}
class AppNamespace: class AppNamespace:
def __init__(self, db, blur_usage, log_requests, app_id): def __init__(self, db, blur_usage, log_requests, app_id):
@ -410,6 +411,7 @@ class AppNamespace:
for mailbox in self._mailboxes.values(): for mailbox in self._mailboxes.values():
if mailbox.has_listeners(): if mailbox.has_listeners():
log.msg("touch %s because listeners" % mailbox._mailbox_id)
mailbox._touch(now) mailbox._touch(now)
db.commit() # make sure the updates are visible below db.commit() # make sure the updates are visible below
@ -418,11 +420,13 @@ class AppNamespace:
for row in db.execute("SELECT * FROM `mailboxes` WHERE `app_id`=?", for row in db.execute("SELECT * FROM `mailboxes` WHERE `app_id`=?",
(self._app_id,)).fetchall(): (self._app_id,)).fetchall():
mailbox_id = row["id"] mailbox_id = row["id"]
log.msg(" 1: age=%s, old=%s, %s" %
(now - row["updated"], now - old, mailbox_id))
if row["updated"] > old: if row["updated"] > old:
new_mailboxes.add(mailbox_id) new_mailboxes.add(mailbox_id)
else: else:
old_mailboxes.add(mailbox_id) old_mailboxes.add(mailbox_id)
#log.msg(" 2: mailboxes:", new_mailboxes, old_mailboxes) log.msg(" 2: mailboxes:", new_mailboxes, old_mailboxes)
old_nameplates = set() old_nameplates = set()
for row in db.execute("SELECT * FROM `nameplates` WHERE `app_id`=?", for row in db.execute("SELECT * FROM `nameplates` WHERE `app_id`=?",
@ -431,7 +435,7 @@ class AppNamespace:
mailbox_id = row["mailbox_id"] mailbox_id = row["mailbox_id"]
if mailbox_id in old_mailboxes: if mailbox_id in old_mailboxes:
old_nameplates.add(npid) old_nameplates.add(npid)
#log.msg(" 3: old_nameplates", old_nameplates) log.msg(" 3: old_nameplates", old_nameplates)
for npid in old_nameplates: for npid in old_nameplates:
log.msg(" deleting nameplate", npid) log.msg(" deleting nameplate", npid)

View File

@ -88,6 +88,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
self._app = None self._app = None
self._side = None self._side = None
self._did_allocate = False # only one allocate() per websocket self._did_allocate = False # only one allocate() per websocket
self._listening = False
self._nameplate_id = None self._nameplate_id = None
self._mailbox = None self._mailbox = None
@ -203,6 +204,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
body=sm.body, server_rx=sm.server_rx, id=sm.msg_id) body=sm.body, server_rx=sm.server_rx, id=sm.msg_id)
def _stop(): def _stop():
pass pass
self._listening = True
for old_sm in self._mailbox.add_listener(self, _send, _stop): for old_sm in self._mailbox.add_listener(self, _send, _stop):
_send(old_sm) _send(old_sm)
@ -233,7 +235,8 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
self.sendMessage(payload, False) self.sendMessage(payload, False)
def onClose(self, wasClean, code, reason): def onClose(self, wasClean, code, reason):
pass if self._mailbox and self._listening:
self._mailbox.remove_listener(self)
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):

View File

@ -1,5 +1,5 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import json, itertools import json, itertools, time
from binascii import hexlify from binascii import hexlify
import mock import mock
from twisted.trial import unittest from twisted.trial import unittest
@ -484,6 +484,14 @@ class WSClient(websocket.WebSocketClientProtocol):
return return
self.events.append(event) self.events.append(event)
def close(self):
self.d = defer.Deferred()
self.transport.loseConnection()
return self.d
def onClose(self, wasClean, code, reason):
if self.d:
self.d.callback((wasClean, code, reason))
def next_event(self): def next_event(self):
assert not self.d assert not self.d
if self.events: if self.events:
@ -840,6 +848,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
m = yield c1.next_non_ack() m = yield c1.next_non_ack()
self.assertEqual(m["type"], "message") self.assertEqual(m["type"], "message")
self.assertEqual(m["body"], "body") self.assertEqual(m["body"], "body")
self.assertTrue(mb1.has_listeners())
mb1.add_message(SidedMessage(side="side2", phase="phase2", mb1.add_message(SidedMessage(side="side2", phase="phase2",
body="body2", server_rx=0, body="body2", server_rx=0,
@ -893,6 +902,7 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
c1 = yield self.make_client() c1 = yield self.make_client()
yield c1.next_non_ack() yield c1.next_non_ack()
c1.send("bind", appid="appid", side="side") c1.send("bind", appid="appid", side="side")
app = self._rendezvous.get_app("appid")
c1.send("close", mood="mood") # must open first c1.send("close", mood="mood") # must open first
err = yield c1.next_non_ack() err = yield c1.next_non_ack()
@ -900,15 +910,41 @@ class WebSocketAPI(ServerBase, unittest.TestCase):
self.assertEqual(err["error"], "must open mailbox before closing") self.assertEqual(err["error"], "must open mailbox before closing")
c1.send("open", mailbox="mb1") c1.send("open", mailbox="mb1")
yield c1.sync()
mb1 = app._mailboxes["mb1"]
self.assertTrue(mb1.has_listeners())
c1.send("close", mood="mood") c1.send("close", mood="mood")
m = yield c1.next_non_ack() m = yield c1.next_non_ack()
self.assertEqual(m["type"], "closed") self.assertEqual(m["type"], "closed")
self.assertFalse(mb1.has_listeners())
c1.send("close", mood="mood") # already closed c1.send("close", mood="mood") # already closed
err = yield c1.next_non_ack() err = yield c1.next_non_ack()
self.assertEqual(err["type"], "error") self.assertEqual(err["type"], "error")
self.assertEqual(err["error"], "must open mailbox before closing") self.assertEqual(err["error"], "must open mailbox before closing")
@inlineCallbacks
def test_disconnect(self):
c1 = yield self.make_client()
yield c1.next_non_ack()
c1.send("bind", appid="appid", side="side")
app = self._rendezvous.get_app("appid")
c1.send("open", mailbox="mb1")
yield c1.sync()
mb1 = app._mailboxes["mb1"]
self.assertTrue(mb1.has_listeners())
yield c1.close()
# wait for the server to notice the socket has closed
started = time.time()
while mb1.has_listeners() and (time.time()-started < 5.0):
d = defer.Deferred()
reactor.callLater(0.01, d.callback, None)
yield d
self.assertFalse(mb1.has_listeners())
class Summary(unittest.TestCase): class Summary(unittest.TestCase):
def test_mailbox(self): def test_mailbox(self):