fix basic test

This commit is contained in:
Brian Warner 2017-02-26 04:13:57 -08:00
parent 1beae97ec4
commit b0c9c9bb4c
3 changed files with 28 additions and 25 deletions

View File

@ -84,11 +84,11 @@ class Mailbox(object):
if side == self._side: if side == self._side:
self.rx_message_ours(phase, body) self.rx_message_ours(phase, body)
else: else:
self.rx_message_theirs(phase, body) self.rx_message_theirs(side, phase, body)
@m.input() @m.input()
def rx_message_ours(self, phase, body): pass def rx_message_ours(self, phase, body): pass
@m.input() @m.input()
def rx_message_theirs(self, phase, body): pass def rx_message_theirs(self, side, phase, body): pass
@m.input() @m.input()
def rx_closed(self): pass def rx_closed(self): pass
@ -129,9 +129,9 @@ class Mailbox(object):
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
self._RC.tx_add(phase, body) self._RC.tx_add(phase, body)
@m.output() @m.output()
def N_release_and_accept(self, phase, body): def N_release_and_accept(self, side, phase, body):
self._N.release() self._N.release()
self._accept(phase, body) self._accept(side, phase, body)
@m.output() @m.output()
def RC_tx_close(self): def RC_tx_close(self):
assert self._mood assert self._mood
@ -139,12 +139,12 @@ class Mailbox(object):
def _RC_tx_close(self): def _RC_tx_close(self):
self._RC.tx_close(self._mailbox, self._mood) self._RC.tx_close(self._mailbox, self._mood)
@m.output() @m.output()
def accept(self, phase, body): def accept(self, side, phase, body):
self._accept(phase, body) self._accept(side, phase, body)
def _accept(self, phase, body): def _accept(self, side, phase, body):
if phase not in self._processed: if phase not in self._processed:
self._processed.add(phase) self._processed.add(phase)
self._O.got_message(phase, body) self._O.got_message(side, phase, body)
@m.output() @m.output()
def dequeue(self, phase, body): def dequeue(self, phase, body):
self._pending_outbound.pop(phase, None) self._pending_outbound.pop(phase, None)

View File

@ -24,41 +24,43 @@ class Order(object):
@m.state(terminal=True) @m.state(terminal=True)
def S1_yes_pake(self): pass def S1_yes_pake(self): pass
def got_message(self, phase, body): def got_message(self, side, phase, body):
#print("ORDER[%s].got_message(%s)" % (self._side, phase)) #print("ORDER[%s].got_message(%s)" % (self._side, phase))
assert isinstance(side, type("")), type(phase)
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
if phase == "pake": if phase == "pake":
self.got_pake(phase, body) self.got_pake(side, phase, body)
else: else:
self.got_non_pake(phase, body) self.got_non_pake(side, phase, body)
@m.input() @m.input()
def got_pake(self, phase, body): pass def got_pake(self, side, phase, body): pass
@m.input() @m.input()
def got_non_pake(self, phase, body): pass def got_non_pake(self, side, phase, body): pass
@m.output() @m.output()
def queue(self, phase, body): def queue(self, side, phase, body):
assert isinstance(side, type("")), type(phase)
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
self._queue.append((phase, body)) self._queue.append((side, phase, body))
@m.output() @m.output()
def notify_key(self, phase, body): def notify_key(self, side, phase, body):
self._K.got_pake(body) self._K.got_pake(body)
@m.output() @m.output()
def drain(self, phase, body): def drain(self, side, phase, body):
del phase del phase
del body del body
for (phase, body) in self._queue: for (side, phase, body) in self._queue:
self._deliver(phase, body) self._deliver(side, phase, body)
self._queue[:] = [] self._queue[:] = []
@m.output() @m.output()
def deliver(self, phase, body): def deliver(self, side, phase, body):
self._deliver(phase, body) self._deliver(side, phase, body)
def _deliver(self, phase, body): def _deliver(self, side, phase, body):
self._R.got_message(phase, body) self._R.got_message(side, phase, body)
S0_no_pake.upon(got_non_pake, enter=S0_no_pake, outputs=[queue]) S0_no_pake.upon(got_non_pake, enter=S0_no_pake, outputs=[queue])
S0_no_pake.upon(got_pake, enter=S1_yes_pake, outputs=[notify_key, drain]) S0_no_pake.upon(got_pake, enter=S1_yes_pake, outputs=[notify_key, drain])

View File

@ -31,11 +31,12 @@ class Receive(object):
def S3_scared(self): pass def S3_scared(self): pass
# from Ordering # from Ordering
def got_message(self, phase, body): def got_message(self, side, phase, body):
assert isinstance(side, type("")), type(phase)
assert isinstance(phase, type("")), type(phase) assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body) assert isinstance(body, type(b"")), type(body)
assert self._key assert self._key
data_key = derive_phase_key(self._key, self._side, phase) data_key = derive_phase_key(self._key, side, phase)
try: try:
plaintext = decrypt_data(data_key, body) plaintext = decrypt_data(data_key, body)
except CryptoError: except CryptoError: