diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index f38390b..7f5c6d8 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -38,6 +38,13 @@ class ReceivedHintsTooEarly(Exception): pass +class UnexpectedKCM(Exception): + pass + + +class UnknownMessageType(Exception): + pass + def make_side(): return bytes_to_hexstr(os.urandom(6)) @@ -94,7 +101,7 @@ class Manager(object): _next_subchannel_id = None # initialized in choose_role m = MethodicalMachine() - set_trace = getattr(m, "_setTrace", lambda self, f: None) + set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._got_versions_d = Deferred() @@ -214,8 +221,9 @@ class Manager(object): self._inbound.handle_data(r.scid, r.data) else: # isinstance(r, Close) self._inbound.handle_close(r.scid) + return if isinstance(r, KCM): - log.err("got unexpected KCM") + log.err(UnexpectedKCM()) elif isinstance(r, Ping): self.handle_ping(r.ping_id) elif isinstance(r, Pong): @@ -223,7 +231,7 @@ class Manager(object): elif isinstance(r, Ack): self._outbound.handle_ack(r.resp_seqnum) # retire queued messages else: - log.err("received unknown message type {}".format(r)) + log.err(UnknownMessageType("{}".format(r))) # pings, pongs, and acks are not queued def send_ping(self, ping_id): diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index f0d50a0..45832c4 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -11,9 +11,11 @@ from ..._dilation import roles from ..._dilation.encode import to_be4 from ..._dilation.manager import (Dilator, Manager, make_side, OldPeerCannotDilateError, - UnknownDilationMessageType) + UnknownDilationMessageType, + UnexpectedKCM, + UnknownMessageType) from ..._dilation.subchannel import _WormholeAddress -from ..._dilation.connection import Open, Data, Close, Ack +from ..._dilation.connection import Open, Data, Close, Ack, KCM, Ping, Pong from .common import clear_mock_calls @@ -570,10 +572,74 @@ class TestManager(unittest.TestCase): self.assertEqual(c4.mock_calls, [mock.call.start()]) clear_mock_calls(c3, connector4, c4) - def test_stop(self): - pass - - def test_mirror(self): # receive a PLEASE with the same side as us: shouldn't happen - pass + m, h = make_manager(leader=True) + + m.start() + clear_mock_calls(h.send) + e = self.assertRaises(ValueError, m.rx_PLEASE, {"side": LEADER}) + self.assertEqual(str(e), "their side shouldn't be equal: reflection?") + + def test_ping_pong(self): + m, h = make_manager(leader=False) + + m.got_record(KCM()) + self.flushLoggedErrors(UnexpectedKCM) + + m.got_record(Ping(1)) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(1))]) + clear_mock_calls(h.outbound) + + m.got_record(Pong(2)) + # currently ignored, will eventually update a timer + + m.got_record("not recognized") + e = self.flushLoggedErrors(UnknownMessageType) + self.assertEqual(len(e), 1) + self.assertEqual(str(e[0].value), "not recognized") + + m.send_ping(3) + self.assertEqual(h.outbound.mock_calls, + [mock.call.send_if_connected(Pong(3))]) + clear_mock_calls(h.outbound) + + def test_subchannel(self): + m, h = make_manager(leader=True) + sc = object() + + m.subchannel_pauseProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_pauseProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_resumeProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_resumeProducing(sc)]) + clear_mock_calls(h.inbound) + + m.subchannel_stopProducing(sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_stopProducing(sc)]) + clear_mock_calls(h.inbound) + + p = object() + streaming = object() + + m.subchannel_registerProducer(sc, p, streaming) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_registerProducer(sc, p, streaming)]) + clear_mock_calls(h.outbound) + + m.subchannel_unregisterProducer(sc) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_unregisterProducer(sc)]) + clear_mock_calls(h.outbound) + + m.subchannel_closed("scid", sc) + self.assertEqual(h.inbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + self.assertEqual(h.outbound.mock_calls, [ + mock.call.subchannel_closed("scid", sc)]) + clear_mock_calls(h.inbound, h.outbound)