466 lines
18 KiB
Python
466 lines
18 KiB
Python
|
from __future__ import print_function
|
||
|
import os, sys, json, re, unicodedata
|
||
|
from six.moves.urllib_parse import urlparse
|
||
|
from binascii import hexlify, unhexlify
|
||
|
from twisted.internet import defer, endpoints, error
|
||
|
from twisted.internet.threads import deferToThread, blockingCallFromThread
|
||
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||
|
from twisted.python import log
|
||
|
from autobahn.twisted import websocket
|
||
|
from nacl.secret import SecretBox
|
||
|
from nacl.exceptions import CryptoError
|
||
|
from nacl import utils
|
||
|
from spake2 import SPAKE2_Symmetric
|
||
|
from .. import __version__
|
||
|
from .. import codes
|
||
|
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||
|
from ..timing import DebugTiming
|
||
|
from hkdf import Hkdf
|
||
|
|
||
|
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    """Derive ``outlen`` bytes of key material from ``skm``.

    :param skm: input (secret) keying material, bytes
    :param outlen: number of bytes to produce
    :param salt: optional salt handed to the extract step
    :param CTXinfo: context string handed to the expand step
    :returns: ``outlen`` bytes
    """
    extractor = Hkdf(salt, skm)
    return extractor.expand(CTXinfo, outlen)
|
||
|
|
||
|
# Layout of a key-confirmation message: random nonce, then MAC.
CONFMSG_NONCE_LENGTH = 16  # bytes (128 bits)
CONFMSG_MAC_LENGTH = 32  # bytes (256 bits)
|
||
|
def make_confmsg(confkey, nonce):
    """Build a key-confirmation message: ``nonce`` followed by a MAC.

    The MAC is CONFMSG_MAC_LENGTH bytes of HKDF output keyed by ``confkey``
    and salted with ``nonce``. The receiver slices the nonce back off the
    front, rebuilds the message with its own confkey and compares.
    """
    mac = HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
    return nonce + mac
|
||
|
|
||
|
def to_bytes(u):
    """Return ``u`` NFC-normalized and encoded as UTF-8 bytes.

    Normalizing first means canonically-equivalent strings (for example a
    precomposed accent versus a base letter plus combining accent) give the
    same bytes. That matters because the result feeds SPAKE2 and HKDF.
    """
    normalized = unicodedata.normalize("NFC", u)
    return normalized.encode("utf-8")
|
||
|
|
||
|
# We send the following messages through the relay server to the far side (by
|
||
|
# sending "add" commands to the server, and getting "message" responses):
|
||
|
#
|
||
|
# phase=setup:
|
||
|
# * unauthenticated version strings (but why?)
|
||
|
# * early warmup for connection hints ("I can do tor, spin up HS")
|
||
|
# * wordlist l10n identifier
|
||
|
# phase=pake: just the SPAKE2 'start' message (binary)
|
||
|
# phase=confirm: key confirmation (nonce+HKDF(confkey, salt=nonce))
|
||
|
# phase=1,2,3,..: application messages
|
||
|
|
||
|
class WSClient(websocket.WebSocketClientProtocol):
    """WebSocket protocol that relays server traffic to its wormhole.

    WSFactory.buildProtocol attaches ``self.wormhole`` and initializes
    ``self.wormhole_open`` to False before any of these callbacks run.
    """

    def onOpen(self):
        # Handshake complete: from here on, a close is reported to the
        # wormhole instead of failing the factory's connection Deferred.
        self.wormhole_open = True
        self.factory.d.callback(self)

    def onMessage(self, payload, isBinary):
        # Only text frames are expected from the relay.
        assert not isBinary
        self.wormhole._ws_dispatch_response(payload)

    def onClose(self, wasClean, code, reason):
        if not self.wormhole_open:
            # we closed before establishing a connection (onConnect) or
            # finishing WebSocket negotiation (onOpen): errback
            self.factory.d.errback(error.ConnectError(reason))
            return
        self.wormhole._ws_closed(wasClean, code, reason)
|
||
|
|
||
|
class WSFactory(websocket.WebSocketClientFactory):
    """Factory that builds WSClient protocols bound to one wormhole.

    The creator assigns ``self.wormhole`` (and ``self.d``, which WSClient
    fires or fails) before connecting.
    """
    protocol = WSClient

    def buildProtocol(self, addr):
        client = websocket.WebSocketClientFactory.buildProtocol(self, addr)
        # Not "open" until the WebSocket handshake finishes (WSClient.onOpen).
        client.wormhole_open = False
        client.wormhole = self.wormhole
        return client
|
||
|
|
||
|
|
||
|
class _GetCode:
    """Allocate a new nameplate from the server and build a code from it.

    ``go()`` sends an "allocate" command through ``send_command``. The owner
    must route the server's "allocated" response to ``_ws_handle_allocated``,
    which fires the Deferred that ``go()`` is waiting on.
    """

    def __init__(self, code_length, send_command, timing=None):
        """
        :param code_length: passed to codes.make_code with the nameplate id
        :param send_command: callable(mtype, **kwargs) that sends a server
            command
        :param timing: timing recorder; defaults to a fresh DebugTiming().
            Callers may also assign ``._timing`` after construction.
        """
        self._code_length = code_length
        self._send_command = send_command
        # go() records timing events; this attribute used to be missing,
        # so go() always failed with AttributeError.
        self._timing = timing if timing is not None else DebugTiming()
        self._allocated_d = defer.Deferred()

    @inlineCallbacks
    def go(self):
        """Fire (via Deferred) with a new unicode code for this wormhole."""
        with self._timing.add("allocate"):
            self._send_command(u"allocate")
            nameplate_id = yield self._allocated_d
        code = codes.make_code(nameplate_id, self._code_length)
        assert isinstance(code, type(u"")), type(code)
        returnValue(code)

    def _ws_handle_allocated(self, msg):
        """Handle the server's "allocated" response carrying the nameplate."""
        nid = msg["nameplate"]
        assert isinstance(nid, type(u"")), type(nid)
        self._allocated_d.callback(nid)
|
||
|
|
||
|
class _InputCode:
    """Read a wormhole code from the user, with nameplate tab-completion.

    Nameplate lists are requested with a "list" command sent through
    ``send_command``. The owner must route the server's "nameplates"
    response to ``_ws_handle_nameplates``.
    """

    def __init__(self, reactor, prompt, code_length, send_command,
                 timing=None):
        """
        :param reactor: Twisted reactor (for thread hand-off and shutdown
            triggers)
        :param prompt: prompt string shown to the user
        :param code_length: passed to codes.input_code_with_completion
        :param send_command: callable(mtype, **kwargs) that sends a server
            command
        :param timing: timing recorder; defaults to a fresh DebugTiming().
            Callers may also assign ``._timing`` after construction.
        """
        self._reactor = reactor
        self._prompt = prompt
        self._code_length = code_length
        self._send_command = send_command
        # go() records timing events; this attribute used to be missing.
        self._timing = timing if timing is not None else DebugTiming()
        # Deferred for the outstanding "list" request, or None when idle.
        self._lister_d = None

    @inlineCallbacks
    def _list(self):
        """Ask the server for the current nameplate ids (Deferred -> list)."""
        self._lister_d = defer.Deferred()
        self._send_command(u"list")
        nameplates = yield self._lister_d
        self._lister_d = None
        returnValue(nameplates)

    def _list_blocking(self):
        # Called from the readline thread: run _list() in the reactor thread
        # and block this thread until its result is available.
        return blockingCallFromThread(self._reactor, self._list)

    @inlineCallbacks
    def go(self):
        """Prompt for a code in a worker thread; the Deferred fires with it."""
        # fetch the list of nameplates ahead of time, to give us a chance to
        # discover the welcome message (and warn the user about an obsolete
        # client)
        #
        # TODO: send the request early, show the prompt right away, hide the
        # latency in the user's indecision and slow typing. If we're lucky
        # the answer will come back before they hit TAB.

        initial_nameplate_ids = yield self._list()
        with self._timing.add("input code", waiting="user"):
            t = self._reactor.addSystemEventTrigger("before", "shutdown",
                                                    self._warn_readline)
            try:
                code = yield deferToThread(codes.input_code_with_completion,
                                           self._prompt,
                                           initial_nameplate_ids,
                                           self._list_blocking,
                                           self._code_length)
            finally:
                # Unregister even if the prompt thread failed, so the
                # warning is not printed at some unrelated later shutdown.
                self._reactor.removeSystemEventTrigger(t)
        returnValue(code)

    def _ws_handle_nameplates(self, msg):
        """Deliver the server's nameplate list to the pending _list() call."""
        nameplates = msg["nameplates"]
        assert isinstance(nameplates, list), type(nameplates)
        for nameplate_id in nameplates:
            assert isinstance(nameplate_id, type(u"")), type(nameplate_id)
        d, self._lister_d = self._lister_d, None
        if d is None:
            # Nobody is waiting (unrequested or duplicate response). This
            # used to crash with AttributeError on None.callback.
            log.msg("ignoring unrequested 'nameplates' response")
            return
        d.callback(nameplates)

    def _warn_readline(self):
        # When our process receives a SIGINT, Twisted's SIGINT handler will
        # stop the reactor and wait for all threads to terminate before the
        # process exits. However, if we were waiting for
        # input_code_with_completion() when SIGINT happened, the readline
        # thread will be blocked waiting for something on stdin. Trick the
        # user into satisfying the blocking read so we can exit.
        print("\nCommand interrupted: please press Return to quit",
              file=sys.stderr)

        # Other potential approaches to this problem:
        # * hard-terminate our process with os._exit(1), but make sure the
        #   tty gets reset to a normal mode ("cooked"?) first, so that the
        #   next shell command the user types is echoed correctly
        # * track down the thread (t.p.threadable.getThreadID from inside the
        #   thread), get a cffi binding to pthread_kill, deliver SIGINT to it
        # * allocate a pty pair (pty.openpty), replace sys.stdin with the
        #   slave, build a pty bridge that copies bytes (and other PTY
        #   things) from the real stdin to the master, then close the slave
        #   at shutdown, so readline sees EOF
        # * write tab-completion and basic editing (TTY raw mode,
        #   backspace-is-erase) without readline, probably with curses or
        #   twisted.conch.insults
        # * write a separate program to get codes (maybe just "wormhole
        #   --internal-get-code"), run it as a subprocess, let it inherit
        #   stdin/stdout, send it SIGINT when we receive SIGINT ourselves. It
        #   needs an RPC mechanism (over some extra file descriptors) to ask
        #   us to fetch the current nameplate_id list.
        #
        # Note that hard-terminating our process with os.kill(os.getpid(),
        # signal.SIGKILL), or SIGTERM, doesn't seem to work: the thread
        # doesn't see the signal, and we must still wait for stdin to make
        # readline finish.
|
||
|
|
||
|
|
||
|
|
||
|
class _Wormhole:
    """Event-driven client for one wormhole exchange (work in progress).

    Naming convention: ``_event_X`` records that X happened. ``_maybe_Y``
    checks its preconditions and acts only once they all hold.
    ``_ws_handle_<type>`` handles a server response of that type.

    NOTE(review): several helpers this class calls are not defined in this
    module as shown. They are ``_ws_send_command``, ``_ws_dispatch_response``
    (called by WSClient), ``_ws_closed`` (called by WSClient), ``_msg_send``,
    ``_signal_error`` and ``from_serialized``. They must exist before the
    class works end to end.
    """

    def __init__(self, appid=None, relay_url=None, reactor=None,
                 tor_manager=None, timing=None):
        # Configuration, in the order the module-level wormhole() passes it.
        # (The old no-argument __init__ made wormhole() raise TypeError.)
        self._appid = appid
        self._ws_url = relay_url
        self._reactor = reactor
        self._tor_manager = tor_manager
        self._timing = timing if timing is not None else DebugTiming()
        # Random identifier for our side. _ws_handle_message compares it to
        # msg["side"] to drop echoes of our own messages.
        self._side = hexlify(os.urandom(5)).decode("ascii")
        self._ws = None
        self._connected = None

        # Code / nameplate / mailbox state. All of these are read before
        # they would otherwise have been assigned.
        self._code = None
        self._nameplate_id = None
        self._mailbox_id = None
        self._mailbox_opened = False
        self._started_get_code = False
        self._started_input_code = False

        # Key-agreement state.
        self._key = None
        self._verifier = None
        self._verifier_waiter = None
        # send a "confirm" message once the key is established
        self._send_confirm = True

        # _ws_handle_welcome shows each of these at most once.
        self.motd_displayed = False
        self.version_warning_displayed = False

        self._flag_need_mailbox = True
        self._flag_need_to_see_mailbox_used = True
        self._flag_need_to_build_msg1 = True
        self._flag_need_to_send_PAKE = True
        self._flag_need_PAKE = True
        self._flag_need_key = True # rename to not self._key

        self._next_send_phase = 0
        self._phase_messages_to_send = [] # not yet acked by server

        self._next_receive_phase = 0
        self._phase_messages_received = {} # phase -> message
        self._phase_message_waiters = {} # phase -> Deferred

    def _start(self):
        d = self._connect() # causes stuff to happen
        d.addErrback(log.err)
        return d # fires when connection is established, if you care

    def _make_endpoint(self, hostname, port):
        if self._tor_manager:
            return self._tor_manager.get_endpoint_for(hostname, port)
        # note: HostnameEndpoints have a default 30s timeout
        return endpoints.HostnameEndpoint(self._reactor, hostname, port)

    def _connect(self):
        # TODO: if we lose the connection, make a new one, re-establish the
        # state
        assert self._side
        p = urlparse(self._ws_url)
        f = WSFactory(self._ws_url)
        f.wormhole = self
        f.d = defer.Deferred()
        # TODO: if hostname="localhost", I get three factories starting
        # and stopping (maybe 127.0.0.1, ::1, and something else?), and
        # an error in the factory is masked.
        ep = self._make_endpoint(p.hostname, p.port or 80)
        # .connect errbacks if the TCP connection fails
        d = ep.connect(f)
        d.addCallback(self._event_connected)
        # f.d is errbacked if WebSocket negotiation fails, and the WebSocket
        # drops any data sent before onOpen() fires, so we must wait for it.
        # (Previously the chain went straight to _event_ws_opened.)
        d.addCallback(lambda _: f.d)
        d.addCallback(self._event_ws_opened)
        return d

    def _event_connected(self, ws, f=None):
        # Called as a one-argument Deferred callback with the protocol
        # instance. ``f`` is unused and optional only for compatibility with
        # the previous (ws, f) signature.
        self._ws = ws
        self._ws_t = self._timing.add("websocket")

    def _event_ws_opened(self, _):
        self._connected = True
        self._ws_send_command(u"bind", appid=self._appid, side=self._side)
        self._maybe_get_mailbox()

    def _ws_handle_welcome(self, msg):
        welcome = msg["welcome"]
        if ("motd" in welcome and
            not self.motd_displayed):
            motd_lines = welcome["motd"].splitlines()
            motd_formatted = "\n ".join(motd_lines)
            print("Server (at %s) says:\n %s" %
                  (self._ws_url, motd_formatted), file=sys.stderr)
            self.motd_displayed = True

        # Only warn if we're running a release version (e.g. 0.0.6, not
        # 0.0.6-DISTANCE-gHASH). Only warn once. Skip the check when the
        # server does not advertise a version (this used to raise KeyError).
        if ("-" not in __version__ and
            not self.version_warning_displayed and
            "current_version" in welcome and
            welcome["current_version"] != __version__):
            print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
            print("Server claims %s is current, but ours is %s"
                  % (welcome["current_version"], __version__), file=sys.stderr)
            self.version_warning_displayed = True

        if "error" in welcome:
            return self._signal_error(welcome["error"])

    # entry point 1: generate a new code
    @inlineCallbacks
    def get_code(self, code_length=2): # XX rename to allocate_code()? create_?
        if self._code is not None: raise UsageError
        if self._started_get_code: raise UsageError
        self._started_get_code = True
        with self._timing.add("API get_code"):
            gc = _GetCode(code_length, self._ws_send_command)
            # _GetCode.go() records timing events; share our recorder.
            gc._timing = self._timing
            self._ws_handle_allocated = gc._ws_handle_allocated
            code = yield gc.go()
            self._event_learned_code(code)
            returnValue(code)

    # entry point 2: interactively type in a code, with completion
    @inlineCallbacks
    def input_code(self, prompt="Enter wormhole code: ", code_length=2):
        if self._code is not None: raise UsageError
        if self._started_input_code: raise UsageError
        self._started_input_code = True
        with self._timing.add("API input_code"):
            # _InputCode's first argument is the reactor; it used to be
            # omitted, shifting every other argument by one.
            gc = _InputCode(self._reactor, prompt, code_length,
                            self._ws_send_command)
            gc._timing = self._timing
            self._ws_handle_nameplates = gc._ws_handle_nameplates
            code = yield gc.go()
            self._event_learned_code(code)
            returnValue(None)

    # entry point 3: paste in a fully-formed code
    def set_code(self, code):
        self._timing.add("API set_code")
        if not isinstance(code, type(u"")): raise TypeError(type(code))
        if self._code is not None: raise UsageError
        self._event_learned_code(code)

    def _event_learned_code(self, code):
        # Validate before storing anything, so a malformed code does not
        # leave self._code set (which would make every retry raise
        # UsageError).
        mo = re.search(r'^(\d+)-', code)
        if not mo:
            raise ValueError("code (%s) must start with NN-" % code)
        nid = mo.group(1)
        assert isinstance(nid, type(u"")), type(nid)
        self._timing.add("code established")
        self._code = code
        self._nameplate_id = nid
        # fire more events
        self._maybe_build_msg1()
        self._event_learned_nameplate()

    def _maybe_build_msg1(self):
        if not (self._code and self._flag_need_to_build_msg1):
            return
        with self._timing.add("pake1", waiting="crypto"):
            self._sp = SPAKE2_Symmetric(to_bytes(self._code),
                                        idSymmetric=to_bytes(self._appid))
            self._msg1 = self._sp.start()
        self._flag_need_to_build_msg1 = False
        self._event_built_msg1()

    def _event_built_msg1(self):
        self._maybe_send_pake()

    # every _maybe_X starts with a set of conditions
    # for each such condition Y, every _event_Y must call _maybe_X

    def _event_learned_nameplate(self):
        self._maybe_get_mailbox()

    def _maybe_get_mailbox(self):
        if not (self._flag_need_mailbox and self._nameplate_id
                and self._connected):
            return
        self._ws_send_command(u"claim", nameplate=self._nameplate_id)

    def _ws_handle_claimed(self, msg):
        mailbox_id = msg["mailbox"]
        assert isinstance(mailbox_id, type(u"")), type(mailbox_id)
        self._mailbox_id = mailbox_id
        self._event_learned_mailbox()

    def _event_welcome(self):
        pass

    def _event_learned_mailbox(self):
        self._flag_need_mailbox = False
        if not self._mailbox_id: raise UsageError
        if self._mailbox_opened: raise UsageError
        # causes old messages to be sent now, and subscribes to new messages
        self._ws_send_command(u"open", mailbox=self._mailbox_id)
        # Record the open so the guard above can catch a second open.
        self._mailbox_opened = True
        self._maybe_send_pake()
        self._maybe_send_phase_messages()

    def _maybe_send_pake(self):
        # The mailbox id lives in _mailbox_id; the old check used a
        # never-assigned self._mailbox attribute.
        if not (self._connected and self._mailbox_id
                and not self._flag_need_to_build_msg1
                and self._flag_need_to_send_PAKE):
            return
        # Clear the flag before sending so a reentrant call cannot send the
        # PAKE message twice. Restore it on failure so a later call retries.
        self._flag_need_to_send_PAKE = False
        d = self._msg_send(u"pake", self._msg1)
        def _pake_failed(f):
            self._flag_need_to_send_PAKE = True
            return f
        d.addErrback(_pake_failed)
        d.addErrback(log.err)

    def _maybe_send_phase_messages(self):
        # TODO: deal with reentrant call (pending messages are resent on
        # every call until the server acks them)
        if not (self._connected and self._mailbox_id
                and self._key is not None):
            return
        # Iterate over a snapshot. If _msg_send returns an already-fired
        # Deferred, _phase_message_sent removes entries from the live list
        # in the middle of this loop, which would skip messages.
        for pm in list(self._phase_messages_to_send):
            (phase, message) = pm
            d = self._msg_send(phase, message)
            def _phase_message_sent(res, pm=pm):
                try:
                    self._phase_messages_to_send.remove(pm)
                except ValueError:
                    pass
            d.addCallback(_phase_message_sent)
            d.addErrback(log.err)

    def _event_received_message(self, msg):
        pass

    def _event_mailbox_used(self):
        if self._flag_need_to_see_mailbox_used:
            self._ws_send_command(u"release")
            self._flag_need_to_see_mailbox_used = False

    def _event_learned_PAKE(self, pake_msg):
        with self._timing.add("pake2", waiting="crypto"):
            self._key = self._sp.finish(pake_msg)
        self._event_established_key()

    def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
        if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
        if self._key is None:
            # call after get_verifier() or get()
            raise UsageError
        return HKDF(self._key, length, CTXinfo=to_bytes(purpose))

    def _event_established_key(self):
        self._timing.add("key established")
        if self._send_confirm:
            # both sides send different (random) confirmation messages
            confkey = self.derive_key(u"wormhole:confirmation")
            nonce = os.urandom(CONFMSG_NONCE_LENGTH)
            confmsg = make_confmsg(confkey, nonce)
            self._msg_send(u"confirm", confmsg, wait=True)
        verifier = self.derive_key(u"wormhole:verifier")
        self._event_computed_verifier(verifier)

    def _event_computed_verifier(self, verifier):
        self._verifier = verifier
        d, self._verifier_waiter = self._verifier_waiter, None
        if d:
            d.callback(verifier)

    def _event_received_confirm(self, body):
        # TODO: we might not have a master key yet, if the caller wasn't
        # waiting in _get_master_key() when a back-to-back pake+_confirm
        # message pair arrived.
        confkey = self.derive_key(u"wormhole:confirmation")
        nonce = body[:CONFMSG_NONCE_LENGTH]
        if body != make_confmsg(confkey, nonce):
            # this makes all API calls fail
            return self._signal_error(WrongPasswordError())

    def _event_received_phase_message(self, phase, message):
        self._phase_messages_received[phase] = message
        if phase in self._phase_message_waiters:
            d = self._phase_message_waiters.pop(phase)
            d.callback(message)

    def _ws_handle_message(self, msg):
        side = msg["side"]
        phase = msg["phase"]
        body = unhexlify(msg["body"].encode("ascii"))
        if side == self._side:
            return
        self._event_received_peer_message(phase, body)

    def XXXackstuff():
        # NOTE(review): unfinished placeholder. It takes no ``self`` and
        # uses names that are not defined here (phase, body, _sent_messages,
        # _delivered_messages, _wakeup), so calling it fails. Kept only so
        # the name does not disappear.
        if phase in self._sent_messages and self._sent_messages[phase] == body:
            self._delivered_messages.add(phase) # ack by server
            self._wakeup()
            return # ignore echoes of our outbound messages

    def _event_received_peer_message(self, phase, body):
        # any message in the mailbox means we no longer need the nameplate
        self._event_mailbox_used()
        #if phase in self._received_messages:
        #    # a nameplate collision would cause this
        #    err = ServerError("got duplicate phase %s" % phase, self._ws_url)
        #    return self._signal_error(err)
        #self._received_messages[phase] = body
        # Route by phase. This replaces a call to the nonexistent
        # self._wakeup(), which raised AttributeError for every message and
        # left PAKE and application messages undelivered.
        if phase == u"pake":
            self._event_learned_PAKE(body)
        elif phase == u"confirm":
            self._event_received_confirm(body)
        else:
            # TODO(review): body is delivered exactly as received; confirm
            # whether application phases must be decrypted first.
            self._event_received_phase_message(phase, body)

    def _event_asked_to_send_phase_message(self, phase, message):
        pm = (phase, message)
        self._phase_messages_to_send.append(pm)
        self._maybe_send_phase_messages()

    def _event_asked_to_close(self):
        pass
|
||
|
|
||
|
|
||
|
|
||
|
def wormhole(appid, relay_url, reactor, tor_manager=None, timing=None):
    """Create a _Wormhole, start its relay connection, and return it.

    The Deferred from ``_start()`` is not returned. ``_start`` attaches
    ``log.err`` to it, so connection failures are logged rather than raised.
    """
    wh = _Wormhole(appid, relay_url, reactor, tor_manager, timing)
    wh._start()
    return wh
|
||
|
|
||
|
def wormhole_from_serialized(data, reactor):
    """Rebuild a wormhole from ``data`` via ``_Wormhole.from_serialized``.

    NOTE(review): ``_Wormhole.from_serialized`` is not defined in this module
    as shown. Confirm it exists before relying on this entry point.
    """
    return _Wormhole.from_serialized(data, reactor)
|