more flake8 fixes

This commit is contained in:
Brian Warner 2018-06-30 16:23:39 -07:00
parent ea35e570a2
commit bf0c93eddc
13 changed files with 155 additions and 96 deletions

View File

@ -80,28 +80,36 @@ class _Framer(object):
set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover
@m.state() @m.state()
def want_relay(self): pass # pragma: no cover def want_relay(self):
pass # pragma: no cover
@m.state(initial=True) @m.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self):
pass # pragma: no cover
@m.state() @m.state()
def want_frame(self): pass # pragma: no cover def want_frame(self):
pass # pragma: no cover
@m.input() @m.input()
def use_relay(self, relay_handshake): pass def use_relay(self, relay_handshake):
pass
@m.input() @m.input()
def connectionMade(self): pass def connectionMade(self):
pass
@m.input() @m.input()
def parse(self): pass def parse(self):
pass
@m.input() @m.input()
def got_relay_ok(self): pass def got_relay_ok(self):
pass
@m.input() @m.input()
def got_prologue(self): pass def got_prologue(self):
pass
@m.output() @m.output()
def store_relay_handshake(self, relay_handshake): def store_relay_handshake(self, relay_handshake):
@ -312,13 +320,16 @@ class _Record(object):
# states: want_prologue, want_handshake, want_record # states: want_prologue, want_handshake, want_record
@n.state(initial=True) @n.state(initial=True)
def want_prologue(self): pass # pragma: no cover def want_prologue(self):
pass # pragma: no cover
@n.state() @n.state()
def want_handshake(self): pass # pragma: no cover def want_handshake(self):
pass # pragma: no cover
@n.state() @n.state()
def want_message(self): pass # pragma: no cover def want_message(self):
pass # pragma: no cover
@n.input() @n.input()
def got_prologue(self): def got_prologue(self):
@ -417,13 +428,16 @@ class DilatedConnectionProtocol(Protocol, object):
self._can_send_records = False self._can_send_records = False
@m.state(initial=True) @m.state(initial=True)
def unselected(self): pass # pragma: no cover def unselected(self):
pass # pragma: no cover
@m.state() @m.state()
def selecting(self): pass # pragma: no cover def selecting(self):
pass # pragma: no cover
@m.state() @m.state()
def selected(self): pass # pragma: no cover def selected(self):
pass # pragma: no cover
@m.input() @m.input()
def got_kcm(self): def got_kcm(self):

View File

@ -4,34 +4,53 @@ class ManagerFollower(_ManagerBase):
set_trace = getattr(m, "_setTrace", lambda self, f: None) set_trace = getattr(m, "_setTrace", lambda self, f: None)
@m.state(initial=True) @m.state(initial=True)
def IDLE(self): pass # pragma: no cover def IDLE(self):
pass # pragma: no cover
@m.state() @m.state()
def WANTING(self): pass # pragma: no cover def WANTING(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTING(self): pass # pragma: no cover def CONNECTING(self):
pass # pragma: no cover
@m.state() @m.state()
def CONNECTED(self): pass # pragma: no cover def CONNECTED(self):
pass # pragma: no cover
@m.state(terminal=True) @m.state(terminal=True)
def STOPPED(self): pass # pragma: no cover def STOPPED(self):
pass # pragma: no cover
@m.input() @m.input()
def start(self): pass # pragma: no cover def start(self):
@m.input() pass # pragma: no cover
def rx_PLEASE(self): pass # pragma: no cover
@m.input()
def rx_DILATE(self): pass # pragma: no cover
@m.input()
def rx_HINTS(self, hint_message): pass # pragma: no cover
@m.input() @m.input()
def connection_made(self): pass # pragma: no cover def rx_PLEASE(self):
pass # pragma: no cover
@m.input() @m.input()
def connection_lost(self): pass # pragma: no cover def rx_DILATE(self):
pass # pragma: no cover
@m.input()
def rx_HINTS(self, hint_message):
pass # pragma: no cover
@m.input()
def connection_made(self):
pass # pragma: no cover
@m.input()
def connection_lost(self):
pass # pragma: no cover
# follower doesn't react to connection_lost, but waits for a new LETS_DILATE # follower doesn't react to connection_lost, but waits for a new LETS_DILATE
@m.input() @m.input()
def stop(self): pass # pragma: no cover def stop(self):
pass # pragma: no cover
# these Outputs behave differently for the Leader vs the Follower # these Outputs behave differently for the Leader vs the Follower
@m.output() @m.output()
@ -48,27 +67,32 @@ class ManagerFollower(_ManagerBase):
@m.output() @m.output()
def use_hints(self, hint_message): def use_hints(self, hint_message):
hint_objs = filter(lambda h: h, # ignore None, unrecognizable hint_objs = filter(lambda h: h, # ignore None, unrecognizable
[parse_hint(hs) for hs in hint_message["hints"]]) [parse_hint(hs) for hs in hint_message["hints"]])
self._connector.got_hints(hint_objs) self._connector.got_hints(hint_objs)
@m.output() @m.output()
def stop_connecting(self): def stop_connecting(self):
self._connector.stop() self._connector.stop()
@m.output() @m.output()
def use_connection(self, c): def use_connection(self, c):
self._use_connection(c) self._use_connection(c)
@m.output() @m.output()
def stop_using_connection(self): def stop_using_connection(self):
self._stop_using_connection() self._stop_using_connection()
@m.output() @m.output()
def signal_error(self): def signal_error(self):
pass # TODO pass # TODO
@m.output() @m.output()
def signal_error_hints(self, hint_message): def signal_error_hints(self, hint_message):
pass # TODO pass # TODO
IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early IDLE.upon(rx_HINTS, enter=STOPPED, outputs=[signal_error_hints]) # too early
IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early IDLE.upon(rx_DILATE, enter=STOPPED, outputs=[signal_error]) # too early
# leader shouldn't send us DILATE before receiving our PLEASE # leader shouldn't send us DILATE before receiving our PLEASE
IDLE.upon(stop, enter=STOPPED, outputs=[]) IDLE.upon(stop, enter=STOPPED, outputs=[])
IDLE.upon(start, enter=WANTING, outputs=[send_please]) IDLE.upon(start, enter=WANTING, outputs=[send_please])
@ -78,7 +102,7 @@ class ManagerFollower(_ManagerBase):
CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints])
CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection]) CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[use_connection])
# shouldn't happen: connection_lost # shouldn't happen: connection_lost
#CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?]) # CONNECTING.upon(connection_lost, enter=CONNECTING, outputs=[?])
CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting, CONNECTING.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_connecting,
start_connecting]) start_connecting])
# receiving rx_DILATE while we're still working on the last one means the # receiving rx_DILATE while we're still working on the last one means the
@ -89,7 +113,7 @@ class ManagerFollower(_ManagerBase):
CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection]) CONNECTED.upon(connection_lost, enter=WANTING, outputs=[stop_using_connection])
CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection, CONNECTED.upon(rx_DILATE, enter=CONNECTING, outputs=[stop_using_connection,
start_connecting]) start_connecting])
CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore
CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection]) CONNECTED.upon(stop, enter=STOPPED, outputs=[stop_using_connection])
# shouldn't happen: connection_made # shouldn't happen: connection_made

View File

@ -3,16 +3,19 @@ import mock
from zope.interface import alsoProvides from zope.interface import alsoProvides
from ..._interfaces import IDilationManager, IWormhole from ..._interfaces import IDilationManager, IWormhole
def mock_manager(): def mock_manager():
m = mock.Mock() m = mock.Mock()
alsoProvides(m, IDilationManager) alsoProvides(m, IDilationManager)
return m return m
def mock_wormhole(): def mock_wormhole():
m = mock.Mock() m = mock.Mock()
alsoProvides(m, IWormhole) alsoProvides(m, IWormhole)
return m return m
def clear_mock_calls(*args): def clear_mock_calls(*args):
for a in args: for a in args:
a.mock_calls[:] = [] a.mock_calls[:] = []

View File

@ -11,12 +11,13 @@ from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
KCM, Open, Ack) KCM, Open, Ack)
from .common import clear_mock_calls from .common import clear_mock_calls
def make_con(role, use_relay=False): def make_con(role, use_relay=False):
clock = Clock() clock = Clock()
eq = EventualQueue(clock) eq = EventualQueue(clock)
connector = mock.Mock() connector = mock.Mock()
alsoProvides(connector, IDilationConnector) alsoProvides(connector, IDilationConnector)
n = mock.Mock() # pretends to be a Noise object n = mock.Mock() # pretends to be a Noise object
n.write_message = mock.Mock(side_effect=[b"handshake"]) n.write_message = mock.Mock(side_effect=[b"handshake"])
c = DilatedConnectionProtocol(eq, role, connector, n, c = DilatedConnectionProtocol(eq, role, connector, n,
b"outbound_prologue\n", b"inbound_prologue\n") b"outbound_prologue\n", b"inbound_prologue\n")
@ -26,6 +27,7 @@ def make_con(role, use_relay=False):
alsoProvides(t, ITransport) alsoProvides(t, ITransport)
return c, n, connector, t, eq return c, n, connector, t, eq
class Connection(unittest.TestCase): class Connection(unittest.TestCase):
def test_bad_prologue(self): def test_bad_prologue(self):
c, n, connector, t, eq = make_con(LEADER) c, n, connector, t, eq = make_con(LEADER)
@ -55,10 +57,10 @@ class Connection(unittest.TestCase):
n.decrypt = mock.Mock(side_effect=[ n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm), encode_record(t_kcm),
encode_record(t_open), encode_record(t_open),
]) ])
exp_kcm = b"\x00\x00\x00\x03kcm" exp_kcm = b"\x00\x00\x00\x03kcm"
n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"]) n.encrypt = mock.Mock(side_effect=[b"kcm", b"ack1"])
m = mock.Mock() # Manager m = mock.Mock() # Manager
c.makeConnection(t) c.makeConnection(t)
self.assertEqual(n.mock_calls, [mock.call.start_handshake()]) self.assertEqual(n.mock_calls, [mock.call.start_handshake()])
@ -86,7 +88,7 @@ class Connection(unittest.TestCase):
self.assertEqual(n.mock_calls, [ self.assertEqual(n.mock_calls, [
mock.call.read_message(b"handshake2"), mock.call.read_message(b"handshake2"),
mock.call.encrypt(encode_record(t_kcm)), mock.call.encrypt(encode_record(t_kcm)),
]) ])
self.assertEqual(connector.mock_calls, []) self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [ self.assertEqual(t.mock_calls, [
mock.call.write(exp_kcm)]) mock.call.write(exp_kcm)])
@ -123,7 +125,7 @@ class Connection(unittest.TestCase):
c.send_record(KCM()) c.send_record(KCM())
self.assertEqual(n.mock_calls, [ self.assertEqual(n.mock_calls, [
mock.call.encrypt(encode_record(t_kcm)), mock.call.encrypt(encode_record(t_kcm)),
]) ])
self.assertEqual(connector.mock_calls, []) self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)]) self.assertEqual(t.mock_calls, [mock.call.write(exp_kcm)])
self.assertEqual(m.mock_calls, []) self.assertEqual(m.mock_calls, [])
@ -163,7 +165,6 @@ class Connection(unittest.TestCase):
def test_no_relay_follower(self): def test_no_relay_follower(self):
return self._test_no_relay(FOLLOWER) return self._test_no_relay(FOLLOWER)
def test_relay(self): def test_relay(self):
c, n, connector, t, eq = make_con(LEADER, use_relay=True) c, n, connector, t, eq = make_con(LEADER, use_relay=True)

View File

@ -2,11 +2,12 @@ from __future__ import print_function, unicode_literals
from twisted.trial import unittest from twisted.trial import unittest
from ..._dilation.encode import to_be4, from_be4 from ..._dilation.encode import to_be4, from_be4
class Encoding(unittest.TestCase): class Encoding(unittest.TestCase):
def test_be4(self): def test_be4(self):
self.assertEqual(to_be4(0), b"\x00\x00\x00\x00") self.assertEqual(to_be4(0), b"\x00\x00\x00\x00")
self.assertEqual(to_be4(1), b"\x00\x00\x00\x01") self.assertEqual(to_be4(1), b"\x00\x00\x00\x01")
self.assertEqual(to_be4(256), b"\x00\x00\x01\x00") self.assertEqual(to_be4(256), b"\x00\x00\x01\x00")
self.assertEqual(to_be4(257), b"\x00\x00\x01\x01") self.assertEqual(to_be4(257), b"\x00\x00\x01\x01")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):

View File

@ -11,6 +11,7 @@ from ..._dilation.subchannel import (ControlEndpoint,
SingleUseEndpointError) SingleUseEndpointError)
from .common import mock_manager from .common import mock_manager
class Endpoints(unittest.TestCase): class Endpoints(unittest.TestCase):
def test_control(self): def test_control(self):
scid0 = b"scid0" scid0 = b"scid0"
@ -94,4 +95,4 @@ class Endpoints(unittest.TestCase):
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)]) self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])
self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)]) self.assertEqual(p2.mock_calls, [mock.call.makeConnection(t2)])
lp.stopListening() # TODO: should this do more? lp.stopListening() # TODO: should this do more?

View File

@ -5,12 +5,14 @@ from twisted.trial import unittest
from twisted.internet.interfaces import ITransport from twisted.internet.interfaces import ITransport
from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect from ..._dilation.connection import _Framer, Frame, Prologue, Disconnect
def make_framer(): def make_framer():
t = mock.Mock() t = mock.Mock()
alsoProvides(t, ITransport) alsoProvides(t, ITransport)
f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n") f = _Framer(t, b"outbound_prologue\n", b"inbound_prologue\n")
return f, t return f, t
class Framer(unittest.TestCase): class Framer(unittest.TestCase):
def test_bad_prologue_length(self): def test_bad_prologue_length(self):
f, t = make_framer() f, t = make_framer()
@ -19,7 +21,7 @@ class Framer(unittest.TestCase):
f.connectionMade() f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = [] t.mock_calls[:] = []
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
self.assertEqual(t.mock_calls, []) self.assertEqual(t.mock_calls, [])
with mock.patch("wormhole._dilation.connection.log.msg") as m: with mock.patch("wormhole._dilation.connection.log.msg") as m:
@ -37,7 +39,7 @@ class Framer(unittest.TestCase):
f.connectionMade() f.connectionMade()
self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")]) self.assertEqual(t.mock_calls, [mock.call.write(b"outbound_prologue\n")])
t.mock_calls[:] = [] t.mock_calls[:] = []
self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it self.assertEqual([], list(f.add_and_parse(b"inbound_"))) # wait for it
self.assertEqual([], list(f.add_and_parse(b"not"))) self.assertEqual([], list(f.add_and_parse(b"not")))
with mock.patch("wormhole._dilation.connection.log.msg") as m: with mock.patch("wormhole._dilation.connection.log.msg") as m:

View File

@ -8,6 +8,7 @@ from ..._dilation.inbound import (Inbound, DuplicateOpenError,
DataForMissingSubchannelError, DataForMissingSubchannelError,
CloseForMissingSubchannelError) CloseForMissingSubchannelError)
def make_inbound(): def make_inbound():
m = mock.Mock() m = mock.Mock()
alsoProvides(m, IDilationManager) alsoProvides(m, IDilationManager)
@ -15,6 +16,7 @@ def make_inbound():
i = Inbound(m, host_addr) i = Inbound(m, host_addr)
return i, m, host_addr return i, m, host_addr
class InboundTest(unittest.TestCase): class InboundTest(unittest.TestCase):
def test_seqnum(self): def test_seqnum(self):
i, m, host_addr = make_inbound() i, m, host_addr = make_inbound()
@ -158,7 +160,7 @@ class InboundTest(unittest.TestCase):
self.assertEqual(c.mock_calls, [mock.call.pauseProducing()]) self.assertEqual(c.mock_calls, [mock.call.pauseProducing()])
c.mock_calls[:] = [] c.mock_calls[:] = []
i.subchannel_pauseProducing(sc2) i.subchannel_pauseProducing(sc2)
self.assertEqual(c.mock_calls, []) # was already paused self.assertEqual(c.mock_calls, []) # was already paused
# tolerate duplicate pauseProducing # tolerate duplicate pauseProducing
i.subchannel_pauseProducing(sc2) i.subchannel_pauseProducing(sc2)

View File

@ -13,12 +13,14 @@ from ..._dilation.manager import (Dilator,
from ..._dilation.subchannel import _WormholeAddress from ..._dilation.subchannel import _WormholeAddress
from .common import clear_mock_calls from .common import clear_mock_calls
def make_dilator(): def make_dilator():
reactor = object() reactor = object()
clock = Clock() clock = Clock()
eq = EventualQueue(clock) eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
term_factory = lambda: term
def term_factory(): return term
coop = Cooperator(terminationPredicateFactory=term_factory, coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually) scheduler=eq.eventually)
send = mock.Mock() send = mock.Mock()
@ -27,6 +29,7 @@ def make_dilator():
dil.wire(send) dil.wire(send)
return dil, send, reactor, eq, clock, coop return dil, send, reactor, eq, clock, coop
class TestDilator(unittest.TestCase): class TestDilator(unittest.TestCase):
def test_leader(self): def test_leader(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
@ -148,12 +151,11 @@ class TestDilator(unittest.TestCase):
d1 = dil.dilate() d1 = dil.dilate()
self.assertNoResult(d1) self.assertNoResult(d1)
dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate" dil.got_wormhole_versions("me", "you", {}) # missing "can-dilate"
eq.flush_sync() eq.flush_sync()
f = self.failureResultOf(d1) f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError) f.check(OldPeerCannotDilateError)
def test_disjoint_versions(self): def test_disjoint_versions(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
d1 = dil.dilate() d1 = dil.dilate()
@ -164,7 +166,6 @@ class TestDilator(unittest.TestCase):
f = self.failureResultOf(d1) f = self.failureResultOf(d1)
f.check(OldPeerCannotDilateError) f.check(OldPeerCannotDilateError)
def test_early_dilate_messages(self): def test_early_dilate_messages(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key" dil._transit_key = b"key"
@ -188,8 +189,6 @@ class TestDilator(unittest.TestCase):
mock.call.rx_HINTS(hintmsg), mock.call.rx_HINTS(hintmsg),
mock.call.when_first_connected()]) mock.call.when_first_connected()])
def test_transit_relay(self): def test_transit_relay(self):
dil, send, reactor, eq, clock, coop = make_dilator() dil, send, reactor, eq, clock, coop = make_dilator()
dil._transit_key = b"key" dil._transit_key = b"key"

View File

@ -16,17 +16,20 @@ Pauser = namedtuple("Pauser", ["seqnum"])
NonPauser = namedtuple("NonPauser", ["seqnum"]) NonPauser = namedtuple("NonPauser", ["seqnum"])
Stopper = namedtuple("Stopper", ["sc"]) Stopper = namedtuple("Stopper", ["sc"])
def make_outbound(): def make_outbound():
m = mock.Mock() m = mock.Mock()
alsoProvides(m, IDilationManager) alsoProvides(m, IDilationManager)
clock = Clock() clock = Clock()
eq = EventualQueue(clock) eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
term_factory = lambda: term
def term_factory(): return term
coop = Cooperator(terminationPredicateFactory=term_factory, coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually) scheduler=eq.eventually)
o = Outbound(m, coop) o = Outbound(m, coop)
c = mock.Mock() # Connection c = mock.Mock() # Connection
def maybe_pause(r): def maybe_pause(r):
if isinstance(r, Pauser): if isinstance(r, Pauser):
o.pauseProducing() o.pauseProducing()
@ -37,6 +40,7 @@ def make_outbound():
o._test_term = term o._test_term = term
return o, m, c return o, m, c
class OutboundTest(unittest.TestCase): class OutboundTest(unittest.TestCase):
def test_build_record(self): def test_build_record(self):
o, m, c = make_outbound() o, m, c = make_outbound()
@ -69,10 +73,10 @@ class OutboundTest(unittest.TestCase):
o.handle_ack(r3.seqnum) o.handle_ack(r3.seqnum)
self.assertEqual(list(o._outbound_queue), []) self.assertEqual(list(o._outbound_queue), [])
o.handle_ack(r3.seqnum) # ignored o.handle_ack(r3.seqnum) # ignored
self.assertEqual(list(o._outbound_queue), []) self.assertEqual(list(o._outbound_queue), [])
o.handle_ack(r1.seqnum) # ignored o.handle_ack(r1.seqnum) # ignored
self.assertEqual(list(o._outbound_queue), []) self.assertEqual(list(o._outbound_queue), [])
def test_duplicate_registerProducer(self): def test_duplicate_registerProducer(self):
@ -192,7 +196,8 @@ class OutboundTest(unittest.TestCase):
clear_mock_calls(c) clear_mock_calls(c)
sc1, sc2, sc3 = object(), object(), object() sc1, sc2, sc3 = object(), object(), object()
p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(name="p2"), mock.Mock(name="p3") p1, p2, p3 = mock.Mock(name="p1"), mock.Mock(
name="p2"), mock.Mock(name="p3")
# we aren't paused yet, since we haven't sent any data # we aren't paused yet, since we haven't sent any data
o.subchannel_registerProducer(sc1, p1, True) o.subchannel_registerProducer(sc1, p1, True)
@ -310,7 +315,8 @@ class OutboundTest(unittest.TestCase):
# and another disconnects itself when called # and another disconnects itself when called
p2.resumeProducing.side_effect = lambda: None p2.resumeProducing.side_effect = lambda: None
p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(sc3) p3.resumeProducing.side_effect = lambda: o.subchannel_unregisterProducer(
sc3)
o.pauseProducing() o.pauseProducing()
o.resumeProducing() o.resumeProducing()
self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(), self.assertEqual(p2.mock_calls, [mock.call.pauseProducing(),
@ -360,7 +366,7 @@ class OutboundTest(unittest.TestCase):
r2 = NonPauser(seqnum=2) r2 = NonPauser(seqnum=2)
# we aren't paused yet, since we haven't sent any data # we aren't paused yet, since we haven't sent any data
o.subchannel_registerProducer(sc1, p1, True) # push o.subchannel_registerProducer(sc1, p1, True) # push
o.queue_and_send_record(r1) o.queue_and_send_record(r1)
# now we're paused # now we're paused
self.assertTrue(o._paused) self.assertTrue(o._paused)
@ -371,7 +377,7 @@ class OutboundTest(unittest.TestCase):
p1.resumeProducing.side_effect = lambda: c.send_record(r1) p1.resumeProducing.side_effect = lambda: c.send_record(r1)
p2.resumeProducing.side_effect = lambda: c.send_record(r2) p2.resumeProducing.side_effect = lambda: c.send_record(r2)
o.subchannel_registerProducer(sc2, p2, False) # pull: always ready o.subchannel_registerProducer(sc2, p2, False) # pull: always ready
# p1 is still first, since p2 was just added (at the end) # p1 is still first, since p2 was just added (at the end)
self.assertTrue(o._paused) self.assertTrue(o._paused)
@ -390,7 +396,7 @@ class OutboundTest(unittest.TestCase):
mock.call.pauseProducing(), mock.call.pauseProducing(),
]) ])
self.assertEqual(p2.mock_calls, []) self.assertEqual(p2.mock_calls, [])
self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next self.assertEqual(list(o._all_producers), [p2, p1]) # now p2 is next
clear_mock_calls(p1, p2, c) clear_mock_calls(p1, p2, c)
# next should fire p2, then p1 # next should fire p2, then p1
@ -404,7 +410,7 @@ class OutboundTest(unittest.TestCase):
]) ])
self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(), self.assertEqual(p2.mock_calls, [mock.call.resumeProducing(),
]) ])
self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat self.assertEqual(list(o._all_producers), [p2, p1]) # p2 still at bat
clear_mock_calls(p1, p2, c) clear_mock_calls(p1, p2, c)
def test_pull_producer(self): def test_pull_producer(self):
@ -430,13 +436,13 @@ class OutboundTest(unittest.TestCase):
it = iter(records) it = iter(records)
p1.resumeProducing.side_effect = lambda: c.send_record(next(it)) p1.resumeProducing.side_effect = lambda: c.send_record(next(it))
o.subchannel_registerProducer(sc1, p1, False) o.subchannel_registerProducer(sc1, p1, False)
eq.flush_sync() # fast forward into the glorious (paused) future eq.flush_sync() # fast forward into the glorious (paused) future
self.assertTrue(o._paused) self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in records[:-1]]) [mock.call.send_record(r) for r in records[:-1]])
self.assertEqual(p1.mock_calls, self.assertEqual(p1.mock_calls,
[mock.call.resumeProducing()]*(len(records)-1)) [mock.call.resumeProducing()] * (len(records) - 1))
clear_mock_calls(c, p1) clear_mock_calls(c, p1)
# next resumeProducing should cause it to disconnect # next resumeProducing should cause it to disconnect
@ -460,7 +466,7 @@ class OutboundTest(unittest.TestCase):
NonPauser(3), NonPauser(13), NonPauser(3), NonPauser(13),
NonPauser(4), NonPauser(14), NonPauser(4), NonPauser(14),
Pauser(5)] Pauser(5)]
expected2 = [ NonPauser(15), expected2 = [NonPauser(15),
NonPauser(6), NonPauser(16), NonPauser(6), NonPauser(16),
NonPauser(7), NonPauser(17), NonPauser(7), NonPauser(17),
NonPauser(8), NonPauser(18), NonPauser(8), NonPauser(18),
@ -487,14 +493,14 @@ class OutboundTest(unittest.TestCase):
p2.resumeProducing.side_effect = lambda: c.send_record(next(it2)) p2.resumeProducing.side_effect = lambda: c.send_record(next(it2))
o.subchannel_registerProducer(sc2, p2, False) o.subchannel_registerProducer(sc2, p2, False)
eq.flush_sync() # fast forward into the glorious (paused) future eq.flush_sync() # fast forward into the glorious (paused) future
sends = [mock.call.resumeProducing()] sends = [mock.call.resumeProducing()]
self.assertTrue(o._paused) self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in expected1]) [mock.call.send_record(r) for r in expected1])
self.assertEqual(p1.mock_calls, 6*sends) self.assertEqual(p1.mock_calls, 6 * sends)
self.assertEqual(p2.mock_calls, 5*sends) self.assertEqual(p2.mock_calls, 5 * sends)
clear_mock_calls(c, p1, p2) clear_mock_calls(c, p1, p2)
o.resumeProducing() o.resumeProducing()
@ -502,13 +508,13 @@ class OutboundTest(unittest.TestCase):
self.assertTrue(o._paused) self.assertTrue(o._paused)
self.assertEqual(c.mock_calls, self.assertEqual(c.mock_calls,
[mock.call.send_record(r) for r in expected2]) [mock.call.send_record(r) for r in expected2])
self.assertEqual(p1.mock_calls, 4*sends) self.assertEqual(p1.mock_calls, 4 * sends)
self.assertEqual(p2.mock_calls, 5*sends) self.assertEqual(p2.mock_calls, 5 * sends)
clear_mock_calls(c, p1, p2) clear_mock_calls(c, p1, p2)
def test_send_if_connected(self): def test_send_if_connected(self):
o, m, c = make_outbound() o, m, c = make_outbound()
o.send_if_connected(Ack(1)) # not connected yet o.send_if_connected(Ack(1)) # not connected yet
o.use_connection(c) o.use_connection(c)
o.send_if_connected(KCM()) o.send_if_connected(KCM())
@ -517,7 +523,7 @@ class OutboundTest(unittest.TestCase):
def test_tolerate_duplicate_pause_resume(self): def test_tolerate_duplicate_pause_resume(self):
o, m, c = make_outbound() o, m, c = make_outbound()
self.assertTrue(o._paused) # no connection self.assertTrue(o._paused) # no connection
o.use_connection(c) o.use_connection(c)
self.assertFalse(o._paused) self.assertFalse(o._paused)
o.pauseProducing() o.pauseProducing()
@ -533,7 +539,7 @@ class OutboundTest(unittest.TestCase):
o, m, c = make_outbound() o, m, c = make_outbound()
o.use_connection(c) o.use_connection(c)
self.assertFalse(o._paused) self.assertFalse(o._paused)
o.stopProducing() # connection does this before loss o.stopProducing() # connection does this before loss
self.assertTrue(o._paused) self.assertTrue(o._paused)
o.stop_using_connection() o.stop_using_connection()
self.assertTrue(o._paused) self.assertTrue(o._paused)
@ -559,13 +565,15 @@ def make_pushpull(pauses):
clock = Clock() clock = Clock()
eq = EventualQueue(clock) eq = EventualQueue(clock)
term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick term = mock.Mock(side_effect=lambda: True) # one write per Eventual tick
term_factory = lambda: term
def term_factory(): return term
coop = Cooperator(terminationPredicateFactory=term_factory, coop = Cooperator(terminationPredicateFactory=term_factory,
scheduler=eq.eventually) scheduler=eq.eventually)
pp = PullToPush(p, unregister, coop) pp = PullToPush(p, unregister, coop)
it = cycle(pauses) it = cycle(pauses)
def action(i): def action(i):
if isinstance(i, Exception): if isinstance(i, Exception):
raise i raise i
@ -574,41 +582,45 @@ def make_pushpull(pauses):
p.resumeProducing.side_effect = lambda: action(next(it)) p.resumeProducing.side_effect = lambda: action(next(it))
return p, unregister, pp, eq return p, unregister, pp, eq
class PretendResumptionError(Exception): class PretendResumptionError(Exception):
pass pass
class PretendUnregisterError(Exception): class PretendUnregisterError(Exception):
pass pass
class PushPull(unittest.TestCase): class PushPull(unittest.TestCase):
# test our wrapper utility, which I copied from # test our wrapper utility, which I copied from
# twisted.internet._producer_helpers since it isn't publically exposed # twisted.internet._producer_helpers since it isn't publically exposed
def test_start_unpaused(self): def test_start_unpaused(self):
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
# if it starts unpaused, it gets one write before being halted # if it starts unpaused, it gets one write before being halted
pp.startStreaming(False) pp.startStreaming(False)
eq.flush_sync() eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
clear_mock_calls(p) clear_mock_calls(p)
# now each time we call resumeProducing, we should see one delivered to # now each time we call resumeProducing, we should see one delivered to
# the underlying IPullProducer # the underlying IPullProducer
pp.resumeProducing() pp.resumeProducing()
eq.flush_sync() eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*1) self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 1)
pp.stopStreaming() pp.stopStreaming()
pp.stopStreaming() # should tolerate this pp.stopStreaming() # should tolerate this
def test_start_unpaused_two_writes(self): def test_start_unpaused_two_writes(self):
p, unr, pp, eq = make_pushpull([False, True]) # pause every other time p, unr, pp, eq = make_pushpull([False, True]) # pause every other time
# it should get two writes, since the first didn't pause # it should get two writes, since the first didn't pause
pp.startStreaming(False) pp.startStreaming(False)
eq.flush_sync() eq.flush_sync()
self.assertEqual(p.mock_calls, [mock.call.resumeProducing()]*2) self.assertEqual(p.mock_calls, [mock.call.resumeProducing()] * 2)
def test_start_paused(self): def test_start_paused(self):
p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing p, unr, pp, eq = make_pushpull([True]) # pause on each resumeProducing
pp.startStreaming(True) pp.startStreaming(True)
eq.flush_sync() eq.flush_sync()
self.assertEqual(p.mock_calls, []) self.assertEqual(p.mock_calls, [])
@ -637,9 +649,5 @@ class PushPull(unittest.TestCase):
self.assertEqual(unr.mock_calls, [mock.call()]) self.assertEqual(unr.mock_calls, [mock.call()])
self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError) self.flushLoggedErrors(PretendResumptionError, PretendUnregisterError)
# TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I # TODO: consider making p1/p2/p3 all elements of a shared Mock, maybe I
# could capture the inter-call ordering that way # could capture the inter-call ordering that way

View File

@ -4,6 +4,7 @@ from twisted.trial import unittest
from ..._dilation.connection import (parse_record, encode_record, from ..._dilation.connection import (parse_record, encode_record,
KCM, Ping, Pong, Open, Data, Close, Ack) KCM, Ping, Pong, Open, Data, Close, Ack)
class Parse(unittest.TestCase): class Parse(unittest.TestCase):
def test_parse(self): def test_parse(self):
self.assertEqual(parse_record(b"\x00"), KCM()) self.assertEqual(parse_record(b"\x00"), KCM())

View File

@ -7,13 +7,15 @@ from ..._dilation.connection import (IFramer, Frame, Prologue,
_Record, Handshake, _Record, Handshake,
Disconnect, Ping) Disconnect, Ping)
def make_record(): def make_record():
f = mock.Mock() f = mock.Mock()
alsoProvides(f, IFramer) alsoProvides(f, IFramer)
n = mock.Mock() # pretends to be a Noise object n = mock.Mock() # pretends to be a Noise object
r = _Record(f, n) r = _Record(f, n)
return r, f, n return r, f, n
class Record(unittest.TestCase): class Record(unittest.TestCase):
def test_good2(self): def test_good2(self):
f = mock.Mock() f = mock.Mock()
@ -23,7 +25,7 @@ class Record(unittest.TestCase):
[Prologue()], [Prologue()],
[Frame(frame=b"rx-handshake")], [Frame(frame=b"rx-handshake")],
[Frame(frame=b"frame1"), Frame(frame=b"frame2")], [Frame(frame=b"frame1"), Frame(frame=b"frame2")],
]) ])
n = mock.Mock() n = mock.Mock()
n.write_message = mock.Mock(return_value=b"tx-handshake") n.write_message = mock.Mock(return_value=b"tx-handshake")
p1, p2 = object(), object() p1, p2 = object(), object()
@ -60,9 +62,9 @@ class Record(unittest.TestCase):
n.mock_calls[:] = [] n.mock_calls[:] = []
# next is a pair of Records # next is a pair of Records
r1, r2 = object() , object() r1, r2 = object(), object()
with mock.patch("wormhole._dilation.connection.parse_record", with mock.patch("wormhole._dilation.connection.parse_record",
side_effect=[r1,r2]) as pr: side_effect=[r1, r2]) as pr:
self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2]) self.assertEqual(list(r.add_and_unframe(b"blah2")), [r1, r2])
self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"), self.assertEqual(n.mock_calls, [mock.call.decrypt(b"frame1"),
mock.call.decrypt(b"frame2")]) mock.call.decrypt(b"frame2")])
@ -186,7 +188,7 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=outbound_handshake) n.write_message = mock.Mock(return_value=outbound_handshake)
n.decrypt = mock.Mock(side_effect=[kcm, msg1]) n.decrypt = mock.Mock(side_effect=[kcm, msg1])
n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1]) n.encrypt = mock.Mock(side_effect=[f_kcm, f_msg1])
f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet f.add_and_parse = mock.Mock(side_effect=[[], # no tokens yet
[Prologue()], [Prologue()],
[Frame("f_handshake")], [Frame("f_handshake")],
[Frame("f_kcm"), [Frame("f_kcm"),
@ -238,7 +240,6 @@ class Record(unittest.TestCase):
f.mock_calls[:] = [] f.mock_calls[:] = []
n.mock_calls[:] = [] n.mock_calls[:] = []
# 5: at this point we ought to be able to send a messge, the KCM # 5: at this point we ought to be able to send a messge, the KCM
with mock.patch("wormhole._dilation.connection.encode_record", with mock.patch("wormhole._dilation.connection.encode_record",
side_effect=[b"r-kcm"]) as er: side_effect=[b"r-kcm"]) as er:

View File

@ -8,6 +8,7 @@ from ..._dilation.subchannel import (Once, SubChannel,
AlreadyClosedError) AlreadyClosedError)
from .common import mock_manager from .common import mock_manager
def make_sc(set_protocol=True): def make_sc(set_protocol=True):
scid = b"scid" scid = b"scid"
hostaddr = _WormholeAddress() hostaddr = _WormholeAddress()
@ -19,6 +20,7 @@ def make_sc(set_protocol=True):
sc._set_protocol(p) sc._set_protocol(p)
return sc, m, scid, hostaddr, peeraddr, p return sc, m, scid, hostaddr, peeraddr, p
class SubChannelAPI(unittest.TestCase): class SubChannelAPI(unittest.TestCase):
def test_once(self): def test_once(self):
o = Once(ValueError) o = Once(ValueError)