Code: handle being connected before being told what to do

This commit is contained in:
Brian Warner 2017-03-04 12:40:19 +01:00
parent 8ee342ad82
commit 6ada8252b7
3 changed files with 91 additions and 15 deletions

View File

@ -1,15 +1,20 @@
digraph { digraph {
start [label="Wormhole Code\nMachine" style="dotted"] start [label="Wormhole Code\nMachine" style="dotted"]
{rank=same; start S0} {rank=same; start S0A S0B}
start -> S0 [style="invis"] start -> S0A [style="invis"]
S0 [label="S0:\nunknown\ndisconnected"] S0A [label="S0A:\nunknown\ndisconnected"]
S0 -> P0_got_code [label="set_code"] S0A -> S0B [label="connected"]
S0B -> S0A [label="lost"]
S0B [label="S0B:\nunknown\nconnected"]
S0A -> P0_got_code [label="set_code"]
S0B -> P0_got_code [label="set_code"]
P0_got_code [shape="box" label="B.got_code"] P0_got_code [shape="box" label="B.got_code"]
P0_got_code -> S4 P0_got_code -> S4
S4 [label="S4: known" color="green"] S4 [label="S4: known" color="green"]
S0 -> P0_list_nameplates [label="input_code"] S0A -> P0_list_nameplates [label="input_code"]
S0B -> P0_list_nameplates [label="input_code"]
S2 [label="S2: typing\nnameplate"] S2 [label="S2: typing\nnameplate"]
S2 -> P2_completion [label="<tab>"] S2 -> P2_completion [label="<tab>"]
@ -32,7 +37,7 @@ digraph {
S3 -> P0_got_code [label="<return>" S3 -> P0_got_code [label="<return>"
color="orange" fontcolor="orange"] color="orange" fontcolor="orange"]
S0 -> S1A [label="allocate_code"] S0A -> S1A [label="allocate_code"]
S1A [label="S1A:\nconnecting"] S1A [label="S1A:\nconnecting"]
S1A -> P1_allocate [label="connected"] S1A -> P1_allocate [label="connected"]
P1_allocate [shape="box" label="RC.tx_allocate"] P1_allocate [shape="box" label="RC.tx_allocate"]
@ -42,5 +47,7 @@ digraph {
S1B -> S1A [label="lost"] S1B -> S1A [label="lost"]
P1_generate [shape="box" label="generate\nrandom code"] P1_generate [shape="box" label="generate\nrandom code"]
P1_generate -> P0_got_code P1_generate -> P0_got_code
S0B -> P1_allocate [label="allocate_code"]
} }

View File

@ -34,7 +34,9 @@ class Code(object):
self._L = _interfaces.ILister(lister) self._L = _interfaces.ILister(lister)
@m.state(initial=True) @m.state(initial=True)
def S0_unknown(self): pass # pragma: no cover def S0A_unknown(self): pass # pragma: no cover
@m.state()
def S0B_unknown_connected(self): pass # pragma: no cover
@m.state() @m.state()
def S1A_connecting(self): pass # pragma: no cover def S1A_connecting(self): pass # pragma: no cover
@m.state() @m.state()
@ -82,6 +84,10 @@ class Code(object):
self._stdio = stdio self._stdio = stdio
self._L.refresh_nameplates() self._L.refresh_nameplates()
@m.output() @m.output()
def stash_code_length_and_RC_tx_allocate(self, code_length):
self._code_length = code_length
self._RC.tx_allocate()
@m.output()
def stash_code_length(self, code_length): def stash_code_length(self, code_length):
self._code_length = code_length self._code_length = code_length
@m.output() @m.output()
@ -113,18 +119,26 @@ class Code(object):
def _B_got_code(self): def _B_got_code(self):
self._B.got_code(self._code) self._B.got_code(self._code)
S0_unknown.upon(set_code, enter=S4_known, outputs=[B_got_code]) S0A_unknown.upon(connected, enter=S0B_unknown_connected, outputs=[])
S0B_unknown_connected.upon(lost, enter=S0A_unknown, outputs=[])
S0_unknown.upon(allocate_code, enter=S1A_connecting, S0A_unknown.upon(set_code, enter=S4_known, outputs=[B_got_code])
outputs=[stash_code_length]) S0B_unknown_connected.upon(set_code, enter=S4_known, outputs=[B_got_code])
S0A_unknown.upon(allocate_code, enter=S1A_connecting,
outputs=[stash_code_length])
S0B_unknown_connected.upon(allocate_code, enter=S1B_allocating,
outputs=[stash_code_length_and_RC_tx_allocate])
S1A_connecting.upon(connected, enter=S1B_allocating, S1A_connecting.upon(connected, enter=S1B_allocating,
outputs=[RC_tx_allocate]) outputs=[RC_tx_allocate])
S1B_allocating.upon(lost, enter=S1A_connecting, outputs=[]) S1B_allocating.upon(lost, enter=S1A_connecting, outputs=[])
S1B_allocating.upon(rx_allocated, enter=S4_known, S1B_allocating.upon(rx_allocated, enter=S4_known,
outputs=[generate_and_B_got_code]) outputs=[generate_and_B_got_code])
S0_unknown.upon(input_code, enter=S2_typing_nameplate, S0A_unknown.upon(input_code, enter=S2_typing_nameplate,
outputs=[start_input_and_L_refresh_nameplates]) outputs=[start_input_and_L_refresh_nameplates])
S0B_unknown_connected.upon(input_code, enter=S2_typing_nameplate,
outputs=[start_input_and_L_refresh_nameplates])
S2_typing_nameplate.upon(tab, enter=S2_typing_nameplate, S2_typing_nameplate.upon(tab, enter=S2_typing_nameplate,
outputs=[do_completion_nameplates]) outputs=[do_completion_nameplates])
S2_typing_nameplate.upon(got_nameplates, enter=S2_typing_nameplate, S2_typing_nameplate.upon(got_nameplates, enter=S2_typing_nameplate,

View File

@ -2,8 +2,9 @@ from __future__ import print_function, unicode_literals
import json import json
from zope.interface import directlyProvides from zope.interface import directlyProvides
from twisted.trial import unittest from twisted.trial import unittest
from .. import timing, _order, _receive, _key from .. import timing, _order, _receive, _key, _code
from .._interfaces import IKey, IReceive, IBoss, ISend, IMailbox from .._interfaces import (IKey, IReceive, IBoss, ISend, IMailbox,
IRendezvousConnector, ILister)
from .._key import derive_key, derive_phase_key, encrypt_data from .._key import derive_key, derive_phase_key, encrypt_data
from ..util import dict_to_bytes, hexstr_to_bytes, bytes_to_hexstr, to_bytes from ..util import dict_to_bytes, hexstr_to_bytes, bytes_to_hexstr, to_bytes
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
@ -174,7 +175,6 @@ class Key(unittest.TestCase):
self.assertEqual(events[2][:2], ("m.add_message", "version")) self.assertEqual(events[2][:2], ("m.add_message", "version"))
self.assertEqual(events[3], ("r.got_key", key2)) self.assertEqual(events[3], ("r.got_key", key2))
def test_bad(self): def test_bad(self):
k, b, m, r, events = self.build() k, b, m, r, events = self.build()
code = u"1-foo" code = u"1-foo"
@ -188,3 +188,58 @@ class Key(unittest.TestCase):
bad_pake_d = {"not_pake_v1": "stuff"} bad_pake_d = {"not_pake_v1": "stuff"}
k.got_pake(dict_to_bytes(bad_pake_d)) k.got_pake(dict_to_bytes(bad_pake_d))
self.assertEqual(events, [("b.scared",)]) self.assertEqual(events, [("b.scared",)])
class Code(unittest.TestCase):
def build(self):
events = []
c = _code.Code(timing.DebugTiming())
b = Dummy("b", events, IBoss, "got_code")
rc = Dummy("rc", events, IRendezvousConnector, "tx_allocate")
l = Dummy("l", events, ILister, "refresh_nameplates")
c.wire(b, rc, l)
return c, b, rc, l, events
def test_set_disconnected(self):
c, b, rc, l, events = self.build()
c.set_code(u"code")
self.assertEqual(events, [("b.got_code", u"code")])
def test_set_connected(self):
c, b, rc, l, events = self.build()
c.connected()
c.set_code(u"code")
self.assertEqual(events, [("b.got_code", u"code")])
def test_allocate_disconnected(self):
c, b, rc, l, events = self.build()
c.allocate_code(2)
self.assertEqual(events, [])
c.connected()
self.assertEqual(events, [("rc.tx_allocate",)])
events[:] = []
c.lost()
self.assertEqual(events, [])
c.connected()
self.assertEqual(events, [("rc.tx_allocate",)])
events[:] = []
c.rx_allocated("4")
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "b.got_code")
code = events[0][1]
self.assert_(code.startswith("4-"), code)
def test_allocate_connected(self):
c, b, rc, l, events = self.build()
c.connected()
c.allocate_code(2)
self.assertEqual(events, [("rc.tx_allocate",)])
events[:] = []
c.rx_allocated("4")
self.assertEqual(len(events), 1, events)
self.assertEqual(events[0][0], "b.got_code")
code = events[0][1]
self.assert_(code.startswith("4-"), code)
# TODO: input_code