This makes w.stop() the correct way to shut everything down, including any Dilator connections (whether in progress, active, or already shutting down).
		
			
				
	
	
		
			316 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			316 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import print_function, absolute_import, unicode_literals
 | |
| import os
 | |
| from six.moves.urllib_parse import urlparse
 | |
| from attr import attrs, attrib
 | |
| from attr.validators import provides, instance_of, optional
 | |
| from zope.interface import implementer
 | |
| from twisted.python import log
 | |
| from twisted.internet import defer, endpoints, task
 | |
| from twisted.application import internet
 | |
| from autobahn.twisted import websocket
 | |
| from . import _interfaces, errors
 | |
| from .util import (bytes_to_hexstr, hexstr_to_bytes, bytes_to_dict,
 | |
|                    dict_to_bytes)
 | |
| 
 | |
| 
 | |
class WSClient(websocket.WebSocketClientProtocol):
    """Autobahn WebSocket protocol that forwards its events to the connector.

    The ``_RC`` attribute is not set by this class; it is attached to each
    instance by ``WSFactory.buildProtocol`` right after construction.
    """

    def onConnect(self, response):
        # This fires during WebSocket negotiation, and isn't very useful
        # unless you want to modify the protocol settings, so do nothing.
        pass

    def onOpen(self, *args):
        # This fires when the WebSocket is ready to go (no arguments are
        # expected). Hand ourselves to the connector so it can start sending.
        self._RC.ws_open(self)

    def onMessage(self, payload, isBinary):
        # The rendezvous protocol only uses text frames.
        assert not isBinary
        try:
            self._RC.ws_message(payload)
        except Exception:
            # log.err() with no arguments records the exception currently
            # being handled; re-raise so the failure is not swallowed here.
            log.err()
            raise

    def onClose(self, wasClean, code, reason):
        # This can also fire for a connection that never finished WebSocket
        # negotiation (onOpen was never called); the connector's ws_close
        # decides how to treat that case.
        self._RC.ws_close(wasClean, code, reason)
 | |
| 
 | |
| 
 | |
class WSFactory(websocket.WebSocketClientFactory):
    """WebSocket client factory whose protocols report to a connector.

    :param RC: the connector object; every protocol built by this factory
        gets it as its ``_RC`` attribute.

    All remaining positional and keyword arguments are passed unchanged to
    Autobahn's ``WebSocketClientFactory`` constructor.
    """

    protocol = WSClient

    def __init__(self, RC, *args, **kwargs):
        base_factory = websocket.WebSocketClientFactory
        base_factory.__init__(self, *args, **kwargs)
        self._RC = RC

    def buildProtocol(self, addr):
        # Let Autobahn build and configure the protocol instance, then give
        # it a reference to the connector so its callbacks can reach it.
        base_factory = websocket.WebSocketClientFactory
        client = base_factory.buildProtocol(self, addr)
        client._RC = self._RC
        return client
 | |
| 
 | |
| 
 | |
