This requires a DB delete/recreate when upgrading. It changes the server protocol, and app IDs, so clients cannot interoperate with each other across this change, nor with the server. Flag day for everyone! Now apps do not share channel IDs, so a lot of usage of app1 will not cause the wormhole codes for app2 to get longer.
367 lines
14 KiB
Python
367 lines
14 KiB
Python
from __future__ import print_function
|
|
import os, sys, json, re, unicodedata
|
|
from six.moves.urllib_parse import urlencode
|
|
from binascii import hexlify, unhexlify
|
|
from zope.interface import implementer
|
|
from twisted.internet import reactor, defer
|
|
from twisted.web import client as web_client
|
|
from twisted.web import error as web_error
|
|
from twisted.web.iweb import IBodyProducer
|
|
from nacl.secret import SecretBox
|
|
from nacl.exceptions import CryptoError
|
|
from nacl import utils
|
|
from spake2 import SPAKE2_Symmetric
|
|
from .eventsource_twisted import ReconnectingEventSource
|
|
from .. import __version__
|
|
from .. import codes
|
|
from ..errors import ServerError, WrongPasswordError, UsageError
|
|
from ..util.hkdf import HKDF
|
|
from ..channel_monitor import monitor
|
|
|
|
def to_bytes(u):
|
|
return unicodedata.normalize("NFC", u).encode("utf-8")
|
|
|
|
@implementer(IBodyProducer)
|
|
class DataProducer:
|
|
def __init__(self, data):
|
|
self.data = data
|
|
self.length = len(data)
|
|
def startProducing(self, consumer):
|
|
consumer.write(self.data)
|
|
return defer.succeed(None)
|
|
def stopProducing(self):
|
|
pass
|
|
def pauseProducing(self):
|
|
pass
|
|
def resumeProducing(self):
|
|
pass
|
|
|
|
|
|
def post_json(agent, url, request_body):
|
|
# POST a JSON body to a URL, parsing the response as JSON
|
|
data = json.dumps(request_body).encode("utf-8")
|
|
d = agent.request("POST", url.encode("utf-8"),
|
|
bodyProducer=DataProducer(data))
|
|
def _check_error(resp):
|
|
if resp.code != 200:
|
|
raise web_error.Error(resp.code, resp.phrase)
|
|
return resp
|
|
d.addCallback(_check_error)
|
|
d.addCallback(web_client.readBody)
|
|
d.addCallback(lambda data: json.loads(data))
|
|
return d
|
|
|
|
def get_json(agent, url):
|
|
# GET from a URL, parsing the response as JSON
|
|
d = agent.request("GET", url.encode("utf-8"))
|
|
def _check_error(resp):
|
|
if resp.code != 200:
|
|
raise web_error.Error(resp.code, resp.phrase)
|
|
return resp
|
|
d.addCallback(_check_error)
|
|
d.addCallback(web_client.readBody)
|
|
d.addCallback(lambda data: json.loads(data))
|
|
return d
|
|
|
|
class Channel:
|
|
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
|
|
agent):
|
|
self._relay_url = relay_url
|
|
self._appid = appid
|
|
self._channelid = channelid
|
|
self._side = side
|
|
self._handle_welcome = handle_welcome
|
|
self._agent = agent
|
|
self._messages = set() # (phase,body) , body is bytes
|
|
self._sent_messages = set() # (phase,body)
|
|
|
|
def _add_inbound_messages(self, messages):
|
|
for msg in messages:
|
|
phase = msg["phase"]
|
|
body = unhexlify(msg["body"].encode("ascii"))
|
|
self._messages.add( (phase, body) )
|
|
|
|
def _find_inbound_message(self, phase):
|
|
for (their_phase,body) in self._messages - self._sent_messages:
|
|
if their_phase == phase:
|
|
return body
|
|
return None
|
|
|
|
def send(self, phase, msg):
|
|
# TODO: retry on failure, with exponential backoff. We're guarding
|
|
# against the rendezvous server being temporarily offline.
|
|
if not isinstance(phase, type(u"")): raise UsageError(type(phase))
|
|
if not isinstance(msg, type(b"")): raise UsageError(type(msg))
|
|
self._sent_messages.add( (phase,msg) )
|
|
payload = {"appid": self._appid,
|
|
"channelid": self._channelid,
|
|
"side": self._side,
|
|
"phase": phase,
|
|
"body": hexlify(msg).decode("ascii")}
|
|
d = post_json(self._agent, self._relay_url+"add", payload)
|
|
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
|
return d
|
|
|
|
def get(self, phase):
|
|
# fire with a bytestring of the first message for 'phase' that wasn't
|
|
# one of ours. It will either come from previously-received messages,
|
|
# or from an EventSource that we attach to the corresponding URL
|
|
body = self._find_inbound_message(phase)
|
|
if body is not None:
|
|
return defer.succeed(body)
|
|
d = defer.Deferred()
|
|
msgs = []
|
|
def _handle(name, data):
|
|
if name == "welcome":
|
|
self._handle_welcome(json.loads(data))
|
|
if name == "message":
|
|
self._add_inbound_messages([json.loads(data)])
|
|
body = self._find_inbound_message(phase)
|
|
if body is not None and not msgs:
|
|
msgs.append(body)
|
|
d.callback(None)
|
|
# TODO: use agent=self._agent
|
|
queryargs = urlencode([("appid", self._appid),
|
|
("channelid", self._channelid)])
|
|
es = ReconnectingEventSource(self._relay_url+"get?%s" % queryargs,
|
|
_handle)
|
|
es.startService() # TODO: .setServiceParent(self)
|
|
es.activate()
|
|
d.addCallback(lambda _: es.deactivate())
|
|
d.addCallback(lambda _: es.stopService())
|
|
d.addCallback(lambda _: msgs[0])
|
|
return d
|
|
|
|
def deallocate(self):
|
|
# only try once, no retries
|
|
d = post_json(self._agent, self._relay_url+"deallocate",
|
|
{"appid": self._appid,
|
|
"channelid": self._channelid,
|
|
"side": self._side})
|
|
d.addBoth(lambda _: None) # ignore POST failure
|
|
return d
|
|
|
|
class ChannelManager:
|
|
def __init__(self, relay, appid, side, handle_welcome):
|
|
assert isinstance(relay, type(u""))
|
|
self._relay = relay
|
|
self._appid = appid
|
|
self._side = side
|
|
self._handle_welcome = handle_welcome
|
|
self._agent = web_client.Agent(reactor)
|
|
|
|
def allocate(self):
|
|
url = self._relay + "allocate"
|
|
d = post_json(self._agent, url, {"appid": self._appid,
|
|
"side": self._side})
|
|
def _got_channel(data):
|
|
if "welcome" in data:
|
|
self._handle_welcome(data["welcome"])
|
|
return data["channelid"]
|
|
d.addCallback(_got_channel)
|
|
return d
|
|
|
|
def list_channels(self):
|
|
queryargs = urlencode([("appid", self._appid)])
|
|
url = self._relay + u"list?%s" % queryargs
|
|
d = get_json(self._agent, url)
|
|
d.addCallback(lambda r: r["channelids"])
|
|
return d
|
|
|
|
def connect(self, channelid):
|
|
return Channel(self._relay, self._appid, channelid, self._side,
|
|
self._handle_welcome, self._agent)
|
|
|
|
class Wormhole:
|
|
motd_displayed = False
|
|
version_warning_displayed = False
|
|
|
|
def __init__(self, appid, relay_url):
|
|
if not isinstance(appid, type(u"")): raise UsageError
|
|
if not isinstance(relay_url, type(u"")): raise UsageError
|
|
if not relay_url.endswith(u"/"): raise UsageError
|
|
self._appid = appid
|
|
self._relay_url = relay_url
|
|
self._set_side(hexlify(os.urandom(5)).decode("ascii"))
|
|
self.code = None
|
|
self.key = None
|
|
self._started_get_code = False
|
|
self._sent_data = False
|
|
self._got_data = False
|
|
|
|
def _set_side(self, side):
|
|
self._side = side
|
|
self._channel_manager = ChannelManager(self._relay_url, self._appid,
|
|
self._side, self.handle_welcome)
|
|
|
|
def handle_welcome(self, welcome):
|
|
if ("motd" in welcome and
|
|
not self.motd_displayed):
|
|
motd_lines = welcome["motd"].splitlines()
|
|
motd_formatted = "\n ".join(motd_lines)
|
|
print("Server (at %s) says:\n %s" %
|
|
(self._relay_url, motd_formatted), file=sys.stderr)
|
|
self.motd_displayed = True
|
|
|
|
# Only warn if we're running a release version (e.g. 0.0.6, not
|
|
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
|
if ("-" not in __version__ and
|
|
not self.version_warning_displayed and
|
|
welcome["current_version"] != __version__):
|
|
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
|
print("Server claims %s is current, but ours is %s"
|
|
% (welcome["current_version"], __version__), file=sys.stderr)
|
|
self.version_warning_displayed = True
|
|
|
|
if "error" in welcome:
|
|
raise ServerError(welcome["error"], self._relay_url)
|
|
|
|
def get_code(self, code_length=2):
|
|
if self.code is not None: raise UsageError
|
|
if self._started_get_code: raise UsageError
|
|
self._started_get_code = True
|
|
d = self._channel_manager.allocate()
|
|
def _got_channelid(channelid):
|
|
code = codes.make_code(channelid, code_length)
|
|
assert isinstance(code, str), type(code)
|
|
self._set_code_and_channelid(code)
|
|
self._start()
|
|
return code
|
|
d.addCallback(_got_channelid)
|
|
return d
|
|
|
|
def set_code(self, code):
|
|
if not isinstance(code, str): raise UsageError
|
|
if self.code is not None: raise UsageError
|
|
self._set_code_and_channelid(code)
|
|
self._start()
|
|
|
|
def _set_code_and_channelid(self, code):
|
|
if self.code is not None: raise UsageError
|
|
mo = re.search(r'^(\d+)-', code)
|
|
if not mo:
|
|
raise ValueError("code (%s) must start with NN-" % code)
|
|
self.code = code
|
|
channelid = int(mo.group(1))
|
|
self.channel = self._channel_manager.connect(channelid)
|
|
monitor.add(self.channel)
|
|
|
|
def _start(self):
|
|
# allocate the rest now too, so it can be serialized
|
|
self.sp = SPAKE2_Symmetric(self.code.encode("ascii"),
|
|
idSymmetric=to_bytes(self._appid))
|
|
self.msg1 = self.sp.start()
|
|
|
|
def serialize(self):
|
|
# I can only be serialized after get_code/set_code and before
|
|
# get_verifier/get_data
|
|
if self.code is None: raise UsageError
|
|
if self.key is not None: raise UsageError
|
|
if self._sent_data: raise UsageError
|
|
if self._got_data: raise UsageError
|
|
data = {
|
|
"appid": self._appid,
|
|
"relay_url": self._relay_url,
|
|
"code": self.code,
|
|
"side": self._side,
|
|
"spake2": json.loads(self.sp.serialize()),
|
|
"msg1": self.msg1.encode("hex"),
|
|
}
|
|
return json.dumps(data)
|
|
|
|
@classmethod
|
|
def from_serialized(klass, data):
|
|
d = json.loads(data)
|
|
self = klass(d["appid"], d["relay_url"])
|
|
self._set_side(d["side"].encode("ascii"))
|
|
self._set_code_and_channelid(d["code"].encode("ascii"))
|
|
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
|
|
self.msg1 = d["msg1"].decode("hex")
|
|
return self
|
|
|
|
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
|
if not isinstance(purpose, type(u"")): raise UsageError
|
|
if self.key is None:
|
|
# call after get_verifier() or get_data()
|
|
raise UsageError
|
|
return HKDF(self.key, length, CTXinfo=to_bytes(purpose))
|
|
|
|
def _encrypt_data(self, key, data):
|
|
assert isinstance(key, type(b"")), type(key)
|
|
assert isinstance(data, type(b"")), type(data)
|
|
if len(key) != SecretBox.KEY_SIZE: raise UsageError
|
|
box = SecretBox(key)
|
|
nonce = utils.random(SecretBox.NONCE_SIZE)
|
|
return box.encrypt(data, nonce)
|
|
|
|
def _decrypt_data(self, key, encrypted):
|
|
assert isinstance(key, type(b"")), type(key)
|
|
assert isinstance(encrypted, type(b"")), type(encrypted)
|
|
if len(key) != SecretBox.KEY_SIZE: raise UsageError
|
|
box = SecretBox(key)
|
|
data = box.decrypt(encrypted)
|
|
return data
|
|
|
|
|
|
def _get_key(self):
|
|
# TODO: prevent multiple invocation
|
|
if self.key:
|
|
return defer.succeed(self.key)
|
|
d = self.channel.send(u"pake", self.msg1)
|
|
d.addCallback(lambda _: self.channel.get(u"pake"))
|
|
def _got_pake(pake_msg):
|
|
key = self.sp.finish(pake_msg)
|
|
self.key = key
|
|
self.verifier = self.derive_key(self._appid+u":Verifier")
|
|
return key
|
|
d.addCallback(_got_pake)
|
|
return d
|
|
|
|
def get_verifier(self):
|
|
if self.code is None: raise UsageError
|
|
d = self._get_key()
|
|
d.addCallback(lambda _: self.verifier)
|
|
return d
|
|
|
|
def send_data(self, outbound_data):
|
|
if self._sent_data: raise UsageError # only call this once
|
|
if not isinstance(outbound_data, type(b"")): raise UsageError
|
|
if self.code is None: raise UsageError
|
|
if self.channel is None: raise UsageError
|
|
# Without predefined roles, we can't derive predictably unique keys
|
|
# for each side, so we use the same key for both. We use random
|
|
# nonces to keep the messages distinct, and the Channel automatically
|
|
# ignores reflections.
|
|
d = self._get_key()
|
|
def _send(key):
|
|
data_key = self.derive_key(u"data-key")
|
|
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
|
return self.channel.send(u"data", outbound_encrypted)
|
|
d.addCallback(_send)
|
|
return d
|
|
|
|
def get_data(self):
|
|
if self._got_data: raise UsageError # only call this once
|
|
if self.code is None: raise UsageError
|
|
if self.channel is None: raise UsageError
|
|
d = self._get_key()
|
|
def _get(key):
|
|
data_key = self.derive_key(u"data-key")
|
|
d1 = self.channel.get(u"data")
|
|
def _decrypt(inbound_encrypted):
|
|
try:
|
|
inbound_data = self._decrypt_data(data_key,
|
|
inbound_encrypted)
|
|
return inbound_data
|
|
except CryptoError:
|
|
raise WrongPasswordError
|
|
d1.addCallback(_decrypt)
|
|
return d1
|
|
d.addCallback(_get)
|
|
return d
|
|
|
|
def close(self, res=None):
|
|
monitor.close(self.channel)
|
|
d = self.channel.deallocate()
|
|
return d
|