from __future__ import print_function, unicode_literals from collections import namedtuple import six from attr import attrs, attrib from attr.validators import instance_of, provides from automat import MethodicalMachine from zope.interface import Interface, implementer from twisted.python import log from twisted.internet.protocol import Protocol from twisted.internet.interfaces import ITransport from .._interfaces import IDilationConnector from ..observer import OneShotObserver from .encode import to_be4, from_be4 from .roles import FOLLOWER # InboundFraming is given data and returns Frames (Noise wire-side # bytestrings). It handles the relay handshake and the prologue. The Frames it # returns are either the ephemeral key (the Noise "handshake") or ciphertext # messages. # The next object up knows whether it's expecting a Handshake or a message. It # feeds the first into Noise as a handshake, it feeds the rest into Noise as a # message (which produces a plaintext stream). It emits tokens that are either # "i've finished with the handshake (so you can send the KCM if you want)", or # "here is a decrypted message (which might be the KCM)". # the transmit direction goes directly to transport.write, and doesn't touch # the state machine. we can do this because the way we encode/encrypt/frame # things doesn't depend upon the receiver state. It would be more safe to e.g. # prohibit sending ciphertext frames unless we're in the received-handshake # state, but then we'll be in the middle of an inbound state transition ("we # just received the handshake, so you can send a KCM now") when we perform an # operation that depends upon the state (send_plaintext(kcm)), which is not a # coherent/safe place to touch the state machine. # we could set a flag and test it from inside send_plaintext, which kind of # violates the state machine owning the state (ideally all "if" statements # would be translated into same-input transitions from different starting # states). 
# For the specific question of sending plaintext frames, Noise will refuse us
# unless it's ready anyway, so the question is probably moot.


class IFramer(Interface):
    """Marker interface for the framing layer (implemented by _Framer)."""


class IRecord(Interface):
    """Marker interface for the record layer (implemented by _Record)."""


def first(l):
    """Return the first element of ``l``.

    Used as an automat ``collector=`` so that an input whose single output
    returns a value hands that value straight back to the caller, instead of
    a list of output results.
    """
    return l[0]


class Disconnect(Exception):
    """Raised by the parsing layers when the peer sent something invalid.

    DilatedConnectionProtocol.dataReceived catches this and drops the
    connection.
    """


# Tokens returned by the _Framer parsers. Each namedtuple typename matches
# the name it is bound to, so repr() is accurate and pickle can locate the
# class by name.
RelayOK = namedtuple("RelayOK", [])
Prologue = namedtuple("Prologue", [])
Frame = namedtuple("Frame", ["frame"])


