remove old files, lots of type work

This commit is contained in:
Brian Warner 2017-02-22 12:51:53 -08:00
parent 3101ca51db
commit a2ed35ceb8
12 changed files with 302 additions and 1184 deletions

View File

@ -18,11 +18,13 @@ digraph {
Connection -> websocket [color="blue"]
#Connection -> Order [color="blue"]
App -> Wormhole [style="dashed" label="set_code\nsend\nclose\n(once)"]
App -> Wormhole [style="dashed" label="start\nset_code\nsend\nclose\n(once)"]
#App -> Wormhole [color="blue"]
Wormhole -> App [style="dashed" label="got_verifier\nreceived\nclosed\n(once)"]
#Wormhole -> Connection [color="blue"]
Wormhole -> Connection [style="dashed" label="start"]
Connection -> Wormhole [style="dashed" label="rx_welcome"]
Wormhole -> Send [style="dashed" label="send"]
@ -75,8 +77,6 @@ digraph {
label="set_code"]
App -> Code [style="dashed"
label="allocate\ninput\nset"]
}

View File

@ -1,180 +0,0 @@
from zope.interface import Interface
from six.moves.urllib_parse import urlparse
from attr import attrs, attrib
from twisted.internet import defer, endpoints #, error
from twisted.application import internet, service
from autobahn.twisted import websocket
from automat import MethodicalMachine
class WSClient(websocket.WebSocketClientProtocol):
def onConnect(self, response):
# this fires during WebSocket negotiation, and isn't very useful
# unless you want to modify the protocol settings
print("onConnect", response)
#self.connection_machine.onConnect(self)
def onOpen(self, *args):
# this fires when the WebSocket is ready to go. No arguments
print("onOpen", args)
#self.wormhole_open = True
# send BIND, since the MailboxMachine does not
self.connection_machine.protocol_onOpen(self)
#self.factory.d.callback(self)
def onMessage(self, payload, isBinary):
print("onMessage")
return
assert not isBinary
self.wormhole._ws_dispatch_response(payload)
def onClose(self, wasClean, code, reason):
print("onClose")
self.connection_machine.protocol_onClose(wasClean, code, reason)
#if self.wormhole_open:
# self.wormhole._ws_closed(wasClean, code, reason)
#else:
# # we closed before establishing a connection (onConnect) or
# # finishing WebSocket negotiation (onOpen): errback
# self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto.connection_machine = self.connection_machine
#proto.wormhole_open = False
return proto
# pip install (path to automat checkout)[visualize]
# automat-visualize wormhole._connection
class IRendezvousClient(Interface):
# must be an IService too
def set_dispatch(dispatcher):
"""Assign a dispatcher object to this client. The following methods
will be called on this object when things happen:
* rx_welcome(welcome -> dict)
* rx_nameplates(nameplates -> list) # [{id: str,..}, ..]
* rx_allocated(nameplate -> str)
* rx_claimed(mailbox -> str)
* rx_released()
* rx_message(side -> str, phase -> str, body -> str, msg_id -> str)
* rx_closed()
* rx_pong(pong -> int)
"""
pass
def tx_list(): pass
def tx_allocate(): pass
def tx_claim(nameplate): pass
def tx_release(): pass
def tx_open(mailbox): pass
def tx_add(phase, body): pass
def tx_close(mood): pass
def tx_ping(ping): pass
# We have one WSRelayClient for each wsurl we know about, and it lasts
# as long as its parent Wormhole does.
@attrs
class WSRelayClient(service.MultiService, object):
_journal = attrib()
_wormhole = attrib()
_mailbox = attrib()
_ws_url = attrib()
_reactor = attrib()
def __init__(self):
f = WSFactory(self._ws_url)
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
f.connection_machine = self # calls onOpen and onClose
p = urlparse(self._ws_url)
ep = self._make_endpoint(p.hostname, p.port or 80)
# default policy: 1s initial, random exponential backoff, max 60s
self._client_service = internet.ClientService(ep, f)
self._connector = None
self._done_d = defer.Deferred()
self._current_delay = self.INITIAL_DELAY
def _make_endpoint(self, hostname, port):
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
# inputs from elsewhere
def d_callback(self, p):
self._p = p
self._m.d_callback()
def d_errback(self, f):
self._f = f
self._m.d_errback()
def protocol_onOpen(self, p):
self._m.onOpen()
def protocol_onClose(self, wasClean, code, reason):
self._m.onClose()
def C_stop(self):
self._m.stop()
def timer_expired(self):
self._m.expire()
# outputs driven by the state machine
def ep_connect(self):
print("ep_connect()")
self._d = self._ep.connect(self._f)
self._d.addCallbacks(self.d_callback, self.d_errback)
def connection_established(self):
self._connection = WSConnection(ws, self._wormhole.appid,
self._wormhole.side, self)
self._mailbox.connected(ws)
self._wormhole.add_connection(self._connection)
self._ws_send_command("bind", appid=self._appid, side=self._side)
def M_lost(self):
self._wormhole.M_lost(self._connection)
self._connection = None
def start_timer(self):
print("start_timer")
self._t = self._reactor.callLater(3.0, self.expire)
def cancel_timer(self):
print("cancel_timer")
self._t.cancel()
self._t = None
def dropConnection(self):
print("dropConnection")
self._ws.dropConnection()
def notify_fail(self):
print("notify_fail", self._f.value if self._f else None)
self._done_d.errback(self._f)
def MC_stopped(self):
pass
def tryit(reactor):
cm = WSRelayClient(None, "ws://127.0.0.1:4000/v1", reactor)
print("_ConnectionMachine created")
print("start:", cm.start())
print("waiting on _done_d to finish")
return cm._done_d
# http://autobahn-python.readthedocs.io/en/latest/websocket/programming.html
# observed sequence of events:
# success: d_callback, onConnect(response), onOpen(), onMessage()
# negotifail (non-websocket): d_callback, onClose()
# noconnect: d_errback
def tryws(reactor):
ws_url = "ws://127.0.0.1:40001/v1"
f = WSFactory(ws_url)
p = urlparse(ws_url)
ep = endpoints.HostnameEndpoint(reactor, p.hostname, p.port or 80)
d = ep.connect(f)
def _good(p): print("_good", p)
def _bad(f): print("_bad", f)
d.addCallbacks(_good, _bad)
return defer.Deferred()
if __name__ == "__main__":
import sys
from twisted.python import log
log.startLogging(sys.stdout)
from twisted.internet.task import react
react(tryit)
# ??? a new WSConnection is created each time the WSRelayClient gets through
# negotiation

View File

@ -18,3 +18,9 @@ class INameplateLister(Interface):
pass
class ICode(Interface):
pass
class ITiming(Interface):
pass
class IJournal(Interface): # TODO: this needs to be public
pass

View File

