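"""Subchannel transports and endpoints for wormhole Dilation.

SubChannel is the ITransport/IProducer/IConsumer for a single subchannel of
a dilated connection; the endpoint classes below (ControlEndpoint,
SubchannelConnectorEndpoint, SubchannelListenerEndpoint) are the Twisted
endpoints used to open and accept subchannels.
"""
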
import six
from collections import deque
from attr import attrs, attrib
from attr.validators import instance_of, provides
from zope.interface import implementer
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
                                         IAddress, IListeningPort,
                                         IHalfCloseableProtocol,
                                         IStreamClientEndpoint,
                                         IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone
from automat import MethodicalMachine
from .._interfaces import ISubChannel, IDilationManager
from ..observer import OneShotObserver

# each subchannel frame (the data passed into transport.write(data)) gets a
# 9-byte header prefix (type, subchannel id, and sequence number), then gets
# encrypted (adding a 16-byte authentication tag). The result is transmitted
# with a 4-byte length prefix (which only covers the padded message, not the
# length prefix itself), so the padded message must be less than 2**32 bytes
# long.
MAX_FRAME_LENGTH = 2**32 - 1 - 9 - 16
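# (i.e. a frame of exactly MAX_FRAME_LENGTH bytes, plus the 9-byte header
# and the 16-byte tag, encodes to 2**32 - 1 bytes, the largest value the
# 4-byte length prefix can carry)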


@attrs
class Once(object):
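    """Callable guard that may be invoked at most once.

    The first call succeeds and records that it ran; every later call raises
    the exception type passed in as ``_errtype``.
    """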
    _errtype = attrib()

    def __attrs_post_init__(self):
        self._called = False

    def __call__(self):
        if self._called:
            raise self._errtype()
        self._called = True


class SingleUseEndpointError(Exception):
    pass

# created in the (OPEN) state, by either:
# * receipt of an OPEN message
# * or local client_endpoint.connect()
# then transitions are:
# (OPEN) rx DATA: deliver .dataReceived(), -> (OPEN)
# (OPEN) rx CLOSE: deliver .connectionLost(), send CLOSE, -> (CLOSED)
# (OPEN) local .write(): send DATA, -> (OPEN)
# (OPEN) local .loseConnection(): send CLOSE, -> (CLOSING)
# (CLOSING) local .write(): error
# (CLOSING) local .loseConnection(): error
# (CLOSING) rx DATA: deliver .dataReceived(), -> (CLOSING)
# (CLOSING) rx CLOSE: deliver .connectionLost(), -> (CLOSED)
# object is deleted upon transition to (CLOSED)
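# (the half-closeable flow below adds read_closed/write_closed states for
# protocols that provide IHalfCloseableProtocol)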


class AlreadyClosedError(Exception):
    pass


class NormalCloseUsedOnHalfCloseable(Exception):
    pass


class HalfCloseUsedOnNonHalfCloseable(Exception):
    pass


@implementer(IAddress)
class _WormholeAddress(object):
    pass


@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
    _scid = attrib(validator=instance_of(six.integer_types))


@attrs(cmp=False)
@implementer(ITransport)
@implementer(IProducer)
@implementer(IConsumer)
@implementer(ISubChannel)
class SubChannel(object):
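    """ITransport for a single subchannel of a dilated wormhole connection.

    Outbound writes are handed to the IDilationManager; inbound DATA/CLOSE
    messages from the manager are delivered to the attached IProtocol, and
    anything that arrives before a protocol is attached is queued and
    replayed by _deliver_queued_data().
    """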
    _scid = attrib(validator=instance_of(six.integer_types))
    _manager = attrib(validator=provides(IDilationManager))
    _host_addr = attrib(validator=instance_of(_WormholeAddress))
    _peer_addr = attrib(validator=instance_of(_SubchannelAddress))

    m = MethodicalMachine()
    set_trace = getattr(m, "_setTrace",
                        lambda self, f: None)  # pragma: no cover

    def __attrs_post_init__(self):
        # self._mailbox = None
        # self._pending_outbound = {}
        # self._processed = set()
        self._protocol = None
        self._pending_remote_data = []
        self._pending_remote_close = False

    @m.state(initial=True)
    def unconnected(self):
        pass  # pragma: no cover

    # once we get the IProtocol, it's either an IHalfCloseableProtocol, or it
    # can only be fully closed
    @m.state()
    def open_half(self):
        pass  # pragma: no cover

    @m.state()
    def read_closed(self):
        pass  # pragma: no cover

    @m.state()
    def write_closed(self):
        pass  # pragma: no cover

    @m.state()
    def open_full(self):
        pass  # pragma: no cover

    @m.state()
    def closing(self):
        pass  # pragma: no cover

    @m.state()
    def closed(self):
        pass  # pragma: no cover
    @m.input()
    def connect_protocol_half(self):
        pass

    @m.input()
    def connect_protocol_full(self):
        pass

    @m.input()
    def remote_data(self, data):
        pass

    @m.input()
    def remote_close(self):
        pass

    @m.input()
    def local_data(self, data):
        pass

    @m.input()
    def local_close(self):
        pass

    @m.output()
    def queue_remote_data(self, data):
        self._pending_remote_data.append(data)

    @m.output()
    def queue_remote_close(self):
        self._pending_remote_close = True

    @m.output()
    def send_data(self, data):
        self._manager.send_data(self._scid, data)

    @m.output()
    def send_close(self):
        self._manager.send_close(self._scid)

    @m.output()
    def signal_dataReceived(self, data):
        assert self._protocol
        self._protocol.dataReceived(data)

    @m.output()
    def signal_readConnectionLost(self):
        IHalfCloseableProtocol(self._protocol).readConnectionLost()

    @m.output()
    def signal_writeConnectionLost(self):
        IHalfCloseableProtocol(self._protocol).writeConnectionLost()

    @m.output()
    def signal_connectionLost(self):
        assert self._protocol
        self._protocol.connectionLost(ConnectionDone())

    @m.output()
    def close_subchannel(self):
        self._manager.subchannel_closed(self._scid, self)
        # we're deleted momentarily

    @m.output()
    def error_closed_write(self, data):
        raise AlreadyClosedError("write not allowed on closed subchannel")

    @m.output()
    def error_closed_close(self):
        raise AlreadyClosedError(
            "loseConnection not allowed on closed subchannel")
    # stuff that arrives before we have a protocol connected
    unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
    unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])

    # IHalfCloseableProtocol flow
    unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
    open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
    open_half.upon(local_data, enter=open_half, outputs=[send_data])
    # remote closes first
    open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
    read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
    read_closed.upon(local_close, enter=closed, outputs=[send_close,
                                                         close_subchannel,
                                                         # TODO: eventual-signal this?
                                                         signal_writeConnectionLost,
                                                         ])
    # local closes first
    open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
                                                             send_close])
    write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
    write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
    write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
                                                           signal_readConnectionLost,
                                                           ])
    # error cases
    write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])

    # fully-closeable-only flow
    unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
    open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
    open_full.upon(local_data, enter=open_full, outputs=[send_data])
    open_full.upon(remote_close, enter=closed, outputs=[send_close,
                                                        close_subchannel,
                                                        signal_connectionLost])
    open_full.upon(local_close, enter=closing, outputs=[send_close])
    closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
    closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
                                                      signal_connectionLost])
    # error cases
    # we won't ever see an OPEN, since L4 will log+ignore those for us
    closing.upon(local_data, enter=closing, outputs=[error_closed_write])
    closing.upon(local_close, enter=closing, outputs=[error_closed_close])
    # the CLOSED state won't ever see messages, since we'll be deleted
    # our endpoints use these

    def _set_protocol(self, protocol):
        assert not self._protocol
        self._protocol = protocol
        if IHalfCloseableProtocol.providedBy(protocol):
            self.connect_protocol_half()
        else:
            # move from UNCONNECTED to OPEN
            self.connect_protocol_full()

    def _deliver_queued_data(self):
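        # replay DATA/CLOSE that arrived before the protocol was attached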
        for data in self._pending_remote_data:
            self.remote_data(data)
        del self._pending_remote_data
        if self._pending_remote_close:
            self.remote_close()
        del self._pending_remote_close

    # ITransport
    def write(self, data):
        assert isinstance(data, type(b""))
        assert len(data) <= MAX_FRAME_LENGTH
        self.local_data(data)

    def writeSequence(self, iovec):
        self.write(b"".join(iovec))

    def loseWriteConnection(self):
        if not IHalfCloseableProtocol.providedBy(self._protocol):
            # this is a clear error
            raise HalfCloseUsedOnNonHalfCloseable()
        self.local_close()

    def loseConnection(self):
        # TODO: what happens if an IHalfCloseableProtocol calls normal
        # loseConnection()? I think we need to close the read side too.
        if IHalfCloseableProtocol.providedBy(self._protocol):
            # I don't know what the correct behavior is, so avoid this for now
            raise NormalCloseUsedOnHalfCloseable()
        self.local_close()

    def getHost(self):
        # we define "host addr" as the overall wormhole
        return self._host_addr

    def getPeer(self):
        # and "peer addr" as the subchannel within that wormhole
        return self._peer_addr

    # IProducer: throttle inbound data (wormhole "up" to local app's Protocol)
    def stopProducing(self):
        self._manager.subchannel_stopProducing(self)

    def pauseProducing(self):
        self._manager.subchannel_pauseProducing(self)

    def resumeProducing(self):
        self._manager.subchannel_resumeProducing(self)

    # IConsumer: allow the wormhole to throttle outbound data (app->wormhole)
    def registerProducer(self, producer, streaming):
        self._manager.subchannel_registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self._manager.subchannel_unregisterProducer(self)


@implementer(IStreamClientEndpoint)
@attrs
class ControlEndpoint(object):
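    """IStreamClientEndpoint for the pre-established control subchannel
    (subchannel zero).

    connect() may be used only once, and waits for the main dilation channel
    to become ready before building and attaching the protocol.
    """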
    _peer_addr = attrib(validator=provides(IAddress))
    _subchannel_zero = attrib(validator=provides(ISubChannel))
    _eventual_queue = attrib(repr=False)
    _used = False

    def __attrs_post_init__(self):
        self._once = Once(SingleUseEndpointError)
        self._wait_for_main_channel = OneShotObserver(self._eventual_queue)

    # from manager

    def _main_channel_ready(self):
        self._wait_for_main_channel.fire(None)

    def _main_channel_failed(self, f):
        self._wait_for_main_channel.error(f)

    @inlineCallbacks
    def connect(self, protocolFactory):
        # return Deferred that fires with IProtocol or Failure(ConnectError)
        self._once()
        yield self._wait_for_main_channel.when_fired()
        p = protocolFactory.buildProtocol(self._peer_addr)
        self._subchannel_zero._set_protocol(p)
        # this sets p.transport and calls p.connectionMade()
        p.makeConnection(self._subchannel_zero)
        self._subchannel_zero._deliver_queued_data()
        returnValue(p)


@implementer(IStreamClientEndpoint)
@attrs
class SubchannelConnectorEndpoint(object):
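    """IStreamClientEndpoint that opens a new subchannel to the peer each
    time connect() is called.

    Rough usage sketch (illustrative only: MyProtocol is a placeholder, and
    how application code obtains this endpoint depends on the surrounding
    dilation API)::

        from twisted.internet.protocol import Factory
        d = endpoint.connect(Factory.forProtocol(MyProtocol))
        d.addCallback(lambda proto: proto.transport.write(b"hello"))
    """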
    _manager = attrib(validator=provides(IDilationManager))
    _host_addr = attrib(validator=instance_of(_WormholeAddress))
    _eventual_queue = attrib(repr=False)

    def __attrs_post_init__(self):
        self._connection_deferreds = deque()
        self._wait_for_main_channel = OneShotObserver(self._eventual_queue)

    def _main_channel_ready(self):
        self._wait_for_main_channel.fire(None)

    def _main_channel_failed(self, f):
        self._wait_for_main_channel.error(f)

    @inlineCallbacks
    def connect(self, protocolFactory):
        # return Deferred that fires with IProtocol or Failure(ConnectError)
        yield self._wait_for_main_channel.when_fired()
        scid = self._manager.allocate_subchannel_id()
        self._manager.send_open(scid)
        peer_addr = _SubchannelAddress(scid)
        # ? f.doStart()
        # ? f.startedConnecting(CONNECTOR)  # ??
        sc = SubChannel(scid, self._manager, self._host_addr, peer_addr)
        self._manager.subchannel_local_open(scid, sc)
        p = protocolFactory.buildProtocol(peer_addr)
        sc._set_protocol(p)
        p.makeConnection(sc)  # set p.transport = sc and call connectionMade()
        returnValue(p)


@implementer(IStreamServerEndpoint)
@attrs
class SubchannelListenerEndpoint(object):
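    """IStreamServerEndpoint for subchannels opened by the peer.

    listen() may be used only once; OPENs that arrive before the factory is
    available are queued, and are connected as soon as listen() is called.
    """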
    _manager = attrib(validator=provides(IDilationManager))
    _host_addr = attrib(validator=provides(IAddress))
    _eventual_queue = attrib(repr=False)

    def __attrs_post_init__(self):
        self._once = Once(SingleUseEndpointError)
        self._factory = None
        self._pending_opens = deque()
        self._wait_for_main_channel = OneShotObserver(self._eventual_queue)

    # from manager (actually Inbound)
    def _got_open(self, t, peer_addr):
        if self._factory:
            self._connect(t, peer_addr)
        else:
            self._pending_opens.append((t, peer_addr))

    def _connect(self, t, peer_addr):
        p = self._factory.buildProtocol(peer_addr)
        t._set_protocol(p)
        p.makeConnection(t)
        t._deliver_queued_data()

    def _main_channel_ready(self):
        self._wait_for_main_channel.fire(None)

    def _main_channel_failed(self, f):
        self._wait_for_main_channel.error(f)

    # IStreamServerEndpoint

    @inlineCallbacks
    def listen(self, protocolFactory):
        self._once()
        yield self._wait_for_main_channel.when_fired()
        self._factory = protocolFactory
        while self._pending_opens:
            (t, peer_addr) = self._pending_opens.popleft()
            self._connect(t, peer_addr)
        lp = SubchannelListeningPort(self._host_addr)
        returnValue(lp)


@implementer(IListeningPort)
@attrs
class SubchannelListeningPort(object):
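    """IListeningPort returned by SubchannelListenerEndpoint.listen().

    startListening() is a no-op (the listener is active as soon as listen()
    returns), and stopListening() is still a TODO.
    """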
    _host_addr = attrib(validator=provides(IAddress))

    def startListening(self):
        pass

    def stopListening(self):
        # TODO
        pass

    def getHost(self):
        return self._host_addr