@attrs
@implementer(IFramer)
class _Framer(object):
    """Split the inbound byte stream into relay-OK, prologue, and frames.

    The framer starts in ``want_prologue``. Calling ``use_relay`` first moves
    it to ``want_relay``: it then sends the relay handshake when the
    connection is made, and sends our prologue only after the relay's
    expected response has arrived. Without a relay, the prologue is sent as
    soon as the connection is made. Once the peer's prologue has been
    received, the inbound stream is parsed as length-prefixed frames and
    ``send_frame`` becomes usable.
    """

    _transport = attrib(validator=provides(ITransport))
    _outbound_prologue = attrib(validator=instance_of(bytes))  # we send this
    _inbound_prologue = attrib(validator=instance_of(bytes))  # we expect this
    _buffer = b""  # inbound bytes not yet parsed into a token
    _can_send_frames = False  # flipped once the peer's prologue arrives

    # inputs: use_relay(relay_handshake), connectionMade, parse,
    #         got_relay_ok, got_prologue
    # parse returns RelayOK / Prologue / Frame depending on the state, or
    #       None when the buffer does not yet hold a complete token
    # out (shared): transport.write (relay handshake, prologue)
    # errors: a bad relay response or bad prologue raises Disconnect
    # states: want_relay, want_prologue, want_frame
    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    @m.state()
    def want_relay(self):
        pass  # pragma: no cover

    @m.state(initial=True)
    def want_prologue(self):
        pass  # pragma: no cover

    @m.state()
    def want_frame(self):
        pass  # pragma: no cover

    @m.input()
    def use_relay(self, relay_handshake):
        pass

    @m.input()
    def connectionMade(self):
        pass

    @m.input()
    def parse(self):
        pass

    @m.input()
    def got_relay_ok(self):
        pass

    @m.input()
    def got_prologue(self):
        pass

    @m.output()
    def store_relay_handshake(self, relay_handshake):
        self._outbound_relay_handshake = relay_handshake
        self._expected_relay_handshake = b"ok\n"  # TODO: make this configurable

    @m.output()
    def send_relay_handshake(self):
        self._transport.write(self._outbound_relay_handshake)

    @m.output()
    def send_prologue(self):
        self._transport.write(self._outbound_prologue)

    @m.output()
    def parse_relay_ok(self):
        # returns None (implicitly) until the full response has arrived
        if self._get_expected("relay_ok", self._expected_relay_handshake):
            return RelayOK()

    @m.output()
    def parse_prologue(self):
        # returns None (implicitly) until the full prologue has arrived
        if self._get_expected("prologue", self._inbound_prologue):
            return Prologue()

    @m.output()
    def can_send_frames(self):
        self._can_send_frames = True  # for assertion in send_frame()

    @m.output()
    def parse_frame(self):
        # wire format: 4-byte length prefix (decoded by from_be4), then body
        if len(self._buffer) < 4:
            return None
        frame_length = from_be4(self._buffer[0:4])
        if len(self._buffer) < 4 + frame_length:
            return None
        frame = self._buffer[4:4 + frame_length]
        self._buffer = self._buffer[4 + frame_length:]  # TODO: avoid copy
        return Frame(frame=frame)

    want_prologue.upon(use_relay, outputs=[store_relay_handshake],
                       enter=want_relay)

    want_relay.upon(connectionMade, outputs=[send_relay_handshake],
                    enter=want_relay)
    want_relay.upon(parse, outputs=[parse_relay_ok], enter=want_relay,
                    collector=first)
    want_relay.upon(got_relay_ok, outputs=[send_prologue],
                    enter=want_prologue)

    want_prologue.upon(connectionMade, outputs=[send_prologue],
                       enter=want_prologue)
    want_prologue.upon(parse, outputs=[parse_prologue], enter=want_prologue,
                       collector=first)
    want_prologue.upon(got_prologue, outputs=[can_send_frames],
                       enter=want_frame)

    want_frame.upon(parse, outputs=[parse_frame], enter=want_frame,
                    collector=first)

    def _get_expected(self, name, expected):
        """Try to consume the fixed byte string ``expected`` from the buffer.

        Returns True, after removing ``expected`` from the front of the
        buffer, once it has fully arrived. Returns False while the buffered
        data is still a valid prefix of ``expected``, or is wrong but not
        yet long enough to report. Raises Disconnect when the buffered data
        cannot match ``expected`` and either contains a newline or is at
        least as long as ``expected``. ``name`` is only used in the log
        message.
        """
        lb = len(self._buffer)
        le = len(expected)
        if self._buffer.startswith(expected):
            # if the buffer starts with the expected string, consume it and
            # return True
            self._buffer = self._buffer[le:]
            return True
        if not expected.startswith(self._buffer):
            # we're not on track: the data we've received so far does not
            # match the expected value, so this can't possibly be right.
            # Don't complain until we see the expected length, or a newline,
            # so we can capture the weird input in the log for debugging.
            if b"\n" in self._buffer or lb >= le:
                log.msg("bad {}: {}".format(name, self._buffer[:le]))
                raise Disconnect()
            return False  # wait a bit longer
        # good so far, just waiting for the rest
        return False

    # external API is: use_relay, connectionMade, add_and_parse, and
    # send_frame

    def add_and_parse(self, data):
        """Buffer ``data`` and yield every complete token it produces.

        This is a generator: ``data`` is only appended to the buffer once
        iteration starts, so callers must consume it. Prologue and Frame
        tokens are yielded. A RelayOK token is handled here and not yielded:
        it fires got_relay_ok, which sends our prologue. Iteration stops as
        soon as the current state's parser needs more data.
        """
        # we can't make this an @m.input because we can't change the state
        # from within an input. Instead, let the state choose the parser to
        # use, and use the parsed token to drive a state transition.
        self._buffer += data
        while True:
            # it'd be nice to use an iterator here, but since self.parse()
            # dispatches to a different parser (depending upon the current
            # state), we'd be using multiple iterators
            token = self.parse()
            if isinstance(token, RelayOK):
                self.got_relay_ok()
            elif isinstance(token, Prologue):
                self.got_prologue()
                yield token  # _Record reacts by triggering send_handshake
            elif isinstance(token, Frame):
                yield token
            else:
                break  # None: not enough buffered data for a full token

    def send_frame(self, frame):
        """Write ``frame`` to the transport behind a 4-byte length prefix.

        Only valid after the peer's prologue has been received.
        """
        assert self._can_send_frames
        self._transport.write(to_be4(len(frame)) + frame)