@ -4,8 +4,10 @@ from spake2 import SPAKE2_Symmetric
from hkdf import Hkdf
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from automat import MethodicalMachine
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes)
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
bytes_to_dict, dict_to_bytes)
from . import _interfaces
CryptoError
__all__ = ["derive_key", "derive_phase_key", "CryptoError",
@ -38,6 +40,14 @@ def decrypt_data(key, encrypted):
data = box.decrypt(encrypted)
return data
def encrypt_data(key, plaintext):
assert isinstance(key, type(b"")), type(key)
assert isinstance(plaintext, type(b"")), type(plaintext)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(plaintext, nonce)
@implementer(_interfaces.IKey)
class Key(object):
m = MethodicalMachine()
@ -57,7 +67,9 @@ class Key(object):
@m.state(terminal=True)
def S3_scared(self): pass
def got_pake(self, payload):
def got_pake(self, body):
assert isinstance(body, type(b"")), type(body)
payload = bytes_to_dict(body)
if "pake_v1" in payload:
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
else:
@ -76,7 +88,8 @@ class Key(object):
self._sp = SPAKE2_Symmetric(to_bytes(code),
idSymmetric=to_bytes(self._appid))
msg1 = self._sp.start()
self._M.add_message("pake", {"pake_v1": bytes_to_hexstr(msg1)})
body = dict_to_bytes({"pake_v1": bytes_to_hexstr(msg1)})
self._M.add_message("pake", body)
@m.output()
def scared(self):

View File

@ -10,6 +10,9 @@ class Mailbox(object):
self._side = side
self._mood = None
self._nameplate = None
self._mailbox = None
self._pending_outbound = {}
self._processed = set()
def wire(self, wormhole, rendezvous_connector, ordering):
self._W = _interfaces.IWormhole(wormhole)
@ -101,15 +104,18 @@ class Mailbox(object):
@m.input()
def rx_claimed(self, mailbox): pass
def rx_message(self, side, phase, msg):
def rx_message(self, side, phase, body):
assert isinstance(side, type("")), type(side)
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
if side == self._side:
self.rx_message_ours(phase, msg)
self.rx_message_ours(phase, body)
else:
self.rx_message_theirs(phase, msg)
self.rx_message_theirs(phase, body)
@m.input()
def rx_message_ours(self, phase, msg): pass
def rx_message_ours(self, phase, body): pass
@m.input()
def rx_message_theirs(self, phase, msg): pass
def rx_message_theirs(self, phase, body): pass
@m.input()
def rx_released(self): pass
@m.input()
@ -119,7 +125,7 @@ class Mailbox(object):
# from Send or Key
@m.input()
def add_message(self, phase, msg): pass
def add_message(self, phase, body): pass
@m.output()
@ -138,8 +144,10 @@ class Mailbox(object):
assert self._mailbox
self._RC.tx_open(self._mailbox)
@m.output()
def queue(self, phase, msg):
self._pending_outbound[phase] = msg
def queue(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._pending_outbound[phase] = body
@m.output()
def store_mailbox_and_RC_tx_open_and_drain(self, mailbox):
self._mailbox = mailbox
@ -149,18 +157,20 @@ class Mailbox(object):
def drain(self):
self._drain()
def _drain(self):
for phase, msg in self._pending_outbound.items():
self._RC.tx_add(phase, msg)
for phase, body in self._pending_outbound.items():
self._RC.tx_add(phase, body)
@m.output()
def RC_tx_add(self, phase, msg):
self._RC.tx_add(phase, msg)
def RC_tx_add(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._RC.tx_add(phase, body)
@m.output()
def RC_tx_release(self):
self._RC.tx_release()
@m.output()
def RC_tx_release_and_accept(self, phase, msg):
def RC_tx_release_and_accept(self, phase, body):
self._RC.tx_release()
self._accept(phase, msg)
self._accept(phase, body)
@m.output()
def record_mood_and_RC_tx_release(self, mood):
self._mood = mood
@ -179,14 +189,14 @@ class Mailbox(object):
self._mood = mood
self._RC.tx_close(self._mood)
@m.output()
def accept(self, phase, msg):
self._accept(phase, msg)
def _accept(self, phase, msg):
def accept(self, phase, body):
self._accept(phase, body)
def _accept(self, phase, body):
if phase not in self._processed:
self._O.got_message(phase, msg)
self._O.got_message(phase, body)
self._processed.add(phase)
@m.output()
def dequeue(self, phase, msg):
def dequeue(self, phase, body):
self._pending_outbound.pop(phase)
@m.output()
def record_mood(self, mood):

View File

@ -19,36 +19,40 @@ class Order(object):
@m.state(terminal=True)
def S1_yes_pake(self): pass
def got_message(self, phase, payload):
def got_message(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
if phase == "pake":
self.got_pake(phase, payload)
self.got_pake(phase, body)
else:
self.got_non_pake(phase, payload)
self.got_non_pake(phase, body)
@m.input()
def got_pake(self, phase, payload): pass
def got_pake(self, phase, body): pass
@m.input()
def got_non_pake(self, phase, payload): pass
def got_non_pake(self, phase, body): pass
@m.output()
def queue(self, phase, payload):
self._queue.append((phase, payload))
def queue(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._queue.append((phase, body))
@m.output()
def notify_key(self, phase, payload):
self._K.got_pake(payload)
def notify_key(self, phase, body):
self._K.got_pake(body)
@m.output()
def drain(self, phase, payload):
def drain(self, phase, body):
del phase
del payload
for (phase, payload) in self._queue:
self._deliver(phase, payload)
del body
for (phase, body) in self._queue:
self._deliver(phase, body)
self._queue[:] = []
@m.output()
def deliver(self, phase, payload):
self._deliver(phase, payload)
def deliver(self, phase, body):
self._deliver(phase, body)
def _deliver(self, phase, payload):
self._R.got_message(phase, payload)
def _deliver(self, phase, body):
self._R.got_message(phase, body)
S0_no_pake.upon(got_non_pake, enter=S0_no_pake, outputs=[queue])
S0_no_pake.upon(got_pake, enter=S1_yes_pake, outputs=[notify_key, drain])

View File

@ -24,7 +24,9 @@ class Receive(object):
@m.state(terminal=True)
def S3_scared(self): pass
def got_message(self, phase, payload):
def got_message(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
assert self._key
data_key = derive_phase_key(self._side, phase)
try:
@ -53,6 +55,8 @@ class Receive(object):
self._W.happy()
@m.output()
def W_got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
self._W.got_message(phase, plaintext)
@m.output()
def W_scared(self):

View File

@ -1,39 +1,200 @@
import os
from six.moves.urllib_parse import urlparse
from attr import attrs, attrib
from attr.validators import provides, instance_of
from zope.interface import implementer
from twisted.application import service
from twisted.python import log
from twisted.internet import defer, endpoints
from twisted.application import internet
from autobahn.twisted import websocket
from . import _interfaces
from .util import (bytes_to_hexstr, hexstr_to_bytes,
bytes_to_dict, dict_to_bytes)
class WSClient(websocket.WebSocketClientProtocol):
def onConnect(self, response):
# this fires during WebSocket negotiation, and isn't very useful
# unless you want to modify the protocol settings
#print("onConnect", response)
pass
def onOpen(self, *args):
# this fires when the WebSocket is ready to go. No arguments
#print("onOpen", args)
#self.wormhole_open = True
self._RC.ws_open(self)
def onMessage(self, payload, isBinary):
#print("onMessage")
assert not isBinary
self._RC.ws_message(payload)
def onClose(self, wasClean, code, reason):
#print("onClose")
self._RC.ws_close(wasClean, code, reason)
#if self.wormhole_open:
# self.wormhole._ws_closed(wasClean, code, reason)
#else:
# # we closed before establishing a connection (onConnect) or
# # finishing WebSocket negotiation (onOpen): errback
# self.factory.d.errback(error.ConnectError(reason))
class WSFactory(websocket.WebSocketClientFactory):
protocol = WSClient
def __init__(self, RC, *args, **kwargs):
websocket.WebSocketClientFactory.__init__(self, *args, **kwargs)
self._RC = RC
def buildProtocol(self, addr):
proto = websocket.WebSocketClientFactory.buildProtocol(self, addr)
proto._RC = self._RC
#proto.wormhole_open = False
return proto
@attrs
@implementer(_interfaces.IRendezvousConnector)
class RendezvousConnector(service.MultiService, object):
def __init__(self, journal, timing):
self._journal = journal
self._timing = timing
class RendezvousConnector(object):
_url = attrib(instance_of(type(u"")))
_appid = attrib(instance_of(type(u"")))
_side = attrib(instance_of(type(u"")))
_reactor = attrib()
_journal = attrib(provides(_interfaces.IJournal))
_timing = attrib(provides(_interfaces.ITiming))
def wire(self, mailbox, code, nameplate_lister):
def __init__(self):
self._ws = None
f = WSFactory(self, self._url)
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
p = urlparse(self._url)
ep = self._make_endpoint(p.hostname, p.port or 80)
self._connector = internet.ClientService(ep, f)
def _make_endpoint(self, hostname, port):
# TODO: Tor goes here
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
def wire(self, wormhole, mailbox, code, nameplate_lister):
self._W = _interfaces.IWormhole(wormhole)
self._M = _interfaces.IMailbox(mailbox)
self._C = _interfaces.ICode(code)
self._NL = _interfaces.INameplateLister(nameplate_lister)
# from Wormhole
def start(self):
self._connector.startService()
# from Mailbox
def tx_claim(self):
pass
def tx_open(self):
pass
def tx_add(self, x):
pass
def tx_claim(self, nameplate):
self._tx("claim", nameplate=nameplate)
def tx_open(self, mailbox):
self._tx("open", mailbox=mailbox)
def tx_add(self, phase, body):
assert isinstance(phase, type("")), type(phase)
assert isinstance(body, type(b"")), type(body)
self._tx("add", phase=phase, body=bytes_to_hexstr(body))
def tx_release(self):
pass
self._tx("release")
def tx_close(self, mood):
pass
self._tx("close", mood=mood)
def stop(self):
pass
d = defer.maybeDeferred(self._connector.stopService)
d.addErrback(log.err) # TODO: deliver error upstairs?
d.addBoth(self._stopped)
# from NameplateLister
def tx_list(self):
pass
self._tx("list")
# from Code
def tx_allocate(self):
self._tx("allocate")
# from our WSClient (the WebSocket protocol)
def ws_open(self, proto):
self._ws = proto
self._tx("bind", appid=self._appid, side=self._side)
self._M.connected()
self._NL.connected()
def ws_message(self, payload):
msg = bytes_to_dict(payload)
if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg)
self._timing.add("ws_receive", _side=self._side, message=msg)
mtype = msg["type"]
meth = getattr(self, "_response_handle_"+mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
return
return meth(msg)
def ws_close(self, wasClean, code, reason):
self._ws = None
self._M.lost()
self._NL.lost()
# internal
def _stopped(self, res):
self._M.stopped()
def _tx(self, mtype, **kwargs):
assert self._ws
# msgid is used by misc/dump-timing.py to correlate our sends with
# their receives, and vice versa. They are also correlated with the
# ACKs we get back from the server (which we otherwise ignore). There
# are so few messages, 16 bits is enough to be mostly-unique.
if self.DEBUG: print("SEND", mtype)
kwargs["id"] = bytes_to_hexstr(os.urandom(2))
kwargs["type"] = mtype
payload = dict_to_bytes(kwargs)
self._timing.add("ws_send", _side=self._side, **kwargs)
self._ws.sendMessage(payload, False)
def _response_handle_allocated(self, msg):
nameplate = msg["nameplate"]
assert isinstance(nameplate, type("")), type(nameplate)
self._C.rx_allocated(nameplate)
def _response_handle_nameplates(self, msg):
nameplates = msg["nameplates"]
assert isinstance(nameplates, list), type(nameplates)
nids = []
for n in nameplates:
assert isinstance(n, dict), type(n)
nameplate_id = n["id"]
assert isinstance(nameplate_id, type("")), type(nameplate_id)
nids.append(nameplate_id)
self._NL.rx_nameplates(nids)
def _response_handle_ack(self, msg):
pass
# record, message, payload, packet, bundle, ciphertext, plaintext
def _response_handle_welcome(self, msg):
self._W.rx_welcome(msg["welcome"])
def _response_handle_claimed(self, msg):
mailbox = msg["mailbox"]
assert isinstance(mailbox, type("")), type(mailbox)
self._M.rx_claimed(mailbox)
def _response_handle_message(self, msg):
side = msg["side"]
phase = msg["phase"]
assert isinstance(phase, type("")), type(phase)
body = hexstr_to_bytes(msg["body"]) # bytes
self._M.rx_message(side, phase, body)
def _response_handle_released(self, msg):
self._M.rx_released()
def _response_handle_closed(self, msg):
self._M.rx_closed()
# record, message, payload, packet, bundle, ciphertext, plaintext

View File

@ -1,7 +1,7 @@
from zope.interface import implementer
from automat import MethodicalMachine
from . import _interfaces
from .util import hexstr_to_bytes
from ._key import derive_phase_key, encrypt_data
@implementer(_interfaces.ISend)
class Send(object):
@ -17,36 +17,34 @@ class Send(object):
@m.state(terminal=True)
def S1_verified_key(self): pass
def got_pake(self, payload):
if "pake_v1" in payload:
self.got_pake_good(hexstr_to_bytes(payload["pake_v1"]))
else:
self.got_pake_bad()
@m.input()
def got_verified_key(self, key): pass
@m.input()
def send(self, phase, payload): pass
def send(self, phase, plaintext): pass
@m.output()
def queue(self, phase, payload):
self._queue.append((phase, payload))
def queue(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
self._queue.append((phase, plaintext))
@m.output()
def record_key(self, key):
self._key = key
@m.output()
def drain(self, key):
del key
for (phase, payload) in self._queue:
self._encrypt_and_send(phase, payload)
for (phase, plaintext) in self._queue:
self._encrypt_and_send(phase, plaintext)
self._queue[:] = []
@m.output()
def deliver(self, phase, payload):
self._encrypt_and_send(phase, payload)
def deliver(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
self._encrypt_and_send(phase, plaintext)
def _encrypt_and_send(self, phase, payload):
data_key = self._derive_phase_key(self._side, phase)
encrypted = self._encrypt_data(data_key, plaintext)
def _encrypt_and_send(self, phase, plaintext):
data_key = derive_phase_key(self._side, phase)
encrypted = encrypt_data(data_key, plaintext)
self._M.add_message(phase, encrypted)
S0_no_key.upon(send, enter=S0_no_key, outputs=[queue])

View File

@ -9,6 +9,7 @@ from ._receive import Receive
from ._rendezvous import RendezvousConnector
from ._nameplate import NameplateListing
from ._code import Code
from .util import bytes_to_dict
@implementer(_interfaces.IWormhole)
class Wormhole:
@ -31,13 +32,13 @@ class Wormhole:
self._O.wire(self._K, self._R)
self._K.wire(self, self._M, self._R)
self._R.wire(self, self._K, self._S)
self._RC.wire(self._M, self._C, self._NL)
self._RC.wire(self, self._M, self._C, self._NL)
self._NL.wire(self._RC, self._C)
self._C.wire(self, self._RC, self._NL)
# these methods are called from outside
def start(self):
self._relay_client.start()
self._RC.start()
# and these are the state-machine transition functions, which don't take
# args
@ -54,7 +55,7 @@ class Wormhole:
# from the Application, or some sort of top-level shim
@m.input()
def send(self, phase, message): pass
def send(self, phase, plaintext): pass
@m.input()
def close(self): pass
@ -69,6 +70,8 @@ class Wormhole:
@m.input()
def scared(self): pass
def got_message(self, phase, plaintext):
assert isinstance(phase, type("")), type(phase)
assert isinstance(plaintext, type(b"")), type(plaintext)
if phase == "version":
self.got_version(plaintext)
else:
@ -91,12 +94,13 @@ class Wormhole:
self._M.set_nameplate(nameplate)
self._K.set_code(code)
@m.output()
def process_version(self, version): # response["message"][phase=version]
pass
def process_version(self, plaintext):
self._their_versions = bytes_to_dict(plaintext)
# ignored for now
@m.output()
def S_send(self, phase, message):
self._S.send(phase, message)
def S_send(self, phase, plaintext):
self._S.send(phase, plaintext)
@m.output()
def close_scared(self):

View File

@ -1,6 +1,9 @@
from zope.interface import implementer
import contextlib
from _interfaces import IJournal
class JournalManager(object):
@implementer(IJournal)
class Journal(object):
def __init__(self, save_checkpoint):
self._save_checkpoint = save_checkpoint
self._outbound_queue = []
@ -8,7 +11,7 @@ class JournalManager(object):
def queue_outbound(self, fn, *args, **kwargs):
assert self._processing
self._outbound_queue.append((fn, args, kwargs))
self._outbound_queue.append((fn, args, kwargs)
@contextlib.contextmanager
def process(self):
@ -21,3 +24,12 @@ class JournalManager(object):
fn(*args, **kwargs)
self._outbound_queue[:] = []
self._processing = False
@implementer(IJournal)
class ImmediateJournal(object):
def queue_outbound(self, fn, *args, **kwargs):
fn(*args, **kwargs)
@contextlib.contextmanager
def process(self):
yield

View File

@ -1,912 +1,5 @@
from __future__ import print_function, absolute_import, unicode_literals
import os, sys, re
from six.moves.urllib_parse import urlparse
from twisted.internet import defer, endpoints #, error
from twisted.internet.threads import deferToThread, blockingCallFromThread
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.python import log, failure
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric
from hashlib import sha256
from . import __version__
from . import codes
#from .errors import ServerError, Timeout
from .errors import (WrongPasswordError, InternalError, WelcomeError,
WormholeClosedError, KeyFormatError)
from .timing import DebugTiming
from .util import (to_bytes, bytes_to_hexstr, hexstr_to_bytes,
dict_to_bytes, bytes_to_dict)
from hkdf import Hkdf
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
return Hkdf(salt, skm).expand(CTXinfo, outlen)
CONFMSG_NONCE_LENGTH = 128//8
CONFMSG_MAC_LENGTH = 256//8
def make_confmsg(confkey, nonce):
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
# We send the following messages through the relay server to the far side (by
# sending "add" commands to the server, and getting "message" responses):
#
# phase=setup:
# * unauthenticated version strings (but why?)
# * early warmup for connection hints ("I can do tor, spin up HS")
# * wordlist l10n identifier
# phase=pake: just the SPAKE2 'start' message (binary)
# phase=version: version data, key verification (HKDF(key, nonce)+nonce)
# phase=1,2,3,..: application messages
class _GetCode:
def __init__(self, code_length, send_command, timing):
self._code_length = code_length
self._send_command = send_command
self._timing = timing
self._allocated_d = defer.Deferred()
@inlineCallbacks
def go(self):
with self._timing.add("allocate"):
self._send_command("allocate")
nameplate_id = yield self._allocated_d
code = codes.make_code(nameplate_id, self._code_length)
assert isinstance(code, type("")), type(code)
returnValue(code)
def _response_handle_allocated(self, msg):
nid = msg["nameplate"]
assert isinstance(nid, type("")), type(nid)
self._allocated_d.callback(nid)
class _InputCode:
def __init__(self, reactor, prompt, code_length, send_command, timing,
stderr):
self._reactor = reactor
self._prompt = prompt
self._code_length = code_length
self._send_command = send_command
self._timing = timing
self._stderr = stderr
@inlineCallbacks
def _list(self):
self._lister_d = defer.Deferred()
self._send_command("list")
nameplates = yield self._lister_d
self._lister_d = None
returnValue(nameplates)
def _list_blocking(self):
return blockingCallFromThread(self._reactor, self._list)
@inlineCallbacks
def go(self):
# fetch the list of nameplates ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete
# client)
#
# TODO: send the request early, show the prompt right away, hide the
# latency in the user's indecision and slow typing. If we're lucky
# the answer will come back before they hit TAB.
initial_nameplate_ids = yield self._list()
with self._timing.add("input code", waiting="user"):
t = self._reactor.addSystemEventTrigger("before", "shutdown",
self._warn_readline)
res = yield deferToThread(codes.input_code_with_completion,
self._prompt,
initial_nameplate_ids,
self._list_blocking,
self._code_length)
(code, used_completion) = res
self._reactor.removeSystemEventTrigger(t)
if not used_completion:
self._remind_about_tab()
returnValue(code)
def _response_handle_nameplates(self, msg):
nameplates = msg["nameplates"]
assert isinstance(nameplates, list), type(nameplates)
nids = []
for n in nameplates:
assert isinstance(n, dict), type(n)
nameplate_id = n["id"]
assert isinstance(nameplate_id, type("")), type(nameplate_id)
nids.append(nameplate_id)
self._lister_d.callback(nids)
def _warn_readline(self):
# When our process receives a SIGINT, Twisted's SIGINT handler will
# stop the reactor and wait for all threads to terminate before the
# process exits. However, if we were waiting for
# input_code_with_completion() when SIGINT happened, the readline
# thread will be blocked waiting for something on stdin. Trick the
# user into satisfying the blocking read so we can exit.
print("\nCommand interrupted: please press Return to quit",
file=sys.stderr)
# Other potential approaches to this problem:
# * hard-terminate our process with os._exit(1), but make sure the
# tty gets reset to a normal mode ("cooked"?) first, so that the
# next shell command the user types is echoed correctly
# * track down the thread (t.p.threadable.getThreadID from inside the
# thread), get a cffi binding to pthread_kill, deliver SIGINT to it
# * allocate a pty pair (pty.openpty), replace sys.stdin with the
# slave, build a pty bridge that copies bytes (and other PTY
# things) from the real stdin to the master, then close the slave
# at shutdown, so readline sees EOF
# * write tab-completion and basic editing (TTY raw mode,
# backspace-is-erase) without readline, probably with curses or
# twisted.conch.insults
# * write a separate program to get codes (maybe just "wormhole
# --internal-get-code"), run it as a subprocess, let it inherit
# stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It
# needs an RPC mechanism (over some extra file descriptors) to ask
# us to fetch the current nameplate_id list.
#
# Note that hard-terminating our process with os.kill(os.getpid(),
# signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread
# doesn't see the signal, and we must still wait for stdin to make
# readline finish.
def _remind_about_tab(self):
print(" (note: you can use <Tab> to complete words)", file=self._stderr)
class _WelcomeHandler:
def __init__(self, url, current_version, signal_error):
self._ws_url = url
self._version_warning_displayed = False
self._current_version = current_version
self._signal_error = signal_error
def handle_welcome(self, welcome):
if "motd" in welcome:
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" %
(self._ws_url, motd_formatted), file=sys.stderr)
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("current_cli_version" in welcome
and "-" not in self._current_version
and not self._version_warning_displayed
and welcome["current_cli_version"] != self._current_version):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_cli_version"], self._current_version),
file=sys.stderr)
self._version_warning_displayed = True
if "error" in welcome:
return self._signal_error(WelcomeError(welcome["error"]),
"unwelcome")
# states for nameplates, mailboxes, and the websocket connection
(CLOSED, OPENING, OPEN, CLOSING) = ("closed", "opening", "open", "closing")
class _Wormhole:
DEBUG = False
def __init__(self, appid, relay_url, reactor, tor_manager, timing, stderr):
self._appid = appid
self._ws_url = relay_url
self._reactor = reactor
self._tor_manager = tor_manager
self._timing = timing
self._stderr = stderr
self._welcomer = _WelcomeHandler(self._ws_url, __version__,
self._signal_error)
self._side = bytes_to_hexstr(os.urandom(5))
self._connection_state = CLOSED
self._connection_waiters = []
self._ws_t = None
self._started_get_code = False
self._get_code = None
self._started_input_code = False
self._input_code_waiter = None
self._code = None
self._nameplate_id = None
self._nameplate_state = CLOSED
self._mailbox_id = None
self._mailbox_state = CLOSED
self._flag_need_nameplate = True
self._flag_need_to_see_mailbox_used = True
self._flag_need_to_build_msg1 = True
self._flag_need_to_send_PAKE = True
self._establish_key_called = False
self._key_waiter = None
self._key = None
self._version_message = None
self._version_checked = False
self._get_verifier_called = False
self._verifier = None # bytes
self._verify_result = None # bytes or a Failure
self._verifier_waiter = None
self._my_versions = {} # sent
self._their_versions = {} # received
self._close_called = False # the close() API has been called
self._closing = False # we've started shutdown
self._disconnect_waiter = defer.Deferred()
self._error = None
self._next_send_phase = 0
# send() queues plaintext here, waiting for a connection and the key
self._plaintext_to_send = [] # (phase, plaintext)
self._sent_phases = set() # to detect double-send
self._next_receive_phase = 0
self._receive_waiters = {} # phase -> Deferred
self._received_messages = {} # phase -> plaintext
# API METHODS for applications to call
# You must use at least one of these entry points, to establish the
# wormhole code. Other APIs will stall or be queued until we have one.
# entry point 1: generate a new code. returns a Deferred
def get_code(self, code_length=2): # XX rename to allocate_code()? create_?
return self._API_get_code(code_length)
# entry point 2: interactively type in a code, with completion. returns
# Deferred
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
return self._API_input_code(prompt, code_length)
# entry point 3: paste in a fully-formed code. No return value.
def set_code(self, code):
self._API_set_code(code)
# todo: restore-saved-state entry points
def establish_key(self):
"""
returns a Deferred that fires when we've established the shared key.
When successful, the Deferred fires with a simple `True`, otherwise
it fails.
"""
return self._API_establish_key()
def verify(self):
"""Returns a Deferred that fires when we've heard back from the other
side, and have confirmed that they used the right wormhole code. When
successful, the Deferred fires with a "verifier" (a bytestring) which
can be compared out-of-band before making additional API calls. If
they used the wrong wormhole code, the Deferred errbacks with
WrongPasswordError.
"""
return self._API_verify()
def send(self, outbound_data):
return self._API_send(outbound_data)
def get(self):
return self._API_get()
def derive_key(self, purpose, length):
"""Derive a new key from the established wormhole channel for some
other purpose. This is a deterministic randomized function of the
session key and the 'purpose' string (unicode/py3-string). This
cannot be called until verify() or get() has fired.
"""
return self._API_derive_key(purpose, length)
def close(self, res=None):
"""Collapse the wormhole, freeing up server resources and flushing
all pending messages. Returns a Deferred that fires when everything
is done. It fires with any argument close() was given, to enable use
as a d.addBoth() handler:
w = wormhole(...)
d = w.get()
..
d.addBoth(w.close)
return d
Another reasonable approach is to use inlineCallbacks:
@inlineCallbacks
def pair(self, code):
w = wormhole(...)
try:
them = yield w.get()
finally:
yield w.close()
"""
return self._API_close(res)
# INTERNAL METHODS beyond here
def _start(self):
d = self._connect() # causes stuff to happen
d.addErrback(log.err)
return d # fires when connection is established, if you care
def _make_endpoint(self, hostname, port):
if self._tor_manager:
return self._tor_manager.get_endpoint_for(hostname, port)
# note: HostnameEndpoints have a default 30s timeout
return endpoints.HostnameEndpoint(self._reactor, hostname, port)
def _connect(self):
# TODO: if we lose the connection, make a new one, re-establish the
# state
assert self._side
self._connection_state = OPENING
self._ws_t = self._timing.add("open websocket")
p = urlparse(self._ws_url)
f = WSFactory(self._ws_url)
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
f.wormhole = self
f.d = defer.Deferred()
# TODO: if hostname="localhost", I get three factories starting
# and stopping (maybe 127.0.0.1, ::1, and something else?), and
# an error in the factory is masked.
ep = self._make_endpoint(p.hostname, p.port or 80)
# .connect errbacks if the TCP connection fails
d = ep.connect(f)
d.addCallback(self._event_connected)
# f.d is errbacked if WebSocket negotiation fails, and the WebSocket
# drops any data sent before onOpen() fires, so we must wait for it
d.addCallback(lambda _: f.d)
d.addCallback(self._event_ws_opened)
return d
def _event_connected(self, ws):
self._ws = ws
if self._ws_t:
self._ws_t.finish()
def _event_ws_opened(self, _):
self._connection_state = OPEN
if self._closing:
return self._maybe_finished_closing()
self._ws_send_command("bind", appid=self._appid, side=self._side)
self._maybe_claim_nameplate()
self._maybe_send_pake()
waiters, self._connection_waiters = self._connection_waiters, []
for d in waiters:
d.callback(None)
def _when_connected(self):
if self._connection_state == OPEN:
return defer.succeed(None)
d = defer.Deferred()
self._connection_waiters.append(d)
return d
def _ws_send_command(self, mtype, **kwargs):
# msgid is used by misc/dump-timing.py to correlate our sends with
# their receives, and vice versa. They are also correlated with the
# ACKs we get back from the server (which we otherwise ignore). There
# are so few messages, 16 bits is enough to be mostly-unique.
if self.DEBUG: print("SEND", mtype)
kwargs["id"] = bytes_to_hexstr(os.urandom(2))
kwargs["type"] = mtype
payload = dict_to_bytes(kwargs)
self._timing.add("ws_send", _side=self._side, **kwargs)
self._ws.sendMessage(payload, False)
def _ws_dispatch_response(self, payload):
msg = bytes_to_dict(payload)
if self.DEBUG and msg["type"]!="ack": print("DIS", msg["type"], msg)
self._timing.add("ws_receive", _side=self._side, message=msg)
mtype = msg["type"]
meth = getattr(self, "_response_handle_"+mtype, None)
if not meth:
# make tests fail, but real application will ignore it
log.err(ValueError("Unknown inbound message type %r" % (msg,)))
return
return meth(msg)
def _response_handle_ack(self, msg):
pass
def _response_handle_welcome(self, msg):
self._welcomer.handle_welcome(msg["welcome"])
# entry point 1: generate a new code
@inlineCallbacks
def _API_get_code(self, code_length):
if self._code is not None: raise InternalError
if self._started_get_code: raise InternalError
self._started_get_code = True
with self._timing.add("API get_code"):
yield self._when_connected()
gc = _GetCode(code_length, self._ws_send_command, self._timing)
self._get_code = gc
self._response_handle_allocated = gc._response_handle_allocated
# TODO: signal_error
code = yield gc.go()
self._get_code = None
self._nameplate_state = OPEN
self._event_learned_code(code)
returnValue(code)
# entry point 2: interactively type in a code, with completion
@inlineCallbacks
def _API_input_code(self, prompt, code_length):
if self._code is not None: raise InternalError
if self._started_input_code: raise InternalError
self._started_input_code = True
with self._timing.add("API input_code"):
yield self._when_connected()
ic = _InputCode(self._reactor, prompt, code_length,
self._ws_send_command, self._timing, self._stderr)
self._response_handle_nameplates = ic._response_handle_nameplates
# we reveal the Deferred we're waiting on, so _signal_error can
# wake us up if something goes wrong (like a welcome error)
self._input_code_waiter = ic.go()
code = yield self._input_code_waiter
self._input_code_waiter = None
self._event_learned_code(code)
returnValue(None)
# entry point 3: paste in a fully-formed code
def _API_set_code(self, code):
self._timing.add("API set_code")
if not isinstance(code, type(u"")):
raise TypeError("Unexpected code type '{}'".format(type(code)))
if self._code is not None:
raise InternalError
self._event_learned_code(code)
# TODO: entry point 4: restore pre-contact saved state (we haven't heard
# from the peer yet, so we still need the nameplate)
# TODO: entry point 5: restore post-contact saved state (so we don't need
# or use the nameplate, only the mailbox)
def _restore_post_contact_state(self, state):
# ...
self._flag_need_nameplate = False
#self._mailbox_id = X(state)
self._event_learned_mailbox()
def _event_learned_code(self, code):
self._timing.add("code established")
# bail out early if the password contains spaces...
# this should raise a useful error
if ' ' in code:
raise KeyFormatError("code (%s) contains spaces." % code)
self._code = code
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
nid = mo.group(1)
assert isinstance(nid, type("")), type(nid)
self._nameplate_id = nid
# fire more events
self._maybe_build_msg1()
self._event_learned_nameplate()
def _maybe_build_msg1(self):
if not (self._code and self._flag_need_to_build_msg1):
return
with self._timing.add("pake1", waiting="crypto"):
self._sp = SPAKE2_Symmetric(to_bytes(self._code),
idSymmetric=to_bytes(self._appid))
self._msg1 = self._sp.start()
self._flag_need_to_build_msg1 = False
self._event_built_msg1()
def _event_built_msg1(self):
self._maybe_send_pake()
# every _maybe_X starts with a set of conditions
# for each such condition Y, every _event_Y must call _maybe_X
def _event_learned_nameplate(self):
self._maybe_claim_nameplate()
def _maybe_claim_nameplate(self):
if not (self._nameplate_id and self._connection_state == OPEN):
return
self._ws_send_command("claim", nameplate=self._nameplate_id)
self._nameplate_state = OPEN
def _response_handle_claimed(self, msg):
mailbox_id = msg["mailbox"]
assert isinstance(mailbox_id, type("")), type(mailbox_id)
self._mailbox_id = mailbox_id
self._event_learned_mailbox()
def _event_learned_mailbox(self):
if not self._mailbox_id: raise InternalError
assert self._mailbox_state == CLOSED, self._mailbox_state
if self._closing:
return
self._ws_send_command("open", mailbox=self._mailbox_id)
self._mailbox_state = OPEN
# causes old messages to be sent now, and subscribes to new messages
self._maybe_send_pake()
self._maybe_send_phase_messages()
def _maybe_send_pake(self):
# TODO: deal with reentrant call
if not (self._connection_state == OPEN
and self._mailbox_state == OPEN
and self._flag_need_to_send_PAKE):
return
body = {"pake_v1": bytes_to_hexstr(self._msg1)}
payload = dict_to_bytes(body)
self._msg_send("pake", payload)
self._flag_need_to_send_PAKE = False
def _event_received_pake(self, pake_msg):
payload = bytes_to_dict(pake_msg)
msg2 = hexstr_to_bytes(payload["pake_v1"])
with self._timing.add("pake2", waiting="crypto"):
self._key = self._sp.finish(msg2)
self._event_established_key()
def _event_established_key(self):
self._timing.add("key established")
self._maybe_notify_key()
# both sides send different (random) version messages
self._send_version_message()
verifier = self._derive_key(b"wormhole:verifier")
self._event_computed_verifier(verifier)
self._maybe_check_version()
self._maybe_send_phase_messages()
def _API_establish_key(self):
if self._error: return defer.fail(self._error)
if self._establish_key_called: raise InternalError
self._establish_key_called = True
if self._key is not None:
return defer.succeed(True)
self._key_waiter = defer.Deferred()
return self._key_waiter
def _maybe_notify_key(self):
if self._key is None:
return
if self._error:
result = failure.Failure(self._error)
else:
result = True
if self._key_waiter and not self._key_waiter.called:
self._key_waiter.callback(result)
def _send_version_message(self):
# this is encrypted like a normal phase message, and includes a
# dictionary of version flags to let the other Wormhole know what
# we're capable of (for future expansion)
plaintext = dict_to_bytes(self._my_versions)
phase = "version"
data_key = self._derive_phase_key(self._side, phase)
encrypted = self._encrypt_data(data_key, plaintext)
self._msg_send(phase, encrypted)
def _API_verify(self):
if self._error: return defer.fail(self._error)
if self._get_verifier_called: raise InternalError
self._get_verifier_called = True
if self._verify_result:
return defer.succeed(self._verify_result) # bytes or Failure
self._verifier_waiter = defer.Deferred()
return self._verifier_waiter
def _event_computed_verifier(self, verifier):
self._verifier = verifier
self._maybe_notify_verify()
def _maybe_notify_verify(self):
if not (self._verifier and self._version_checked):
return
if self._error:
self._verify_result = failure.Failure(self._error)
else:
self._verify_result = self._verifier
if self._verifier_waiter and not self._verifier_waiter.called:
self._verifier_waiter.callback(self._verify_result)
def _event_received_version(self, side, body):
# We ought to have the master key by now, because sensible peers
# should always send "pake" before sending "version". It might be
# nice to relax this requirement, which means storing the received
# version message, and having _event_established_key call
# _check_version()
self._version_message = (side, body)
self._maybe_check_version()
def _maybe_check_version(self):
if not (self._key and self._version_message):
return
if self._version_checked:
return
self._version_checked = True
side, body = self._version_message
data_key = self._derive_phase_key(side, "version")
try:
plaintext = self._decrypt_data(data_key, body)
except CryptoError:
# this makes all API calls fail
if self.DEBUG: print("CONFIRM FAILED")
self._signal_error(WrongPasswordError(), "scary")
return
msg = bytes_to_dict(plaintext)
self._version_received(msg)
self._maybe_notify_verify()
def _version_received(self, msg):
self._their_versions = msg
def _API_send(self, outbound_data):
if self._error: raise self._error
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
phase = self._next_send_phase
self._next_send_phase += 1
self._plaintext_to_send.append( (phase, outbound_data) )
with self._timing.add("API send", phase=phase):
self._maybe_send_phase_messages()
def _derive_phase_key(self, side, phase):
assert isinstance(side, type("")), type(side)
assert isinstance(phase, type("")), type(phase)
side_bytes = side.encode("ascii")
phase_bytes = phase.encode("ascii")
purpose = (b"wormhole:phase:"
+ sha256(side_bytes).digest()
+ sha256(phase_bytes).digest())
return self._derive_key(purpose)
def _maybe_send_phase_messages(self):
# TODO: deal with reentrant call
if not (self._connection_state == OPEN
and self._mailbox_state == OPEN
and self._key):
return
plaintexts = self._plaintext_to_send
self._plaintext_to_send = []
for pm in plaintexts:
(phase_int, plaintext) = pm
assert isinstance(phase_int, int), type(phase_int)
phase = "%d" % phase_int
data_key = self._derive_phase_key(self._side, phase)
encrypted = self._encrypt_data(data_key, plaintext)
self._msg_send(phase, encrypted)
def _encrypt_data(self, key, data):
# Without predefined roles, we can't derive predictably unique keys
# for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and we automatically ignore
# reflections.
# TODO: HKDF(side, nonce, key) ?? include 'side' to prevent
# reflections, since we no longer compare messages
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _msg_send(self, phase, body):
if phase in self._sent_phases: raise InternalError
assert self._mailbox_state == OPEN, self._mailbox_state
self._sent_phases.add(phase)
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
self._timing.add("add", phase=phase)
self._ws_send_command("add", phase=phase, body=bytes_to_hexstr(body))
def _event_mailbox_used(self):
if self.DEBUG: print("_event_mailbox_used")
if self._flag_need_to_see_mailbox_used:
self._maybe_release_nameplate()
self._flag_need_to_see_mailbox_used = False
def _API_derive_key(self, purpose, length):
if self._error: raise self._error
if self._key is None:
raise InternalError # call derive_key after get_verifier() or get()
if not isinstance(purpose, type("")): raise TypeError(type(purpose))
return self._derive_key(to_bytes(purpose), length)
def _derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(b"")): raise TypeError(type(purpose))
if self._key is None:
raise InternalError # call derive_key after get_verifier() or get()
return HKDF(self._key, length, CTXinfo=purpose)
def _response_handle_message(self, msg):
side = msg["side"]
phase = msg["phase"]
assert isinstance(phase, type("")), type(phase)
body = hexstr_to_bytes(msg["body"])
if side == self._side:
return
self._event_received_peer_message(side, phase, body)
def _event_received_peer_message(self, side, phase, body):
# any message in the mailbox means we no longer need the nameplate
self._event_mailbox_used()
if self._closing:
log.msg("received peer message while closing '%s'" % phase)
if phase in self._received_messages:
log.msg("ignoring duplicate peer message '%s'" % phase)
return
if phase == "pake":
self._received_messages["pake"] = body
return self._event_received_pake(body)
if phase == "version":
self._received_messages["version"] = body
return self._event_received_version(side, body)
if re.search(r'^\d+$', phase):
return self._event_received_phase_message(side, phase, body)
# ignore unrecognized phases, for forwards-compatibility
log.msg("received unknown phase '%s'" % phase)
def _event_received_phase_message(self, side, phase, body):
# It's a numbered phase message, aimed at the application above us.
# Decrypt and deliver upstairs, notifying anyone waiting on it
try:
data_key = self._derive_phase_key(side, phase)
plaintext = self._decrypt_data(data_key, body)
except CryptoError:
e = WrongPasswordError()
self._signal_error(e, "scary") # flunk all other API calls
# make tests fail, if they aren't explicitly catching it
if self.DEBUG: print("CryptoError in msg received")
log.err(e)
if self.DEBUG: print(" did log.err", e)
return # ignore this message
self._received_messages[phase] = plaintext
if phase in self._receive_waiters:
d = self._receive_waiters.pop(phase)
d.callback(plaintext)
def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
assert len(key) == SecretBox.KEY_SIZE, len(key)
box = SecretBox(key)
data = box.decrypt(encrypted)
return data
def _API_get(self):
if self._error: return defer.fail(self._error)
phase = "%d" % self._next_receive_phase
self._next_receive_phase += 1
with self._timing.add("API get", phase=phase):
if phase in self._received_messages:
return defer.succeed(self._received_messages[phase])
d = self._receive_waiters[phase] = defer.Deferred()
return d
def _signal_error(self, error, mood):
if self.DEBUG: print("_signal_error", error, mood)
if self._error:
return
self._maybe_close(error, mood)
if self.DEBUG: print("_signal_error done")
@inlineCallbacks
def _API_close(self, res, mood="happy"):
if self.DEBUG: print("close")
if self._close_called: raise InternalError
self._close_called = True
self._maybe_close(WormholeClosedError(), mood)
if self.DEBUG: print("waiting for disconnect")
yield self._disconnect_waiter
returnValue(res)
def _maybe_close(self, error, mood):
if self._closing:
return
# ordering constraints:
# * must wait for nameplate/mailbox acks before closing the websocket
# * must mark APIs for failure before errbacking Deferreds
# * since we give up control
# * must mark self._closing before errbacking Deferreds
# * since caller may call close() when we give up control
# * and close() will reenter _maybe_close
self._error = error # causes new API calls to fail
# since we're about to give up control by errbacking any API
# Deferreds, set self._closing, to make sure that a new call to
# close() isn't going to confuse anything
self._closing = True
# now errback all API deferreds except close(): get_code,
# input_code, verify, get
if self._input_code_waiter and not self._input_code_waiter.called:
self._input_code_waiter.errback(error)
for d in self._connection_waiters: # input_code, get_code (early)
if self.DEBUG: print("EB cw")
d.errback(error)
if self._get_code: # get_code (late)
if self.DEBUG: print("EB gc")
self._get_code._allocated_d.errback(error)
if self._verifier_waiter and not self._verifier_waiter.called:
if self.DEBUG: print("EB VW")
self._verifier_waiter.errback(error)
if self._key_waiter and not self._key_waiter.called:
if self.DEBUG: print("EB KW")
self._key_waiter.errback(error)
for d in self._receive_waiters.values():
if self.DEBUG: print("EB RW")
d.errback(error)
# Release nameplate and close mailbox, if either was claimed/open.
# Since _closing is True when both ACKs come back, the handlers will
# close the websocket. When *that* finishes, _disconnect_waiter()
# will fire.
self._maybe_release_nameplate()
self._maybe_close_mailbox(mood)
# In the off chance we got closed before we even claimed the
# nameplate, give _maybe_finished_closing a chance to run now.
self._maybe_finished_closing()
def _maybe_release_nameplate(self):
if self.DEBUG: print("_maybe_release_nameplate", self._nameplate_state)
if self._nameplate_state == OPEN:
if self.DEBUG: print(" sending release")
self._ws_send_command("release")
self._nameplate_state = CLOSING
def _response_handle_released(self, msg):
self._nameplate_state = CLOSED
self._maybe_finished_closing()
def _maybe_close_mailbox(self, mood):
if self.DEBUG: print("_maybe_close_mailbox", self._mailbox_state)
if self._mailbox_state == OPEN:
if self.DEBUG: print(" sending close")
self._ws_send_command("close", mood=mood)
self._mailbox_state = CLOSING
def _response_handle_closed(self, msg):
self._mailbox_state = CLOSED
self._maybe_finished_closing()
def _maybe_finished_closing(self):
if self.DEBUG: print("_maybe_finished_closing", self._closing, self._nameplate_state, self._mailbox_state, self._connection_state)
if not self._closing:
return
if (self._nameplate_state == CLOSED
and self._mailbox_state == CLOSED
and self._connection_state == OPEN):
self._connection_state = CLOSING
self._drop_connection()
def _drop_connection(self):
# separate method so it can be overridden by tests
self._ws.transport.loseConnection() # probably flushes output
# calls _ws_closed() when done
def _ws_closed(self, wasClean, code, reason):
# For now (until we add reconnection), losing the websocket means
# losing everything. Make all API callers fail. Help someone waiting
# in close() to finish
self._connection_state = CLOSED
self._disconnect_waiter.callback(None)
self._maybe_finished_closing()
# what needs to happen when _ws_closed() happens unexpectedly
# * errback all API deferreds
# * maybe: cause new API calls to fail
# * obviously can't release nameplate or close mailbox
# * can't re-close websocket
# * close(wait=True) callers should fire right away
from .journal import ImmediateJournal
def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None,
stderr=sys.stderr):
@ -935,17 +28,10 @@ class _JournaledWormhole(service.MultiService):
event_dispatcher_args=()):
pass
class ImmediateJM(object):
def queue_outbound(self, fn, *args, **kwargs):
fn(*args, **kwargs)
@contextlib.contextmanager
def process(self):
yield
class _Wormhole(_JournaledWormhole):
# send events to self, deliver them via Deferreds
def __init__(self, reactor):
_JournaledWormhole.__init__(self, reactor, ImmediateJM(), self)
_JournaledWormhole.__init__(self, reactor, ImmediateJournal(), self)
def wormhole(reactor):
w = _Wormhole(reactor)
@ -956,5 +42,5 @@ def journaled_from_data(state, reactor, journal,
event_handler, event_handler_args=()):
pass
def journaled(reactor, journal, event_handler, event_handler_args()):
def journaled(reactor, journal, event_handler, event_handler_args=()):
pass