diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index 426bc0f..1a6d865 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -290,8 +290,11 @@ class Basic(unittest.TestCase): self.assertEqual(w._drop_connection.mock_calls, []) response(w, type=u"released") - self.successResultOf(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + self.assertNoResult(d) + + w._ws_closed(True, None, None) + self.successResultOf(d) def test_close_wait_2(self): # close after both claiming the nameplate and opening the mailbox @@ -314,10 +317,14 @@ class Basic(unittest.TestCase): response(w, type=u"released") self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") - self.successResultOf(d) + self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_close_wait_3(self): # close after claiming the nameplate, opening the mailbox, then # releasing the nameplate @@ -348,10 +355,14 @@ class Basic(unittest.TestCase): response(w, type=u"released") self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, []) + response(w, type=u"closed") - self.successResultOf(d) + self.assertNoResult(d) self.assertEqual(w._drop_connection.mock_calls, [mock.call()]) + w._ws_closed(True, None, None) + self.successResultOf(d) + def test_get_code_mock(self): timing = DebugTiming() w = wormhole._Wormhole(APPID, u"relay_url", reactor, None, timing) diff --git a/src/wormhole/wormhole.py b/src/wormhole/wormhole.py index 31ffe4c..ccb7f9c 100644 --- a/src/wormhole/wormhole.py +++ b/src/wormhole/wormhole.py @@ -235,6 +235,7 @@ class _Wormhole: self._flag_need_to_send_PAKE = True self._key = None self._closed = False + self._disconnect_waiter = defer.Deferred() self._mood = u"happy" self._get_verifier_called = False @@ -611,6 +612,8 @@ class _Wormhole: if self._mailbox_opened: yield self._close_waiter self._drop_connection() + if wait: + yield self._disconnect_waiter def _maybe_release_nameplate(self): if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_claimed, self._nameplate_released) @@ -635,7 +638,7 @@ class _Wormhole: # calls _ws_closed() when done def _ws_closed(self, wasClean, code, reason): - pass + self._disconnect_waiter.callback(None) def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None): timing = timing or DebugTiming()