# RelayOK: Newline-terminated buddy-is-connected response from Relay.
#          First data received from relay.
# Prologue: this-is-really-wormhole response from peer (the exact bytes are
#           supplied as inbound_prologue). First data received from peer.
# Frame: Either handshake or encrypted message. Length-prefixed on wire.
#   Handshake: the Noise ephemeral key, first framed message
#   Message: plaintext: encoded KCM/PING/PONG/OPEN/DATA/CLOSE/ACK
#     KCM: Key Confirmation Message (encrypted b"\x00"). First frame
#          from peer. Sent immediately by Follower, after Selection by Leader.
#     Record: namedtuple of KCM/Open/Data/Close/Ack/Ping/Pong


Handshake = namedtuple("Handshake", [])
# decrypted frames: produces KCM, Ping, Pong, Open, Data, Close, Ack
KCM = namedtuple("KCM", [])
Ping = namedtuple("Ping", ["ping_id"])  # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"])  # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"])  # scid is integer
Ack = namedtuple("Ack", ["resp_seqnum"])  # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records

# First byte of every decrypted message: selects the record type. Both
# parse_record and encode_record use these, so the two stay in sync.
T_KCM = b"\x00"
T_PING = b"\x01"
T_PONG = b"\x02"
T_OPEN = b"\x03"
T_DATA = b"\x04"
T_CLOSE = b"\x05"
T_ACK = b"\x06"


def parse_record(plaintext):
    """Decode one decrypted message into a record namedtuple.

    Inverse of encode_record. The first byte selects the record type (see the
    T_* constants). Integer fields are decoded with from_be4 from 4-byte
    slices, and Data carries everything after byte 9 as its payload.

    Raises ValueError (after logging) for an unknown type byte, including
    empty input. NOTE(review): input lengths are not checked here, so a
    truncated record yields short slices; what from_be4 does with them is
    not visible from this module -- TODO confirm.
    """
    msgtype = plaintext[0:1]
    if msgtype == T_KCM:
        return KCM()
    if msgtype == T_PING:
        ping_id = plaintext[1:5]
        return Ping(ping_id)
    if msgtype == T_PONG:
        ping_id = plaintext[1:5]
        return Pong(ping_id)
    if msgtype == T_OPEN:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        return Open(seqnum, scid)
    if msgtype == T_DATA:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        data = plaintext[9:]
        return Data(seqnum, scid, data)
    if msgtype == T_CLOSE:
        scid = from_be4(plaintext[1:5])
        seqnum = from_be4(plaintext[5:9])
        return Close(seqnum, scid)
    if msgtype == T_ACK:
        resp_seqnum = from_be4(plaintext[1:5])
        return Ack(resp_seqnum)
    log.err("received unknown message type: {}".format(plaintext))
    raise ValueError("unknown message type: {!r}".format(msgtype))


def encode_record(r):
    """Serialize a record namedtuple into plaintext bytes (parse_record's
    inverse).

    The type byte comes from the T_* constants. scid/seqnum/resp_seqnum must
    be integers; they are encoded with to_be4. Raises TypeError when ``r``
    is not one of the record types.
    """
    if isinstance(r, KCM):
        return T_KCM
    if isinstance(r, Ping):
        # NOTE(review): ping_id length is not checked, but parse_record reads
        # exactly 4 bytes of it
        return T_PING + r.ping_id
    if isinstance(r, Pong):
        return T_PONG + r.ping_id
    if isinstance(r, Open):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return T_OPEN + to_be4(r.scid) + to_be4(r.seqnum)
    if isinstance(r, Data):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return T_DATA + to_be4(r.scid) + to_be4(r.seqnum) + r.data
    if isinstance(r, Close):
        assert isinstance(r.scid, six.integer_types)
        assert isinstance(r.seqnum, six.integer_types)
        return T_CLOSE + to_be4(r.scid) + to_be4(r.seqnum)
    if isinstance(r, Ack):
        assert isinstance(r.resp_seqnum, six.integer_types)
        return T_ACK + to_be4(r.resp_seqnum)
    raise TypeError(r)


