make close() always wait

This commit is contained in:
Brian Warner 2016-05-23 23:59:49 -07:00
parent e11a6f8243
commit 9bd5afe7df
2 changed files with 57 additions and 34 deletions

View File

@ -282,7 +282,8 @@ class Basic(unittest.TestCase):
self.assertIn(u"1", w._received_messages) self.assertIn(u"1", w._received_messages)
self.assertNotIn(u"1", w._receive_waiters) self.assertNotIn(u"1", w._receive_waiters)
w.close() d = w.close()
self.assertNoResult(d)
out = ws.outbound() out = ws.outbound()
self.assertEqual(len(out), 1) self.assertEqual(len(out), 1)
self.check_out(out[0], type=u"close", mood=u"happy") self.check_out(out[0], type=u"close", mood=u"happy")
@ -293,6 +294,7 @@ class Basic(unittest.TestCase):
response(w, type=u"closed") response(w, type=u"closed")
self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
w._ws_closed(True, None, None) w._ws_closed(True, None, None)
self.assertEqual(self.successResultOf(d), None)
def test_close_wait_0(self): def test_close_wait_0(self):
# Close before the connection is established. The connection still # Close before the connection is established. The connection still
@ -301,7 +303,7 @@ class Basic(unittest.TestCase):
w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing)
w._drop_connection = mock.Mock() w._drop_connection = mock.Mock()
d = w.close(wait=True) d = w.close()
self.assertNoResult(d) self.assertNoResult(d)
ws = MockWebSocket() ws = MockWebSocket()
@ -322,7 +324,7 @@ class Basic(unittest.TestCase):
w._event_connected(ws) w._event_connected(ws)
w._event_ws_opened(None) w._event_ws_opened(None)
d = w.close(wait=True) d = w.close()
self.check_outbound(ws, [u"bind"]) self.check_outbound(ws, [u"bind"])
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) self.assertEqual(w._drop_connection.mock_calls, [mock.call()])
@ -346,7 +348,7 @@ class Basic(unittest.TestCase):
response(w, type=u"claimed", mailbox=u"mb123") response(w, type=u"claimed", mailbox=u"mb123")
d = w.close(wait=True) d = w.close()
self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) self.check_outbound(ws, [u"open", u"add", u"release", u"close"])
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, []) self.assertEqual(w._drop_connection.mock_calls, [])
@ -375,7 +377,7 @@ class Basic(unittest.TestCase):
w.set_code(CODE) w.set_code(CODE)
self.check_outbound(ws, [u"bind", u"claim"]) self.check_outbound(ws, [u"bind", u"claim"])
d = w.close(wait=True) d = w.close()
response(w, type=u"claimed", mailbox=u"mb123") response(w, type=u"claimed", mailbox=u"mb123")
self.check_outbound(ws, [u"release"]) self.check_outbound(ws, [u"release"])
self.assertNoResult(d) self.assertNoResult(d)
@ -401,7 +403,7 @@ class Basic(unittest.TestCase):
response(w, type=u"claimed", mailbox=u"mb456") response(w, type=u"claimed", mailbox=u"mb456")
self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"])
d = w.close(wait=True) d = w.close()
self.check_outbound(ws, [u"release", u"close"]) self.check_outbound(ws, [u"release", u"close"])
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, []) self.assertEqual(w._drop_connection.mock_calls, [])
@ -439,7 +441,7 @@ class Basic(unittest.TestCase):
self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", self.check_outbound(ws, [u"bind", u"claim", u"open", u"add",
u"release"]) u"release"])
d = w.close(wait=True) d = w.close()
self.check_outbound(ws, [u"close"]) self.check_outbound(ws, [u"close"])
self.assertNoResult(d) self.assertNoResult(d)
self.assertEqual(w._drop_connection.mock_calls, []) self.assertEqual(w._drop_connection.mock_calls, [])
@ -625,8 +627,8 @@ class Wormholes(ServerBase, unittest.TestCase):
dataY = yield w2.get() dataY = yield w2.get()
self.assertEqual(dataX, b"data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_same_message(self): def test_same_message(self):
@ -643,8 +645,8 @@ class Wormholes(ServerBase, unittest.TestCase):
dataY = yield w2.get() dataY = yield w2.get()
self.assertEqual(dataX, b"data") self.assertEqual(dataX, b"data")
self.assertEqual(dataY, b"data") self.assertEqual(dataY, b"data")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_interleaved(self): def test_interleaved(self):
@ -659,8 +661,8 @@ class Wormholes(ServerBase, unittest.TestCase):
w2.send(b"data2") w2.send(b"data2")
dataX = yield d dataX = yield d
self.assertEqual(dataX, b"data2") self.assertEqual(dataX, b"data2")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_unidirectional(self): def test_unidirectional(self):
@ -671,8 +673,8 @@ class Wormholes(ServerBase, unittest.TestCase):
w1.send(b"data1") w1.send(b"data1")
dataY = yield w2.get() dataY = yield w2.get()
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_early(self): def test_early(self):
@ -684,8 +686,8 @@ class Wormholes(ServerBase, unittest.TestCase):
w2.set_code(u"123-abc-def") w2.set_code(u"123-abc-def")
dataY = yield d dataY = yield d
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_fixed_code(self): def test_fixed_code(self):
@ -698,8 +700,8 @@ class Wormholes(ServerBase, unittest.TestCase):
(dataX, dataY) = dl (dataX, dataY) = dl
self.assertEqual(dataX, b"data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
@ -718,8 +720,8 @@ class Wormholes(ServerBase, unittest.TestCase):
(dataX, dataY) = dl (dataX, dataY) = dl
self.assertEqual(dataX, b"data4") self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data3") self.assertEqual(dataY, b"data3")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
@inlineCallbacks @inlineCallbacks
def test_wrong_password(self): def test_wrong_password(self):
@ -738,8 +740,8 @@ class Wormholes(ServerBase, unittest.TestCase):
yield self.assertFailure(w2.get(), WrongPasswordError) yield self.assertFailure(w2.get(), WrongPasswordError)
yield self.assertFailure(w1.get(), WrongPasswordError) yield self.assertFailure(w1.get(), WrongPasswordError)
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
self.flushLoggedErrors(WrongPasswordError) self.flushLoggedErrors(WrongPasswordError)
@inlineCallbacks @inlineCallbacks
@ -758,8 +760,8 @@ class Wormholes(ServerBase, unittest.TestCase):
dataY = yield w2.get() dataY = yield w2.get()
self.assertEqual(dataX, b"data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1") self.assertEqual(dataY, b"data1")
yield w1.close(wait=True) yield w1.close()
yield w2.close(wait=True) yield w2.close()
class Errors(ServerBase, unittest.TestCase): class Errors(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
@ -773,7 +775,7 @@ class Errors(ServerBase, unittest.TestCase):
self.assertRaises(UsageError, w.set_code, u"123-nope") self.assertRaises(UsageError, w.set_code, u"123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), UsageError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError)
yield w.close(wait=True) yield w.close()
@inlineCallbacks @inlineCallbacks
def test_codes_2(self): def test_codes_2(self):
@ -782,5 +784,5 @@ class Errors(ServerBase, unittest.TestCase):
self.assertRaises(UsageError, w.set_code, u"123-nope") self.assertRaises(UsageError, w.set_code, u"123-nope")
yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.get_code(), UsageError)
yield self.assertFailure(w.input_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError)
yield w.close(wait=True) yield w.close()

View File

@ -297,8 +297,29 @@ class _Wormhole:
""" """
return self._API_derive_key(purpose, length) return self._API_derive_key(purpose, length)
def close(self, wait=False): def close(self, res=None):
return self._API_close(wait) """Collapse the wormhole, freeing up server resources and flushing
all pending messages. Returns a Deferred that fires when everything
is done. It fires with any argument close() was given, to enable use
as a d.addBoth() handler:
w = wormhole(...)
d = w.get()
..
d.addBoth(w.close)
return d
Another reasonable approach is to use inlineCallbacks:
@inlineCallbacks
def pair(self, code):
w = wormhole(...)
try:
them = yield w.get()
finally:
yield w.close()
"""
return self._API_close(res)
# INTERNAL METHODS beyond here # INTERNAL METHODS beyond here
@ -690,14 +711,14 @@ class _Wormhole:
if self.DEBUG: print("_signal_error done") if self.DEBUG: print("_signal_error done")
@inlineCallbacks @inlineCallbacks
def _API_close(self, wait=False, mood=u"happy"): def _API_close(self, res, mood=u"happy"):
if self.DEBUG: print("close", wait) if self.DEBUG: print("close")
if self._close_called: raise UsageError if self._close_called: raise UsageError
self._close_called = True self._close_called = True
self._maybe_close(WormholeClosedError(), mood) self._maybe_close(WormholeClosedError(), mood)
if wait: if self.DEBUG: print("waiting for disconnect")
if self.DEBUG: print("waiting for disconnect") yield self._disconnect_waiter
yield self._disconnect_waiter returnValue(res)
def _maybe_close(self, error, mood): def _maybe_close(self, error, mood):
if self._closing: if self._closing: