diff --git a/docs/state-machines/code.dot b/docs/state-machines/code.dot index b56fa1d..0cde977 100644 --- a/docs/state-machines/code.dot +++ b/docs/state-machines/code.dot @@ -1,15 +1,20 @@ digraph { start [label="Wormhole Code\nMachine" style="dotted"] - {rank=same; start S0} - start -> S0 [style="invis"] - S0 [label="S0:\nunknown\ndisconnected"] - S0 -> P0_got_code [label="set_code"] + {rank=same; start S0A S0B} + start -> S0A [style="invis"] + S0A [label="S0A:\nunknown\ndisconnected"] + S0A -> S0B [label="connected"] + S0B -> S0A [label="lost"] + S0B [label="S0B:\nunknown\nconnected"] + S0A -> P0_got_code [label="set_code"] + S0B -> P0_got_code [label="set_code"] P0_got_code [shape="box" label="B.got_code"] P0_got_code -> S4 S4 [label="S4: known" color="green"] - S0 -> P0_list_nameplates [label="input_code"] + S0A -> P0_list_nameplates [label="input_code"] + S0B -> P0_list_nameplates [label="input_code"] S2 [label="S2: typing\nnameplate"] S2 -> P2_completion [label=""] @@ -32,7 +37,7 @@ digraph { S3 -> P0_got_code [label="" color="orange" fontcolor="orange"] - S0 -> S1A [label="allocate_code"] + S0A -> S1A [label="allocate_code"] S1A [label="S1A:\nconnecting"] S1A -> P1_allocate [label="connected"] P1_allocate [shape="box" label="RC.tx_allocate"] @@ -42,5 +47,7 @@ digraph { S1B -> S1A [label="lost"] P1_generate [shape="box" label="generate\nrandom code"] P1_generate -> P0_got_code + + S0B -> P1_allocate [label="allocate_code"] } diff --git a/src/wormhole/_code.py b/src/wormhole/_code.py index eeba08f..307dc5e 100644 --- a/src/wormhole/_code.py +++ b/src/wormhole/_code.py @@ -34,7 +34,9 @@ class Code(object): self._L = _interfaces.ILister(lister) @m.state(initial=True) - def S0_unknown(self): pass # pragma: no cover + def S0A_unknown(self): pass # pragma: no cover + @m.state() + def S0B_unknown_connected(self): pass # pragma: no cover @m.state() def S1A_connecting(self): pass # pragma: no cover @m.state() @@ -82,6 +84,10 @@ class Code(object): self._stdio = stdio self._L.refresh_nameplates() @m.output() + def stash_code_length_and_RC_tx_allocate(self, code_length): + self._code_length = code_length + self._RC.tx_allocate() + @m.output() def stash_code_length(self, code_length): self._code_length = code_length @m.output() @@ -113,18 +119,26 @@ class Code(object): def _B_got_code(self): self._B.got_code(self._code) - S0_unknown.upon(set_code, enter=S4_known, outputs=[B_got_code]) + S0A_unknown.upon(connected, enter=S0B_unknown_connected, outputs=[]) + S0B_unknown_connected.upon(lost, enter=S0A_unknown, outputs=[]) - S0_unknown.upon(allocate_code, enter=S1A_connecting, - outputs=[stash_code_length]) + S0A_unknown.upon(set_code, enter=S4_known, outputs=[B_got_code]) + S0B_unknown_connected.upon(set_code, enter=S4_known, outputs=[B_got_code]) + + S0A_unknown.upon(allocate_code, enter=S1A_connecting, + outputs=[stash_code_length]) + S0B_unknown_connected.upon(allocate_code, enter=S1B_allocating, + outputs=[stash_code_length_and_RC_tx_allocate]) S1A_connecting.upon(connected, enter=S1B_allocating, outputs=[RC_tx_allocate]) S1B_allocating.upon(lost, enter=S1A_connecting, outputs=[]) S1B_allocating.upon(rx_allocated, enter=S4_known, outputs=[generate_and_B_got_code]) - S0_unknown.upon(input_code, enter=S2_typing_nameplate, - outputs=[start_input_and_L_refresh_nameplates]) + S0A_unknown.upon(input_code, enter=S2_typing_nameplate, + outputs=[start_input_and_L_refresh_nameplates]) + S0B_unknown_connected.upon(input_code, enter=S2_typing_nameplate, + outputs=[start_input_and_L_refresh_nameplates]) S2_typing_nameplate.upon(tab, enter=S2_typing_nameplate, outputs=[do_completion_nameplates]) S2_typing_nameplate.upon(got_nameplates, enter=S2_typing_nameplate, diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index a5c3730..4830901 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -2,8 +2,9 @@ from __future__ import print_function, unicode_literals import json from zope.interface import directlyProvides from twisted.trial import unittest -from .. import timing, _order, _receive, _key -from .._interfaces import IKey, IReceive, IBoss, ISend, IMailbox +from .. import timing, _order, _receive, _key, _code +from .._interfaces import (IKey, IReceive, IBoss, ISend, IMailbox, + IRendezvousConnector, ILister) from .._key import derive_key, derive_phase_key, encrypt_data from ..util import dict_to_bytes, hexstr_to_bytes, bytes_to_hexstr, to_bytes from spake2 import SPAKE2_Symmetric @@ -174,7 +175,6 @@ class Key(unittest.TestCase): self.assertEqual(events[2][:2], ("m.add_message", "version")) self.assertEqual(events[3], ("r.got_key", key2)) - def test_bad(self): k, b, m, r, events = self.build() code = u"1-foo" @@ -188,3 +188,58 @@ class Key(unittest.TestCase): bad_pake_d = {"not_pake_v1": "stuff"} k.got_pake(dict_to_bytes(bad_pake_d)) self.assertEqual(events, [("b.scared",)]) + +class Code(unittest.TestCase): + def build(self): + events = [] + c = _code.Code(timing.DebugTiming()) + b = Dummy("b", events, IBoss, "got_code") + rc = Dummy("rc", events, IRendezvousConnector, "tx_allocate") + l = Dummy("l", events, ILister, "refresh_nameplates") + c.wire(b, rc, l) + return c, b, rc, l, events + + def test_set_disconnected(self): + c, b, rc, l, events = self.build() + c.set_code(u"code") + self.assertEqual(events, [("b.got_code", u"code")]) + + def test_set_connected(self): + c, b, rc, l, events = self.build() + c.connected() + c.set_code(u"code") + self.assertEqual(events, [("b.got_code", u"code")]) + + def test_allocate_disconnected(self): + c, b, rc, l, events = self.build() + c.allocate_code(2) + self.assertEqual(events, []) + c.connected() + self.assertEqual(events, [("rc.tx_allocate",)]) + events[:] = [] + c.lost() + self.assertEqual(events, []) + c.connected() + self.assertEqual(events, [("rc.tx_allocate",)]) + events[:] = [] + c.rx_allocated("4") + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "b.got_code") + code = events[0][1] + self.assert_(code.startswith("4-"), code) + + def test_allocate_connected(self): + c, b, rc, l, events = self.build() + c.connected() + c.allocate_code(2) + self.assertEqual(events, [("rc.tx_allocate",)]) + events[:] = [] + c.rx_allocated("4") + self.assertEqual(len(events), 1, events) + self.assertEqual(events[0][0], "b.got_code") + code = events[0][1] + self.assert_(code.startswith("4-"), code) + + # TODO: input_code + +