From 53ffbe1632fbaab3715dbffe80c758d14ed2af78 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 3 Feb 2019 17:06:20 -0800 Subject: [PATCH] fix Noise handshake ordering I mistakenly believed that Noise handshakes are simultaneous. In fact, the Responder waits until it sees the Initiator's handshake before sending its own. I had to update the Connection state machines to work this way (the Record machine now has set_role_leader and set_role_follower), and update the tests to match. For debugging I added a `_role` property to Record, but it should probably be removed. --- src/wormhole/_dilation/_noise.py | 6 ++ src/wormhole/_dilation/connection.py | 102 +++++++++++++++++--- src/wormhole/test/dilate/test_connection.py | 21 +++- src/wormhole/test/dilate/test_record.py | 16 ++- 4 files changed, 120 insertions(+), 25 deletions(-) diff --git a/src/wormhole/_dilation/_noise.py b/src/wormhole/_dilation/_noise.py index bb4cf58..1005264 100644 --- a/src/wormhole/_dilation/_noise.py +++ b/src/wormhole/_dilation/_noise.py @@ -4,6 +4,12 @@ except ImportError: class NoiseInvalidMessage(Exception): pass +try: + from noise.exceptions import NoiseHandshakeError +except ImportError: + class NoiseHandshakeError(Exception): + pass + try: from noise.connection import NoiseConnection except ImportError: diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index c53f036..b11b7bc 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport from .._interfaces import IDilationConnector from ..observer import OneShotObserver from .encode import to_be4, from_be4 -from .roles import FOLLOWER -from ._noise import NoiseInvalidMessage +from .roles import LEADER, FOLLOWER +from ._noise import NoiseInvalidMessage, NoiseHandshakeError # InboundFraming is given data and returns Frames (Noise wire-side # bytestrings). It handles the relay handshake and the prologue. The Frames it @@ -56,6 +56,23 @@ def first(l): class Disconnect(Exception): pass +# all connections look like: +# (step 1: only for outbound connections) +# 1: if we're connecting to a transit relay: +# * send "sided relay handshake": "please relay TOKEN for side SIDE\n" +# * the relay will send "ok\n" if/when our peer connects +# * a non-relay will probably send junk +# * wait for "ok\n", hang up if we get anything different +# (all subsequent steps are for both inbound and outbound connections) +# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n" +# 3: wait for the opposite PROLOGUE string, else hang up +# (everything past this point is a Frame, with be4 length prefix. Frames are +# either noise handshake or an encrypted message) +# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it +# 5: if FOLLOWER, send noise response string. if LEADER, wait for it +# 6: ... + + RelayOK = namedtuple("RelayOk", []) Prologue = namedtuple("Prologue", []) @@ -193,7 +210,7 @@ class _Framer(object): def add_and_parse(self, data): # we can't make this an @m.input because we can't change the state # from within an input. Instead, let the state choose the parser to - # use, and use the parsed token drive a state transition. + # use, then use the parsed token to drive a state transition. self._buffer += data while True: # it'd be nice to use an iterator here, but since self.parse() @@ -302,11 +319,16 @@ def encode_record(r): raise TypeError(r) +def _is_role(_record, _attr, value): + if value not in [LEADER, FOLLOWER]: + raise ValueError("role must be LEADER or FOLLOWER") + @attrs @implementer(IRecord) class _Record(object): _framer = attrib(validator=provides(IFramer)) _noise = attrib() + _role = attrib(default="unspecified", validator=_is_role) # for debugging n = MethodicalMachine() # TODO: set_trace @@ -321,17 +343,37 @@ class _Record(object): # states: want_prologue, want_handshake, want_record @n.state(initial=True) - def want_prologue(self): + def no_role_set(self): pass # pragma: no cover @n.state() - def want_handshake(self): + def want_prologue_leader(self): + pass # pragma: no cover + + @n.state() + def want_prologue_follower(self): + pass # pragma: no cover + + @n.state() + def want_handshake_leader(self): + pass # pragma: no cover + + @n.state() + def want_handshake_follower(self): pass # pragma: no cover @n.state() def want_message(self): pass # pragma: no cover + @n.input() + def set_role_leader(self): + pass + + @n.input() + def set_role_follower(self): + pass + @n.input() def got_prologue(self): pass @@ -340,9 +382,20 @@ class _Record(object): def got_frame(self, frame): pass + @n.output() + def ignore_and_send_handshake(self, frame): + self._send_handshake() + @n.output() def send_handshake(self): - handshake = self._noise.write_message() # generate the ephemeral key + self._send_handshake() + + def _send_handshake(self): + try: + handshake = self._noise.write_message() # generate the ephemeral key + except NoiseHandshakeError as e: + log.err(e, "noise error during handshake") + raise self._framer.send_frame(handshake) @n.output() @@ -367,10 +420,19 @@ class _Record(object): raise Disconnect() return parse_record(message) - want_prologue.upon(got_prologue, outputs=[send_handshake], - enter=want_handshake) - want_handshake.upon(got_frame, outputs=[process_handshake], - collector=first, enter=want_message) + no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader) + want_prologue_leader.upon(got_prologue, outputs=[send_handshake], + enter=want_handshake_leader) + want_handshake_leader.upon(got_frame, outputs=[process_handshake], + collector=first, enter=want_message) + + no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower) + want_prologue_follower.upon(got_prologue, outputs=[], + enter=want_handshake_follower) + want_handshake_follower.upon(got_frame, outputs=[process_handshake, + ignore_and_send_handshake], + collector=first, enter=want_message) + want_message.upon(got_frame, outputs=[decrypt_message], collector=first, enter=want_message) @@ -493,12 +555,20 @@ class DilatedConnectionProtocol(Protocol, object): # IProtocol methods def connectionMade(self): - framer = _Framer(self.transport, - self._outbound_prologue, self._inbound_prologue) - if self._use_relay: - framer.use_relay(self._relay_handshake) - self._record = _Record(framer, self._noise) - self._record.connectionMade() + try: + framer = _Framer(self.transport, + self._outbound_prologue, self._inbound_prologue) + if self._use_relay: + framer.use_relay(self._relay_handshake) + self._record = _Record(framer, self._noise, self._role) + if self._role is LEADER: + self._record.set_role_leader() + else: + self._record.set_role_follower() + self._record.connectionMade() + except: + log.err() + raise def dataReceived(self, data): try: diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py index ee761fd..959c5cd 100644 --- a/src/wormhole/test/dilate/test_connection.py +++ b/src/wormhole/test/dilate/test_connection.py @@ -69,10 +69,20 @@ class Connection(unittest.TestCase): clear_mock_calls(n, connector, t, m) c.dataReceived(b"inbound_prologue\n") - self.assertEqual(n.mock_calls, [mock.call.write_message()]) - self.assertEqual(connector.mock_calls, []) + exp_handshake = b"\x00\x00\x00\x09handshake" - self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + if role is LEADER: + # the LEADER sends the Noise handshake message immediately upon + # receipt of the prologue + self.assertEqual(n.mock_calls, [mock.call.write_message()]) + self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)]) + else: + # however the FOLLOWER waits until receiving the leader's + # handshake before sending their own + self.assertEqual(n.mock_calls, []) + self.assertEqual(t.mock_calls, []) + self.assertEqual(connector.mock_calls, []) + clear_mock_calls(n, connector, t, m) c.dataReceived(b"\x00\x00\x00\x0Ahandshake2") @@ -84,13 +94,16 @@ class Connection(unittest.TestCase): self.assertEqual(t.mock_calls, []) self.assertEqual(c._manager, None) else: - # we're the follower, so we encrypt and send the KCM immediately + # we're the follower, so we send our Noise handshake, then + # encrypt and send the KCM immediately self.assertEqual(n.mock_calls, [ mock.call.read_message(b"handshake2"), + mock.call.write_message(), mock.call.encrypt(encode_record(t_kcm)), ]) self.assertEqual(connector.mock_calls, []) self.assertEqual(t.mock_calls, [ + mock.call.write(exp_handshake), mock.call.write(exp_kcm)]) self.assertEqual(c._manager, None) clear_mock_calls(n, connector, t, m) diff --git a/src/wormhole/test/dilate/test_record.py b/src/wormhole/test/dilate/test_record.py index 63a784c..252a8b0 100644 --- a/src/wormhole/test/dilate/test_record.py +++ b/src/wormhole/test/dilate/test_record.py @@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage from ..._dilation.connection import (IFramer, Frame, Prologue, _Record, Handshake, Disconnect, Ping) +from ..._dilation.roles import LEADER def make_record(): f = mock.Mock() alsoProvides(f, IFramer) n = mock.Mock() # pretends to be a Noise object - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() return r, f, n @@ -30,7 +32,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") p1, p2 = object(), object() n.decrypt = mock.Mock(side_effect=[p1, p2]) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -79,7 +82,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") nvm = NoiseInvalidMessage() n.read_message = mock.Mock(side_effect=nvm) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -103,7 +107,8 @@ class Record(unittest.TestCase): n.write_message = mock.Mock(return_value=b"tx-handshake") nvm = NoiseInvalidMessage() n.decrypt = mock.Mock(side_effect=nvm) - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) r.connectionMade() self.assertEqual(f.mock_calls, [mock.call.connectionMade()]) @@ -124,7 +129,8 @@ class Record(unittest.TestCase): f1 = object() n.encrypt = mock.Mock(return_value=f1) r1 = Ping(b"pingid") - r = _Record(f, n) + r = _Record(f, n, LEADER) + r.set_role_leader() self.assertEqual(f.mock_calls, []) m1 = object() with mock.patch("wormhole._dilation.connection.encode_record",