diff --git a/docs/state-machines/key.dot b/docs/state-machines/key.dot index b6d49f8..01f5f2d 100644 --- a/docs/state-machines/key.dot +++ b/docs/state-machines/key.dot @@ -10,6 +10,26 @@ digraph { start [label="Key\nMachine" style="dotted"] + /* two connected state machines: the first just puts the messages in + the right order, the second handles PAKE */ + + {rank=same; SO_00 PO_got_code SO_10} + {rank=same; SO_01 PO_got_both SO_11} + SO_00 [label="S00"] + SO_01 [label="S01: pake"] + SO_10 [label="S10: code"] + SO_11 [label="S11: both"] + SO_00 -> SO_01 [label="got_pake\n(early)"] + SO_00 -> PO_got_code [label="got_code"] + PO_got_code [shape="box" label="K1.got_code"] + PO_got_code -> SO_10 + SO_01 -> PO_got_both [label="got_code"] + PO_got_both [shape="box" label="K1.got_code\nK1.got_pake"] + PO_got_both -> SO_11 + SO_10 -> PO_got_pake [label="got_pake"] + PO_got_pake [shape="box" label="K1.got_pake"] + PO_got_pake -> SO_11 + S0 [label="S0: know\nnothing"] S0 -> P0_build [label="got_code"] @@ -30,14 +50,14 @@ digraph { S1 -> P_mood_scary [label="got_pake\npake bad"] P_mood_scary [shape="box" color="red" label="W.scared"] - P_mood_scary -> S3 [color="red"] - S3 [label="S3:\nscared" color="red"] + P_mood_scary -> S5 [color="red"] + S5 [label="S5:\nscared" color="red"] S1 -> P1_compute [label="got_pake\npake good"] #S1 -> P_mood_lonely [label="close"] P1_compute [label="compute_key\nM.add_message(version)\nB.got_key\nB.got_verifier\nR.got_key" shape="box"] - P1_compute -> S2 + P1_compute -> S4 - S2 [label="S2: know_key" color="green"] + S4 [label="S4: know_key" color="green"] } diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index a101cb3..a48ad11 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -88,7 +88,7 @@ class Boss(object): def _set_trace(self, client_name, which, logger): names = {"B": self, "N": self._N, "M": self._M, "S": self._S, - "O": self._O, "K": self._K, "R": self._R, + "O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R, "RC": self._RC, "L": self._L, "C": self._C, "T": self._T} for machine in which.split(): diff --git a/src/wormhole/_key.py b/src/wormhole/_key.py index 2436de4..330833f 100644 --- a/src/wormhole/_key.py +++ b/src/wormhole/_key.py @@ -52,9 +52,63 @@ def encrypt_data(key, plaintext): nonce = utils.random(SecretBox.NONCE_SIZE) return box.encrypt(plaintext, nonce) +# the Key we expose to callers (Boss, Ordering) is responsible for sorting +# the two messages (got_code and got_pake), then delivering them to +# _SortedKey in the right order. + @attrs @implementer(_interfaces.IKey) class Key(object): + _appid = attrib(validator=instance_of(type(u""))) + _versions = attrib(validator=instance_of(dict)) + _side = attrib(validator=instance_of(type(u""))) + _timing = attrib(validator=provides(_interfaces.ITiming)) + m = MethodicalMachine() + @m.setTrace() + def _set_trace(): pass # pragma: no cover + + def __attrs_post_init__(self): + self._SK = _SortedKey(self._appid, self._versions, self._side, + self._timing) + + def wire(self, boss, mailbox, receive): + self._SK.wire(boss, mailbox, receive) + + @m.state(initial=True) + def S00(self): pass # pragma: no cover + @m.state() + def S01(self): pass # pragma: no cover + @m.state() + def S10(self): pass # pragma: no cover + @m.state() + def S11(self): pass # pragma: no cover + + @m.input() + def got_code(self, code): pass + @m.input() + def got_pake(self, body): pass + + @m.output() + def stash_pake(self, body): + self._pake = body + @m.output() + def deliver_code(self, code): + self._SK.got_code(code) + @m.output() + def deliver_pake(self, body): + self._SK.got_pake(body) + @m.output() + def deliver_code_and_stashed_pake(self, code): + self._SK.got_code(code) + self._SK.got_pake(self._pake) + + S00.upon(got_code, enter=S10, outputs=[deliver_code]) + S10.upon(got_pake, enter=S11, outputs=[deliver_pake]) + S00.upon(got_pake, enter=S01, outputs=[stash_pake]) + S01.upon(got_code, enter=S11, outputs=[deliver_code_and_stashed_pake]) + +@attrs +class _SortedKey(object): _appid = attrib(validator=instance_of(type(u""))) _versions = attrib(validator=instance_of(dict)) _side = attrib(validator=instance_of(type(u"")))