diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index 4b9ced7..d142eed 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -80,28 +80,36 @@ class _Framer(object): set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover @m.state() - def want_relay(self): pass # pragma: no cover + def want_relay(self): + pass # pragma: no cover @m.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): + pass # pragma: no cover @m.state() - def want_frame(self): pass # pragma: no cover + def want_frame(self): + pass # pragma: no cover @m.input() - def use_relay(self, relay_handshake): pass + def use_relay(self, relay_handshake): + pass @m.input() - def connectionMade(self): pass + def connectionMade(self): + pass @m.input() - def parse(self): pass + def parse(self): + pass @m.input() - def got_relay_ok(self): pass + def got_relay_ok(self): + pass @m.input() - def got_prologue(self): pass + def got_prologue(self): + pass @m.output() def store_relay_handshake(self, relay_handshake): @@ -312,13 +320,16 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): pass # pragma: no cover + def want_prologue(self): + pass # pragma: no cover @n.state() - def want_handshake(self): pass # pragma: no cover + def want_handshake(self): + pass # pragma: no cover @n.state() - def want_message(self): pass # pragma: no cover + def want_message(self): + pass # pragma: no cover @n.input() def got_prologue(self): @@ -417,13 +428,16 @@ class DilatedConnectionProtocol(Protocol, object): self._can_send_records = False @m.state(initial=True) - def unselected(self): pass # pragma: no cover + def unselected(self): + pass # pragma: no cover @m.state() - def selecting(self): pass # pragma: no cover + def selecting(self): + pass # pragma: no cover @m.state() - def selected(self): pass # pragma: no cover + def selected(self): + pass # pragma: no cover @m.input() def got_kcm(self): diff --git a/src/wormhole/_dilation/old-follower.py b/src/wormhole/_dilation/old-follower.py index 68e3a38..7be4307 100644 --- a/src/wormhole/_dilation/old-follower.py +++ b/src/wormhole/_dilation/old-follower.py @@ -4,34 +4,53 @@ class ManagerFollower(_ManagerBase): set_trace = getattr(m, "_setTrace", lambda self, f: None) @m.state(initial=True) - def IDLE(self): pass # pragma: no cover + def IDLE(self): + pass # pragma: no cover @m.state() - def WANTING(self): pass # pragma: no cover + def WANTING(self): + pass # pragma: no cover + @m.state() - def CONNECTING(self): pass # pragma: no cover + def CONNECTING(self): + pass # pragma: no cover + @m.state() - def CONNECTED(self): pass # pragma: no cover + def CONNECTED(self): + pass # pragma: no cover + @m.state(terminal=True) - def STOPPED(self): pass # pragma: no cover + def STOPPED(self): + pass # pragma: no cover @m.input() - def start(self): pass # pragma: no cover - @m.input() - def rx_PLEASE(self): pass # pragma: no cover - @m.input() - def rx_DILATE(self): pass # pragma: no cover - @m.input() - def rx_HINTS(self, hint_message): pass # pragma: no cover + def start(self): + pass # pragma: no cover @m.input() - def connection_made(self): pass # pragma: no cover + def rx_PLEASE(self): + pass # pragma: no cover + @m.input() - def connection_lost(self): pass # pragma: no cover + def rx_DILATE(self): + pass # pragma: no cover + + @m.input() + def rx_HINTS(self, hint_message): + pass # pragma: no cover + + @m.input() + def connection_made(self): + pass # pragma: no cover + + @m.input() + def connection_lost(self): + pass # pragma: no cover # follower doesn't react to connection_lost, but waits for a new LETS_DILATE @m.input() - def stop(self): pass # pragma: no cover + def stop(self): + pass # pragma: no cover # these Outputs behave differently for the Leader vs the Follower @m.output() @@ -48,27 +67,32 @@ class ManagerFollower(_ManagerBase): @m.output() def use_hints(self, hint_message): - hint_objs = filter(lambda h: h, # ignore None, unrecognizable + hint_objs = filter(lambda h: h, # ignore None, unrecognizable [parse_hint(hs) for hs in hint_message["hints"]]) self._connector.got_hints(hint_objs) + @m.output() def stop_connecting(self): self._connector.stop() + @m.output() def use_connection(self, c): self._use_connection(c) + @m.output() def stop_using_connection(self): self._stop_using_connection() + @m.output() def signal_error(self): - pass # TODO + pass # TODO + @m.output() def signal_error_hints(self, hint_message): - pass # TODO + pass # TODO - IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early - IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early + IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early + IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early # leader shouldn't send us DILATE before receiving our PLEASE IDLE.upon(stop, enter=STOPPED, outputs=[]) IDLE.upon(start, enter=WANTING, outputs=[send_please]) @@ -78,7 +102,7 @@ class ManagerFollower(_ManagerBase): CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) # shouldn't happen: connection_lost - #CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) + # CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, start_connecting]) # receiving rx_DILATE while we're still working on the last one means the @@ -89,7 +113,7 @@ class ManagerFollower(_ManagerBase): CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, start_connecting]) - CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore + CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) # shouldn't happen: connection_made diff --git a/src/wormhole/test/dilate/common.py b/src/wormhole/test/dilate/common.py index 2ddacfb..4f398d7 100644 --- a/src/wormhole/test/dilate/common.py +++ b/src/wormhole/test/dilate/common.py @@ -3,16 +3,19 @@ import mock from zope.interface import alsoProvides from ..._interfaces import IDilationManager, IWormhole + def mock_manager(): m = mock.Mock() alsoProvides(m, IDilationManager) return m + def mock_wormhole(): m = mock.Mock() alsoProvides(m, IWormhole) return m + def clear_mock_calls(*args): for a in args: a.mock_calls[:] = [] diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py index 406e42a..ee761fd 100644 --- a/src/wormhole/test/dilate/test_connection.py +++ b/src/wormhole/test/dilate/test_connection.py @@ -11,12 +11,13 @@ from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, KCM, Open, Ack) from .common import clear_mock_calls + def make_con(role, use_relay=False): clock = Clock() eq = EventualQueue(clock) connector = mock.Mock() alsoProvides(connector, IDilationConnector) - n = mock.Mock() # pretends to be a Noise object + n = mock.Mock() # pretends to be a Noise object n.write_message = mock.Mock(side_effect=[b"handshake"]) c = DilatedConnectionProtocol(eq, role, connector, n, b"outbound_prologue\n", b"inbound_prologue\n") @@ -26,6 +27,7 @@ def make_con(role, use_relay=False): alsoProvides(t, ITransport) return c, n, connector, t, eq + class Connection(unittest.TestCase): def test_bad_prologue(self): c, n, connector, t, eq = make_con(LEADER) @@ -55,10 +57,10 @@ class Connection(unittest.TestCase): n.decrypt = mock.Mock(side_effect=[ encode_record(t_kcm), encode_record(t_open), - ]) + ]) exp_kcm = b"\x00\x00\x00\x03kcm" n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) - m = mock.Mock() # Manager + m = mock.Mock() # Manager c.makeConnection(t) self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) @@ -86,7 +88,7 @@ class Connection(unittest.TestCase): self.assertEqual(n.mock_calls, [ mock.call.read_message(b"handshake2"), mock.call.encrypt(encode_record(t_kcm)), - ]) + ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [ mock.call.write(exp_kcm)]) @@ -123,7 +125,7 @@ class Connection(unittest.TestCase): c.send_record(KCM()) self.assertEqual(n.mock_calls, [ mock.call.encrypt(encode_record(t_kcm)), - ]) + ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) self.assertEqual(m.mock_calls, []) @@ -163,7 +165,6 @@ class Connection(unittest.TestCase): def test_no_relay_follower(self): return self._test_no_relay(FOLLOWER) - def test_relay(self): c, n, connector, t, eq = make_con(LEADER, use_relay=True) diff --git a/src/wormhole/test/dilate/test_encoding.py b/src/wormhole/test/dilate/test_encoding.py index e2c854e..6bf2c7a 100644 --- a/src/wormhole/test/dilate/test_encoding.py +++ b/src/wormhole/test/dilate/test_encoding.py @@ -2,11 +2,12 @@ from __future__ import print_function, unicode_literals from twisted.trial import unittest from ..._dilation.encode import to_be4, from_be4 + class Encoding(unittest.TestCase): def test_be4(self): - self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") - self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") + self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") + self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") self.assertEqual(to_be4(257), b"\x00\x00\x01\x01") with self.assertRaises(ValueError): diff --git a/src/wormhole/test/dilate/test_endpoints.py b/src/wormhole/test/dilate/test_endpoints.py index bd8f995..ba07fe0 100644 --- a/src/wormhole/test/dilate/test_endpoints.py +++ b/src/wormhole/test/dilate/test_endpoints.py @@ -11,6 +11,7 @@ from ..._dilation.subchannel import (ControlEndpoint, SingleUseEndpointError) from .common import mock_manager + class Endpoints(unittest.TestCase): def test_control(self): scid0 = b"scid0" @@ -94,4 +95,4 @@ class Endpoints(unittest.TestCase): self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) - lp.stopListening() # TODO: should this do more? + lp.stopListening() # TODO: should this do more? diff --git a/src/wormhole/test/dilate/test_framer.py b/src/wormhole/test/dilate/test_framer.py index 81d4cf9..51ac039 100644 --- a/src/wormhole/test/dilate/test_framer.py +++ b/src/wormhole/test/dilate/test_framer.py @@ -5,12 +5,14 @@ from twisted.trial import unittest from twisted.internet.interfaces import ITransport from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect + def make_framer(): t = mock.Mock() alsoProvides(t, ITransport) f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") return f, t + class Framer(unittest.TestCase): def test_bad_prologue_length(self): f, t = make_framer() @@ -19,7 +21,7 @@ class Framer(unittest.TestCase): f.connectionMade() self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) t.mock_calls[:] = [] - self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual(t.mock_calls, []) with mock.patch("wormhole._dilation.connection.log.msg") as m: @@ -37,7 +39,7 @@ class Framer(unittest.TestCase): f.connectionMade() self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) t.mock_calls[:] = [] - self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it + self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual([], list(f.add_and_parse(b"not"))) with mock.patch("wormhole._dilation.connection.log.msg") as m: diff --git a/src/wormhole/test/dilate/test_inbound.py b/src/wormhole/test/dilate/test_inbound.py index d147283..392a661 100644 --- a/src/wormhole/test/dilate/test_inbound.py +++ b/src/wormhole/test/dilate/test_inbound.py @@ -8,6 +8,7 @@ from ..._dilation.inbound import (Inbound, DuplicateOpenError, DataForMissingSubchannelError, CloseForMissingSubchannelError) + def make_inbound(): m = mock.Mock() alsoProvides(m, IDilationManager) @@ -15,6 +16,7 @@ def make_inbound(): i = Inbound(m, host_addr) return i, m, host_addr + class InboundTest(unittest.TestCase): def test_seqnum(self): i, m, host_addr = make_inbound() @@ -158,7 +160,7 @@ class InboundTest(unittest.TestCase): self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) c.mock_calls[:] = [] i.subchannel_pauseProducing(sc2) - self.assertEqual(c.mock_calls, []) # was already paused + self.assertEqual(c.mock_calls, []) # was already paused # tolerate duplicate pauseProducing i.subchannel_pauseProducing(sc2) diff --git a/src/wormhole/test/dilate/test_manager.py b/src/wormhole/test/dilate/test_manager.py index 625039d..6acf25b 100644 --- a/src/wormhole/test/dilate/test_manager.py +++ b/src/wormhole/test/dilate/test_manager.py @@ -13,12 +13,14 @@ from ..._dilation.manager import (Dilator, from ..._dilation.subchannel import _WormholeAddress from .common import clear_mock_calls + def make_dilator(): reactor = object() clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) send = mock.Mock() @@ -27,6 +29,7 @@ def make_dilator(): dil.wire(send) return dil, send, reactor, eq, clock, coop + class TestDilator(unittest.TestCase): def test_leader(self): dil, send, reactor, eq, clock, coop = make_dilator() @@ -148,12 +151,11 @@ class TestDilator(unittest.TestCase): d1 = dil.dilate() self.assertNoResult(d1) - dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" + dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" eq.flush_sync() f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) - def test_disjoint_versions(self): dil, send, reactor, eq, clock, coop = make_dilator() d1 = dil.dilate() @@ -164,7 +166,6 @@ class TestDilator(unittest.TestCase): f = self.failureResultOf(d1) f.check(OldPeerCannotDilateError) - def test_early_dilate_messages(self): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" @@ -188,8 +189,6 @@ class TestDilator(unittest.TestCase): mock.call.rx_HINTS(hintmsg), mock.call.when_first_connected()]) - - def test_transit_relay(self): dil, send, reactor, eq, clock, coop = make_dilator() dil._transit_key = b"key" diff --git a/src/wormhole/test/dilate/test_outbound.py b/src/wormhole/test/dilate/test_outbound.py index fab596a..db38502 100644 --- a/src/wormhole/test/dilate/test_outbound.py +++ b/src/wormhole/test/dilate/test_outbound.py @@ -16,17 +16,20 @@ Pauser = namedtuple("Pauser", ["seqnum"]) NonPauser = namedtuple("NonPauser", ["seqnum"]) Stopper = namedtuple("Stopper", ["sc"]) + def make_outbound(): m = mock.Mock() alsoProvides(m, IDilationManager) clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) o = Outbound(m, coop) - c = mock.Mock() # Connection + c = mock.Mock() # Connection + def maybe_pause(r): if isinstance(r, Pauser): o.pauseProducing() @@ -37,6 +40,7 @@ def make_outbound(): o._test_term = term return o, m, c + class OutboundTest(unittest.TestCase): def test_build_record(self): o, m, c = make_outbound() @@ -69,10 +73,10 @@ class OutboundTest(unittest.TestCase): o.handle_ack(r3.seqnum) self.assertEqual(list(o._outbound_queue), []) - o.handle_ack(r3.seqnum) # ignored + o.handle_ack(r3.seqnum) # ignored self.assertEqual(list(o._outbound_queue), []) - o.handle_ack(r1.seqnum) # ignored + o.handle_ack(r1.seqnum) # ignored self.assertEqual(list(o._outbound_queue), []) def test_duplicate_registerProducer(self): @@ -192,7 +196,8 @@ class OutboundTest(unittest.TestCase): clear_mock_calls(c) sc1, sc2, sc3 = object(), object(), object() - p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3") + p1, p2, p3 = mock.Mock(name="p1"), mock.Mock( + name="p2"), mock.Mock(name="p3") # we aren't paused yet, since we haven't sent any data o.subchannel_registerProducer(sc1, p1, True) @@ -310,7 +315,8 @@ class OutboundTest(unittest.TestCase): # and another disconnects itself when called p2.resumeProducing.side_effect = lambda: None - p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3) + p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer( + sc3) o.pauseProducing() o.resumeProducing() self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), @@ -360,7 +366,7 @@ class OutboundTest(unittest.TestCase): r2 = NonPauser(seqnum=2) # we aren't paused yet, since we haven't sent any data - o.subchannel_registerProducer(sc1, p1, True) # push + o.subchannel_registerProducer(sc1, p1, True) # push o.queue_and_send_record(r1) # now we're paused self.assertTrue(o._paused) @@ -371,7 +377,7 @@ class OutboundTest(unittest.TestCase): p1.resumeProducing.side_effect = lambda: c.send_record(r1) p2.resumeProducing.side_effect = lambda: c.send_record(r2) - o.subchannel_registerProducer(sc2, p2, False) # pull: always ready + o.subchannel_registerProducer(sc2, p2, False) # pull: always ready # p1 is still first, since p2 was just added (at the end) self.assertTrue(o._paused) @@ -390,7 +396,7 @@ class OutboundTest(unittest.TestCase): mock.call.pauseProducing(), ]) self.assertEqual(p2.mock_calls, []) - self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next + self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next clear_mock_calls(p1, p2, c) # next should fire p2, then p1 @@ -404,7 +410,7 @@ class OutboundTest(unittest.TestCase): ]) self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), ]) - self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat + self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat clear_mock_calls(p1, p2, c) def test_pull_producer(self): @@ -430,13 +436,13 @@ class OutboundTest(unittest.TestCase): it = iter(records) p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) o.subchannel_registerProducer(sc1, p1, False) - eq.flush_sync() # fast forward into the glorious (paused) future + eq.flush_sync() # fast forward into the glorious (paused) future self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in records[:-1]]) self.assertEqual(p1.mock_calls, - [mock.call.resumeProducing()]*(len(records)-1)) + [mock.call.resumeProducing()] * (len(records) - 1)) clear_mock_calls(c, p1) # next resumeProducing should cause it to disconnect @@ -460,7 +466,7 @@ class OutboundTest(unittest.TestCase): NonPauser(3), NonPauser(13), NonPauser(4), NonPauser(14), Pauser(5)] - expected2 = [ NonPauser(15), + expected2 = [NonPauser(15), NonPauser(6), NonPauser(16), NonPauser(7), NonPauser(17), NonPauser(8), NonPauser(18), @@ -487,14 +493,14 @@ class OutboundTest(unittest.TestCase): p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) o.subchannel_registerProducer(sc2, p2, False) - eq.flush_sync() # fast forward into the glorious (paused) future + eq.flush_sync() # fast forward into the glorious (paused) future sends = [mock.call.resumeProducing()] self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in expected1]) - self.assertEqual(p1.mock_calls, 6*sends) - self.assertEqual(p2.mock_calls, 5*sends) + self.assertEqual(p1.mock_calls, 6 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) clear_mock_calls(c, p1, p2) o.resumeProducing() @@ -502,13 +508,13 @@ class OutboundTest(unittest.TestCase): self.assertTrue(o._paused) self.assertEqual(c.mock_calls, [mock.call.send_record(r) for r in expected2]) - self.assertEqual(p1.mock_calls, 4*sends) - self.assertEqual(p2.mock_calls, 5*sends) + self.assertEqual(p1.mock_calls, 4 * sends) + self.assertEqual(p2.mock_calls, 5 * sends) clear_mock_calls(c, p1, p2) def test_send_if_connected(self): o, m, c = make_outbound() - o.send_if_connected(Ack(1)) # not connected yet + o.send_if_connected(Ack(1)) # not connected yet o.use_connection(c) o.send_if_connected(KCM()) @@ -517,7 +523,7 @@ class OutboundTest(unittest.TestCase): def test_tolerate_duplicate_pause_resume(self): o, m, c = make_outbound() - self.assertTrue(o._paused) # no connection + self.assertTrue(o._paused) # no connection o.use_connection(c) self.assertFalse(o._paused) o.pauseProducing() @@ -533,7 +539,7 @@ class OutboundTest(unittest.TestCase): o, m, c = make_outbound() o.use_connection(c) self.assertFalse(o._paused) - o.stopProducing() # connection does this before loss + o.stopProducing() # connection does this before loss self.assertTrue(o._paused) o.stop_using_connection() self.assertTrue(o._paused) @@ -559,13 +565,15 @@ def make_pushpull(pauses): clock = Clock() eq = EventualQueue(clock) - term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick - term_factory = lambda: term + term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick + + def term_factory(): return term coop = Cooperator(terminationPredicateFactory=term_factory, scheduler=eq.eventually) pp = PullToPush(p, unregister, coop) it = cycle(pauses) + def action(i): if isinstance(i, Exception): raise i @@ -574,41 +582,45 @@ def make_pushpull(pauses): p.resumeProducing.side_effect = lambda: action(next(it)) return p, unregister, pp, eq + class PretendResumptionError(Exception): pass + + class PretendUnregisterError(Exception): pass + class PushPull(unittest.TestCase): # test our wrapper utility, which I copied from # twisted.internet._producer_helpers since it isn't publically exposed def test_start_unpaused(self): - p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing # if it starts unpaused, it gets one write before being halted pp.startStreaming(False) eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) clear_mock_calls(p) # now each time we call resumeProducing, we should see one delivered to # the underlying IPullProducer pp.resumeProducing() eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1) pp.stopStreaming() - pp.stopStreaming() # should tolerate this + pp.stopStreaming() # should tolerate this def test_start_unpaused_two_writes(self): - p, unr, pp, eq = make_pushpull([False, True]) # pause every other time + p, unr, pp, eq = make_pushpull([False, True]) # pause every other time # it should get two writes, since the first didn't pause pp.startStreaming(False) eq.flush_sync() - self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2) + self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2) def test_start_paused(self): - p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing + p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing pp.startStreaming(True) eq.flush_sync() self.assertEqual(p.mock_calls, []) @@ -637,9 +649,5 @@ class PushPull(unittest.TestCase): self.assertEqual(unr.mock_calls, [mock.call()]) self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) - - - - # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I # could capture the inter-call ordering that way diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py index 8365e62..f7276a6 100644 --- a/src/wormhole/test/dilate/test_parse.py +++ b/src/wormhole/test/dilate/test_parse.py @@ -4,6 +4,7 @@ from twisted.trial import unittest from ..._dilation.connection import (parse_record, encode_record, KCM, Ping, Pong, Open, Data, Close, Ack) + class Parse(unittest.TestCase): def test_parse(self): self.assertEqual(parse_record(b"\x00"), KCM()) diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py index 810396c..41b36e3 100644 --- a/src/wormhole/test/dilate/test_record.py +++ b/src/wormhole/test/dilate/test_record.py @@ -7,13 +7,15 @@ from ..._dilation.connection import (IFramer, Frame, Prologue, _Record, Handshake, Disconnect, Ping) + def make_record(): f = mock.Mock() alsoProvides(f, IFramer) - n = mock.Mock() # pretends to be a Noise object + n = mock.Mock() # pretends to be a Noise object r = _Record(f, n) return r, f, n + class Record(unittest.TestCase): def test_good2(self): f = mock.Mock() @@ -23,7 +25,7 @@ class Record(unittest.TestCase): [Prologue()], [Frame(frame=b"rx-handshake")], [Frame(frame=b"frame1"), Frame(frame=b"frame2")], - ]) + ]) n = mock.Mock() n.write_message = mock.Mock(return_value=b"tx-handshake") p1, p2 = object(), object() @@ -60,9 +62,9 @@ class Record(unittest.TestCase): n.mock_calls[:] = [] # next is a pair of Records - r1, r2 = object() , object() + r1, r2 = object(), object() with mock.patch("wormhole._dilation.connection.parse_record", - side_effect=[r1,r2]) as pr: + side_effect=[r1, r2]) as pr: self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"), mock.call.decrypt(b"frame2")]) @@ -186,7 +188,7 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=outbound_handshake) n.decrypt = mock.Mock(side_effect=[kcm, msg1]) n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) - f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet + f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet [Prologue()], [Frame("f_handshake")], [Frame("f_kcm"), @@ -238,7 +240,6 @@ class Record(unittest.TestCase): f.mock_calls[:] = [] n.mock_calls[:] = [] - # 5: at this point we ought to be able to send a messge, the KCM with mock.patch("wormhole._dilation.connection.encode_record", side_effect=[b"r-kcm"]) as er: diff --git a/src/wormhole/test/dilate/test_subchannel.py b/src/wormhole/test/dilate/test_subchannel.py index 69fa001..d56cdf1 100644 --- a/src/wormhole/test/dilate/test_subchannel.py +++ b/src/wormhole/test/dilate/test_subchannel.py @@ -8,6 +8,7 @@ from ..._dilation.subchannel import (Once, SubChannel, AlreadyClosedError) from .common import mock_manager + def make_sc(set_protocol=True): scid = b"scid" hostaddr = _WormholeAddress() @@ -19,6 +20,7 @@ def make_sc(set_protocol=True): sc._set_protocol(p) return sc, m, scid, hostaddr, peeraddr, p + class SubChannelAPI(unittest.TestCase): def test_once(self): o = Once(ValueError)