@attrs
@implementer(IRecord)
class _Record(object):
    """Noise layer sitting between the framer and the connection protocol.

    Sends our Noise handshake once the peer's prologue arrives, treats the
    first inbound frame as the peer's handshake, and decrypts every later
    frame into a record via parse_record. Outbound records are encoded,
    encrypted, and handed to the framer.
    """

    _framer = attrib(validator=provides(IFramer))
    _noise = attrib()

    n = MethodicalMachine()
    # TODO: set_trace

    def __attrs_post_init__(self):
        # prepare the Noise state before any handshake message is written or
        # read
        self._noise.start_handshake()

    # in: got_prologue, got_frame(frame)
    # out: got_frame returns a Handshake (first frame) or a decoded Record
    # out: framer.send_frame (our noise handshake)
    # errors: an invalid handshake or frame raises Disconnect
    # states: want_prologue, want_handshake, want_message

    @n.state(initial=True)
    def want_prologue(self):
        pass  # pragma: no cover

    @n.state()
    def want_handshake(self):
        pass  # pragma: no cover

    @n.state()
    def want_message(self):
        pass  # pragma: no cover

    @n.input()
    def got_prologue(self):
        pass

    @n.input()
    def got_frame(self, frame):
        pass

    @n.output()
    def send_handshake(self):
        handshake = self._noise.write_message()  # generate the ephemeral key
        self._framer.send_frame(handshake)

    @n.output()
    def process_handshake(self, frame):
        from noise.exceptions import NoiseInvalidMessage
        try:
            payload = self._noise.read_message(frame)
            # Noise can include unencrypted data in the handshake, but we don't
            # use it
            del payload
        except NoiseInvalidMessage as e:
            log.err(e, "bad inbound noise handshake")
            raise Disconnect()
        return Handshake()

    @n.output()
    def decrypt_message(self, frame):
        from noise.exceptions import NoiseInvalidMessage
        try:
            message = self._noise.decrypt(frame)
        except NoiseInvalidMessage as e:
            # if this happens during tests, flunk the test
            log.err(e, "bad inbound noise frame")
            raise Disconnect()
        return parse_record(message)

    want_prologue.upon(got_prologue, outputs=[send_handshake],
                       enter=want_handshake)
    want_handshake.upon(got_frame, outputs=[process_handshake],
                        collector=first, enter=want_message)
    want_message.upon(got_frame, outputs=[decrypt_message],
                      collector=first, enter=want_message)

    # external API is: connectionMade, add_and_unframe, send_record

    def connectionMade(self):
        """Forward connectionMade to the framer (sends relay handshake or
        prologue)."""
        self._framer.connectionMade()

    def add_and_unframe(self, data):
        """Feed ``data`` to the framer; yield a Handshake or Record per
        frame.

        Generator: nothing is processed until the caller iterates. The
        Prologue token is consumed here (it triggers send_handshake) and is
        not yielded.
        """
        for token in self._framer.add_and_parse(data):
            if isinstance(token, Prologue):
                self.got_prologue()  # triggers send_handshake
            else:
                assert isinstance(token, Frame)
                yield self.got_frame(token.frame)  # Handshake or a Record type

    def send_record(self, r):
        """Encode, encrypt, and frame record ``r`` onto the transport."""
        message = encode_record(r)
        frame = self._noise.encrypt(message)
        self._framer.send_frame(frame)


@attrs
class DilatedConnectionProtocol(Protocol, object):
    """I manage an L2 connection.

    When a new L2 connection is needed (as determined by the Leader),
    both Leader and Follower will initiate many simultaneous connections
    (probably TCP, but conceivably others). A subset will actually
    connect. A subset of those will successfully pass negotiation by
    exchanging handshakes to demonstrate knowledge of the session key.
    One of the negotiated connections will be selected by the Leader for
    active use, and the others will be dropped.

    At any given time, there is at most one active L2 connection.
    """

    _eventual_queue = attrib()
    _role = attrib()
    _connector = attrib(validator=provides(IDilationConnector))
    _noise = attrib()
    _outbound_prologue = attrib(validator=instance_of(bytes))
    _inbound_prologue = attrib(validator=instance_of(bytes))

    _use_relay = False
    _relay_handshake = None

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace", lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        self._manager = None  # set if/when we are selected
        self._disconnected = OneShotObserver(self._eventual_queue)
        self._can_send_records = False

    # states: unselected -> (got_kcm) -> selecting -> (select) -> selected

    @m.state(initial=True)
    def unselected(self):
        pass  # pragma: no cover

    @m.state()
    def selecting(self):
        pass  # pragma: no cover

    @m.state()
    def selected(self):
        pass  # pragma: no cover

    @m.input()
    def got_kcm(self):
        pass

    @m.input()
    def select(self, manager):
        pass  # fires set_manager()

    @m.input()
    def got_record(self, record):
        pass

    @m.output()
    def add_candidate(self):
        self._connector.add_candidate(self)

    @m.output()
    def set_manager(self, manager):
        self._manager = manager

    @m.output()
    def can_send_records(self, manager):
        # manager is unused here; only the flag matters
        self._can_send_records = True

    @m.output()
    def deliver_record(self, record):
        self._manager.got_record(record)

    unselected.upon(got_kcm, outputs=[add_candidate], enter=selecting)
    selecting.upon(select, outputs=[set_manager, can_send_records],
                   enter=selected)
    selected.upon(got_record, outputs=[deliver_record], enter=selected)

    # called by Connector

    def use_relay(self, relay_handshake):
        """Route this connection through a relay.

        Only takes effect if called before connectionMade, which is where the
        framer is built and told to perform ``relay_handshake``.
        """
        assert isinstance(relay_handshake, bytes)
        self._use_relay = True
        self._relay_handshake = relay_handshake

    def when_disconnected(self):
        """Return the OneShotObserver's when_fired() result; it fires with
        this protocol once connectionLost runs (presumably a Deferred --
        confirm against OneShotObserver)."""
        return self._disconnected.when_fired()

    def disconnect(self):
        """Drop the underlying transport."""
        self.transport.loseConnection()

    # select() called by Connector

    # called by Manager
    def send_record(self, record):
        """Encrypt and send ``record``; only valid after select()."""
        assert self._can_send_records
        self._record.send_record(record)

    # IProtocol methods

    def connectionMade(self):
        """Build the framer and record layers, then start talking."""
        framer = _Framer(self.transport,
                         self._outbound_prologue, self._inbound_prologue)
        if self._use_relay:
            framer.use_relay(self._relay_handshake)
        self._record = _Record(framer, self._noise)
        self._record.connectionMade()

    def dataReceived(self, data):
        """Push inbound bytes through the record layer and act on each token.

        A Disconnect raised while parsing drops the connection.
        NOTE(review): only Disconnect is caught; a ValueError from
        parse_record (unknown record type), or a state-machine error such as
        a second KCM (no transition is defined for it), propagates to the
        caller -- confirm that is intended.
        """
        try:
            for token in self._record.add_and_unframe(data):
                assert isinstance(token, Handshake_or_Records)
                if isinstance(token, Handshake):
                    if self._role is FOLLOWER:
                        self._record.send_record(KCM())
                elif isinstance(token, KCM):
                    # if we're the leader, add this connection as a candidate.
                    # if we're the follower, accept this connection. Both
                    # roles report to the Connector here; presumably the
                    # Connector applies the role-specific behavior.
                    self.got_kcm()  # connector.add_candidate()
                else:
                    self.got_record(token)  # manager.got_record()
        except Disconnect:
            self.transport.loseConnection()

    def connectionLost(self, why=None):
        """Notify when_disconnected() observers, passing this protocol."""
        self._disconnected.fire(self)