diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 097efe7..686335f 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -252,11 +252,12 @@ class Boss(object): def scared(self): pass - def got_message(self, phase, plaintext): + def got_message(self, side, phase, plaintext): + # this is only called for side != ours assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) if phase == "version": - self._got_version(plaintext) + self._got_version(side, plaintext) elif re.search(r'^\d+$', phase): self._got_phase(int(phase), plaintext) else: @@ -265,7 +266,7 @@ class Boss(object): log.err(_UnknownPhaseError("received unknown phase '%s'" % phase)) @m.input() - def _got_version(self, plaintext): + def _got_version(self, side, plaintext): pass @m.input() @@ -290,9 +291,10 @@ class Boss(object): self._W.got_code(code) @m.output() - def process_version(self, plaintext): + def process_version(self, side, plaintext): # most of this is wormhole-to-wormhole, ignored for now # in the future, this is how Dilation is signalled + self._their_side = side self._their_versions = bytes_to_dict(plaintext) # but this part is app-to-app app_versions = self._their_versions.get("app_versions", {}) diff --git a/src/wormhole/_receive.py b/src/wormhole/_receive.py index 8e9de4f..832dc44 100644 --- a/src/wormhole/_receive.py +++ b/src/wormhole/_receive.py @@ -53,10 +53,10 @@ class Receive(object): except CryptoError: self.got_message_bad() return - self.got_message_good(phase, plaintext) + self.got_message_good(side, phase, plaintext) @m.input() - def got_message_good(self, phase, plaintext): + def got_message_good(self, side, phase, plaintext): pass @m.input() @@ -73,23 +73,23 @@ class Receive(object): self._key = key @m.output() - def S_got_verified_key(self, phase, plaintext): + def S_got_verified_key(self, side, phase, plaintext): assert self._key self._S.got_verified_key(self._key) @m.output() - def W_happy(self, phase, plaintext): + def W_happy(self, side, phase, plaintext): self._B.happy() @m.output() - def W_got_verifier(self, phase, plaintext): + def W_got_verifier(self, side, phase, plaintext): self._B.got_verifier(derive_key(self._key, b"wormhole:verifier")) @m.output() - def W_got_message(self, phase, plaintext): + def W_got_message(self, side, phase, plaintext): assert isinstance(phase, type("")), type(phase) assert isinstance(plaintext, type(b"")), type(plaintext) - self._B.got_message(phase, plaintext) + self._B.got_message(side, phase, plaintext) @m.output() def W_scared(self): diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 860b0b8..8a17b76 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -167,7 +167,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"phase2") @@ -178,8 +178,8 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), - ("b.got_message", u"phase2", data2), + ("b.got_message", u"side", u"phase1", data1), + ("b.got_message", u"side", u"phase2", data2), ]) def test_early_bad(self): @@ -217,7 +217,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ]) phase2_key = derive_phase_key(key, u"side", u"bad") @@ -228,7 +228,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ("b.scared", ), ]) r.got_message(u"side", u"phase1", good_body) @@ -237,7 +237,7 @@ class Receive(unittest.TestCase): ("s.got_verified_key", key), ("b.happy", ), ("b.got_verifier", verifier), - ("b.got_message", u"phase1", data1), + ("b.got_message", u"side", u"phase1", data1), ("b.scared", ), ]) @@ -1320,8 +1320,8 @@ class Boss(unittest.TestCase): b.got_key(b"key") b.happy() b.got_verifier(b"verifier") - b.got_message("version", b"{}") - b.got_message("0", b"msg1") + b.got_message("side", "version", b"{}") + b.got_message("side", "0", b"msg1") self.assertEqual(events, [ ("w.got_key", b"key"), ("w.got_verifier", b"verifier"), @@ -1477,7 +1477,7 @@ class Boss(unittest.TestCase): b.happy() # phase=version - b.got_message("unknown-phase", b"spooky") + b.got_message("side", "unknown-phase", b"spooky") self.assertEqual(events, []) self.flushLoggedErrors(errors._UnknownPhaseError)