magic-wormhole/src/wormhole/twisted/transcribe.py
2016-04-20 19:18:41 -07:00

499 lines
20 KiB
Python

from __future__ import print_function
import os, sys, json, re, unicodedata
from six.moves.urllib_parse import urlparse
from binascii import hexlify, unhexlify
from twisted.internet import reactor, defer, endpoints, error
from twisted.internet.threads import deferToThread, blockingCallFromThread
from twisted.internet.defer import inlineCallbacks, returnValue
from autobahn.twisted import websocket
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric
from .. import __version__
from .. import codes
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
from ..timing import DebugTiming
from hkdf import Hkdf
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
    """Derive ``outlen`` bytes from the secret key material ``skm``.

    ``salt`` (may be None) is handed to the Hkdf constructor and
    ``CTXinfo`` is the context/info string passed to its expand step.
    Returns the derived bytes.
    """
    extractor = Hkdf(salt, skm)
    return extractor.expand(CTXinfo, outlen)
# Byte lengths of the two parts of a key-confirmation message:
# a random nonce followed by a MAC derived from it (see make_confmsg).
CONFMSG_NONCE_LENGTH = 16  # 128 bits
CONFMSG_MAC_LENGTH = 32  # 256 bits
def make_confmsg(confkey, nonce):
    """Build a key-confirmation message: ``nonce`` followed by its MAC.

    The MAC is CONFMSG_MAC_LENGTH bytes of HKDF output keyed by
    ``confkey`` with ``nonce`` as the salt. A receiver can recompute it
    from the nonce prefix to check that both sides derived the same key.
    """
    mac = HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
    return nonce + mac
def to_bytes(u):
    """Return text ``u`` NFC-normalized and encoded as UTF-8 bytes."""
    normalized = unicodedata.normalize("NFC", u)
    return normalized.encode("utf-8")
def close_on_error(meth): # method decorator
    """Decorate a Deferred-returning Wormhole method so failures close it.

    The wrapped method's successful result passes through untouched. On
    failure, the wormhole is closed with a "mood" describing the error
    (Timeout -> "lonely", WrongPasswordError -> "scary", anything else ->
    "errory"), and the original Failure is then re-delivered to the caller
    whatever the outcome of close(). Precondition failures (TypeError,
    UsageError) do not close the wormhole.
    """
    import functools
    # Clients report certain errors as "moods", so the server can make a
    # rough count of failed connections (due to mismatched passwords,
    # attacks, or timeouts). We don't report precondition failures, as those
    # are the responsibility/fault of the local application code. We count
    # non-precondition errors in case they represent server-side problems.
    @functools.wraps(meth)
    def _wrapper(self, *args, **kwargs):
        d = defer.maybeDeferred(meth, self, *args, **kwargs)
        def _onerror(f):
            # close()'s first positional parameter is `res`, so the mood
            # must be passed by keyword or it would be silently ignored
            # (and every error close reported as the default "happy").
            if f.check(Timeout):
                d2 = self.close(mood=u"lonely")
            elif f.check(WrongPasswordError):
                d2 = self.close(mood=u"scary")
            elif f.check(TypeError, UsageError):
                # preconditions don't warrant _close_with_error()
                d2 = defer.succeed(None)
            else:
                d2 = self.close(mood=u"errory")
            # ignore close()'s own result/failure; surface the original one
            d2.addBoth(lambda _: f)
            return d2
        d.addErrback(_onerror)
        return d
    return _wrapper
class WSClient(websocket.WebSocketClientProtocol):
    """WebSocket client protocol that forwards traffic to its Wormhole.

    WSFactory.buildProtocol assigns ``wormhole`` (the owning Wormhole) and
    ``wormhole_open`` (False until onOpen fires) on each instance.
    """
    def onOpen(self):
        # negotiation finished: fire the Deferred the Wormhole is waiting on
        self.wormhole_open = True
        self.factory.d.callback(self)

    def onMessage(self, payload, isBinary):
        assert not isBinary
        self.wormhole._ws_dispatch_msg(payload)

    def onClose(self, wasClean, code, reason):
        if not self.wormhole_open:
            # we closed before establishing a connection (onConnect) or
            # finishing WebSocket negotiation (onOpen): errback
            self.factory.d.errback(error.ConnectError(reason))
            return
        self.wormhole._ws_closed(wasClean, code, reason)
