From 9bd5afe7df097bc46263bb327c46a5be588d1d16 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Mon, 23 May 2016 23:59:49 -0700 Subject: [PATCH] make close() always wait --- src/wormhole/test/test_wormhole.py | 56 ++++++++++++++++-------------- src/wormhole/wormhole.py | 35 +++++++++++++++---- 2 files changed, 57 insertions(+), 34 deletions(-) diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index c0f2823..c9e9a0f 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -282,7 +282,8 @@ class Basic(unittest.TestCase): self.assertIn(u"1", w._received_messages) self.assertNotIn(u"1", w._receive_waiters) - w.close() + d = w.close() + self.assertNoResult(d) out = ws.outbound() self.assertEqual(len(out), 1) self.check_out(out[0], type=u"close", mood=u"happy") @@ -293,6 +294,7 @@ class Basic(unittest.TestCase): response(w, type=u"closed") self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) w._ws_closed(True, None, None) + self.assertEqual(self.successResultOf(d), None) def test_close_wait_0(self): # Close before the connection is established. The connection still @@ -301,7 +303,7 @@ class Basic(unittest.TestCase): w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) w._drop_connection = mock.Mock() - d = w.close(wait=True) + d = w.close() self.assertNoResult(d) ws = MockWebSocket() @@ -322,7 +324,7 @@ class Basic(unittest.TestCase): w._event_connected(ws) w._event_ws_opened(None) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"bind"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) @@ -346,7 +348,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb123") - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"open", u"add", u"release", u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -375,7 +377,7 @@ class Basic(unittest.TestCase): w.set_code(CODE) self.check_outbound(ws, [u"bind", u"claim"]) - d = w.close(wait=True) + d = w.close() response(w, type=u"claimed", mailbox=u"mb123") self.check_outbound(ws, [u"release"]) self.assertNoResult(d) @@ -401,7 +403,7 @@ class Basic(unittest.TestCase): response(w, type=u"claimed", mailbox=u"mb456") self.check_outbound(ws, [u"bind", u"claim", u"open", u"add"]) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"release", u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -439,7 +441,7 @@ class Basic(unittest.TestCase): self.check_outbound(ws, [u"bind", u"claim", u"open", u"add", u"release"]) - d = w.close(wait=True) + d = w.close() self.check_outbound(ws, [u"close"]) self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) @@ -625,8 +627,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_same_message(self): @@ -643,8 +645,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data") self.assertEqual(dataY, b"data") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_interleaved(self): @@ -659,8 +661,8 @@ class Wormholes(ServerBase, unittest.TestCase): w2.send(b"data2") dataX = yield d self.assertEqual(dataX, b"data2") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_unidirectional(self): @@ -671,8 +673,8 @@ class Wormholes(ServerBase, unittest.TestCase): w1.send(b"data1") dataY = yield w2.get() self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_early(self): @@ -684,8 +686,8 @@ class Wormholes(ServerBase, unittest.TestCase): w2.set_code(u"123-abc-def") dataY = yield d self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_fixed_code(self): @@ -698,8 +700,8 @@ class Wormholes(ServerBase, unittest.TestCase): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks @@ -718,8 +720,8 @@ class Wormholes(ServerBase, unittest.TestCase): (dataX, dataY) = dl self.assertEqual(dataX, b"data4") self.assertEqual(dataY, b"data3") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() @inlineCallbacks def test_wrong_password(self): @@ -738,8 +740,8 @@ class Wormholes(ServerBase, unittest.TestCase): yield self.assertFailure(w2.get(), WrongPasswordError) yield self.assertFailure(w1.get(), WrongPasswordError) - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() self.flushLoggedErrors(WrongPasswordError) @inlineCallbacks @@ -758,8 +760,8 @@ class Wormholes(ServerBase, unittest.TestCase): dataY = yield w2.get() self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") - yield w1.close(wait=True) - yield w2.close(wait=True) + yield w1.close() + yield w2.close() class Errors(ServerBase, unittest.TestCase): @inlineCallbacks @@ -773,7 +775,7 @@ class Errors(ServerBase, unittest.TestCase): self.assertRaises(UsageError, w.set_code, u"123-nope") yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError) - yield w.close(wait=True) + yield w.close() @inlineCallbacks def test_codes_2(self): @@ -782,5 +784,5 @@ class Errors(ServerBase, unittest.TestCase): self.assertRaises(UsageError, w.set_code, u"123-nope") yield self.assertFailure(w.get_code(), UsageError) yield self.assertFailure(w.input_code(), UsageError) - yield w.close(wait=True) + yield w.close() diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 134aa0b..26c05a7 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -297,8 +297,29 @@ class _Wormhole: """ return self._API_derive_key(purpose, length) - def close(self, wait=False): - return self._API_close(wait) + def close(self, res=None): + """Collapse the wormhole, freeing up server resources and flushing + all pending messages. Returns a Deferred that fires when everything + is done. It fires with any argument close() was given, to enable use + as a d.addBoth() handler: + + w = wormhole(...) + d = w.get() + .. + d.addBoth(w.close) + return d + + Another reasonable approach is to use inlineCallbacks: + + @inlineCallbacks + def pair(self, code): + w = wormhole(...) + try: + them = yield w.get() + finally: + yield w.close() + """ + return self._API_close(res) # INTERNAL METHODS beyond here @@ -690,14 +711,14 @@ class _Wormhole: if self.DEBUG: print("_signal_error done") @inlineCallbacks - def _API_close(self, wait=False, mood=u"happy"): - if self.DEBUG: print("close", wait) + def _API_close(self, res, mood=u"happy"): + if self.DEBUG: print("close") if self._close_called: raise UsageError self._close_called = True self._maybe_close(WormholeClosedError(), mood) - if wait: - if self.DEBUG: print("waiting for disconnect") - yield self._disconnect_waiter + if self.DEBUG: print("waiting for disconnect") + yield self._disconnect_waiter + returnValue(res) def _maybe_close(self, error, mood): if self._closing: