from __future__ import print_function, unicode_literals import os from collections import deque from attr import attrs, attrib from attr.validators import provides, instance_of, optional from automat import MethodicalMachine from zope.interface import implementer from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.interfaces import IAddress from twisted.python import log from .._interfaces import IDilator, IDilationManager, ISend, ITerminator from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr from ..observer import OneShotObserver from .._key import derive_key from .encode import to_be4 from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress, ControlEndpoint, SubchannelConnectorEndpoint, SubchannelListenerEndpoint) from .connector import Connector from .._hints import parse_hint from .roles import LEADER, FOLLOWER from .connection import KCM, Ping, Pong, Open, Data, Close, Ack from .inbound import Inbound from .outbound import Outbound # exported to Wormhole() for inclusion in versions message DILATION_VERSIONS = ["1"] class OldPeerCannotDilateError(Exception): pass class UnknownDilationMessageType(Exception): pass class ReceivedHintsTooEarly(Exception): pass class UnexpectedKCM(Exception): pass class UnknownMessageType(Exception): pass def make_side(): return bytes_to_hexstr(os.urandom(6)) # new scheme: # * both sides send PLEASE as soon as they have an unverified key and # w.dilate has been called, # * PLEASE includes a dilation-specific "side" (independent of the "side" # used by mailbox messages) # * higher "side" is Leader, lower is Follower # * PLEASE includes can-dilate list of version integers, requires overlap # "1" is current # * we start dilation after both w.dilate() and receiving VERSION, putting us # in WANTING, then we process all previously-queued inbound DILATE-n # messages. When PLEASE arrives, we move to CONNECTING # * HINTS sent after dilation starts # * only Leader sends RECONNECT, only Follower sends RECONNECTING. This # is the only difference between the two sides, and is not enforced # by the protocol (i.e. if the Follower sends RECONNECT to the Leader, # the Leader will obey, although TODO how confusing will this get?) # * upon receiving RECONNECT: drop Connector, start new Connector, send # RECONNECTING, start sending HINTS # * upon sending RECONNECT: go into FLUSHING state and ignore all HINTS until # RECONNECTING received. The new Connector can be spun up earlier, and it # can send HINTS, but it must not be given any HINTS that arrive before # RECONNECTING (since they're probably stale) # * after VERSIONS(KCM) received, we might learn that they other side cannot # dilate. w.dilate errbacks at this point # * maybe signal warning if we stay in a "want" state for too long # * nobody sends HINTS until they're ready to receive # * nobody sends HINTS unless they've called w.dilate() and received PLEASE # * nobody connects to inbound hints unless they've called w.dilate() # * if leader calls w.dilate() but not follower, leader waits forever in # "want" (doesn't send anything) # * if follower calls w.dilate() but not leader, follower waits forever # in "want", leader waits forever in "wanted" @attrs(cmp=False) @implementer(IDilationManager) class Manager(object): _S = attrib(validator=provides(ISend), repr=False) _my_side = attrib(validator=instance_of(type(u""))) _transit_key = attrib(validator=instance_of(bytes), repr=False) _transit_relay_location = attrib(validator=optional(instance_of(str))) _reactor = attrib(repr=False) _eventual_queue = attrib(repr=False) _cooperator = attrib(repr=False) _host_addr = attrib(validator=provides(IAddress)) _no_listen = attrib(default=False) _tor = None # TODO _timing = None # TODO _next_subchannel_id = None # initialized in choose_role m = MethodicalMachine() set_trace = getattr(m, "_setTrace", lambda self, f: None) # pragma: no cover def __attrs_post_init__(self): self._got_versions_d = Deferred() self._my_role = None # determined upon rx_PLEASE self._connection = None self._made_first_connection = False self._first_connected = OneShotObserver(self._eventual_queue) self._stopped = OneShotObserver(self._eventual_queue) self._next_dilation_phase = 0 # I kept getting confused about which methods were for inbound data # (and thus flow-control methods go "out") and which were for # outbound data (with flow-control going "in"), so I split them up # into separate pieces. self._inbound = Inbound(self, self._host_addr) self._outbound = Outbound(self, self._cooperator) # from us to peer def set_listener_endpoint(self, listener_endpoint): self._inbound.set_listener_endpoint(listener_endpoint) def set_subchannel_zero(self, scid0, sc0): self._inbound.set_subchannel_zero(scid0, sc0) def when_first_connected(self): return self._first_connected.when_fired() def when_stopped(self): return self._stopped.when_fired() def send_dilation_phase(self, **fields): dilation_phase = self._next_dilation_phase self._next_dilation_phase += 1 self._S.send("dilate-%d" % dilation_phase, dict_to_bytes(fields)) def send_hints(self, hints): # from Connector self.send_dilation_phase(type="connection-hints", hints=hints) # forward inbound-ish things to _Inbound def subchannel_pauseProducing(self, sc): self._inbound.subchannel_pauseProducing(sc) def subchannel_resumeProducing(self, sc): self._inbound.subchannel_resumeProducing(sc) def subchannel_stopProducing(self, sc): self._inbound.subchannel_stopProducing(sc) # forward outbound-ish things to _Outbound def subchannel_registerProducer(self, sc, producer, streaming): self._outbound.subchannel_registerProducer(sc, producer, streaming) def subchannel_unregisterProducer(self, sc): self._outbound.subchannel_unregisterProducer(sc) def send_open(self, scid): assert isinstance(scid, bytes) self._queue_and_send(Open, scid) def send_data(self, scid, data): assert isinstance(scid, bytes) self._queue_and_send(Data, scid, data) def send_close(self, scid): assert isinstance(scid, bytes) self._queue_and_send(Close, scid) def _queue_and_send(self, record_type, *args): r = self._outbound.build_record(record_type, *args) # Outbound owns the send_record() pipe, so that it can stall new # writes after a new connection is made until after all queued # messages are written (to preserve ordering). self._outbound.queue_and_send_record(r) # may trigger pauseProducing def subchannel_closed(self, scid, sc): # let everyone clean up. This happens just after we delivered # connectionLost to the Protocol, except for the control channel, # which might get connectionLost later after they use ep.connect. # TODO: is this inversion a problem? self._inbound.subchannel_closed(scid, sc) self._outbound.subchannel_closed(scid, sc) # our Connector calls these def connector_connection_made(self, c): self.connection_made() # state machine update self._connection = c self._inbound.use_connection(c) self._outbound.use_connection(c) # does c.registerProducer if not self._made_first_connection: self._made_first_connection = True self._first_connected.fire(None) pass def connector_connection_lost(self): self._stop_using_connection() if self._my_role is LEADER: self.connection_lost_leader() # state machine else: self.connection_lost_follower() def _stop_using_connection(self): # the connection is already lost by this point self._connection = None self._inbound.stop_using_connection() self._outbound.stop_using_connection() # does c.unregisterProducer # from our active Connection def got_record(self, r): # records with sequence numbers: always ack, ignore old ones if isinstance(r, (Open, Data, Close)): self.send_ack(r.seqnum) # always ack, even for old ones if self._inbound.is_record_old(r): return self._inbound.update_ack_watermark(r.seqnum) if isinstance(r, Open): self._inbound.handle_open(r.scid) elif isinstance(r, Data): self._inbound.handle_data(r.scid, r.data) else: # isinstance(r, Close) self._inbound.handle_close(r.scid) return if isinstance(r, KCM): log.err(UnexpectedKCM()) elif isinstance(r, Ping): self.handle_ping(r.ping_id) elif isinstance(r, Pong): self.handle_pong(r.ping_id) elif isinstance(r, Ack): self._outbound.handle_ack(r.resp_seqnum) # retire queued messages else: log.err(UnknownMessageType("{}".format(r))) # pings, pongs, and acks are not queued def send_ping(self, ping_id): self._outbound.send_if_connected(Ping(ping_id)) def send_pong(self, ping_id): self._outbound.send_if_connected(Pong(ping_id)) def send_ack(self, resp_seqnum): self._outbound.send_if_connected(Ack(resp_seqnum)) def handle_ping(self, ping_id): self.send_pong(ping_id) def handle_pong(self, ping_id): # TODO: update is-alive timer pass # subchannel maintenance def allocate_subchannel_id(self): scid_num = self._next_subchannel_id self._next_subchannel_id += 2 return to_be4(scid_num) # state machine # We are born WANTING after the local app calls w.dilate(). We start # CONNECTING when we receive PLEASE from the remote side def start(self): self.send_please() def send_please(self): self.send_dilation_phase(type="please", side=self._my_side) @m.state(initial=True) def WANTING(self): pass # pragma: no cover @m.state() def CONNECTING(self): pass # pragma: no cover @m.state() def CONNECTED(self): pass # pragma: no cover @m.state() def FLUSHING(self): pass # pragma: no cover @m.state() def ABANDONING(self): pass # pragma: no cover @m.state() def LONELY(self): pass # pragma: no cover @m.state() def STOPPING(self): pass # pragma: no cover @m.state(terminal=True) def STOPPED(self): pass # pragma: no cover @m.input() def rx_PLEASE(self, message): pass # pragma: no cover @m.input() # only sent by Follower def rx_HINTS(self, hint_message): pass # pragma: no cover @m.input() # only Leader sends RECONNECT, so only Follower receives it def rx_RECONNECT(self): pass # pragma: no cover @m.input() # only Follower sends RECONNECTING, so only Leader receives it def rx_RECONNECTING(self): pass # pragma: no cover # Connector gives us connection_made() @m.input() def connection_made(self): pass # pragma: no cover # our connection_lost() fires connection_lost_leader or # connection_lost_follower depending upon our role. If either side sees a # problem with the connection (timeouts, bad authentication) then they # just drop it and let connection_lost() handle the cleanup. @m.input() def connection_lost_leader(self): pass # pragma: no cover @m.input() def connection_lost_follower(self): pass @m.input() def stop(self): pass # pragma: no cover @m.output() def choose_role(self, message): their_side = message["side"] if self._my_side > their_side: self._my_role = LEADER # scid 0 is reserved for the control channel. the leader uses odd # numbers starting with 1 self._next_subchannel_id = 1 elif their_side > self._my_side: self._my_role = FOLLOWER # the follower uses even numbers starting with 2 self._next_subchannel_id = 2 else: raise ValueError("their side shouldn't be equal: reflection?") # these Outputs behave differently for the Leader vs the Follower @m.output() def start_connecting_ignore_message(self, message): del message # ignored return self._start_connecting() @m.output() def start_connecting(self): self._start_connecting() def _start_connecting(self): assert self._my_role is not None self._connector = Connector(self._transit_key, self._transit_relay_location, self, self._reactor, self._eventual_queue, self._no_listen, self._tor, self._timing, self._my_side, # needed for relay handshake self._my_role) self._connector.start() @m.output() def send_reconnect(self): self.send_dilation_phase(type="reconnect") # TODO: generation number? @m.output() def send_reconnecting(self): self.send_dilation_phase(type="reconnecting") # TODO: generation? @m.output() def use_hints(self, hint_message): hint_objs = filter(lambda h: h, # ignore None, unrecognizable [parse_hint(hs) for hs in hint_message["hints"]]) hint_objs = list(hint_objs) self._connector.got_hints(hint_objs) @m.output() def stop_connecting(self): self._connector.stop() @m.output() def abandon_connection(self): # we think we're still connected, but the Leader disagrees. Or we've # been told to shut down. self._connection.disconnect() # let connection_lost do cleanup @m.output() def notify_stopped(self): self._stopped.fire(None) # we start CONNECTING when we get rx_PLEASE WANTING.upon(rx_PLEASE, enter=CONNECTING, outputs=[choose_role, start_connecting_ignore_message]) CONNECTING.upon(connection_made, enter=CONNECTED, outputs=[]) # Leader CONNECTED.upon(connection_lost_leader, enter=FLUSHING, outputs=[send_reconnect]) FLUSHING.upon(rx_RECONNECTING, enter=CONNECTING, outputs=[start_connecting]) # Follower # if we notice a lost connection, just wait for the Leader to notice too CONNECTED.upon(connection_lost_follower, enter=LONELY, outputs=[]) LONELY.upon(rx_RECONNECT, enter=CONNECTING, outputs=[send_reconnecting, start_connecting]) # but if they notice it first, abandon our (seemingly functional) # connection, then tell them that we're ready to try again CONNECTED.upon(rx_RECONNECT, enter=ABANDONING, outputs=[abandon_connection]) ABANDONING.upon(connection_lost_follower, enter=CONNECTING, outputs=[send_reconnecting, start_connecting]) # and if they notice a problem while we're still connecting, abandon our # incomplete attempt and try again. in this case we don't have to wait # for a connection to finish shutdown CONNECTING.upon(rx_RECONNECT, enter=CONNECTING, outputs=[stop_connecting, send_reconnecting, start_connecting]) # rx_HINTS never changes state, they're just accepted or ignored WANTING.upon(rx_HINTS, enter=WANTING, outputs=[]) # too early CONNECTING.upon(rx_HINTS, enter=CONNECTING, outputs=[use_hints]) CONNECTED.upon(rx_HINTS, enter=CONNECTED, outputs=[]) # too late, ignore FLUSHING.upon(rx_HINTS, enter=FLUSHING, outputs=[]) # stale, ignore LONELY.upon(rx_HINTS, enter=LONELY, outputs=[]) # stale, ignore ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[]) WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped]) CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection]) ABANDONING.upon(stop, enter=STOPPING, outputs=[]) FLUSHING.upon(stop, enter=STOPPED, outputs=[notify_stopped]) LONELY.upon(stop, enter=STOPPED, outputs=[notify_stopped]) STOPPING.upon(connection_lost_leader, enter=STOPPED, outputs=[notify_stopped]) STOPPING.upon(connection_lost_follower, enter=STOPPED, outputs=[notify_stopped]) @attrs @implementer(IDilator) class Dilator(object): """I launch the dilation process. I am created with every Wormhole (regardless of whether .dilate() was called or not), and I handle the initial phase of dilation, before we know whether we'll be the Leader or the Follower. Once we hear the other side's VERSION message (which tells us that we have a connection, they are capable of dilating, and which side we're on), then we build a Manager and hand control to it. """ _reactor = attrib() _eventual_queue = attrib() _cooperator = attrib() def __attrs_post_init__(self): self._got_versions_d = Deferred() self._started = False self._endpoints = OneShotObserver(self._eventual_queue) self._pending_inbound_dilate_messages = deque() self._manager = None self._host_addr = _WormholeAddress() def wire(self, sender, terminator): self._S = ISend(sender) self._T = ITerminator(terminator) # this is the primary entry point, called when w.dilate() is invoked def dilate(self, transit_relay_location=None, no_listen=False): self._transit_relay_location = transit_relay_location if not self._started: self._started = True self._start(no_listen).addBoth(self._endpoints.fire) return self._endpoints.when_fired() @inlineCallbacks def _start(self, no_listen): # first, we wait until we hear the VERSION message, which tells us 1: # the PAKE key works, so we can talk securely, 2: that they can do # dilation at all (if they can't then w.dilate() errbacks) dilation_version = yield self._got_versions_d # TODO: we could probably return the endpoints earlier, if we flunk # any connection/listen attempts upon OldPeerCannotDilateError, or # if/when we give up on the initial connection if not dilation_version: # "1" or None # TODO: be more specific about the error. dilation_version==None # means we had no version in common with them, which could either # be because they're so old they don't dilate at all, or because # they're so new that they no longer accomodate our old version raise OldPeerCannotDilateError() my_dilation_side = make_side() self._manager = Manager(self._S, my_dilation_side, self._transit_key, self._transit_relay_location, self._reactor, self._eventual_queue, self._cooperator, self._host_addr, no_listen) # We must open subchannel0 early, since messages may arrive very # quickly once the connection is established. This subchannel may or # may not ever get revealed to the caller, since the peer might not # even be capable of dilation. scid0 = to_be4(0) peer_addr0 = _SubchannelAddress(scid0) sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0) self._manager.set_subchannel_zero(scid0, sc0) self._manager.start() while self._pending_inbound_dilate_messages: plaintext = self._pending_inbound_dilate_messages.popleft() self.received_dilate(plaintext) yield self._manager.when_first_connected() # we can open non-zero subchannels as soon as we get our first # connection control_ep = ControlEndpoint(peer_addr0) control_ep._subchannel_zero_opened(sc0) connect_ep = SubchannelConnectorEndpoint(self._manager, self._host_addr) listen_ep = SubchannelListenerEndpoint(self._manager, self._host_addr) self._manager.set_listener_endpoint(listen_ep) endpoints = (control_ep, connect_ep, listen_ep) returnValue(endpoints) # Called by Terminator after everything else (mailbox, nameplate, server # connection) has shut down. Expects to fire T.stoppedD() when Dilator is # stopped too. def stop(self): if not self._started: self._T.stoppedD() return if self._started: self._manager.stop() # TODO: avoid Deferreds for control flow, hard to serialize self._manager.when_stopped().addCallback(lambda _: self._T.stoppedD()) # TODO: tolerate multiple calls # from Boss def got_key(self, key): # TODO: verify this happens before got_wormhole_versions, or add a gate # to tolerate either ordering purpose = b"dilation-v1" LENGTH = 32 # TODO: whatever Noise wants, I guess self._transit_key = derive_key(key, purpose, LENGTH) def got_wormhole_versions(self, their_wormhole_versions): assert self._transit_key is not None # this always happens before received_dilate dilation_version = None their_dilation_versions = set(their_wormhole_versions.get("can-dilate", [])) my_versions = set(DILATION_VERSIONS) shared_versions = my_versions.intersection(their_dilation_versions) if "1" in shared_versions: dilation_version = "1" self._got_versions_d.callback(dilation_version) def received_dilate(self, plaintext): # this receives new in-order DILATE-n payloads, decrypted but not # de-JSONed. # this can appear before our .dilate() method is called, in which case # we queue them for later if not self._manager: self._pending_inbound_dilate_messages.append(plaintext) return message = bytes_to_dict(plaintext) type = message["type"] if type == "please": self._manager.rx_PLEASE(message) elif type == "connection-hints": self._manager.rx_HINTS(message) elif type == "reconnect": self._manager.rx_RECONNECT() elif type == "reconnecting": self._manager.rx_RECONNECTING() else: log.err(UnknownDilationMessageType(message)) return