class WSFactory(websocket.WebSocketClientFactory):
    """Client factory whose protocols are bound to ``self.wormhole``.

    The creator is expected to set ``wormhole`` (and the ``d`` Deferred that
    WSClient fires) on the factory before connecting.
    """
    protocol = WSClient

    def buildProtocol(self, addr):
        client = websocket.WebSocketClientFactory.buildProtocol(self, addr)
        # not yet open: WSClient.onOpen flips this to True
        client.wormhole_open = False
        client.wormhole = self.wormhole
        return client
class Wormhole:
    """One side of a wormhole session, driven by Twisted Deferreds.

    Typical flow: establish a code with get_code(), input_code() followed
    by set_code(), or set_code() directly; optionally serialize() and later
    from_serialized(); then get_verifier() and send_data()/get_data(); and
    finally close(). All traffic goes over a single WebSocket to the relay
    server. Inbound messages are dispatched by type to the
    _ws_handle_<type>() methods.
    """
    # Defaults for flags that are set on the instance once the server's
    # MOTD / version warning has been printed.
    motd_displayed = False
    version_warning_displayed = False
    # When True, send a key-confirmation message once the key is known.
    _send_confirm = True

    def __init__(self, appid, relay_url, tor_manager=None, timing=None,
                 reactor=reactor):
        """
        :param appid: unicode application id; sent in the "bind" message
            and used as the SPAKE2 symmetric identity.
        :param relay_url: unicode relay server URL; must end with "/".
        :param tor_manager: optional object used by _make_endpoint() instead
            of a direct HostnameEndpoint.
        :param timing: optional timing recorder (a new DebugTiming is
            created when omitted).
        :param reactor: the Twisted reactor to use.
        :raises TypeError: if appid or relay_url is not unicode.
        :raises UsageError: if relay_url does not end with "/".
        """
        if not isinstance(appid, type(u"")): raise TypeError(type(appid))
        if not isinstance(relay_url, type(u"")):
            raise TypeError(type(relay_url))
        if not relay_url.endswith(u"/"): raise UsageError
        self._appid = appid
        self._relay_url = relay_url
        # NOTE(review): only "http:" is rewritten; an "https:" relay URL
        # would produce an "https:...ws" URL here -- confirm intended.
        self._ws_url = relay_url.replace("http:", "ws:") + "ws"
        self._tor_manager = tor_manager
        self._timing = timing or DebugTiming()
        self._reactor = reactor
        self._side = hexlify(os.urandom(5)).decode("ascii")
        self._code = None
        self._channelid = None
        self._key = None
        self._started_get_code = False
        self._sent_messages = set() # (phase, body_bytes)
        self._delivered_messages = set() # (phase, body_bytes)
        self._received_messages = {} # phase -> body_bytes
        self._sent_phases = set() # phases, to prohibit double-send
        self._got_phases = set() # phases, to prohibit double-read
        self._sleepers = [] # Deferreds parked in _sleep()
        self._confirmation_failed = False
        self._closed = False
        self._deallocated_status = None
        self._timing_started = self._timing.add_event("wormhole")
        self._ws = None
        self._ws_channel_claimed = False
        self._error = None # set by _signal_error(); re-raised by _sleep()

    def _make_endpoint(self, hostname, port):
        """Return the client endpoint used to reach the relay server."""
        if self._tor_manager:
            # NOTE(review): endpointForURI() is called without hostname/port;
            # confirm this matches the tor_manager API.
            return self._tor_manager.endpointForURI()
        return endpoints.HostnameEndpoint(self._reactor, hostname, port) # 30s

    @inlineCallbacks
    def _get_websocket(self):
        """Return (via Deferred) the connected, bound WSClient, creating it
        on first use: TCP connect, WebSocket negotiation, then "bind"."""
        if not self._ws:
            # TODO: if we lose the connection, make a new one
            #from twisted.python import log
            #log.startLogging(sys.stderr)
            assert self._side
            assert not self._ws_channel_claimed
            p = urlparse(self._ws_url)
            f = WSFactory(self._ws_url)
            f.wormhole = self
            f.d = defer.Deferred()
            # TODO: if hostname="localhost", I get three factories starting
            # and stopping (maybe 127.0.0.1, ::1, and something else?), and
            # an error in the factory is masked.
            ep = self._make_endpoint(p.hostname, p.port or 80)
            # .connect errbacks if the TCP connection fails
            self._ws = yield ep.connect(f)
            # f.d is errbacked if WebSocket negotiation fails
            yield f.d # WebSocket drops data sent before onOpen() fires
            self._ws_send(u"bind", appid=self._appid, side=self._side)
        # the socket is connected, and bound, but no channel has been claimed
        returnValue(self._ws)

    @inlineCallbacks
    def _ws_send(self, mtype, **kwargs):
        """Send one JSON message of type ``mtype`` with ``kwargs`` as fields."""
        ws = yield self._get_websocket()
        kwargs["type"] = mtype
        payload = json.dumps(kwargs).encode("utf-8")
        ws.sendMessage(payload, False)

    def _ws_dispatch_msg(self, payload):
        """Decode an inbound JSON payload and route it by its "type" field.

        :raises ValueError: if there is no _ws_handle_<type> method.
        """
        msg = json.loads(payload.decode("utf-8"))
        mtype = msg["type"]
        meth = getattr(self, "_ws_handle_"+mtype, None)
        if not meth:
            raise ValueError("Unknown inbound message type %r" % (msg,))
        return meth(msg)

    def _ws_handle_welcome(self, msg):
        """Show the server MOTD / version warning (once each) and signal any
        server-reported error."""
        welcome = msg["welcome"]
        if ("motd" in welcome and
            not self.motd_displayed):
            motd_lines = welcome["motd"].splitlines()
            motd_formatted = "\n ".join(motd_lines)
            print("Server (at %s) says:\n %s" %
                  (self._ws_url, motd_formatted), file=sys.stderr)
            self.motd_displayed = True
        # Only warn if we're running a release version (e.g. 0.0.6, not
        # 0.0.6-DISTANCE-gHASH). Only warn once. A welcome without
        # "current_version" must not raise KeyError (which would also skip
        # the error check below).
        current_version = welcome.get("current_version")
        if ("-" not in __version__ and
            not self.version_warning_displayed and
            current_version is not None and
            current_version != __version__):
            print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
            print("Server claims %s is current, but ours is %s"
                  % (current_version, __version__), file=sys.stderr)
            self.version_warning_displayed = True
        if "error" in welcome:
            # _signal_error() requires an Exception, not the raw string
            err = ServerError(welcome["error"], self._ws_url)
            return self._signal_error(err)

    @inlineCallbacks
    def _sleep(self):
        """Wait until the next _wakeup(); raise self._error if one is (or
        becomes) set."""
        if self._error: # don't sleep if the bed's already on fire
            raise self._error
        d = defer.Deferred()
        self._sleepers.append(d)
        yield d
        if self._error:
            raise self._error

    def _wakeup(self):
        """Fire every Deferred currently parked in _sleep()."""
        sleepers = self._sleepers
        self._sleepers = []
        for d in sleepers:
            d.callback(None)
        # NOTE: callers should avoid reentrancy themselves. An
        # eventual-send would be safer here, but it makes synchronizing
        # unit tests annoying.

    def _signal_error(self, error):
        """Record ``error`` (an Exception) and wake all sleepers so they
        raise it."""
        assert isinstance(error, Exception)
        self._error = error
        self._wakeup()

    def _ws_handle_error(self, msg):
        err = ServerError("%s: %s" % (msg["error"], msg["orig"]),
                          self._ws_url)
        return self._signal_error(err)

    @inlineCallbacks
    def _claim_channel_and_watch(self):
        """Claim self._channelid (once) and ask the server to watch it."""
        assert self._channelid is not None
        yield self._get_websocket()
        if not self._ws_channel_claimed:
            yield self._ws_send(u"claim", channelid=self._channelid)
            self._ws_channel_claimed = True
            yield self._ws_send(u"watch")

    # entry point 1: generate a new code
    @inlineCallbacks
    def get_code(self, code_length=2): # rename to allocate_code()? create_?
        """Allocate a channel from the server and return a new code for it.

        :raises UsageError: if a code is already set or get_code() was
            already called.
        """
        if self._code is not None: raise UsageError
        if self._started_get_code: raise UsageError
        self._started_get_code = True
        _sent = self._timing.add_event("allocate")
        yield self._ws_send(u"allocate")
        while self._channelid is None:
            yield self._sleep()
        self._timing.finish_event(_sent)
        code = codes.make_code(self._channelid, code_length)
        assert isinstance(code, type(u"")), type(code)
        self._set_code(code)
        self._start()
        returnValue(code)

    def _ws_handle_allocated(self, msg):
        if self._channelid is not None:
            # _signal_error() requires an Exception, not a bare string
            err = ServerError("got duplicate channelid", self._ws_url)
            return self._signal_error(err)
        self._channelid = msg["channelid"]
        self._wakeup()

    def _start(self):
        # allocate the rest now too, so it can be serialized
        self._sp = SPAKE2_Symmetric(to_bytes(self._code),
                                    idSymmetric=to_bytes(self._appid))
        self._msg1 = self._sp.start()

    # entry point 2a: interactively type in a code, with completion
    @inlineCallbacks
    def input_code(self, prompt="Enter wormhole code: ", code_length=2):
        """Prompt the user (in a thread) for a code, with tab-completion of
        channel ids; returns the code without setting it."""
        def _lister():
            # runs in the input thread: hop back to the reactor to list
            return blockingCallFromThread(self._reactor, self._list_channels)
        # fetch the list of channels ahead of time, to give us a chance to
        # discover the welcome message (and warn the user about an obsolete
        # client)
        #
        # TODO: send the request early, show the prompt right away, hide the
        # latency in the user's indecision and slow typing. If we're lucky
        # the answer will come back before they hit TAB.
        initial_channelids = yield self._list_channels()
        _start = self._timing.add_event("input code", waiting="user")
        code = yield deferToThread(codes.input_code_with_completion,
                                   prompt,
                                   initial_channelids, _lister,
                                   code_length)
        self._timing.finish_event(_start)
        returnValue(code) # application will give this to set_code()

    @inlineCallbacks
    def _list_channels(self):
        """Ask the server for the current channel ids and return them."""
        _sent = self._timing.add_event("list")
        self._latest_channelids = None
        yield self._ws_send(u"list")
        while self._latest_channelids is None:
            yield self._sleep()
        self._timing.finish_event(_sent)
        returnValue(self._latest_channelids)

    def _ws_handle_channelids(self, msg):
        self._latest_channelids = msg["channelids"]
        self._wakeup()

    # entry point 2b: paste in a fully-formed code
    def set_code(self, code):
        """Use an existing code of the form "NN-..." (NN = channel id).

        :raises TypeError: if code is not unicode.
        :raises UsageError: if a code is already set.
        :raises ValueError: if code does not start with digits and "-".
        """
        if not isinstance(code, type(u"")): raise TypeError(type(code))
        if self._code is not None: raise UsageError
        mo = re.search(r'^(\d+)-', code)
        if not mo:
            raise ValueError("code (%s) must start with NN-" % code)
        self._channelid = int(mo.group(1))
        self._set_code(code)
        self._start()

    def _set_code(self, code):
        if self._code is not None: raise UsageError
        self._timing.add_event("code established")
        self._code = code

    def serialize(self):
        """Return a JSON string capturing this not-yet-keyed session."""
        # I can only be serialized after get_code/set_code and before
        # get_verifier/get_data
        if self._code is None: raise UsageError
        if self._key is not None: raise UsageError
        if self._sent_phases: raise UsageError
        if self._got_phases: raise UsageError
        data = {
            "appid": self._appid,
            "relay_url": self._relay_url,
            "code": self._code,
            "channelid": self._channelid,
            "side": self._side,
            "spake2": json.loads(self._sp.serialize().decode("ascii")),
            "msg1": hexlify(self._msg1).decode("ascii"),
        }
        return json.dumps(data)

    # entry point 3: resume a previously-serialized session
    @classmethod
    def from_serialized(klass, data):
        """Rebuild a Wormhole from the output of serialize()."""
        d = json.loads(data)
        self = klass(d["appid"], d["relay_url"])
        self._side = d["side"]
        self._channelid = d["channelid"]
        self._set_code(d["code"])
        sp_data = json.dumps(d["spake2"]).encode("ascii")
        self._sp = SPAKE2_Symmetric.from_serialized(sp_data)
        self._msg1 = unhexlify(d["msg1"].encode("ascii"))
        return self

    @close_on_error
    @inlineCallbacks
    def get_verifier(self):
        """Return (via Deferred) the verifier derived from the session key."""
        if self._closed: raise UsageError
        if self._code is None: raise UsageError
        yield self._get_master_key()
        returnValue(self._verifier)

    @inlineCallbacks
    def _get_master_key(self):
        """If no key yet: exchange SPAKE2 messages, derive the key and the
        verifier, and (if _send_confirm) send a confirmation message."""
        # TODO: prevent multiple invocation
        if not self._key:
            yield self._claim_channel_and_watch()
            yield self._msg_send(u"pake", self._msg1)
            pake_msg = yield self._msg_get(u"pake")
            self._key = self._sp.finish(pake_msg)
            self._verifier = self.derive_key(u"wormhole:verifier")
            self._timing.add_event("key established")
            if self._send_confirm:
                # both sides send different (random) confirmation messages
                confkey = self.derive_key(u"wormhole:confirmation")
                nonce = os.urandom(CONFMSG_NONCE_LENGTH)
                confmsg = make_confmsg(confkey, nonce)
                yield self._msg_send(u"_confirm", confmsg)

    @inlineCallbacks
    def _msg_send(self, phase, body, wait=False):
        """Post ``body`` (bytes) for ``phase``; with wait=True, also wait
        until the server echoes it back."""
        self._sent_messages.add( (phase, body) )
        # TODO: retry on failure, with exponential backoff. We're guarding
        # against the rendezvous server being temporarily offline.
        yield self._ws_send(u"add", phase=phase,
                            body=hexlify(body).decode("ascii"))
        if wait:
            while (phase, body) not in self._delivered_messages:
                yield self._sleep()

    def _ws_handle_message(self, msg):
        """Record an inbound channel message (or the echo of our own) and
        verify the peer's key-confirmation message."""
        m = msg["message"]
        phase = m["phase"]
        body = unhexlify(m["body"].encode("ascii"))
        if (phase, body) in self._sent_messages:
            self._delivered_messages.add( (phase, body) ) # ack by server
            self._wakeup()
            return # ignore echoes of our outbound messages
        if phase in self._received_messages:
            # a channel collision would cause this
            err = ServerError("got duplicate phase %s" % phase, self._ws_url)
            return self._signal_error(err)
        self._received_messages[phase] = body
        if phase == u"_confirm":
            from hmac import compare_digest
            # NOTE(review): derive_key() raises UsageError if the key is not
            # yet established; this relies on the peer's "pake" message
            # being processed before its "_confirm".
            confkey = self.derive_key(u"wormhole:confirmation")
            nonce = body[:CONFMSG_NONCE_LENGTH]
            # constant-time compare: body arrives from the network
            if not compare_digest(body, make_confmsg(confkey, nonce)):
                # this makes all API calls fail
                return self._signal_error(WrongPasswordError())
        # now notify anyone waiting on it
        self._wakeup()

    @inlineCallbacks
    def _msg_get(self, phase):
        """Return (via Deferred) the body received for ``phase``."""
        _start = self._timing.add_event("get(%s)" % phase)
        while phase not in self._received_messages:
            yield self._sleep() # we can wait a long time here
            # that will throw an error if something goes wrong
        self._timing.finish_event(_start)
        returnValue(self._received_messages[phase])

    def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
        """Derive ``length`` bytes from the session key for ``purpose``.

        :raises TypeError: if purpose is not unicode.
        :raises UsageError: if the session key is not established yet.
        """
        if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
        if self._key is None:
            # call after get_verifier() or get_data()
            raise UsageError
        return HKDF(self._key, length, CTXinfo=to_bytes(purpose))

    def _encrypt_data(self, key, data):
        """Encrypt ``data`` with SecretBox under ``key`` and a random nonce."""
        assert isinstance(key, type(b"")), type(key)
        assert isinstance(data, type(b"")), type(data)
        assert len(key) == SecretBox.KEY_SIZE, len(key)
        box = SecretBox(key)
        nonce = utils.random(SecretBox.NONCE_SIZE)
        return box.encrypt(data, nonce)

    def _decrypt_data(self, key, encrypted):
        """Decrypt SecretBox output; raises CryptoError if it is invalid."""
        assert isinstance(key, type(b"")), type(key)
        assert isinstance(encrypted, type(b"")), type(encrypted)
        assert len(key) == SecretBox.KEY_SIZE, len(key)
        box = SecretBox(key)
        data = box.decrypt(encrypted)
        return data

    @close_on_error
    @inlineCallbacks
    def send_data(self, outbound_data, phase=u"data", wait=False):
        """Encrypt and send ``outbound_data`` (bytes) for ``phase`` (once).

        :raises TypeError: for non-bytes data or non-unicode phase.
        :raises UsageError: if closed, no code is set, the phase starts with
            "_" (reserved), or this phase was already sent.
        """
        if not isinstance(outbound_data, type(b"")):
            raise TypeError(type(outbound_data))
        if not isinstance(phase, type(u"")): raise TypeError(type(phase))
        if self._closed: raise UsageError
        if self._code is None:
            raise UsageError("You must set_code() before send_data()")
        if phase.startswith(u"_"): raise UsageError # reserved for internals
        if phase in self._sent_phases: raise UsageError # only call this once
        self._sent_phases.add(phase)
        _sent = self._timing.add_event("API send data", phase=phase, wait=wait)
        # Without predefined roles, we can't derive predictably unique keys
        # for each side, so we use the same key for both. We use random
        # nonces to keep the messages distinct, and we automatically ignore
        # reflections.
        yield self._get_master_key()
        data_key = self.derive_key(u"wormhole:phase:%s" % phase)
        outbound_encrypted = self._encrypt_data(data_key, outbound_data)
        yield self._msg_send(phase, outbound_encrypted, wait)
        self._timing.finish_event(_sent)

    @close_on_error
    @inlineCallbacks
    def get_data(self, phase=u"data"):
        """Wait for, decrypt, and return the peer's data for ``phase`` (once).

        :raises WrongPasswordError: if the data fails to decrypt.
        :raises UsageError: if closed, no code is set, the phase starts with
            "_" (reserved), or this phase was already read.
        """
        if not isinstance(phase, type(u"")): raise TypeError(type(phase))
        if self._closed: raise UsageError
        if self._code is None: raise UsageError
        if phase.startswith(u"_"): raise UsageError # reserved for internals
        if phase in self._got_phases: raise UsageError # only call this once
        self._got_phases.add(phase)
        _sent = self._timing.add_event("API get data", phase=phase)
        yield self._get_master_key()
        body = yield self._msg_get(phase) # we can wait a long time here
        self._timing.finish_event(_sent)
        try:
            data_key = self.derive_key(u"wormhole:phase:%s" % phase)
            inbound_data = self._decrypt_data(data_key, body)
        except CryptoError:
            raise WrongPasswordError
        returnValue(inbound_data)

    def _ws_closed(self, wasClean, code, reason):
        """Called by WSClient.onClose once an opened WebSocket closes."""
        self._ws = None
        # TODO: schedule reconnect, unless we're done

    @inlineCallbacks
    def close(self, res=None, mood=u"happy"):
        """Deallocate the channel (reporting ``mood``) and drop the
        connection. Idempotent. ``res`` is accepted but unused.

        :raises TypeError: if mood is neither None nor unicode.
        """
        if not isinstance(mood, (type(None), type(u""))):
            raise TypeError(type(mood))
        if self._closed:
            returnValue(None)
        self._closed = True
        # finish the session timing event even when there is no connection
        self._timing.finish_event(self._timing_started, mood=mood)
        if not self._ws:
            returnValue(None)
        try:
            yield self._deallocate(mood)
        finally:
            # TODO: mark WebSocket as don't-reconnect
            # Always drop the connection, even if _deallocate() failed (it
            # raises immediately when an earlier error was signalled).
            # _ws_closed() may already have cleared self._ws meanwhile.
            if self._ws:
                self._ws.transport.loseConnection() # probably flushes
            self._ws = None

    @inlineCallbacks
    def _deallocate(self, mood=None):
        """Send "deallocate" and return the server's status once acked."""
        _sent = self._timing.add_event("close")
        yield self._ws_send(u"deallocate", mood=mood)
        while self._deallocated_status is None:
            yield self._sleep()
        self._timing.finish_event(_sent)
        # TODO: set a timeout, don't wait forever for an ack
        # TODO: if the connection is lost, let it go
        returnValue(self._deallocated_status)

    def _ws_handle_deallocated(self, msg):
        self._deallocated_status = msg["status"]
        self._wakeup()