fix Noise handshake ordering

I mistakenly believed that Noise handshakes are simultaneous. In fact, the
Responder waits until it sees the Initiator's handshake before sending its
own. I had to update the Connection state machines to work this way (the
Record machine now has set_role_leader and set_role_follower), and update the
tests to match.

For debugging I added a `_role` property to Record, but it should probably be
removed.
This commit is contained in:
Brian Warner 2019-02-03 17:06:20 -08:00
parent d1ff97f988
commit 53ffbe1632
4 changed files with 120 additions and 25 deletions

View File

@ -4,6 +4,12 @@ except ImportError:
class NoiseInvalidMessage(Exception):
pass
try:
from noise.exceptions import NoiseHandshakeError
except ImportError:
class NoiseHandshakeError(Exception):
pass
try:
from noise.connection import NoiseConnection
except ImportError:

View File

@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport
from .._interfaces import IDilationConnector
from ..observer import OneShotObserver
from .encode import to_be4, from_be4
from .roles import FOLLOWER
from ._noise import NoiseInvalidMessage
from .roles import LEADER, FOLLOWER
from ._noise import NoiseInvalidMessage, NoiseHandshakeError
# InboundFraming is given data and returns Frames (Noise wire-side
# bytestrings). It handles the relay handshake and the prologue. The Frames it
@ -56,6 +56,23 @@ def first(l):
class Disconnect(Exception):
pass
# all connections look like:
# (step 1: only for outbound connections)
# 1: if we're connecting to a transit relay:
# * send "sided relay handshake": "please relay TOKEN for side SIDE\n"
# * the relay will send "ok\n" if/when our peer connects
# * a non-relay will probably send junk
# * wait for "ok\n", hang up if we get anything different
# (all subsequent steps are for both inbound and outbound connections)
# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n"
# 3: wait for the opposite PROLOGUE string, else hang up
# (everything past this point is a Frame, with be4 length prefix. Frames are
# either noise handshake or an encrypted message)
# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it
# 5: if FOLLOWER, send noise response string. if LEADER, wait for it
# 6: ...
RelayOK = namedtuple("RelayOk", [])
Prologue = namedtuple("Prologue", [])
@ -193,7 +210,7 @@ class _Framer(object):
def add_and_parse(self, data):
# we can't make this an @m.input because we can't change the state
# from within an input. Instead, let the state choose the parser to
# use, and use the parsed token drive a state transition.
# use, then use the parsed token to drive a state transition.
self._buffer += data
while True:
# it'd be nice to use an iterator here, but since self.parse()
@ -302,11 +319,16 @@ def encode_record(r):
raise TypeError(r)
def _is_role(_record, _attr, value):
if value not in [LEADER, FOLLOWER]:
raise ValueError("role must be LEADER or FOLLOWER")
@attrs
@implementer(IRecord)
class _Record(object):
_framer = attrib(validator=provides(IFramer))
_noise = attrib()
_role = attrib(default="unspecified", validator=_is_role) # for debugging
n = MethodicalMachine()
# TODO: set_trace
@ -321,17 +343,37 @@ class _Record(object):
# states: want_prologue, want_handshake, want_record
@n.state(initial=True)
def want_prologue(self):
def no_role_set(self):
pass # pragma: no cover
@n.state()
def want_handshake(self):
def want_prologue_leader(self):
pass # pragma: no cover
@n.state()
def want_prologue_follower(self):
pass # pragma: no cover
@n.state()
def want_handshake_leader(self):
pass # pragma: no cover
@n.state()
def want_handshake_follower(self):
pass # pragma: no cover
@n.state()
def want_message(self):
pass # pragma: no cover
@n.input()
def set_role_leader(self):
pass
@n.input()
def set_role_follower(self):
pass
@n.input()
def got_prologue(self):
pass
@ -340,9 +382,20 @@ class _Record(object):
def got_frame(self, frame):
pass
@n.output()
def ignore_and_send_handshake(self, frame):
self._send_handshake()
@n.output()
def send_handshake(self):
handshake = self._noise.write_message() # generate the ephemeral key
self._send_handshake()
def _send_handshake(self):
try:
handshake = self._noise.write_message() # generate the ephemeral key
except NoiseHandshakeError as e:
log.err(e, "noise error during handshake")
raise
self._framer.send_frame(handshake)
@n.output()
@ -367,10 +420,19 @@ class _Record(object):
raise Disconnect()
return parse_record(message)
want_prologue.upon(got_prologue, outputs=[send_handshake],
enter=want_handshake)
want_handshake.upon(got_frame, outputs=[process_handshake],
collector=first, enter=want_message)
no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader)
want_prologue_leader.upon(got_prologue, outputs=[send_handshake],
enter=want_handshake_leader)
want_handshake_leader.upon(got_frame, outputs=[process_handshake],
collector=first, enter=want_message)
no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower)
want_prologue_follower.upon(got_prologue, outputs=[],
enter=want_handshake_follower)
want_handshake_follower.upon(got_frame, outputs=[process_handshake,
ignore_and_send_handshake],
collector=first, enter=want_message)
want_message.upon(got_frame, outputs=[decrypt_message],
collector=first, enter=want_message)
@ -493,12 +555,20 @@ class DilatedConnectionProtocol(Protocol, object):
# IProtocol methods
def connectionMade(self):
framer = _Framer(self.transport,
self._outbound_prologue, self._inbound_prologue)
if self._use_relay:
framer.use_relay(self._relay_handshake)
self._record = _Record(framer, self._noise)
self._record.connectionMade()
try:
framer = _Framer(self.transport,
self._outbound_prologue, self._inbound_prologue)
if self._use_relay:
framer.use_relay(self._relay_handshake)
self._record = _Record(framer, self._noise, self._role)
if self._role is LEADER:
self._record.set_role_leader()
else:
self._record.set_role_follower()
self._record.connectionMade()
except:
log.err()
raise
def dataReceived(self, data):
try:

View File

@ -69,10 +69,20 @@ class Connection(unittest.TestCase):
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"inbound_prologue\n")
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(connector.mock_calls, [])
exp_handshake = b"\x00\x00\x00\x09handshake"
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
if role is LEADER:
# the LEADER sends the Noise handshake message immediately upon
# receipt of the prologue
self.assertEqual(n.mock_calls, [mock.call.write_message()])
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
else:
# however the FOLLOWER waits until receiving the leader's
# handshake before sending their own
self.assertEqual(n.mock_calls, [])
self.assertEqual(t.mock_calls, [])
self.assertEqual(connector.mock_calls, [])
clear_mock_calls(n, connector, t, m)
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
@ -84,13 +94,16 @@ class Connection(unittest.TestCase):
self.assertEqual(t.mock_calls, [])
self.assertEqual(c._manager, None)
else:
# we're the follower, so we encrypt and send the KCM immediately
# we're the follower, so we send our Noise handshake, then
# encrypt and send the KCM immediately
self.assertEqual(n.mock_calls, [
mock.call.read_message(b"handshake2"),
mock.call.write_message(),
mock.call.encrypt(encode_record(t_kcm)),
])
self.assertEqual(connector.mock_calls, [])
self.assertEqual(t.mock_calls, [
mock.call.write(exp_handshake),
mock.call.write(exp_kcm)])
self.assertEqual(c._manager, None)
clear_mock_calls(n, connector, t, m)

View File

@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage
from ..._dilation.connection import (IFramer, Frame, Prologue,
_Record, Handshake,
Disconnect, Ping)
from ..._dilation.roles import LEADER
def make_record():
f = mock.Mock()
alsoProvides(f, IFramer)
n = mock.Mock() # pretends to be a Noise object
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
return r, f, n
@ -30,7 +32,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
p1, p2 = object(), object()
n.decrypt = mock.Mock(side_effect=[p1, p2])
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -79,7 +82,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.read_message = mock.Mock(side_effect=nvm)
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -103,7 +107,8 @@ class Record(unittest.TestCase):
n.write_message = mock.Mock(return_value=b"tx-handshake")
nvm = NoiseInvalidMessage()
n.decrypt = mock.Mock(side_effect=nvm)
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
r.connectionMade()
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
@ -124,7 +129,8 @@ class Record(unittest.TestCase):
f1 = object()
n.encrypt = mock.Mock(return_value=f1)
r1 = Ping(b"pingid")
r = _Record(f, n)
r = _Record(f, n, LEADER)
r.set_role_leader()
self.assertEqual(f.mock_calls, [])
m1 = object()
with mock.patch("wormhole._dilation.connection.encode_record",