fix Noise handshake ordering
I mistakenly believed that Noise handshakes are simultaneous. In fact, the Responder waits until it sees the Initiator's handshake before sending its own. I had to update the Connection state machines to work this way (the Record machine now has set_role_leader and set_role_follower), and update the tests to match. For debugging I added a `_role` property to Record, but it should probably be removed.
This commit is contained in:
parent
d1ff97f988
commit
53ffbe1632
|
@ -4,6 +4,12 @@ except ImportError:
|
||||||
class NoiseInvalidMessage(Exception):
|
class NoiseInvalidMessage(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from noise.exceptions import NoiseHandshakeError
|
||||||
|
except ImportError:
|
||||||
|
class NoiseHandshakeError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from noise.connection import NoiseConnection
|
from noise.connection import NoiseConnection
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -11,8 +11,8 @@ from twisted.internet.interfaces import ITransport
|
||||||
from .._interfaces import IDilationConnector
|
from .._interfaces import IDilationConnector
|
||||||
from ..observer import OneShotObserver
|
from ..observer import OneShotObserver
|
||||||
from .encode import to_be4, from_be4
|
from .encode import to_be4, from_be4
|
||||||
from .roles import FOLLOWER
|
from .roles import LEADER, FOLLOWER
|
||||||
from ._noise import NoiseInvalidMessage
|
from ._noise import NoiseInvalidMessage, NoiseHandshakeError
|
||||||
|
|
||||||
# InboundFraming is given data and returns Frames (Noise wire-side
|
# InboundFraming is given data and returns Frames (Noise wire-side
|
||||||
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
# bytestrings). It handles the relay handshake and the prologue. The Frames it
|
||||||
|
@ -56,6 +56,23 @@ def first(l):
|
||||||
class Disconnect(Exception):
|
class Disconnect(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# all connections look like:
|
||||||
|
# (step 1: only for outbound connections)
|
||||||
|
# 1: if we're connecting to a transit relay:
|
||||||
|
# * send "sided relay handshake": "please relay TOKEN for side SIDE\n"
|
||||||
|
# * the relay will send "ok\n" if/when our peer connects
|
||||||
|
# * a non-relay will probably send junk
|
||||||
|
# * wait for "ok\n", hang up if we get anything different
|
||||||
|
# (all subsequent steps are for both inbound and outbound connections)
|
||||||
|
# 2: send PROLOGUE_LEADER/FOLLOWER: "Magic-Wormhole Dilation Handshale v1 (l/f)\n\n"
|
||||||
|
# 3: wait for the opposite PROLOGUE string, else hang up
|
||||||
|
# (everything past this point is a Frame, with be4 length prefix. Frames are
|
||||||
|
# either noise handshake or an encrypted message)
|
||||||
|
# 4: if LEADER, send noise handshake string. if FOLLOWER, wait for it
|
||||||
|
# 5: if FOLLOWER, send noise response string. if LEADER, wait for it
|
||||||
|
# 6: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
RelayOK = namedtuple("RelayOk", [])
|
RelayOK = namedtuple("RelayOk", [])
|
||||||
Prologue = namedtuple("Prologue", [])
|
Prologue = namedtuple("Prologue", [])
|
||||||
|
@ -193,7 +210,7 @@ class _Framer(object):
|
||||||
def add_and_parse(self, data):
|
def add_and_parse(self, data):
|
||||||
# we can't make this an @m.input because we can't change the state
|
# we can't make this an @m.input because we can't change the state
|
||||||
# from within an input. Instead, let the state choose the parser to
|
# from within an input. Instead, let the state choose the parser to
|
||||||
# use, and use the parsed token drive a state transition.
|
# use, then use the parsed token to drive a state transition.
|
||||||
self._buffer += data
|
self._buffer += data
|
||||||
while True:
|
while True:
|
||||||
# it'd be nice to use an iterator here, but since self.parse()
|
# it'd be nice to use an iterator here, but since self.parse()
|
||||||
|
@ -302,11 +319,16 @@ def encode_record(r):
|
||||||
raise TypeError(r)
|
raise TypeError(r)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_role(_record, _attr, value):
|
||||||
|
if value not in [LEADER, FOLLOWER]:
|
||||||
|
raise ValueError("role must be LEADER or FOLLOWER")
|
||||||
|
|
||||||
@attrs
|
@attrs
|
||||||
@implementer(IRecord)
|
@implementer(IRecord)
|
||||||
class _Record(object):
|
class _Record(object):
|
||||||
_framer = attrib(validator=provides(IFramer))
|
_framer = attrib(validator=provides(IFramer))
|
||||||
_noise = attrib()
|
_noise = attrib()
|
||||||
|
_role = attrib(default="unspecified", validator=_is_role) # for debugging
|
||||||
|
|
||||||
n = MethodicalMachine()
|
n = MethodicalMachine()
|
||||||
# TODO: set_trace
|
# TODO: set_trace
|
||||||
|
@ -321,17 +343,37 @@ class _Record(object):
|
||||||
# states: want_prologue, want_handshake, want_record
|
# states: want_prologue, want_handshake, want_record
|
||||||
|
|
||||||
@n.state(initial=True)
|
@n.state(initial=True)
|
||||||
def want_prologue(self):
|
def no_role_set(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
@n.state()
|
@n.state()
|
||||||
def want_handshake(self):
|
def want_prologue_leader(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@n.state()
|
||||||
|
def want_prologue_follower(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@n.state()
|
||||||
|
def want_handshake_leader(self):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@n.state()
|
||||||
|
def want_handshake_follower(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
@n.state()
|
@n.state()
|
||||||
def want_message(self):
|
def want_message(self):
|
||||||
pass # pragma: no cover
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@n.input()
|
||||||
|
def set_role_leader(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@n.input()
|
||||||
|
def set_role_follower(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@n.input()
|
@n.input()
|
||||||
def got_prologue(self):
|
def got_prologue(self):
|
||||||
pass
|
pass
|
||||||
|
@ -340,9 +382,20 @@ class _Record(object):
|
||||||
def got_frame(self, frame):
|
def got_frame(self, frame):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@n.output()
|
||||||
|
def ignore_and_send_handshake(self, frame):
|
||||||
|
self._send_handshake()
|
||||||
|
|
||||||
@n.output()
|
@n.output()
|
||||||
def send_handshake(self):
|
def send_handshake(self):
|
||||||
handshake = self._noise.write_message() # generate the ephemeral key
|
self._send_handshake()
|
||||||
|
|
||||||
|
def _send_handshake(self):
|
||||||
|
try:
|
||||||
|
handshake = self._noise.write_message() # generate the ephemeral key
|
||||||
|
except NoiseHandshakeError as e:
|
||||||
|
log.err(e, "noise error during handshake")
|
||||||
|
raise
|
||||||
self._framer.send_frame(handshake)
|
self._framer.send_frame(handshake)
|
||||||
|
|
||||||
@n.output()
|
@n.output()
|
||||||
|
@ -367,10 +420,19 @@ class _Record(object):
|
||||||
raise Disconnect()
|
raise Disconnect()
|
||||||
return parse_record(message)
|
return parse_record(message)
|
||||||
|
|
||||||
want_prologue.upon(got_prologue, outputs=[send_handshake],
|
no_role_set.upon(set_role_leader, outputs=[], enter=want_prologue_leader)
|
||||||
enter=want_handshake)
|
want_prologue_leader.upon(got_prologue, outputs=[send_handshake],
|
||||||
want_handshake.upon(got_frame, outputs=[process_handshake],
|
enter=want_handshake_leader)
|
||||||
collector=first, enter=want_message)
|
want_handshake_leader.upon(got_frame, outputs=[process_handshake],
|
||||||
|
collector=first, enter=want_message)
|
||||||
|
|
||||||
|
no_role_set.upon(set_role_follower, outputs=[], enter=want_prologue_follower)
|
||||||
|
want_prologue_follower.upon(got_prologue, outputs=[],
|
||||||
|
enter=want_handshake_follower)
|
||||||
|
want_handshake_follower.upon(got_frame, outputs=[process_handshake,
|
||||||
|
ignore_and_send_handshake],
|
||||||
|
collector=first, enter=want_message)
|
||||||
|
|
||||||
want_message.upon(got_frame, outputs=[decrypt_message],
|
want_message.upon(got_frame, outputs=[decrypt_message],
|
||||||
collector=first, enter=want_message)
|
collector=first, enter=want_message)
|
||||||
|
|
||||||
|
@ -493,12 +555,20 @@ class DilatedConnectionProtocol(Protocol, object):
|
||||||
# IProtocol methods
|
# IProtocol methods
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
framer = _Framer(self.transport,
|
try:
|
||||||
self._outbound_prologue, self._inbound_prologue)
|
framer = _Framer(self.transport,
|
||||||
if self._use_relay:
|
self._outbound_prologue, self._inbound_prologue)
|
||||||
framer.use_relay(self._relay_handshake)
|
if self._use_relay:
|
||||||
self._record = _Record(framer, self._noise)
|
framer.use_relay(self._relay_handshake)
|
||||||
self._record.connectionMade()
|
self._record = _Record(framer, self._noise, self._role)
|
||||||
|
if self._role is LEADER:
|
||||||
|
self._record.set_role_leader()
|
||||||
|
else:
|
||||||
|
self._record.set_role_follower()
|
||||||
|
self._record.connectionMade()
|
||||||
|
except:
|
||||||
|
log.err()
|
||||||
|
raise
|
||||||
|
|
||||||
def dataReceived(self, data):
|
def dataReceived(self, data):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -69,10 +69,20 @@ class Connection(unittest.TestCase):
|
||||||
clear_mock_calls(n, connector, t, m)
|
clear_mock_calls(n, connector, t, m)
|
||||||
|
|
||||||
c.dataReceived(b"inbound_prologue\n")
|
c.dataReceived(b"inbound_prologue\n")
|
||||||
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
|
||||||
self.assertEqual(connector.mock_calls, [])
|
|
||||||
exp_handshake = b"\x00\x00\x00\x09handshake"
|
exp_handshake = b"\x00\x00\x00\x09handshake"
|
||||||
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
if role is LEADER:
|
||||||
|
# the LEADER sends the Noise handshake message immediately upon
|
||||||
|
# receipt of the prologue
|
||||||
|
self.assertEqual(n.mock_calls, [mock.call.write_message()])
|
||||||
|
self.assertEqual(t.mock_calls, [mock.call.write(exp_handshake)])
|
||||||
|
else:
|
||||||
|
# however the FOLLOWER waits until receiving the leader's
|
||||||
|
# handshake before sending their own
|
||||||
|
self.assertEqual(n.mock_calls, [])
|
||||||
|
self.assertEqual(t.mock_calls, [])
|
||||||
|
self.assertEqual(connector.mock_calls, [])
|
||||||
|
|
||||||
clear_mock_calls(n, connector, t, m)
|
clear_mock_calls(n, connector, t, m)
|
||||||
|
|
||||||
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
c.dataReceived(b"\x00\x00\x00\x0Ahandshake2")
|
||||||
|
@ -84,13 +94,16 @@ class Connection(unittest.TestCase):
|
||||||
self.assertEqual(t.mock_calls, [])
|
self.assertEqual(t.mock_calls, [])
|
||||||
self.assertEqual(c._manager, None)
|
self.assertEqual(c._manager, None)
|
||||||
else:
|
else:
|
||||||
# we're the follower, so we encrypt and send the KCM immediately
|
# we're the follower, so we send our Noise handshake, then
|
||||||
|
# encrypt and send the KCM immediately
|
||||||
self.assertEqual(n.mock_calls, [
|
self.assertEqual(n.mock_calls, [
|
||||||
mock.call.read_message(b"handshake2"),
|
mock.call.read_message(b"handshake2"),
|
||||||
|
mock.call.write_message(),
|
||||||
mock.call.encrypt(encode_record(t_kcm)),
|
mock.call.encrypt(encode_record(t_kcm)),
|
||||||
])
|
])
|
||||||
self.assertEqual(connector.mock_calls, [])
|
self.assertEqual(connector.mock_calls, [])
|
||||||
self.assertEqual(t.mock_calls, [
|
self.assertEqual(t.mock_calls, [
|
||||||
|
mock.call.write(exp_handshake),
|
||||||
mock.call.write(exp_kcm)])
|
mock.call.write(exp_kcm)])
|
||||||
self.assertEqual(c._manager, None)
|
self.assertEqual(c._manager, None)
|
||||||
clear_mock_calls(n, connector, t, m)
|
clear_mock_calls(n, connector, t, m)
|
||||||
|
|
|
@ -6,13 +6,15 @@ from ..._dilation._noise import NoiseInvalidMessage
|
||||||
from ..._dilation.connection import (IFramer, Frame, Prologue,
|
from ..._dilation.connection import (IFramer, Frame, Prologue,
|
||||||
_Record, Handshake,
|
_Record, Handshake,
|
||||||
Disconnect, Ping)
|
Disconnect, Ping)
|
||||||
|
from ..._dilation.roles import LEADER
|
||||||
|
|
||||||
|
|
||||||
def make_record():
|
def make_record():
|
||||||
f = mock.Mock()
|
f = mock.Mock()
|
||||||
alsoProvides(f, IFramer)
|
alsoProvides(f, IFramer)
|
||||||
n = mock.Mock() # pretends to be a Noise object
|
n = mock.Mock() # pretends to be a Noise object
|
||||||
r = _Record(f, n)
|
r = _Record(f, n, LEADER)
|
||||||
|
r.set_role_leader()
|
||||||
return r, f, n
|
return r, f, n
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +32,8 @@ class Record(unittest.TestCase):
|
||||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||||
p1, p2 = object(), object()
|
p1, p2 = object(), object()
|
||||||
n.decrypt = mock.Mock(side_effect=[p1, p2])
|
n.decrypt = mock.Mock(side_effect=[p1, p2])
|
||||||
r = _Record(f, n)
|
r = _Record(f, n, LEADER)
|
||||||
|
r.set_role_leader()
|
||||||
self.assertEqual(f.mock_calls, [])
|
self.assertEqual(f.mock_calls, [])
|
||||||
r.connectionMade()
|
r.connectionMade()
|
||||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||||
|
@ -79,7 +82,8 @@ class Record(unittest.TestCase):
|
||||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||||
nvm = NoiseInvalidMessage()
|
nvm = NoiseInvalidMessage()
|
||||||
n.read_message = mock.Mock(side_effect=nvm)
|
n.read_message = mock.Mock(side_effect=nvm)
|
||||||
r = _Record(f, n)
|
r = _Record(f, n, LEADER)
|
||||||
|
r.set_role_leader()
|
||||||
self.assertEqual(f.mock_calls, [])
|
self.assertEqual(f.mock_calls, [])
|
||||||
r.connectionMade()
|
r.connectionMade()
|
||||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||||
|
@ -103,7 +107,8 @@ class Record(unittest.TestCase):
|
||||||
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
n.write_message = mock.Mock(return_value=b"tx-handshake")
|
||||||
nvm = NoiseInvalidMessage()
|
nvm = NoiseInvalidMessage()
|
||||||
n.decrypt = mock.Mock(side_effect=nvm)
|
n.decrypt = mock.Mock(side_effect=nvm)
|
||||||
r = _Record(f, n)
|
r = _Record(f, n, LEADER)
|
||||||
|
r.set_role_leader()
|
||||||
self.assertEqual(f.mock_calls, [])
|
self.assertEqual(f.mock_calls, [])
|
||||||
r.connectionMade()
|
r.connectionMade()
|
||||||
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
self.assertEqual(f.mock_calls, [mock.call.connectionMade()])
|
||||||
|
@ -124,7 +129,8 @@ class Record(unittest.TestCase):
|
||||||
f1 = object()
|
f1 = object()
|
||||||
n.encrypt = mock.Mock(return_value=f1)
|
n.encrypt = mock.Mock(return_value=f1)
|
||||||
r1 = Ping(b"pingid")
|
r1 = Ping(b"pingid")
|
||||||
r = _Record(f, n)
|
r = _Record(f, n, LEADER)
|
||||||
|
r.set_role_leader()
|
||||||
self.assertEqual(f.mock_calls, [])
|
self.assertEqual(f.mock_calls, [])
|
||||||
m1 = object()
|
m1 = object()
|
||||||
with mock.patch("wormhole._dilation.connection.encode_record",
|
with mock.patch("wormhole._dilation.connection.encode_record",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user