remove old files, lots of type work
This commit is contained in:
parent 3101ca51db
commit a2ed35ceb8
@@ -18,11 +18,13 @@ digraph {
  Connection -> websocket [color="blue"]
  #Connection -> Order [color="blue"]

  App -> Wormhole [style="dashed" label="set_code\nsend\nclose\n(once)"]
  App -> Wormhole [style="dashed" label="start\nset_code\nsend\nclose\n(once)"]
  #App -> Wormhole [color="blue"]
  Wormhole -> App [style="dashed" label="got_verifier\nreceived\nclosed\n(once)"]

  #Wormhole -> Connection [color="blue"]
  Wormhole -> Connection [style="dashed" label="start"]
  Connection -> Wormhole [style="dashed" label="rx_welcome"]

  Wormhole -> Send [style="dashed" label="send"]

@@ -75,8 +77,6 @@ digraph {
               label="set_code"]
  App -> Code [style="dashed"
               label="allocate\ninput\nset"]

}
@@ -1,180 +0,0 @@
from zope.interface import Interface
|
||||
from six.moves.urllib_parse import urlparse
|
||||
from attr import attrs, attrib
|
||||
from twisted.internet import defer, endpoints #, error
|
||||
from twisted.application import internet, service
|
||||
from autobahn.twisted import websocket
|
||||
from automat import MethodicalMachine
|
||||
|
||||
class WSClient(websocket.WebSocketClientProtocol):
|
||||
def onConnect(self, response):
|
||||
# this fires during WebSocket negotiation, and isn't very useful
|
||||
# unless you want to modify the protocol settings
|
||||
print("onConnect", response)
|
||||
#self.connection_machine.onConnect(self)
|
||||
|
||||
def onOpen(self, *args):
|
||||
# this fires when the WebSocket is ready to go. No arguments
|
||||
print("onOpen", args)
|
||||
#self.wormhole_open = True
|
||||
# send BIND, since the MailboxMachine does not
|
||||
self.connection_machine.protocol_onOpen(self)
|
||||
#self.factory.d.callback(self)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
print("onMessage")
|
||||
return
|
||||
assert not isBinary
|
||||
self.wormhole._ws_dispatch_response(payload)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
print("onClose")
|
||||
self.connection_machine.protocol_onClose(wasClean, code, reason)
|
||||
#if self.wormhole_open:
|
||||
# self.wormhole._ws_closed(wasClean, code, reason)
|
||||
#else:
|
||||
# # we closed before establishing a connection (onConnect) or
|
||||
# # finishing WebSocket negotiation (onOpen): errback
|
||||
# self.factory.d.errback(error.ConnectError(reason))
|
||||
|
||||
class WSFactory(websocket.WebSocketClientFactory):
|
||||
protocol = WSClient
|
||||
def buildProtocol(self, addr):
|
||||
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
|
||||
proto.connection_machine = self.connection_machine
|
||||
#proto.wormhole_open = False
|
||||
return proto
|
||||
|
||||
# pip install (path to automat checkout)[visualize]
|
||||
# automat-visualize wormhole._connection
|
||||
|
||||
class IRendezvousClient(Interface):
|
||||
# must be an IService too
|
||||
def set_dispatch(dispatcher):
|
||||
"""Assign a dispatcher object to this client. The following methods
|
||||
will be called on this object when things happen:
|
||||
* rx_welcome(welcome -> dict)
|
||||
* rx_nameplates(nameplates -> list) # [{id: str,..}, ..]
|
||||
* rx_allocated(nameplate -> str)
|
||||
* rx_claimed(mailbox -> str)
|
||||
* rx_released()
|
||||
* rx_message(side -> str, phase -> str, body -> str, msg_id -> str)
|
||||
* rx_closed()
|
||||
* rx_pong(pong -> int)
|
||||
"""
|
||||
pass
|
||||
def tx_list(): pass
|
||||
def tx_allocate(): pass
|
||||
def tx_claim(nameplate): pass
|
||||
def tx_release(): pass
|
||||
def tx_open(mailbox): pass
|
||||
def tx_add(phase, body): pass
|
||||
def tx_close(mood): pass
|
||||
def tx_ping(ping): pass
|
||||
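A minimal sketch (class name and behaviour are hypothetical; only the set_dispatch() contract above is assumed) of a dispatcher object that satisfies this callback protocol:

class LoggingDispatcher(object):
    # hypothetical dispatcher: just records every rx_* callback it receives
    def __init__(self):
        self.events = []
    def rx_welcome(self, welcome):
        self.events.append(("welcome", welcome))
    def rx_nameplates(self, nameplates):
        self.events.append(("nameplates", [n["id"] for n in nameplates]))
    def rx_allocated(self, nameplate):
        self.events.append(("allocated", nameplate))
    def rx_claimed(self, mailbox):
        self.events.append(("claimed", mailbox))
    def rx_released(self):
        self.events.append(("released",))
    def rx_message(self, side, phase, body, msg_id):
        self.events.append(("message", side, phase, body, msg_id))
    def rx_closed(self):
        self.events.append(("closed",))
    def rx_pong(self, pong):
        self.events.append(("pong", pong))

# usage, assuming some IRendezvousClient implementation 'rc':
#   rc.set_dispatch(LoggingDispatcher())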
|
||||
# We have one WSRelayClient for each wsurl we know about, and it lasts
|
||||
# as long as its parent Wormhole does.
|
||||
|
||||
@attrs
|
||||
class WSRelayClient(service.MultiService, object):
|
||||
_journal = attrib()
|
||||
_wormhole = attrib()
|
||||
_mailbox = attrib()
|
||||
_ws_url = attrib()
|
||||
_reactor = attrib()
|
||||
|
||||
def __init__(self):
|
||||
f = WSFactory(self._ws_url)
|
||||
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||
f.connection_machine = self # calls onOpen and onClose
|
||||
p = urlparse(self._ws_url)
|
||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||
# default policy: 1s initial, random exponential backoff, max 60s
|
||||
self._client_service = internet.ClientService(ep, f)
|
||||
self._connector = None
|
||||
self._done_d = defer.Deferred()
|
||||
self._current_delay = self.INITIAL_DELAY
|
||||
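The "default policy" comment above refers to ClientService's built-in retry behaviour; a minimal standalone sketch (the endpoint/URL values are illustrative) of spelling that policy out with Twisted's backoffPolicy helper:

from twisted.application import internet
from twisted.internet import endpoints, reactor

# 1s initial delay, random exponential backoff, capped at 60s
retry = internet.backoffPolicy(initialDelay=1.0, maxDelay=60.0)
ep = endpoints.HostnameEndpoint(reactor, "127.0.0.1", 4000)
f = WSFactory("ws://127.0.0.1:4000/v1")   # WSFactory from this module
f.connection_machine = None               # would be the WSRelayClient, as in __init__ above
svc = internet.ClientService(ep, f, retryPolicy=retry)
svc.startService()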
|
||||
def _make_endpoint(self, hostname, port):
|
||||
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
|
||||
|
||||
# inputs from elsewhere
|
||||
def d_callback(self, p):
|
||||
self._p = p
|
||||
self._m.d_callback()
|
||||
def d_errback(self, f):
|
||||
self._f = f
|
||||
self._m.d_errback()
|
||||
def protocol_onOpen(self, p):
|
||||
self._m.onOpen()
|
||||
def protocol_onClose(self, wasClean, code, reason):
|
||||
self._m.onClose()
|
||||
def C_stop(self):
|
||||
self._m.stop()
|
||||
def timer_expired(self):
|
||||
self._m.expire()
|
||||
|
||||
# outputs driven by the state machine
|
||||
def ep_connect(self):
|
||||
print("ep_connect()")
|
||||
self._d = self._ep.connect(self._f)
|
||||
self._d.addCallbacks(self.d_callback, self.d_errback)
|
||||
def connection_established(self):
|
||||
self._connection = WSConnection(ws, self._wormhole.appid,
|
||||
self._wormhole.side, self)
|
||||
self._mailbox.connected(ws)
|
||||
self._wormhole.add_connection(self._connection)
|
||||
self._ws_send_command("bind", appid=self._appid, side=self._side)
|
||||
def M_lost(self):
|
||||
self._wormhole.M_lost(self._connection)
|
||||
self._connection = None
|
||||
def start_timer(self):
|
||||
print("start_timer")
|
||||
self._t = self._reactor.callLater(3.0, self.expire)
|
||||
def cancel_timer(self):
|
||||
print("cancel_timer")
|
||||
self._t.cancel()
|
||||
self._t = None
|
||||
def dropConnection(self):
|
||||
print("dropConnection")
|
||||
self._ws.dropConnection()
|
||||
def notify_fail(self):
|
||||
print("notify_fail", self._f.value if self._f else None)
|
||||
self._done_d.errback(self._f)
|
||||
def MC_stopped(self):
|
||||
pass
|
||||
|
||||
|
||||
def tryit(reactor):
|
||||
cm = WSRelayClient(None, "ws://127.0.0.1:4000/v1", reactor)
|
||||
print("_ConnectionMachine created")
|
||||
print("start:", cm.start())
|
||||
print("waiting on _done_d to finish")
|
||||
return cm._done_d
|
||||
|
||||
# http://autobahn-python.readthedocs.io/en/latest/websocket/programming.html
|
||||
# observed sequence of events:
|
||||
# success: d_callback, onConnect(response), onOpen(), onMessage()
|
||||
# negotifail (non-websocket): d_callback, onClose()
|
||||
# noconnect: d_errback
|
||||
|
||||
def tryws(reactor):
|
||||
ws_url = "ws://127.0.0.1:40001/v1"
|
||||
f = WSFactory(ws_url)
|
||||
p = urlparse(ws_url)
|
||||
ep = endpoints.HostnameEndpoint(reactor, p.hostname, p.port or 80)
|
||||
d = ep.connect(f)
|
||||
def _good(p): print("_good", p)
|
||||
def _bad(f): print("_bad", f)
|
||||
d.addCallbacks(_good, _bad)
|
||||
return defer.Deferred()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
from twisted.python import log
|
||||
log.startLogging(sys.stdout)
|
||||
from twisted.internet.task import react
|
||||
react(tryit)
|
||||
|
||||
# ??? a new WSConnection is created each time the WSRelayClient gets through
|
||||
# negotiation
|
|
@@ -18,3 +18,9 @@ class INameplateLister(Interface):
    pass
class ICode(Interface):
    pass

class ITiming(Interface):
    pass

class IJournal(Interface): # TODO: this needs to be public
    pass
|
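These marker interfaces are consumed via zope.interface adaptation in the wire() methods elsewhere in this commit (e.g. _interfaces.IWormhole(wormhole)): calling the interface returns the object unchanged when it declares the interface, and raises TypeError otherwise. A minimal sketch of that pattern with a hypothetical IThing:

from zope.interface import Interface, implementer

class IThing(Interface):      # hypothetical marker, like ICode/ITiming above
    pass

@implementer(IThing)
class Thing(object):
    pass

t = IThing(Thing())           # adaptation: returns the same Thing instance
# IThing(object())            # would raise TypeError ("Could not adapt")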
@@ -4,8 +4,10 @@ from spake2 import SPAKE2_Symmetric
from hkdf import Hkdf
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl import utils
|
||||
from automat import MethodicalMachine
|
||||
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes)
|
||||
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
|
||||
bytes_to_dict, dict_to_bytes)
|
||||
from . import _interfaces
|
||||
CryptoError
|
||||
__all__ = ["derive_key", "derive_phase_key", "CryptoError",
|
||||
|
@@ -38,6 +40,14 @@ def decrypt_data(key, encrypted):
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
def encrypt_data(key, plaintext):
|
||||
assert isinstance(key, type(b"")), type(key)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
assert len(key) == SecretBox.KEY_SIZE, len(key)
|
||||
box = SecretBox(key)
|
||||
nonce = utils.random(SecretBox.NONCE_SIZE)
|
||||
return box.encrypt(plaintext, nonce)
|
||||
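A quick round-trip sketch (the key value is a throwaway example) showing that the new encrypt_data and the existing decrypt_data are inverses, since SecretBox prepends the random nonce to the ciphertext it returns:

from nacl.secret import SecretBox
from nacl import utils

key = utils.random(SecretBox.KEY_SIZE)   # throwaway 32-byte key
ct = encrypt_data(key, b"hello")         # nonce || ciphertext
assert decrypt_data(key, ct) == b"hello"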
|
||||
@implementer(_interfaces.IKey)
|
||||
class Key(object):
|
||||
m = MethodicalMachine()
|
||||
|
@@ -57,7 +67,9 @@ class Key(object):
@m.state(terminal=True)
|
||||
def S3_scared(self): pass
|
||||
|
||||
def got_pake(self, payload):
|
||||
def got_pake(self, body):
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
payload = bytes_to_dict(body)
|
||||
if "pake_v1" in payload:
|
||||
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
|
||||
else:
|
||||
|
@@ -76,7 +88,8 @@ class Key(object):
self._sp = SPAKE2_Symmetric(to_bytes(code),
|
||||
idSymmetric=to_bytes(self._appid))
|
||||
msg1 = self._sp.start()
|
||||
self._M.add_message("pake", {"pake_v1": bytes_to_hexstr(msg1)})
|
||||
body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)})
|
||||
self._M.add_message("pake", body)
|
||||
|
||||
@m.output()
|
||||
def scared(self):
|
||||
|
|
|
@@ -10,6 +10,9 @@ class Mailbox(object):
self._side = side
|
||||
self._mood = None
|
||||
self._nameplate = None
|
||||
self._mailbox = None
|
||||
self._pending_outbound = {}
|
||||
self._processed = set()
|
||||
|
||||
def wire(self, wormhole, rendezvous_connector, ordering):
|
||||
self._W = _interfaces.IWormhole(wormhole)
|
||||
|
@@ -101,15 +104,18 @@ class Mailbox(object):
@m.input()
|
||||
def rx_claimed(self, mailbox): pass
|
||||
|
||||
def rx_message(self, side, phase, msg):
|
||||
def rx_message(self, side, phase, body):
|
||||
assert isinstance(side, type("")), type(side)
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
if side == self._side:
|
||||
self.rx_message_ours(phase, msg)
|
||||
self.rx_message_ours(phase, body)
|
||||
else:
|
||||
self.rx_message_theirs(phase, msg)
|
||||
self.rx_message_theirs(phase, body)
|
||||
@m.input()
|
||||
def rx_message_ours(self, phase, msg): pass
|
||||
def rx_message_ours(self, phase, body): pass
|
||||
@m.input()
|
||||
def rx_message_theirs(self, phase, msg): pass
|
||||
def rx_message_theirs(self, phase, body): pass
|
||||
@m.input()
|
||||
def rx_released(self): pass
|
||||
@m.input()
|
||||
|
@@ -119,7 +125,7 @@ class Mailbox(object):
|
||||
# from Send or Key
|
||||
@m.input()
|
||||
def add_message(self, phase, msg): pass
|
||||
def add_message(self, phase, body): pass
|
||||
|
||||
|
||||
@m.output()
|
||||
|
@@ -138,8 +144,10 @@ class Mailbox(object):
assert self._mailbox
|
||||
self._RC.tx_open(self._mailbox)
|
||||
@m.output()
|
||||
def queue(self, phase, msg):
|
||||
self._pending_outbound[phase] = msg
|
||||
def queue(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
self._pending_outbound[phase] = body
|
||||
@m.output()
|
||||
def store_mailbox_and_RC_tx_open_and_drain(self, mailbox):
|
||||
self._mailbox = mailbox
|
||||
|
@@ -149,18 +157,20 @@ class Mailbox(object):
def drain(self):
|
||||
self._drain()
|
||||
def _drain(self):
|
||||
for phase, msg in self._pending_outbound.items():
|
||||
self._RC.tx_add(phase, msg)
|
||||
for phase, body in self._pending_outbound.items():
|
||||
self._RC.tx_add(phase, body)
|
||||
@m.output()
|
||||
def RC_tx_add(self, phase, msg):
|
||||
self._RC.tx_add(phase, msg)
|
||||
def RC_tx_add(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
self._RC.tx_add(phase, body)
|
||||
@m.output()
|
||||
def RC_tx_release(self):
|
||||
self._RC.tx_release()
|
||||
@m.output()
|
||||
def RC_tx_release_and_accept(self, phase, msg):
|
||||
def RC_tx_release_and_accept(self, phase, body):
|
||||
self._RC.tx_release()
|
||||
self._accept(phase, msg)
|
||||
self._accept(phase, body)
|
||||
@m.output()
|
||||
def record_mood_and_RC_tx_release(self, mood):
|
||||
self._mood = mood
|
||||
|
@@ -179,14 +189,14 @@ class Mailbox(object):
self._mood = mood
|
||||
self._RC.tx_close(self._mood)
|
||||
@m.output()
|
||||
def accept(self, phase, msg):
|
||||
self._accept(phase, msg)
|
||||
def _accept(self, phase, msg):
|
||||
def accept(self, phase, body):
|
||||
self._accept(phase, body)
|
||||
def _accept(self, phase, body):
|
||||
if phase not in self._processed:
|
||||
self._O.got_message(phase, msg)
|
||||
self._O.got_message(phase, body)
|
||||
self._processed.add(phase)
|
||||
@m.output()
|
||||
def dequeue(self, phase, msg):
|
||||
def dequeue(self, phase, body):
|
||||
self._pending_outbound.pop(phase)
|
||||
@m.output()
|
||||
def record_mood(self, mood):
|
||||
|
|
|
@@ -19,36 +19,40 @@ class Order(object):
@m.state(terminal=True)
|
||||
def S1_yes_pake(self): pass
|
||||
|
||||
def got_message(self, phase, payload):
|
||||
def got_message(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
if phase == "pake":
|
||||
self.got_pake(phase, payload)
|
||||
self.got_pake(phase, body)
|
||||
else:
|
||||
self.got_non_pake(phase, payload)
|
||||
self.got_non_pake(phase, body)
|
||||
|
||||
@m.input()
|
||||
def got_pake(self, phase, payload): pass
|
||||
def got_pake(self, phase, body): pass
|
||||
@m.input()
|
||||
def got_non_pake(self, phase, payload): pass
|
||||
def got_non_pake(self, phase, body): pass
|
||||
|
||||
@m.output()
|
||||
def queue(self, phase, payload):
|
||||
self._queue.append((phase, payload))
|
||||
def queue(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
self._queue.append((phase, body))
|
||||
@m.output()
|
||||
def notify_key(self, phase, payload):
|
||||
self._K.got_pake(payload)
|
||||
def notify_key(self, phase, body):
|
||||
self._K.got_pake(body)
|
||||
@m.output()
|
||||
def drain(self, phase, payload):
|
||||
def drain(self, phase, body):
|
||||
del phase
|
||||
del payload
|
||||
for (phase, payload) in self._queue:
|
||||
self._deliver(phase, payload)
|
||||
del body
|
||||
for (phase, body) in self._queue:
|
||||
self._deliver(phase, body)
|
||||
self._queue[:] = []
|
||||
@m.output()
|
||||
def deliver(self, phase, payload):
|
||||
self._deliver(phase, payload)
|
||||
def deliver(self, phase, body):
|
||||
self._deliver(phase, body)
|
||||
|
||||
def _deliver(self, phase, payload):
|
||||
self._R.got_message(phase, payload)
|
||||
def _deliver(self, phase, body):
|
||||
self._R.got_message(phase, body)
|
||||
|
||||
S0_no_pake.upon(got_non_pake, enter=S0_no_pake, outputs=[queue])
|
||||
S0_no_pake.upon(got_pake, enter=S1_yes_pake, outputs=[notify_key, drain])
|
||||
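The upon() lines above are Automat transition declarations; a minimal self-contained sketch of the same pattern, using a hypothetical two-state machine rather than this class:

from automat import MethodicalMachine

class Gate(object):
    m = MethodicalMachine()

    @m.state(initial=True)
    def closed(self): pass
    @m.state()
    def opened(self): pass

    @m.input()
    def open(self): pass
    @m.input()
    def poke(self): pass

    @m.output()
    def note(self):
        print("poked while open")

    closed.upon(open, enter=opened, outputs=[])
    opened.upon(poke, enter=opened, outputs=[note])

g = Gate()
g.open()
g.poke()   # prints "poked while open"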
|
|
|
@@ -24,7 +24,9 @@ class Receive(object):
@m.state(terminal=True)
|
||||
def S3_scared(self): pass
|
||||
|
||||
def got_message(self, phase, payload):
|
||||
def got_message(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
assert self._key
|
||||
data_key = derive_phase_key(self._side, phase)
|
||||
try:
|
||||
|
@@ -53,6 +55,8 @@ class Receive(object):
self._W.happy()
|
||||
@m.output()
|
||||
def W_got_message(self, phase, plaintext):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
self._W.got_message(phase, plaintext)
|
||||
@m.output()
|
||||
def W_scared(self):
|
||||
|
|
|
@@ -1,39 +1,200 @@
import os
|
||||
from six.moves.urllib_parse import urlparse
|
||||
from attr import attrs, attrib
|
||||
from attr.validators import provides, instance_of
|
||||
from zope.interface import implementer
|
||||
from twisted.application import service
|
||||
from twisted.python import log
|
||||
from twisted.internet import defer, endpoints
|
||||
from twisted.application import internet
|
||||
from autobahn.twisted import websocket
|
||||
from . import _interfaces
|
||||
from .util import (bytes_to_hexstr, hexstr_to_bytes,
|
||||
bytes_to_dict, dict_to_bytes)
|
||||
|
||||
class WSClient(websocket.WebSocketClientProtocol):
|
||||
def onConnect(self, response):
|
||||
# this fires during WebSocket negotiation, and isn't very useful
|
||||
# unless you want to modify the protocol settings
|
||||
#print("onConnect", response)
|
||||
pass
|
||||
|
||||
def onOpen(self, *args):
|
||||
# this fires when the WebSocket is ready to go. No arguments
|
||||
#print("onOpen", args)
|
||||
#self.wormhole_open = True
|
||||
self._RC.ws_open(self)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
#print("onMessage")
|
||||
assert not isBinary
|
||||
self._RC.ws_message(payload)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
#print("onClose")
|
||||
self._RC.ws_close(wasClean, code, reason)
|
||||
#if self.wormhole_open:
|
||||
# self.wormhole._ws_closed(wasClean, code, reason)
|
||||
#else:
|
||||
# # we closed before establishing a connection (onConnect) or
|
||||
# # finishing WebSocket negotiation (onOpen): errback
|
||||
# self.factory.d.errback(error.ConnectError(reason))
|
||||
|
||||
class WSFactory(websocket.WebSocketClientFactory):
|
||||
protocol = WSClient
|
||||
def __init__(self, RC, *args, **kwargs):
|
||||
websocket.WebSocketClientFactory.__init__(self, *args, **kwargs)
|
||||
self._RC = RC
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
|
||||
proto._RC = self._RC
|
||||
#proto.wormhole_open = False
|
||||
return proto
|
||||
|
||||
@attrs
|
||||
@implementer(_interfaces.IRendezvousConnector)
|
||||
class RendezvousConnector(service.MultiService, object):
|
||||
def __init__(self, journal, timing):
|
||||
self._journal = journal
|
||||
self._timing = timing
|
||||
class RendezvousConnector(object):
|
||||
_url = attrib(instance_of(type(u"")))
|
||||
_appid = attrib(instance_of(type(u"")))
|
||||
_side = attrib(instance_of(type(u"")))
|
||||
_reactor = attrib()
|
||||
_journal = attrib(provides(_interfaces.IJournal))
|
||||
_timing = attrib(provides(_interfaces.ITiming))
|
||||
|
||||
def wire(self, mailbox, code, nameplate_lister):
|
||||
def __init__(self):
|
||||
self._ws = None
|
||||
f = WSFactory(self, self._url)
|
||||
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||
p = urlparse(self._url)
|
||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||
self._connector = internet.ClientService(ep, f)
|
||||
|
||||
def _make_endpoint(self, hostname, port):
|
||||
# TODO: Tor goes here
|
||||
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
|
||||
|
||||
def wire(self, wormhole, mailbox, code, nameplate_lister):
|
||||
self._W = _interfaces.IWormhole(wormhole)
|
||||
self._M = _interfaces.IMailbox(mailbox)
|
||||
self._C = _interfaces.ICode(code)
|
||||
self._NL = _interfaces.INameplateLister(nameplate_lister)
|
||||
|
||||
# from Wormhole
|
||||
def start(self):
|
||||
self._connector.startService()
|
||||
|
||||
# from Mailbox
|
||||
def tx_claim(self):
|
||||
pass
|
||||
def tx_open(self):
|
||||
pass
|
||||
def tx_add(self, x):
|
||||
pass
|
||||
def tx_claim(self, nameplate):
|
||||
self._tx("claim", nameplate=nameplate)
|
||||
|
||||
def tx_open(self, mailbox):
|
||||
self._tx("open", mailbox=mailbox)
|
||||
|
||||
def tx_add(self, phase, body):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(body, type(b"")), type(body)
|
||||
self._tx("add", phase=phase, body=bytes_to_hexstr(body))
|
||||
|
||||
def tx_release(self):
|
||||
pass
|
||||
self._tx("release")
|
||||
|
||||
def tx_close(self, mood):
|
||||
pass
|
||||
self._tx("close", mood=mood)
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
d = defer.maybeDeferred(self._connector.stopService)
|
||||
d.addErrback(log.err) # TODO: deliver error upstairs?
|
||||
d.addBoth(self._stopped)
|
||||
|
||||
|
||||
# from NameplateLister
|
||||
def tx_list(self):
|
||||
pass
|
||||
self._tx("list")
|
||||
|
||||
# from Code
|
||||
def tx_allocate(self):
|
||||
self._tx("allocate")
|
||||
|
||||
# from our WSClient (the WebSocket protocol)
|
||||
def ws_open(self, proto):
|
||||
self._ws = proto
|
||||
self._tx("bind", appid=self._appid, side=self._side)
|
||||
self._M.connected()
|
||||
self._NL.connected()
|
||||
|
||||
def ws_message(self, payload):
|
||||
msg = bytes_to_dict(payload)
|
||||
if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg)
|
||||
self._timing.add("ws_receive", _side=self._side, message=msg)
|
||||
mtype = msg["type"]
|
||||
meth = getattr(self, "_response_handle_"+mtype, None)
|
||||
if not meth:
|
||||
# make tests fail, but real application will ignore it
|
||||
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
|
||||
return
|
||||
return meth(msg)
|
||||
|
||||
def ws_close(self, wasClean, code, reason):
|
||||
self._ws = None
|
||||
self._M.lost()
|
||||
self._NL.lost()
|
||||
|
||||
# internal
|
||||
def _stopped(self, res):
|
||||
self._M.stopped()
|
||||
|
||||
def _tx(self, mtype, **kwargs):
|
||||
assert self._ws
|
||||
# msgid is used by misc/dump-timing.py to correlate our sends with
|
||||
# their receives, and vice versa. They are also correlated with the
|
||||
# ACKs we get back from the server (which we otherwise ignore). There
|
||||
# are so few messages, 16 bits is enough to be mostly-unique.
|
||||
if self.DEBUG: print("SEND", mtype)
|
||||
kwargs["id"] = bytes_to_hexstr(os.urandom(2))
|
||||
kwargs["type"] = mtype
|
||||
payload = dict_to_bytes(kwargs)
|
||||
self._timing.add("ws_send", _side=self._side, **kwargs)
|
||||
self._ws.sendMessage(payload, False)
|
||||
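For reference, the wire format produced by _tx() is just the keyword arguments JSON-encoded (via dict_to_bytes) with "type" and the short random "id" added; a hypothetical example of the bytes handed to sendMessage():

# e.g. the "bind" sent from ws_open() serializes roughly as:
#   {"appid": "<appid>", "side": "<side>", "id": "9f3c", "type": "bind"}
# where "id" is two random bytes rendered as hex and key order depends on
# dict_to_bytes; the "9f3c" value is illustrative.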
|
||||
def _response_handle_allocated(self, msg):
|
||||
nameplate = msg["nameplate"]
|
||||
assert isinstance(nameplate, type("")), type(nameplate)
|
||||
self._C.rx_allocated(nameplate)
|
||||
|
||||
def _response_handle_nameplates(self, msg):
|
||||
nameplates = msg["nameplates"]
|
||||
assert isinstance(nameplates, list), type(nameplates)
|
||||
nids = []
|
||||
for n in nameplates:
|
||||
assert isinstance(n, dict), type(n)
|
||||
nameplate_id = n["id"]
|
||||
assert isinstance(nameplate_id, type("")), type(nameplate_id)
|
||||
nids.append(nameplate_id)
|
||||
self._NL.rx_nameplates(nids)
|
||||
|
||||
def _response_handle_ack(self, msg):
|
||||
pass
|
||||
|
||||
# record, message, payload, packet, bundle, ciphertext, plaintext
|
||||
|
||||
def _response_handle_welcome(self, msg):
|
||||
self._W.rx_welcome(msg["welcome"])
|
||||
|
||||
def _response_handle_claimed(self, msg):
|
||||
mailbox = msg["mailbox"]
|
||||
assert isinstance(mailbox, type("")), type(mailbox)
|
||||
self._M.rx_claimed(mailbox)
|
||||
|
||||
def _response_handle_message(self, msg):
|
||||
side = msg["side"]
|
||||
phase = msg["phase"]
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
body = hexstr_to_bytes(msg["body"]) # bytes
|
||||
self._M.rx_message(side, phase, body)
|
||||
|
||||
def _response_handle_released(self, msg):
|
||||
self._M.rx_released()
|
||||
|
||||
def _response_handle_closed(self, msg):
|
||||
self._M.rx_closed()
|
||||
|
||||
|
||||
# record, message, payload, packet, bundle, ciphertext, plaintext
|
||||
|
|
|
@@ -1,7 +1,7 @@
from zope.interface import implementer
|
||||
from automat import MethodicalMachine
|
||||
from . import _interfaces
|
||||
from .util import hexstr_to_bytes
|
||||
from ._key import derive_phase_key, encrypt_data
|
||||
|
||||
@implementer(_interfaces.ISend)
|
||||
class Send(object):
|
||||
|
@@ -17,36 +17,34 @@ class Send(object):
@m.state(terminal=True)
|
||||
def S1_verified_key(self): pass
|
||||
|
||||
def got_pake(self, payload):
|
||||
if "pake_v1" in payload:
|
||||
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
|
||||
else:
|
||||
self.got_pake_bad()
|
||||
|
||||
@m.input()
|
||||
def got_verified_key(self, key): pass
|
||||
@m.input()
|
||||
def send(self, phase, payload): pass
|
||||
def send(self, phase, plaintext): pass
|
||||
|
||||
@m.output()
|
||||
def queue(self, phase, payload):
|
||||
self._queue.append((phase, payload))
|
||||
def queue(self, phase, plaintext):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
self._queue.append((phase, plaintext))
|
||||
@m.output()
|
||||
def record_key(self, key):
|
||||
self._key = key
|
||||
@m.output()
|
||||
def drain(self, key):
|
||||
del key
|
||||
for (phase, payload) in self._queue:
|
||||
self._encrypt_and_send(phase, payload)
|
||||
for (phase, plaintext) in self._queue:
|
||||
self._encrypt_and_send(phase, plaintext)
|
||||
self._queue[:] = []
|
||||
@m.output()
|
||||
def deliver(self, phase, payload):
|
||||
self._encrypt_and_send(phase, payload)
|
||||
def deliver(self, phase, plaintext):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
self._encrypt_and_send(phase, plaintext)
|
||||
|
||||
def _encrypt_and_send(self, phase, payload):
|
||||
data_key = self._derive_phase_key(self._side, phase)
|
||||
encrypted = self._encrypt_data(data_key, plaintext)
|
||||
def _encrypt_and_send(self, phase, plaintext):
|
||||
data_key = derive_phase_key(self._side, phase)
|
||||
encrypted = encrypt_data(data_key, plaintext)
|
||||
self._M.add_message(phase, encrypted)
|
||||
|
||||
S0_no_key.upon(send, enter=S0_no_key, outputs=[queue])
|
||||
|
|
|
@@ -9,6 +9,7 @@ from ._receive import Receive
from ._rendezvous import RendezvousConnector
|
||||
from ._nameplate import NameplateListing
|
||||
from ._code import Code
|
||||
from .util import bytes_to_dict
|
||||
|
||||
@implementer(_interfaces.IWormhole)
|
||||
class Wormhole:
|
||||
|
@@ -31,13 +32,13 @@ class Wormhole:
self._O.wire(self._K, self._R)
|
||||
self._K.wire(self, self._M, self._R)
|
||||
self._R.wire(self, self._K, self._S)
|
||||
self._RC.wire(self._M, self._C, self._NL)
|
||||
self._RC.wire(self, self._M, self._C, self._NL)
|
||||
self._NL.wire(self._RC, self._C)
|
||||
self._C.wire(self, self._RC, self._NL)
|
||||
|
||||
# these methods are called from outside
|
||||
def start(self):
|
||||
self._relay_client.start()
|
||||
self._RC.start()
|
||||
|
||||
# and these are the state-machine transition functions, which don't take
|
||||
# args
|
||||
|
@@ -54,7 +55,7 @@ class Wormhole:
|
||||
# from the Application, or some sort of top-level shim
|
||||
@m.input()
|
||||
def send(self, phase, message): pass
|
||||
def send(self, phase, plaintext): pass
|
||||
@m.input()
|
||||
def close(self): pass
|
||||
|
||||
|
@@ -69,6 +70,8 @@ class Wormhole:
@m.input()
|
||||
def scared(self): pass
|
||||
def got_message(self, phase, plaintext):
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
assert isinstance(plaintext, type(b"")), type(plaintext)
|
||||
if phase == "version":
|
||||
self.got_version(plaintext)
|
||||
else:
|
||||
|
@@ -91,12 +94,13 @@ class Wormhole:
self._M.set_nameplate(nameplate)
|
||||
self._K.set_code(code)
|
||||
@m.output()
|
||||
def process_version(self, version): # response["message"][phase=version]
|
||||
pass
|
||||
def process_version(self, plaintext):
|
||||
self._their_versions = bytes_to_dict(plaintext)
|
||||
# ignored for now
|
||||
|
||||
@m.output()
|
||||
def S_send(self, phase, message):
|
||||
self._S.send(phase, message)
|
||||
def S_send(self, phase, plaintext):
|
||||
self._S.send(phase, plaintext)
|
||||
|
||||
@m.output()
|
||||
def close_scared(self):
|
||||
|
|
|
@@ -1,6 +1,9 @@
from zope.interface import implementer
import contextlib
from _interfaces import IJournal

class JournalManager(object):
@implementer(IJournal)
class Journal(object):
    def __init__(self, save_checkpoint):
        self._save_checkpoint = save_checkpoint
        self._outbound_queue = []
@@ -8,7 +11,7 @@ class JournalManager(object):

    def queue_outbound(self, fn, *args, **kwargs):
        assert self._processing
        self._outbound_queue.append((fn, args, kwargs))
        self._outbound_queue.append((fn, args, kwargs))

    @contextlib.contextmanager
    def process(self):
@@ -21,3 +24,12 @@ class JournalManager(object):
            fn(*args, **kwargs)
        self._outbound_queue[:] = []
        self._processing = False


@implementer(IJournal)
class ImmediateJournal(object):
    def queue_outbound(self, fn, *args, **kwargs):
        fn(*args, **kwargs)
    @contextlib.contextmanager
    def process(self):
        yield
|
|
@@ -1,912 +1,5 @@
from __future__ import print_function, absolute_import, unicode_literals
|
||||
import os, sys, re
|
||||
from six.moves.urllib_parse import urlparse
|
||||
from twisted.internet import defer, endpoints #, error
|
||||
from twisted.internet.threads import deferToThread, blockingCallFromThread
|
||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||
from twisted.python import log, failure
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl import utils
|
||||
from spake2 import SPAKE2_Symmetric
|
||||
from hashlib import sha256
|
||||
from . import __version__
|
||||
from . import codes
|
||||
#from .errors import ServerError, Timeout
|
||||
from .errors import (WrongPasswordError, InternalError, WelcomeError,
|
||||
WormholeClosedError, KeyFormatError)
|
||||
from .timing import DebugTiming
|
||||
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
|
||||
dict_to_bytes, bytes_to_dict)
|
||||
from hkdf import Hkdf
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
||||
CONFMSG_NONCE_LENGTH = 128//8
|
||||
CONFMSG_MAC_LENGTH = 256//8
|
||||
def make_confmsg(confkey, nonce):
|
||||
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
|
||||
|
||||
|
||||
# We send the following messages through the relay server to the far side (by
|
||||
# sending "add" commands to the server, and getting "message" responses):
|
||||
#
|
||||
# phase=setup:
|
||||
# * unauthenticated version strings (but why?)
|
||||
# * early warmup for connection hints ("I can do tor, spin up HS")
|
||||
# * wordlist l10n identifier
|
||||
# phase=pake: just the SPAKE2 'start' message (binary)
|
||||
# phase=version: version data, key verification (HKDF(key, nonce)+nonce)
|
||||
# phase=1,2,3,..: application messages
|
||||
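Concretely, each of these phases travels as an "add" command whose hex-encoded body the peer later receives back in a "message" response; a hypothetical pake exchange looks roughly like:

# client -> server (from the "add" command path below):
#   {"type": "add", "phase": "pake", "body": "<hex of the pake_v1 JSON>", "id": "..."}
# server -> both clients:
#   {"type": "message", "side": "<sender side>", "phase": "pake", "body": "<same hex>"}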
|
||||
class _GetCode:
|
||||
def __init__(self, code_length, send_command, timing):
|
||||
self._code_length = code_length
|
||||
self._send_command = send_command
|
||||
self._timing = timing
|
||||
self._allocated_d = defer.Deferred()
|
||||
|
||||
@inlineCallbacks
|
||||
def go(self):
|
||||
with self._timing.add("allocate"):
|
||||
self._send_command("allocate")
|
||||
nameplate_id = yield self._allocated_d
|
||||
code = codes.make_code(nameplate_id, self._code_length)
|
||||
assert isinstance(code, type("")), type(code)
|
||||
returnValue(code)
|
||||
|
||||
def _response_handle_allocated(self, msg):
|
||||
nid = msg["nameplate"]
|
||||
assert isinstance(nid, type("")), type(nid)
|
||||
self._allocated_d.callback(nid)
|
||||
|
||||
class _InputCode:
|
||||
def __init__(self, reactor, prompt, code_length, send_command, timing,
|
||||
stderr):
|
||||
self._reactor = reactor
|
||||
self._prompt = prompt
|
||||
self._code_length = code_length
|
||||
self._send_command = send_command
|
||||
self._timing = timing
|
||||
self._stderr = stderr
|
||||
|
||||
@inlineCallbacks
|
||||
def _list(self):
|
||||
self._lister_d = defer.Deferred()
|
||||
self._send_command("list")
|
||||
nameplates = yield self._lister_d
|
||||
self._lister_d = None
|
||||
returnValue(nameplates)
|
||||
|
||||
def _list_blocking(self):
|
||||
return blockingCallFromThread(self._reactor, self._list)
|
||||
|
||||
@inlineCallbacks
|
||||
def go(self):
|
||||
# fetch the list of nameplates ahead of time, to give us a chance to
|
||||
# discover the welcome message (and warn the user about an obsolete
|
||||
# client)
|
||||
#
|
||||
# TODO: send the request early, show the prompt right away, hide the
|
||||
# latency in the user's indecision and slow typing. If we're lucky
|
||||
# the answer will come back before they hit TAB.
|
||||
|
||||
initial_nameplate_ids = yield self._list()
|
||||
with self._timing.add("input code", waiting="user"):
|
||||
t = self._reactor.addSystemEventTrigger("before", "shutdown",
|
||||
self._warn_readline)
|
||||
res = yield deferToThread(codes.input_code_with_completion,
|
||||
self._prompt,
|
||||
initial_nameplate_ids,
|
||||
self._list_blocking,
|
||||
self._code_length)
|
||||
(code, used_completion) = res
|
||||
self._reactor.removeSystemEventTrigger(t)
|
||||
if not used_completion:
|
||||
self._remind_about_tab()
|
||||
returnValue(code)
|
||||
|
||||
def _response_handle_nameplates(self, msg):
|
||||
nameplates = msg["nameplates"]
|
||||
assert isinstance(nameplates, list), type(nameplates)
|
||||
nids = []
|
||||
for n in nameplates:
|
||||
assert isinstance(n, dict), type(n)
|
||||
nameplate_id = n["id"]
|
||||
assert isinstance(nameplate_id, type("")), type(nameplate_id)
|
||||
nids.append(nameplate_id)
|
||||
self._lister_d.callback(nids)
|
||||
|
||||
def _warn_readline(self):
|
||||
# When our process receives a SIGINT, Twisted's SIGINT handler will
|
||||
# stop the reactor and wait for all threads to terminate before the
|
||||
# process exits. However, if we were waiting for
|
||||
# input_code_with_completion() when SIGINT happened, the readline
|
||||
# thread will be blocked waiting for something on stdin. Trick the
|
||||
# user into satisfying the blocking read so we can exit.
|
||||
print("\nCommand interrupted: please press Return to quit",
|
||||
file=sys.stderr)
|
||||
|
||||
# Other potential approaches to this problem:
|
||||
# * hard-terminate our process with os._exit(1), but make sure the
|
||||
# tty gets reset to a normal mode ("cooked"?) first, so that the
|
||||
# next shell command the user types is echoed correctly
|
||||
# * track down the thread (t.p.threadable.getThreadID from inside the
|
||||
# thread), get a cffi binding to pthread_kill, deliver SIGINT to it
|
||||
# * allocate a pty pair (pty.openpty), replace sys.stdin with the
|
||||
# slave, build a pty bridge that copies bytes (and other PTY
|
||||
# things) from the real stdin to the master, then close the slave
|
||||
# at shutdown, so readline sees EOF
|
||||
# * write tab-completion and basic editing (TTY raw mode,
|
||||
# backspace-is-erase) without readline, probably with curses or
|
||||
# twisted.conch.insults
|
||||
# * write a separate program to get codes (maybe just "wormhole
|
||||
# --internal-get-code"), run it as a subprocess, let it inherit
|
||||
# stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It
|
||||
# needs an RPC mechanism (over some extra file descriptors) to ask
|
||||
# us to fetch the current nameplate_id list.
|
||||
#
|
||||
# Note that hard-terminating our process with os.kill(os.getpid(),
|
||||
# signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread
|
||||
# doesn't see the signal, and we must still wait for stdin to make
|
||||
# readline finish.
|
||||
|
||||
def _remind_about_tab(self):
|
||||
print(" (note: you can use <Tab> to complete words)", file=self._stderr)
|
||||
|
||||
class _WelcomeHandler:
|
||||
def __init__(self, url, current_version, signal_error):
|
||||
self._ws_url = url
|
||||
self._version_warning_displayed = False
|
||||
self._current_version = current_version
|
||||
self._signal_error = signal_error
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if "motd" in welcome:
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" %
|
||||
(self._ws_url, motd_formatted), file=sys.stderr)
|
||||
|
||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
||||
if ("current_cli_version" in welcome
|
||||
and "-" not in self._current_version
|
||||
and not self._version_warning_displayed
|
||||
and welcome["current_cli_version"] != self._current_version):
|
||||
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
||||
print("Server claims %s is current, but ours is %s"
|
||||
% (welcome["current_cli_version"], self._current_version),
|
||||
file=sys.stderr)
|
||||
self._version_warning_displayed = True
|
||||
|
||||
if "error" in welcome:
|
||||
return self._signal_error(WelcomeError(welcome["error"]),
|
||||
"unwelcome")
|
||||
|
||||
# states for nameplates, mailboxes, and the websocket connection
|
||||
(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing")
|
||||
|
||||
class _Wormhole:
|
||||
DEBUG = False
|
||||
|
||||
def __init__(self, appid, relay_url, reactor, tor_manager, timing, stderr):
|
||||
self._appid = appid
|
||||
self._ws_url = relay_url
|
||||
self._reactor = reactor
|
||||
self._tor_manager = tor_manager
|
||||
self._timing = timing
|
||||
self._stderr = stderr
|
||||
|
||||
self._welcomer = _WelcomeHandler(self._ws_url, __version__,
|
||||
self._signal_error)
|
||||
self._side = bytes_to_hexstr(os.urandom(5))
|
||||
self._connection_state = CLOSED
|
||||
self._connection_waiters = []
|
||||
self._ws_t = None
|
||||
self._started_get_code = False
|
||||
self._get_code = None
|
||||
self._started_input_code = False
|
||||
self._input_code_waiter = None
|
||||
self._code = None
|
||||
self._nameplate_id = None
|
||||
self._nameplate_state = CLOSED
|
||||
self._mailbox_id = None
|
||||
self._mailbox_state = CLOSED
|
||||
self._flag_need_nameplate = True
|
||||
self._flag_need_to_see_mailbox_used = True
|
||||
self._flag_need_to_build_msg1 = True
|
||||
self._flag_need_to_send_PAKE = True
|
||||
self._establish_key_called = False
|
||||
self._key_waiter = None
|
||||
self._key = None
|
||||
|
||||
self._version_message = None
|
||||
self._version_checked = False
|
||||
self._get_verifier_called = False
|
||||
self._verifier = None # bytes
|
||||
self._verify_result = None # bytes or a Failure
|
||||
self._verifier_waiter = None
|
||||
|
||||
self._my_versions = {} # sent
|
||||
self._their_versions = {} # received
|
||||
|
||||
self._close_called = False # the close() API has been called
|
||||
self._closing = False # we've started shutdown
|
||||
self._disconnect_waiter = defer.Deferred()
|
||||
self._error = None
|
||||
|
||||
self._next_send_phase = 0
|
||||
# send() queues plaintext here, waiting for a connection and the key
|
||||
self._plaintext_to_send = [] # (phase, plaintext)
|
||||
self._sent_phases = set() # to detect double-send
|
||||
|
||||
self._next_receive_phase = 0
|
||||
self._receive_waiters = {} # phase -> Deferred
|
||||
self._received_messages = {} # phase -> plaintext
|
||||
|
||||
# API METHODS for applications to call
|
||||
|
||||
# You must use at least one of these entry points, to establish the
|
||||
# wormhole code. Other APIs will stall or be queued until we have one.
|
||||
|
||||
# entry point 1: generate a new code. returns a Deferred
|
||||
def get_code(self, code_length=2): # XX rename to allocate_code()? create_?
|
||||
return self._API_get_code(code_length)
|
||||
|
||||
# entry point 2: interactively type in a code, with completion. returns
|
||||
# Deferred
|
||||
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
|
||||
return self._API_input_code(prompt, code_length)
|
||||
|
||||
# entry point 3: paste in a fully-formed code. No return value.
|
||||
def set_code(self, code):
|
||||
self._API_set_code(code)
|
||||
|
||||
# todo: restore-saved-state entry points
|
||||
|
||||
def establish_key(self):
|
||||
"""
|
||||
returns a Deferred that fires when we've established the shared key.
|
||||
When successful, the Deferred fires with a simple `True`, otherwise
|
||||
it fails.
|
||||
"""
|
||||
return self._API_establish_key()
|
||||
|
||||
def verify(self):
|
||||
"""Returns a Deferred that fires when we've heard back from the other
|
||||
side, and have confirmed that they used the right wormhole code. When
|
||||
successful, the Deferred fires with a "verifier" (a bytestring) which
|
||||
can be compared out-of-band before making additional API calls. If
|
||||
they used the wrong wormhole code, the Deferred errbacks with
|
||||
WrongPasswordError.
|
||||
"""
|
||||
return self._API_verify()
|
||||
|
||||
def send(self, outbound_data):
|
||||
return self._API_send(outbound_data)
|
||||
|
||||
def get(self):
|
||||
return self._API_get()
|
||||
|
||||
def derive_key(self, purpose, length):
|
||||
"""Derive a new key from the established wormhole channel for some
|
||||
other purpose. This is a deterministic randomized function of the
|
||||
session key and the 'purpose' string (unicode/py3-string). This
|
||||
cannot be called until verify() or get() has fired.
|
||||
"""
|
||||
return self._API_derive_key(purpose, length)
|
||||
|
||||
def close(self, res=None):
|
||||
"""Collapse the wormhole, freeing up server resources and flushing
|
||||
all pending messages. Returns a Deferred that fires when everything
|
||||
is done. It fires with any argument close() was given, to enable use
|
||||
as a d.addBoth() handler:
|
||||
|
||||
w = wormhole(...)
|
||||
d = w.get()
|
||||
..
|
||||
d.addBoth(w.close)
|
||||
return d
|
||||
|
||||
Another reasonable approach is to use inlineCallbacks:
|
||||
|
||||
@inlineCallbacks
|
||||
def pair(self, code):
|
||||
w = wormhole(...)
|
||||
try:
|
||||
them = yield w.get()
|
||||
finally:
|
||||
yield w.close()
|
||||
"""
|
||||
return self._API_close(res)
|
||||
|
||||
# INTERNAL METHODS beyond here
|
||||
|
||||
def _start(self):
|
||||
d = self._connect() # causes stuff to happen
|
||||
d.addErrback(log.err)
|
||||
return d # fires when connection is established, if you care
|
||||
|
||||
|
||||
|
||||
def _make_endpoint(self, hostname, port):
|
||||
if self._tor_manager:
|
||||
return self._tor_manager.get_endpoint_for(hostname, port)
|
||||
# note: HostnameEndpoints have a default 30s timeout
|
||||
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
|
||||
|
||||
def _connect(self):
|
||||
# TODO: if we lose the connection, make a new one, re-establish the
|
||||
# state
|
||||
assert self._side
|
||||
self._connection_state = OPENING
|
||||
self._ws_t = self._timing.add("open websocket")
|
||||
p = urlparse(self._ws_url)
|
||||
f = WSFactory(self._ws_url)
|
||||
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||
f.wormhole = self
|
||||
f.d = defer.Deferred()
|
||||
# TODO: if hostname="localhost", I get three factories starting
|
||||
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
|
||||
# an error in the factory is masked.
|
||||
ep = self._make_endpoint(p.hostname, p.port or 80)
|
||||
# .connect errbacks if the TCP connection fails
|
||||
d = ep.connect(f)
|
||||
d.addCallback(self._event_connected)
|
||||
# f.d is errbacked if WebSocket negotiation fails, and the WebSocket
|
||||
# drops any data sent before onOpen() fires, so we must wait for it
|
||||
d.addCallback(lambda _: f.d)
|
||||
d.addCallback(self._event_ws_opened)
|
||||
return d
|
||||
|
||||
def _event_connected(self, ws):
|
||||
self._ws = ws
|
||||
if self._ws_t:
|
||||
self._ws_t.finish()
|
||||
|
||||
def _event_ws_opened(self, _):
|
||||
self._connection_state = OPEN
|
||||
if self._closing:
|
||||
return self._maybe_finished_closing()
|
||||
self._ws_send_command("bind", appid=self._appid, side=self._side)
|
||||
self._maybe_claim_nameplate()
|
||||
self._maybe_send_pake()
|
||||
waiters, self._connection_waiters = self._connection_waiters, []
|
||||
for d in waiters:
|
||||
d.callback(None)
|
||||
|
||||
def _when_connected(self):
|
||||
if self._connection_state == OPEN:
|
||||
return defer.succeed(None)
|
||||
d = defer.Deferred()
|
||||
self._connection_waiters.append(d)
|
||||
return d
|
||||
|
||||
def _ws_send_command(self, mtype, **kwargs):
|
||||
# msgid is used by misc/dump-timing.py to correlate our sends with
|
||||
# their receives, and vice versa. They are also correlated with the
|
||||
# ACKs we get back from the server (which we otherwise ignore). There
|
||||
# are so few messages, 16 bits is enough to be mostly-unique.
|
||||
if self.DEBUG: print("SEND", mtype)
|
||||
kwargs["id"] = bytes_to_hexstr(os.urandom(2))
|
||||
kwargs["type"] = mtype
|
||||
payload = dict_to_bytes(kwargs)
|
||||
self._timing.add("ws_send", _side=self._side, **kwargs)
|
||||
self._ws.sendMessage(payload, False)
|
||||
|
||||
def _ws_dispatch_response(self, payload):
|
||||
msg = bytes_to_dict(payload)
|
||||
if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg)
|
||||
self._timing.add("ws_receive", _side=self._side, message=msg)
|
||||
mtype = msg["type"]
|
||||
meth = getattr(self, "_response_handle_"+mtype, None)
|
||||
if not meth:
|
||||
# make tests fail, but real application will ignore it
|
||||
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
|
||||
return
|
||||
return meth(msg)
|
||||
|
||||
def _response_handle_ack(self, msg):
|
||||
pass
|
||||
|
||||
def _response_handle_welcome(self, msg):
|
||||
self._welcomer.handle_welcome(msg["welcome"])
|
||||
|
||||
# entry point 1: generate a new code
|
||||
@inlineCallbacks
|
||||
def _API_get_code(self, code_length):
|
||||
if self._code is not None: raise InternalError
|
||||
if self._started_get_code: raise InternalError
|
||||
self._started_get_code = True
|
||||
with self._timing.add("API get_code"):
|
||||
yield self._when_connected()
|
||||
gc = _GetCode(code_length, self._ws_send_command, self._timing)
|
||||
self._get_code = gc
|
||||
self._response_handle_allocated = gc._response_handle_allocated
|
||||
# TODO: signal_error
|
||||
code = yield gc.go()
|
||||
self._get_code = None
|
||||
self._nameplate_state = OPEN
|
||||
self._event_learned_code(code)
|
||||
returnValue(code)
|
||||
|
||||
# entry point 2: interactively type in a code, with completion
|
||||
@inlineCallbacks
|
||||
def _API_input_code(self, prompt, code_length):
|
||||
if self._code is not None: raise InternalError
|
||||
if self._started_input_code: raise InternalError
|
||||
self._started_input_code = True
|
||||
with self._timing.add("API input_code"):
|
||||
yield self._when_connected()
|
||||
ic = _InputCode(self._reactor, prompt, code_length,
|
||||
self._ws_send_command, self._timing, self._stderr)
|
||||
self._response_handle_nameplates = ic._response_handle_nameplates
|
||||
# we reveal the Deferred we're waiting on, so _signal_error can
|
||||
# wake us up if something goes wrong (like a welcome error)
|
||||
self._input_code_waiter = ic.go()
|
||||
code = yield self._input_code_waiter
|
||||
self._input_code_waiter = None
|
||||
self._event_learned_code(code)
|
||||
returnValue(None)
|
||||
|
||||
# entry point 3: paste in a fully-formed code
|
||||
def _API_set_code(self, code):
|
||||
self._timing.add("API set_code")
|
||||
if not isinstance(code, type(u"")):
|
||||
raise TypeError("Unexpected code type '{}'".format(type(code)))
|
||||
if self._code is not None:
|
||||
raise InternalError
|
||||
self._event_learned_code(code)
|
||||
|
||||
# TODO: entry point 4: restore pre-contact saved state (we haven't heard
|
||||
# from the peer yet, so we still need the nameplate)
|
||||
|
||||
# TODO: entry point 5: restore post-contact saved state (so we don't need
|
||||
# or use the nameplate, only the mailbox)
|
||||
def _restore_post_contact_state(self, state):
|
||||
# ...
|
||||
self._flag_need_nameplate = False
|
||||
#self._mailbox_id = X(state)
|
||||
self._event_learned_mailbox()
|
||||
|
||||
def _event_learned_code(self, code):
|
||||
self._timing.add("code established")
|
||||
# bail out early if the password contains spaces...
|
||||
# this should raise a useful error
|
||||
if ' ' in code:
|
||||
raise KeyFormatError("code (%s) contains spaces." % code)
|
||||
self._code = code
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
nid = mo.group(1)
|
||||
assert isinstance(nid, type("")), type(nid)
|
||||
self._nameplate_id = nid
|
||||
# fire more events
|
||||
self._maybe_build_msg1()
|
||||
self._event_learned_nameplate()
|
||||
|
||||
def _maybe_build_msg1(self):
|
||||
if not (self._code and self._flag_need_to_build_msg1):
|
||||
return
|
||||
with self._timing.add("pake1", waiting="crypto"):
|
||||
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
|
||||
idSymmetric=to_bytes(self._appid))
|
||||
self._msg1 = self._sp.start()
|
||||
self._flag_need_to_build_msg1 = False
|
||||
self._event_built_msg1()
|
||||
|
||||
def _event_built_msg1(self):
|
||||
self._maybe_send_pake()
|
||||
|
||||
# every _maybe_X starts with a set of conditions
|
||||
# for each such condition Y, every _event_Y must call _maybe_X
|
||||
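A tiny illustration (hypothetical names) of the _event_*/_maybe_* convention described above: each _maybe_X re-checks its full condition set, and every event that can satisfy one of those conditions calls it.

class MaybePattern(object):
    # hypothetical: "claim" needs both a nameplate and an open connection
    def __init__(self):
        self._nameplate_id = None
        self._connection_open = False
    def _event_learned_nameplate(self, nid):
        self._nameplate_id = nid
        self._maybe_claim_nameplate()
    def _event_connection_opened(self):
        self._connection_open = True
        self._maybe_claim_nameplate()
    def _maybe_claim_nameplate(self):
        if not (self._nameplate_id and self._connection_open):
            return
        print("claim", self._nameplate_id)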
|
||||
def _event_learned_nameplate(self):
|
||||
self._maybe_claim_nameplate()
|
||||
|
||||
def _maybe_claim_nameplate(self):
|
||||
if not (self._nameplate_id and self._connection_state == OPEN):
|
||||
return
|
||||
self._ws_send_command("claim", nameplate=self._nameplate_id)
|
||||
self._nameplate_state = OPEN
|
||||
|
||||
def _response_handle_claimed(self, msg):
|
||||
mailbox_id = msg["mailbox"]
|
||||
assert isinstance(mailbox_id, type("")), type(mailbox_id)
|
||||
self._mailbox_id = mailbox_id
|
||||
self._event_learned_mailbox()
|
||||
|
||||
def _event_learned_mailbox(self):
|
||||
if not self._mailbox_id: raise InternalError
|
||||
assert self._mailbox_state == CLOSED, self._mailbox_state
|
||||
if self._closing:
|
||||
return
|
||||
self._ws_send_command("open", mailbox=self._mailbox_id)
|
||||
self._mailbox_state = OPEN
|
||||
# causes old messages to be sent now, and subscribes to new messages
|
||||
self._maybe_send_pake()
|
||||
self._maybe_send_phase_messages()
|
||||
|
||||
def _maybe_send_pake(self):
|
||||
# TODO: deal with reentrant call
|
||||
if not (self._connection_state == OPEN
|
||||
and self._mailbox_state == OPEN
|
||||
and self._flag_need_to_send_PAKE):
|
||||
return
|
||||
body = {"pake_v1": bytes_to_hexstr(self._msg1)}
|
||||
payload = dict_to_bytes(body)
|
||||
self._msg_send("pake", payload)
|
||||
self._flag_need_to_send_PAKE = False
|
||||
|
||||
def _event_received_pake(self, pake_msg):
|
||||
payload = bytes_to_dict(pake_msg)
|
||||
msg2 = hexstr_to_bytes(payload["pake_v1"])
|
||||
with self._timing.add("pake2", waiting="crypto"):
|
||||
self._key = self._sp.finish(msg2)
|
||||
self._event_established_key()
|
||||
|
||||
def _event_established_key(self):
|
||||
self._timing.add("key established")
|
||||
self._maybe_notify_key()
|
||||
|
||||
# both sides send different (random) version messages
|
||||
self._send_version_message()
|
||||
|
||||
verifier = self._derive_key(b"wormhole:verifier")
|
||||
self._event_computed_verifier(verifier)
|
||||
|
||||
self._maybe_check_version()
|
||||
self._maybe_send_phase_messages()
|
||||
|
||||
def _API_establish_key(self):
|
||||
if self._error: return defer.fail(self._error)
|
||||
if self._establish_key_called: raise InternalError
|
||||
self._establish_key_called = True
|
||||
if self._key is not None:
|
||||
return defer.succeed(True)
|
||||
self._key_waiter = defer.Deferred()
|
||||
return self._key_waiter
|
||||
|
||||
def _maybe_notify_key(self):
|
||||
if self._key is None:
|
||||
return
|
||||
if self._error:
|
||||
result = failure.Failure(self._error)
|
||||
else:
|
||||
result = True
|
||||
if self._key_waiter and not self._key_waiter.called:
|
||||
self._key_waiter.callback(result)
|
||||
|
||||
def _send_version_message(self):
|
||||
# this is encrypted like a normal phase message, and includes a
|
||||
# dictionary of version flags to let the other Wormhole know what
|
||||
# we're capable of (for future expansion)
|
||||
plaintext = dict_to_bytes(self._my_versions)
|
||||
phase = "version"
|
||||
data_key = self._derive_phase_key(self._side, phase)
|
||||
encrypted = self._encrypt_data(data_key, plaintext)
|
||||
self._msg_send(phase, encrypted)
|
||||
|
||||
def _API_verify(self):
|
||||
if self._error: return defer.fail(self._error)
|
||||
if self._get_verifier_called: raise InternalError
|
||||
self._get_verifier_called = True
|
||||
if self._verify_result:
|
||||
return defer.succeed(self._verify_result) # bytes or Failure
|
||||
self._verifier_waiter = defer.Deferred()
|
||||
return self._verifier_waiter
|
||||
|
||||
def _event_computed_verifier(self, verifier):
|
||||
self._verifier = verifier
|
||||
self._maybe_notify_verify()
|
||||
|
||||
def _maybe_notify_verify(self):
|
||||
if not (self._verifier and self._version_checked):
|
||||
return
|
||||
if self._error:
|
||||
self._verify_result = failure.Failure(self._error)
|
||||
else:
|
||||
self._verify_result = self._verifier
|
||||
if self._verifier_waiter and not self._verifier_waiter.called:
|
||||
self._verifier_waiter.callback(self._verify_result)
|
||||
|
||||
def _event_received_version(self, side, body):
|
||||
# We ought to have the master key by now, because sensible peers
|
||||
# should always send "pake" before sending "version". It might be
|
||||
# nice to relax this requirement, which means storing the received
|
||||
# version message, and having _event_established_key call
|
||||
# _check_version()
|
||||
self._version_message = (side, body)
|
||||
self._maybe_check_version()
|
||||
|
||||
def _maybe_check_version(self):
|
||||
if not (self._key and self._version_message):
|
||||
return
|
||||
if self._version_checked:
|
||||
return
|
||||
self._version_checked = True
|
||||
|
||||
side, body = self._version_message
|
||||
data_key = self._derive_phase_key(side, "version")
|
||||
try:
|
||||
plaintext = self._decrypt_data(data_key, body)
|
||||
except CryptoError:
|
||||
# this makes all API calls fail
|
||||
if self.DEBUG: print("CONFIRM FAILED")
|
||||
self._signal_error(WrongPasswordError(), "scary")
|
||||
return
|
||||
msg = bytes_to_dict(plaintext)
|
||||
self._version_received(msg)
|
||||
|
||||
self._maybe_notify_verify()
|
||||
|
||||
def _version_received(self, msg):
|
||||
self._their_versions = msg
|
||||
|
||||
def _API_send(self, outbound_data):
|
||||
if self._error: raise self._error
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
raise TypeError(type(outbound_data))
|
||||
phase = self._next_send_phase
|
||||
self._next_send_phase += 1
|
||||
self._plaintext_to_send.append( (phase, outbound_data) )
|
||||
with self._timing.add("API send", phase=phase):
|
||||
self._maybe_send_phase_messages()
|
||||
|
||||
def _derive_phase_key(self, side, phase):
|
||||
assert isinstance(side, type("")), type(side)
|
||||
assert isinstance(phase, type("")), type(phase)
|
||||
side_bytes = side.encode("ascii")
|
||||
phase_bytes = phase.encode("ascii")
|
||||
purpose = (b"wormhole:phase:"
|
||||
+ sha256(side_bytes).digest()
|
||||
+ sha256(phase_bytes).digest())
|
||||
return self._derive_key(purpose)
|
||||
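A standalone sketch of the derivation above (the master key and side/phase values are throwaway examples), re-stating the HKDF helper defined near the top of this file:

from hashlib import sha256
from hkdf import Hkdf
from nacl.secret import SecretBox

def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    return Hkdf(salt, skm).expand(CTXinfo, outlen)

master_key = b"\x00" * 32                     # throwaway master key
side, phase = "a1b2c3d4e5", "version"
purpose = (b"wormhole:phase:"
           + sha256(side.encode("ascii")).digest()
           + sha256(phase.encode("ascii")).digest())
data_key = HKDF(master_key, SecretBox.KEY_SIZE, CTXinfo=purpose)
assert len(data_key) == SecretBox.KEY_SIZE    # 32-byte per-phase key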
|
||||
def _maybe_send_phase_messages(self):
|
||||
# TODO: deal with reentrant call
|
||||
if not (self._connection_state == OPEN
|
||||
and self._mailbox_state == OPEN
|
||||
and self._key):
|
||||
return
|
||||
plaintexts = self._plaintext_to_send
|
||||
self._plaintext_to_send = []
|
||||
for pm in plaintexts:
|
||||
(phase_int, plaintext) = pm
|
||||
assert isinstance(phase_int, int), type(phase_int)
|
||||
phase = "%d" % phase_int
|
||||
data_key = self._derive_phase_key(self._side, phase)
|
||||
encrypted = self._encrypt_data(data_key, plaintext)
|
||||
self._msg_send(phase, encrypted)
|
||||
|
||||
    def _encrypt_data(self, key, data):
        # Without predefined roles, we can't derive predictably unique keys
        # for each side, so we use the same key for both. We use random
        # nonces to keep the messages distinct, and we automatically ignore
        # reflections.
        # TODO: HKDF(side, nonce, key) ?? include 'side' to prevent
        # reflections, since we no longer compare messages
        assert isinstance(key, type(b"")), type(key)
        assert isinstance(data, type(b"")), type(data)
        assert len(key) == SecretBox.KEY_SIZE, len(key)
        box = SecretBox(key)
        nonce = utils.random(SecretBox.NONCE_SIZE)
        return box.encrypt(data, nonce)

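    # SecretBox.encrypt() returns the nonce prepended to the ciphertext, so
    # the whole blob can be hex-encoded and handed to the mailbox as-is.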
    def _msg_send(self, phase, body):
        if phase in self._sent_phases: raise InternalError
        assert self._mailbox_state == OPEN, self._mailbox_state
        self._sent_phases.add(phase)
        # TODO: retry on failure, with exponential backoff. We're guarding
        # against the rendezvous server being temporarily offline.
        self._timing.add("add", phase=phase)
        self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body))

    def _event_mailbox_used(self):
        if self.DEBUG: print("_event_mailbox_used")
        if self._flag_need_to_see_mailbox_used:
            self._maybe_release_nameplate()
            self._flag_need_to_see_mailbox_used = False

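    # Public derive_key() entry point: only valid once the master key exists,
    # i.e. after get_verifier() or get() has fired. A hypothetical caller
    # might do:
    #   transit_key = w.derive_key("transit-purpose", SecretBox.KEY_SIZE)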
    def _API_derive_key(self, purpose, length):
        if self._error: raise self._error
        if self._key is None:
            raise InternalError # call derive_key after get_verifier() or get()
        if not isinstance(purpose, type("")): raise TypeError(type(purpose))
        return self._derive_key(to_bytes(purpose), length)

    def _derive_key(self, purpose, length=SecretBox.KEY_SIZE):
        if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
        if self._key is None:
            raise InternalError # call derive_key after get_verifier() or get()
        return HKDF(self._key, length, CTXinfo=purpose)

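    # Inbound "message" responses from the rendezvous server land here. The
    # body is hex-encoded; messages tagged with our own side are reflections
    # of what we sent and are dropped.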
    def _response_handle_message(self, msg):
        side = msg["side"]
        phase = msg["phase"]
        assert isinstance(phase, type("")), type(phase)
        body = hexstr_to_bytes(msg["body"])
        if side == self._side:
            return
        self._event_received_peer_message(side, phase, body)

    def _event_received_peer_message(self, side, phase, body):
        # any message in the mailbox means we no longer need the nameplate
        self._event_mailbox_used()

        if self._closing:
            log.msg("received peer message while closing '%s'" % phase)
        if phase in self._received_messages:
            log.msg("ignoring duplicate peer message '%s'" % phase)
            return

        if phase == "pake":
            self._received_messages["pake"] = body
            return self._event_received_pake(body)
        if phase == "version":
            self._received_messages["version"] = body
            return self._event_received_version(side, body)
        if re.search(r'^\d+$', phase):
            return self._event_received_phase_message(side, phase, body)
        # ignore unrecognized phases, for forwards-compatibility
        log.msg("received unknown phase '%s'" % phase)

    def _event_received_phase_message(self, side, phase, body):
        # It's a numbered phase message, aimed at the application above us.
        # Decrypt and deliver upstairs, notifying anyone waiting on it
        try:
            data_key = self._derive_phase_key(side, phase)
            plaintext = self._decrypt_data(data_key, body)
        except CryptoError:
            e = WrongPasswordError()
            self._signal_error(e, "scary") # flunk all other API calls
            # make tests fail, if they aren't explicitly catching it
            if self.DEBUG: print("CryptoError in msg received")
            log.err(e)
            if self.DEBUG: print(" did log.err", e)
            return # ignore this message
        self._received_messages[phase] = plaintext
        if phase in self._receive_waiters:
            d = self._receive_waiters.pop(phase)
            d.callback(plaintext)

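    # SecretBox.decrypt() raises CryptoError when the key doesn't match the
    # ciphertext; callers above interpret that as a wrong password.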
    def _decrypt_data(self, key, encrypted):
        assert isinstance(key, type(b"")), type(key)
        assert isinstance(encrypted, type(b"")), type(encrypted)
        assert len(key) == SecretBox.KEY_SIZE, len(key)
        box = SecretBox(key)
        data = box.decrypt(encrypted)
        return data

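    # get() delivers phase messages strictly in numerical order, one per
    # call: it returns an already-received phase immediately, otherwise a
    # Deferred that fires when that phase arrives.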
    def _API_get(self):
        if self._error: return defer.fail(self._error)
        phase = "%d" % self._next_receive_phase
        self._next_receive_phase += 1
        with self._timing.add("API get", phase=phase):
            if phase in self._received_messages:
                return defer.succeed(self._received_messages[phase])
            d = self._receive_waiters[phase] = defer.Deferred()
            return d

    def _signal_error(self, error, mood):
        if self.DEBUG: print("_signal_error", error, mood)
        if self._error:
            return
        self._maybe_close(error, mood)
        if self.DEBUG: print("_signal_error done")

    @inlineCallbacks
    def _API_close(self, res, mood="happy"):
        if self.DEBUG: print("close")
        if self._close_called: raise InternalError
        self._close_called = True
        self._maybe_close(WormholeClosedError(), mood)
        if self.DEBUG: print("waiting for disconnect")
        yield self._disconnect_waiter
        returnValue(res)

    def _maybe_close(self, error, mood):
        if self._closing:
            return

        # ordering constraints:
        # * must wait for nameplate/mailbox acks before closing the websocket
        # * must mark APIs for failure before errbacking Deferreds
        #   * since we give up control
        # * must mark self._closing before errbacking Deferreds
        #   * since caller may call close() when we give up control
        #   * and close() will reenter _maybe_close

        self._error = error # causes new API calls to fail

        # since we're about to give up control by errbacking any API
        # Deferreds, set self._closing, to make sure that a new call to
        # close() isn't going to confuse anything
        self._closing = True

        # now errback all API deferreds except close(): get_code,
        # input_code, verify, get
        if self._input_code_waiter and not self._input_code_waiter.called:
            self._input_code_waiter.errback(error)
        for d in self._connection_waiters: # input_code, get_code (early)
            if self.DEBUG: print("EB cw")
            d.errback(error)
        if self._get_code: # get_code (late)
            if self.DEBUG: print("EB gc")
            self._get_code._allocated_d.errback(error)
        if self._verifier_waiter and not self._verifier_waiter.called:
            if self.DEBUG: print("EB VW")
            self._verifier_waiter.errback(error)
        if self._key_waiter and not self._key_waiter.called:
            if self.DEBUG: print("EB KW")
            self._key_waiter.errback(error)
        for d in self._receive_waiters.values():
            if self.DEBUG: print("EB RW")
            d.errback(error)
        # Release nameplate and close mailbox, if either was claimed/open.
        # Since _closing is True when both ACKs come back, the handlers will
        # close the websocket. When *that* finishes, _disconnect_waiter()
        # will fire.
        self._maybe_release_nameplate()
        self._maybe_close_mailbox(mood)
        # In the off chance we got closed before we even claimed the
        # nameplate, give _maybe_finished_closing a chance to run now.
        self._maybe_finished_closing()

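    # Teardown handshake: "release" frees the nameplate and "close" closes
    # the mailbox on the rendezvous server; the matching "released"/"closed"
    # acks move each state to CLOSED, and only then is the websocket dropped.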
    def _maybe_release_nameplate(self):
        if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state)
        if self._nameplate_state == OPEN:
            if self.DEBUG: print(" sending release")
            self._ws_send_command("release")
            self._nameplate_state = CLOSING

    def _response_handle_released(self, msg):
        self._nameplate_state = CLOSED
        self._maybe_finished_closing()

    def _maybe_close_mailbox(self, mood):
        if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state)
        if self._mailbox_state == OPEN:
            if self.DEBUG: print(" sending close")
            self._ws_send_command("close", mood=mood)
            self._mailbox_state = CLOSING

    def _response_handle_closed(self, msg):
        self._mailbox_state = CLOSED
        self._maybe_finished_closing()

    def _maybe_finished_closing(self):
        if self.DEBUG: print("_maybe_finished_closing", self._closing, self._nameplate_state, self._mailbox_state, self._connection_state)
        if not self._closing:
            return
        if (self._nameplate_state == CLOSED
            and self._mailbox_state == CLOSED
            and self._connection_state == OPEN):
            self._connection_state = CLOSING
            self._drop_connection()

    def _drop_connection(self):
        # separate method so it can be overridden by tests
        self._ws.transport.loseConnection() # probably flushes output
        # calls _ws_closed() when done

    def _ws_closed(self, wasClean, code, reason):
        # For now (until we add reconnection), losing the websocket means
        # losing everything. Make all API callers fail. Help someone waiting
        # in close() to finish
        self._connection_state = CLOSED
        self._disconnect_waiter.callback(None)
        self._maybe_finished_closing()

        # what needs to happen when _ws_closed() happens unexpectedly
        # * errback all API deferreds
        # * maybe: cause new API calls to fail
        # * obviously can't release nameplate or close mailbox
        # * can't re-close websocket
        # * close(wait=True) callers should fire right away

from .journal import ImmediateJournal

def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None,
             stderr=sys.stderr):
@@ -935,17 +28,10 @@ class _JournaledWormhole(service.MultiService):
                 event_dispatcher_args=()):
        pass

class ImmediateJM(object):
    def queue_outbound(self, fn, *args, **kwargs):
        fn(*args, **kwargs)
    @contextlib.contextmanager
    def process(self):
        yield

class _Wormhole(_JournaledWormhole):
    # send events to self, deliver them via Deferreds
    def __init__(self, reactor):
        _JournaledWormhole.__init__(self, reactor, ImmediateJM(), self)
        _JournaledWormhole.__init__(self, reactor, ImmediateJournal(), self)

def wormhole(reactor):
    w = _Wormhole(reactor)
@@ -956,5 +42,5 @@ def journaled_from_data(state, reactor, journal,
                        event_handler, event_handler_args=()):
    pass

def journaled(reactor, journal, event_handler, event_handler_args()):
def journaled(reactor, journal, event_handler, event_handler_args=()):
    pass