fix Noise handshake ordering
I mistakenly believed that Noise handshakes are simultaneous. In fact, the Responder waits until it sees the Initiator's handshake before sending its own. I had to update the Connection state machines to work this way (the Record machine now has set_role_leader and set_role_follower), and update the tests to match. For debugging I added a `_role` property to Record, but it should probably be removed.
This commit is contained in:
parent
d1ff97f988
commit
53ffbe1632
|
@ -4,6 +4,12 @@ except ImportError:
|
|||
class NoiseInvalidMessage(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
from noise.exceptions import NoiseHandshakeError
|
||||
except ImportError:
|
||||
class NoiseHandshakeError(Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
from noise.connection import NoiseConnection
|
||||
except ImportError:
|
||||
|
|
|
@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport
|
|||
from .._interfaces import IDilationConnector
|
||||
from ..observer import OneShotObserver
|
||||
from .encode import to_be4, from_be4
|
||||
from .roles import FOLLOWER
|
||||
from ._noise import NoiseInvalidMessage
|
||||
from .roles import LEADER, FOLLOWER
|
||||
from ._noise import NoiseInvalidMessage, NoiseHandshakeError
|
||||
|
||||
# InboundFraming is given data and returns Frames (Noise wire-side
|
||||
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
||||
|
@ -56,6 +56,23 @@ def first(l):
|
|||
class Disconnect(Exception):
|
||||
pass
|
||||
|
||||
# all connections look like:
|
||||
# (step 1: only for outbound connections)
|
||||
# 1: if we're connecting to a transit relay:
|
||||
# * send "sided relay handshake": "please relay TOKEN for side SIDE\n"
|
||||
# * the relay will send "ok\n" if/when our peer connects
|
||||
# * a non-relay will probably send junk
|
||||
# * wait for "ok\n", hang up if we get anything different
|
||||
# (all subsequent steps are for both inbound and outbound connections)
|
||||
# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n"
|
||||
# 3: wait for the opposite PROLOGUE string, else hang up
|
||||
# (everything past this point is a Frame, with be4 length prefix. Frames are
|
||||
# either noise handshake or an encrypted message)
|
||||
# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it
|
||||
# 5: if FOLLOWER, send noise response string. if LEADER, wait for it
|
||||
# 6: ...
|
||||
|
||||
|
||||
|
||||
RelayOK = namedtuple("RelayOk", [])
|
||||
Prologue = namedtuple("Prologue", [])
|
||||
|
@ -193,7 +210,7 @@ class _Framer(object):
|
|||
def add_and_parse(self, data):
|
||||
# we can't make this an @m.input because we can't change the state
|
||||
# from within an input. Instead, let the state choose the parser to
|
||||
# use, and use the parsed token drive a state transition.
|
||||
# use, then use the parsed token to drive a state transition.
|
||||
self._buffer += data
|
||||
while True:
|
||||
# it'd be nice to use an iterator here, but since self.parse()
|
||||
|
@ -302,11 +319,16 @@ def encode_record(r):
|
|||
raise TypeError(r)
|
||||
|
||||
|
||||
def _is_role(_record, _attr, value):
|
||||
if value not in [LEADER, FOLLOWER]:
|
||||
raise ValueError("role must be LEADER or FOLLOWER")
|
||||
|
||||
@attrs
|
||||
@implementer(IRecord)
|
||||
class _Record(object):
|
||||
_framer = attrib(validator=provides(IFramer))
|
||||
_noise = attrib()
|
||||
_role = attrib(default="unspecified", validator=_is_role) # for debugging
|
||||
|
||||
n = MethodicalMachine()
|
||||
# TODO: set_trace
|
||||
|
@ -321,17 +343,37 @@ class _Record(object):
|
|||
# states: want_prologue, want_handshake, want_record
|
||||
|
||||
@n.state(initial=True)
|
||||
def want_prologue(self):
|
||||
def no_role_set(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake(self):
|
||||
def want_prologue_leader(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_prologue_follower(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake_leader(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_handshake_follower(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.state()
|
||||
def want_message(self):
|
||||
pass # pragma: no cover
|
||||
|
||||
@n.input()
|
||||
def set_role_leader(self):
|
||||
pass
|
||||
|
||||
@n.input()
|
||||
def set_role_follower(self):
|
||||
pass
|
||||
|
||||
@n.input()
|
||||
def got_prologue(self):
|
||||
pass
|
||||
|
@ -340,9 +382,20 @@ class _Record(object):
|
|||
def got_frame(self, frame):
|
||||
pass
|
||||
|
||||
@n.output()
|
||||
def ignore_and_send_handshake(self, frame):
|
||||
self._send_handshake()
|
||||
|
||||
@n.output()
|
||||
def send_handshake(self):
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
self._send_handshake()
|
||||
|
||||
def _send_handshake(self):
|
||||
try:
|
||||
handshake = self._noise.write_message() # generate the ephemeral key
|
||||
except NoiseHandshakeError as e:
|
||||
log.err(e, "noise error during handshake")
|
||||
raise
|
||||
self._framer.send_frame(handshake)
|
||||
|
||||
@n.output()
|
||||
|
@ -367,10 +420,19 @@ class _Record(object):
|
|||
raise Disconnect()
|
||||
return parse_record(message)
|
||||
|
||||
want_prologue.upon(got_prologue, outputs=[send_handshake],
|
||||
enter=want_handshake)
|
||||
want_handshake.upon(got_frame, outputs=[process_handshake],
|
||||
collector=first, enter=want_message)
|
||||
no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader)
|
||||
want_prologue_leader.upon(got_prologue, outputs=[send_handshake],
|
||||
enter=want_handshake_leader)
|
||||
want_handshake_leader.upon(got_frame, outputs=[process_handshake],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower)
|
||||
want_prologue_follower.upon(got_prologue, outputs=[],
|
||||
enter=want_handshake_follower)
|
||||
want_handshake_follower.upon(got_frame, outputs=[process_handshake,
|
||||
ignore_and_send_handshake],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
want_message.upon(got_frame, outputs=[decrypt_message],
|
||||
collector=first, enter=want_message)
|
||||
|
||||
|
@ -493,12 +555,20 @@ class DilatedConnectionProtocol(Protocol, object):
|
|||
# IProtocol methods
|
||||
|
||||
def connectionMade(self):
|
||||
framer = _Framer(self.transport,
|
||||
self._outbound_prologue, self._inbound_prologue)
|
||||
if self._use_relay:
|
||||
framer.use_relay(self._relay_handshake)
|
||||
self._record = _Record(framer, self._noise)
|
||||
self._record.connectionMade()
|
||||
try:
|
||||
framer = _Framer(self.transport,
|
||||
self._outbound_prologue, self._inbound_prologue)
|
||||
if self._use_relay:
|
||||
framer.use_relay(self._relay_handshake)
|
||||
self._record = _Record(framer, self._noise, self._role)
|
||||
if self._role is LEADER:
|
||||
self._record.set_role_leader()
|
||||
else:
|
||||
self._record.set_role_follower()
|
||||
self._record.connectionMade()
|
||||
except:
|
||||
log.err()
|
||||
raise
|
||||
|
||||
def dataReceived(self, data):
|
||||
try:
|
||||
|
|
|
@ -69,10 +69,20 @@ class Connection(unittest.TestCase):
|
|||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"inbound_prologue\n")
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
|
||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
if role is LEADER:
|
||||
# the LEADER sends the Noise handshake message immediately upon
|
||||
# receipt of the prologue
|
||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||
else:
|
||||
# however the FOLLOWER waits until receiving the leader's
|
||||
# handshake before sending their own
|
||||
self.assertEqual(n.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
||||
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
||||
|
@ -84,13 +94,16 @@ class Connection(unittest.TestCase):
|
|||
self.assertEqual(t.mock_calls, [])
|
||||
self.assertEqual(c._manager, None)
|
||||
else:
|
||||
# we're the follower, so we encrypt and send the KCM immediately
|
||||
# we're the follower, so we send our Noise handshake, then
|
||||
# encrypt and send the KCM immediately
|
||||
self.assertEqual(n.mock_calls, [
|
||||
mock.call.read_message(b"handshake2"),
|
||||
mock.call.write_message(),
|
||||
mock.call.encrypt(encode_record(t_kcm)),
|
||||
])
|
||||
self.assertEqual(connector.mock_calls, [])
|
||||
self.assertEqual(t.mock_calls, [
|
||||
mock.call.write(exp_handshake),
|
||||
mock.call.write(exp_kcm)])
|
||||
self.assertEqual(c._manager, None)
|
||||
clear_mock_calls(n, connector, t, m)
|
||||
|
|
|
@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage
|
|||
from ..._dilation.connection import (IFramer, Frame, Prologue,
|
||||
_Record, Handshake,
|
||||
Disconnect, Ping)
|
||||
from ..._dilation.roles import LEADER
|
||||
|
||||
|
||||
def make_record():
|
||||
f = mock.Mock()
|
||||
alsoProvides(f, IFramer)
|
||||
n = mock.Mock() # pretends to be a Noise object
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
return r, f, n
|
||||
|
||||
|
||||
|
@ -30,7 +32,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
p1, p2 = object(), object()
|
||||
n.decrypt = mock.Mock(side_effect=[p1, p2])
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -79,7 +82,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.read_message = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -103,7 +107,8 @@ class Record(unittest.TestCase):
|
|||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||
nvm = NoiseInvalidMessage()
|
||||
n.decrypt = mock.Mock(side_effect=nvm)
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
r.connectionMade()
|
||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||
|
@ -124,7 +129,8 @@ class Record(unittest.TestCase):
|
|||
f1 = object()
|
||||
n.encrypt = mock.Mock(return_value=f1)
|
||||
r1 = Ping(b"pingid")
|
||||
r = _Record(f, n)
|
||||
r = _Record(f, n, LEADER)
|
||||
r.set_role_leader()
|
||||
self.assertEqual(f.mock_calls, [])
|
||||
m1 = object()
|
||||
with mock.patch("wormhole._dilation.connection.encode_record",
|
||||
|
|
Loading…
Reference in New Issue
Block a user