@attrs
@implementer(_interfaces.IRendezvousConnector)
class RendezvousConnector(object):
    """Own the WebSocket connection to the rendezvous server.

    Outbound ``tx_*`` calls from the other state machines are encoded with
    ``dict_to_bytes`` and sent on the currently-open WebSocket. Inbound
    messages are decoded and dispatched to the matching
    ``_response_handle_<type>`` method, which forwards them to the machines
    supplied via ``wire()``.

    Connecting (and reconnecting) is delegated to a Twisted
    ``ClientService``. If the very first connection fails -- either the
    underlying connect or the WebSocket negotiation that follows it -- the
    service is stopped and the Boss is told via
    ``B.error(ServerConnectionError)``, unless ``stop()`` was already called.
    """
    _url = attrib(validator=instance_of(type(u"")))
    _appid = attrib(validator=instance_of(type(u"")))
    _side = attrib(validator=instance_of(type(u"")))
    _reactor = attrib()
    _journal = attrib(validator=provides(_interfaces.IJournal))
    _tor = attrib(validator=optional(provides(_interfaces.ITorManager)))
    _timing = attrib(validator=provides(_interfaces.ITiming))
    _client_version = attrib(validator=instance_of(tuple))

    def __attrs_post_init__(self):
        self._have_made_a_successful_connection = False
        # set by stop(), so that a deliberate shutdown is not reported to
        # the Boss as a connection error
        self._stopping = False

        self._trace = None
        self._ws = None  # the open WSClient, or None when not connected
        f = WSFactory(self, self._url)
        f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
        p = urlparse(self._url)
        # NOTE(review): the port defaults to 80 regardless of URL scheme, and
        # _make_endpoint has no TLS yet, so wss:// URLs presumably won't
        # work -- confirm before relying on them.
        ep = self._make_endpoint(p.hostname, p.port or 80)
        self._connector = internet.ClientService(ep, f)
        # Nothing has connected yet, so make whenConnected() errback after
        # the first failed attempt (with None it would only fail once the
        # service is stopped).
        d = self._connector.whenConnected(failAfterFailures=1)
        # if the initial connection fails, signal an error and shut down. do
        # this in a different reactor turn to avoid some hazards
        d.addBoth(lambda res: task.deferLater(self._reactor, 0.0, lambda: res))
        # TODO: use EventualQueue
        d.addErrback(self._initial_connection_failed)
        # optional debugging hook: if set, called with every decoded inbound
        # message (see ws_message)
        self._debug_record_inbound_f = None

    def set_trace(self, f):
        """Install a tracing callback used by _debug()."""
        self._trace = f

    def _debug(self, what):
        # Report `what` through the trace callback, if one is installed.
        if self._trace:
            self._trace(old_state="", input=what, new_state="")

    def _make_endpoint(self, hostname, port):
        """Return a client endpoint for hostname:port, via Tor if configured."""
        if self._tor:
            # TODO: when we enable TLS, maybe add tls=True here
            return self._tor.stream_via(hostname, port)
        return endpoints.HostnameEndpoint(self._reactor, hostname, port)

    def wire(self, boss, nameplate, mailbox, allocator, lister, terminator):
        """Connect this object to the state machines it talks to."""
        self._B = _interfaces.IBoss(boss)
        self._N = _interfaces.INameplate(nameplate)
        self._M = _interfaces.IMailbox(mailbox)
        self._A = _interfaces.IAllocator(allocator)
        self._L = _interfaces.ILister(lister)
        self._T = _interfaces.ITerminator(terminator)

    # from Boss
    def start(self):
        """Start the ClientService, which begins connecting."""
        self._connector.startService()

    # from Mailbox
    def tx_claim(self, nameplate):
        self._tx("claim", nameplate=nameplate)

    def tx_open(self, mailbox):
        self._tx("open", mailbox=mailbox)

    def tx_add(self, phase, body):
        """Send a message body (bytes, hex-encoded on the wire) for a phase."""
        assert isinstance(phase, type("")), type(phase)
        assert isinstance(body, type(b"")), type(body)
        self._tx("add", phase=phase, body=bytes_to_hexstr(body))

    def tx_release(self, nameplate):
        self._tx("release", nameplate=nameplate)

    def tx_close(self, mailbox, mood):
        self._tx("close", mailbox=mailbox, mood=mood)

    def stop(self):
        """Stop (re)connecting and close any connection.

        The Terminator is notified via ``stoppedRC()`` once the
        ClientService has finished stopping.
        """
        # ClientService.stopService is defined to "Stop attempting to
        # reconnect and close any existing connections"
        # Mark that we are stopping first, so neither _initial_connection_failed
        # nor ws_close reports the resulting failure/close as an error.
        self._stopping = True
        d = defer.maybeDeferred(self._connector.stopService)
        # ClientService.stopService always fires with None, even if the
        # initial connection failed, so log.err just in case
        d.addErrback(log.err)
        d.addBoth(self._stopped)

    # from Lister
    def tx_list(self):
        self._tx("list")

    # from Code
    def tx_allocate(self):
        self._tx("allocate")

    # from our ClientService
    def _initial_connection_failed(self, f):
        # Errback for whenConnected(): `f` is the Failure describing why the
        # first connection attempt failed.
        if self._stopping:
            # stop() was called; this failure is a consequence of shutting
            # down, not something to report.
            return
        sce = errors.ServerConnectionError(self._url, f.value)
        # this should happen right away: the ClientService ought to be in
        # the "_waiting" state, and everything in the _waiting.stop
        # transition is immediate
        self._stop_and_report(sce)

    def _stop_and_report(self, sce):
        # Shut down the ClientService, then hand `sce` to the Boss. Any
        # error while stopping is logged so the Boss is still told.
        d = defer.maybeDeferred(self._connector.stopService)
        d.addErrback(log.err)  # just in case something goes wrong
        d.addCallback(lambda _: self._B.error(sce))

    # from our WSClient (the WebSocket protocol)
    def ws_open(self, proto):
        """Record the open WebSocket, send "bind", and notify the machines."""
        self._debug("R.connected")
        self._have_made_a_successful_connection = True
        self._ws = proto
        try:
            self._tx(
                "bind",
                appid=self._appid,
                side=self._side,
                client_version=self._client_version)
            self._N.connected()
            self._M.connected()
            self._L.connected()
            self._A.connected()
        except Exception as e:
            self._B.error(e)
            raise
        self._debug("R.connected finished notifications")

    def ws_message(self, payload):
        """Decode one inbound message and dispatch it by its "type" field."""
        msg = bytes_to_dict(payload)
        if msg["type"] != "ack":
            self._debug("R.rx(%s %s%s)" % (
                msg["type"],
                msg.get("phase", ""),
                "[mine]" if msg.get("side", "") == self._side else "",
            ))

        self._timing.add("ws_receive", _side=self._side, message=msg)
        if self._debug_record_inbound_f:
            self._debug_record_inbound_f(msg)
        mtype = msg["type"]
        meth = getattr(self, "_response_handle_" + mtype, None)
        if not meth:
            # make tests fail, but real application will ignore it
            log.err(
                errors._UnknownMessageTypeError(
                    "Unknown inbound message type %r" % (msg, )))
            return
        try:
            return meth(msg)
        except Exception as e:
            log.err(e)
            self._B.error(e)
            raise

    def ws_close(self, wasClean, code, reason):
        """Handle the WebSocket closing (or failing to finish negotiation)."""
        self._debug("R.lost")
        was_open = bool(self._ws)
        self._ws = None
        # when Autobahn connects to a non-websocket server, it gets a
        # CLOSE_STATUS_CODE_ABNORMAL_CLOSE, and delivers onClose() without
        # ever calling onOpen first. This confuses our state machines, so
        # avoid telling them we've lost the connection unless we'd previously
        # told them we'd connected.
        if was_open:
            self._N.lost()
            self._M.lost()
            self._L.lost()
            self._A.lost()

        # and if this happens on the very first connection, then we treat it
        # as a failed initial connection, even though ClientService didn't
        # notice it. There's a Twisted ticket (#8375) about giving
        # ClientService an extra setup function to use, so it can tell
        # whether post-connection negotiation was successful or not, and
        # restart the process if it fails. That would be useful here, so that
        # failAfterFailures=1 would do the right thing if the initial TCP
        # connection succeeds but the first WebSocket negotiation fails.
        # If stop() is already in progress, the close was caused by the
        # shutdown itself, so (like _initial_connection_failed) don't report it.
        if not self._have_made_a_successful_connection and not self._stopping:
            # shut down the ClientService, which currently thinks it has a
            # valid connection, then tell the Boss to quit and inform the user
            sce = errors.ServerConnectionError(self._url, reason)
            self._stop_and_report(sce)

    # internal
    def _stopped(self, res):
        # Final step of stop(): tell the Terminator we are done.
        self._T.stoppedRC()

    def _tx(self, mtype, **kwargs):
        # Encode and send one message of type `mtype` on the open WebSocket.
        assert self._ws
        # msgid is used by misc/dump-timing.py to correlate our sends with
        # their receives, and vice versa. They are also correlated with the
        # ACKs we get back from the server (which we otherwise ignore). There
        # are so few messages, 16 bits is enough to be mostly-unique.
        kwargs["id"] = bytes_to_hexstr(os.urandom(2))
        kwargs["type"] = mtype
        self._debug("R.tx(%s %s)" % (mtype.upper(), kwargs.get("phase", "")))
        payload = dict_to_bytes(kwargs)
        self._timing.add("ws_send", _side=self._side, **kwargs)
        self._ws.sendMessage(payload, False)

    def _response_handle_allocated(self, msg):
        nameplate = msg["nameplate"]
        assert isinstance(nameplate, type("")), type(nameplate)
        self._A.rx_allocated(nameplate)

    def _response_handle_nameplates(self, msg):
        # we get list of {id: ID}, with maybe more attributes in the future
        nameplates = msg["nameplates"]
        assert isinstance(nameplates, list), type(nameplates)
        nids = set()
        for n in nameplates:
            assert isinstance(n, dict), type(n)
            nameplate_id = n["id"]
            assert isinstance(nameplate_id, type("")), type(nameplate_id)
            nids.add(nameplate_id)
        # deliver a set of nameplate ids
        self._L.rx_nameplates(nids)

    def _response_handle_ack(self, msg):
        # ACKs are only used for timing correlation (see _tx); ignore them.
        pass

    def _response_handle_error(self, msg):
        # the server sent us a type=error. Most cases are due to our mistakes
        # (malformed protocol messages, sending things in the wrong order),
        # but it can also result from CrowdedError (more than two clients
        # using the same channel).
        err = msg["error"]
        orig = msg["orig"]
        self._B.rx_error(err, orig)

    def _response_handle_welcome(self, msg):
        self._B.rx_welcome(msg["welcome"])

    def _response_handle_claimed(self, msg):
        mailbox = msg["mailbox"]
        assert isinstance(mailbox, type("")), type(mailbox)
        self._N.rx_claimed(mailbox)

    def _response_handle_message(self, msg):
        side = msg["side"]
        phase = msg["phase"]
        assert isinstance(phase, type("")), type(phase)
        body = hexstr_to_bytes(msg["body"])  # bytes
        self._M.rx_message(side, phase, body)

    def _response_handle_released(self, msg):
        self._N.rx_released()

    def _response_handle_closed(self, msg):
        self._M.rx_closed()

    # record, message, payload, packet, bundle, ciphertext, plaintext
 |