make twisted work, get serialization into shape, add proper tests

This commit is contained in:
Brian Warner 2015-06-20 19:18:21 -07:00
parent 0f58f3906d
commit 25472423c6
2 changed files with 244 additions and 124 deletions

View File

@ -0,0 +1,132 @@
import json
from twisted.trial import unittest
from twisted.internet import defer, protocol, endpoints, reactor
from twisted.application import service
from ..servers.relay import RelayServer
from ..twisted.transcribe import SymmetricWormhole, UsageError
from .. import __version__
#from twisted.python import log
#import sys
#log.startLogging(sys.stdout)
def allocate_port():
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
d = ep.listen(protocol.Factory())
def _listening(lp):
port = lp.getHost().port
d2 = lp.stopListening()
d2.addCallback(lambda _: port)
return d2
d.addCallback(_listening)
return d
def allocate_ports():
d = defer.DeferredList([allocate_port(), allocate_port()])
def _done(results):
port1 = results[0][1]
port2 = results[1][1]
return (port1, port2)
d.addCallback(_done)
return d
class Basic(unittest.TestCase):
def setUp(self):
self.sp = service.MultiService()
self.sp.startService()
d = allocate_ports()
def _got_ports(ports):
relayport, transitport = ports
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
"tcp:%s:interface=127.0.0.1" % transitport,
__version__)
s.setServiceParent(self.sp)
self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = "tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports)
return d
def tearDown(self):
return self.sp.stopService()
def test_basic(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
w2 = SymmetricWormhole(appid, self.relayurl)
d = w1.get_code()
def _got_code(code):
w2.set_code(code)
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
return d
def test_fixed_code(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
w2 = SymmetricWormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant")
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
return d
def test_errors(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data")
w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code)
w2 = SymmetricWormhole(appid, self.relayurl)
d = w2.get_code()
self.assertRaises(UsageError, w2.get_code)
return d
def test_serialize(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = SymmetricWormhole(appid, self.relayurl)
d = w1.get_code()
def _got_code(code):
self.assertRaises(UsageError, w2.serialize) # too early
w2.set_code(code)
w2.serialize() # ok
s = w1.serialize()
self.assertEqual(type(s), type(""))
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
new_w1 = SymmetricWormhole.from_serialized(s)
d1 = new_w1.get_data("data1")
d2 = w2.get_data("data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done)
return d

View File

@ -1,14 +1,14 @@
from __future__ import print_function from __future__ import print_function
import sys, json import os, sys, json, re
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from zope.interface import implementer from zope.interface import implementer
#from twisted.application import service
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web import client as web_client from twisted.web import client as web_client
from twisted.web import error as web_error from twisted.web import error as web_error
from twisted.web.iweb import IBodyProducer from twisted.web.iweb import IBodyProducer
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric from spake2 import SPAKE2_Symmetric
from .eventsource import ReconnectingEventSource from .eventsource import ReconnectingEventSource
from .. import __version__ from .. import __version__
@ -21,10 +21,17 @@ class WrongPasswordError(Exception):
Key confirmation failed. Key confirmation failed.
""" """
class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us."""
class UsageError(Exception):
"""The programmer did something wrong."""
@implementer(IBodyProducer) @implementer(IBodyProducer)
class DataProducer: class DataProducer:
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
self.length = len(data)
def startProducing(self, consumer): def startProducing(self, consumer):
consumer.write(self.data) consumer.write(self.data)
return defer.succeed(None) return defer.succeed(None)
@ -35,112 +42,87 @@ class DataProducer:
def resumeProducing(self): def resumeProducing(self):
pass pass
'''
class TwistedInitiator(service.MultiService):
"""I am a service, and I must be running to function. Either call my
.startService() method, or .setServiceParent() me to some other running
service. You can use i.when_done().addCallback(i.disownServiceParent) to
make me go away when everything is done.
"""
def __init__(self, appid, data, reactor, relay):
self.appid = appid
self.data = data
self.reactor = reactor
self.relay = relay
self.code = None
def set_code(self, code): # used for human-made pre-generated codes
assert self.code is None
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1))
self.code = code
self.sp = SPAKE2_A(self.code.encode("ascii"),
idA=self.appid+":Initiator",
idB=self.appid+":Receiver")
def get_code(self, length=2):
assert self.code is None
d = self._allocate_channel()
def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length)
self.set_code(code)
return code
d.addCallback(_got_channel_id)
return d
def serialize(self):
if not self.code:
raise ValueEror
def get_data(self, outbound_data):
msg = self.sp.start()
# change SPAKE2 to choose random_scalar earlier, to make getting the
# first message idempotent.
...
def when_get_code(self):
pass # return Deferred
def when_get_data(self):
pass # return Deferred
class TwistedReceiver(service.MultiService):
def __init__(self, appid, data, code, reactor, relay):
self.appid = appid
self.data = data
self.code = code
self.reactor = reactor
self.relay = relay
def when_get_data(self):
pass # return Deferred
'''
class SymmetricWormhole: class SymmetricWormhole:
def __init__(self, appid, relay): def __init__(self, appid, relay):
self.appid = appid self.appid = appid
self.relay = relay self.relay = relay
self.agent = web_client.Agent(reactor) self.agent = web_client.Agent(reactor)
self.side = None
self.code = None
self.key = None self.key = None
self._started_get_code = False
def set_code(self, code): def get_code(self, code_length=2):
assert self.code is None if self.code is not None: raise UsageError
self.code = code if self._started_get_code: raise UsageError
# allocate the rest now too, so it can be serialized self._started_get_code = True
self.sp = SPAKE2_Symmetric(self.code.encode("ascii"), self.side = hexlify(os.urandom(5))
idA=self.appid+":SymmetricA", d = self._allocate_channel()
idB=self.appid+":SymmetricB") def _got_channel_id(channel_id):
self.msg1 = self.sp.start() code = codes.make_code(channel_id, code_length)
self._set_code_and_channel_id(code)
self._start()
return code
d.addCallback(_got_channel_id)
return d
def _allocate_channel(self): def _allocate_channel(self):
url = self.relay + "allocate/%s" % self.side url = self.relay + "allocate/%s" % self.side
d = self.post(url) d = self.post(url)
def _got_channel(data_json): def _got_channel(data):
data = json.loads(data_json)
if "welcome" in data: if "welcome" in data:
self.handle_welcome(data["welcome"]) self.handle_welcome(data["welcome"])
return data["channel-id"] return data["channel-id"]
d.addCallback(_got_channel) d.addCallback(_got_channel)
return d return d
def _deallocate(self, res): def set_code(self, code):
d = self.agent.request("POST", self.url("deallocate")) if self.code is not None: raise UsageError
d.addBoth(lambda _: res) # ignore POST failure, pass-through result if self.side is not None: raise UsageError
return d self._set_code_and_channel_id(code)
self.side = hexlify(os.urandom(5))
self._start()
def get_code(self, code_length=2): def _set_code_and_channel_id(self, code):
if self.code is not None: if self.code is not None: raise UsageError
return defer.succeed(self.code) mo = re.search(r'^(\d+)-', code)
d = self._allocate_channel() if not mo:
def _got_channel_id(channel_id): raise ValueError("code (%s) must start with NN-" % code)
code = codes.make_code(channel_id, code_length) self.channel_id = int(mo.group(1))
self.set_code(code) self.code = code
return code
d.addCallback(_got_channel_id) def _start(self):
return d # allocate the rest now too, so it can be serialized
self.sp = SPAKE2_Symmetric(self.code.encode("ascii"),
idA=self.appid+":SymmetricA",
idB=self.appid+":SymmetricB")
self.msg1 = self.sp.start()
def serialize(self):
# I can only be serialized after get_code/set_code and before
# get_verifier/get_data
if self.code is None: raise UsageError
if self.key is not None: raise UsageError
data = {
"appid": self.appid,
"relay": self.relay,
"code": self.code,
"side": self.side,
"spake2": json.loads(self.sp.serialize()),
"msg1": self.msg1.encode("hex"),
}
return json.dumps(data)
@classmethod
def from_serialized(klass, data):
d = json.loads(data)
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
self._set_code_and_channel_id(d["code"].encode("ascii"))
self.side = d["side"].encode("ascii")
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
self.msg1 = d["msg1"].decode("hex")
return self
motd_displayed = False motd_displayed = False
version_warning_displayed = False version_warning_displayed = False
@ -174,6 +156,8 @@ class SymmetricWormhole:
return url return url
def post(self, url, post_json=None): def post(self, url, post_json=None):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
p = None p = None
if post_json: if post_json:
data = json.dumps(post_json).encode("utf-8") data = json.dumps(post_json).encode("utf-8")
@ -188,7 +172,7 @@ class SymmetricWormhole:
d.addCallback(lambda data: json.loads(data)) d.addCallback(lambda data: json.loads(data))
return d return d
def get_msgs(self, old_msgs, verb, msgnum): def _get_msgs(self, old_msgs, verb, msgnum):
# fire with a list of messages that match verb/msgnum, which either # fire with a list of messages that match verb/msgnum, which either
# came from old_msgs, or from an EventSource that we attached to the # came from old_msgs, or from an EventSource that we attached to the
# corresponding URL # corresponding URL
@ -200,9 +184,9 @@ class SymmetricWormhole:
if name == "welcome": if name == "welcome":
self.handle_welcome(json.loads(data)) self.handle_welcome(json.loads(data))
if name == "message": if name == "message":
msgs.extend(json.loads(data)["message"]) msgs.append(json.loads(data)["message"])
d.callback(None) d.callback(None)
es = ReconnectingEventSource(None, lambda: self.url("post", "pake"), es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum),
_handle)#, agent=self.agent) _handle)#, agent=self.agent)
es.startService() # TODO: .setServiceParent(self) es.startService() # TODO: .setServiceParent(self)
es.activate() es.activate()
@ -212,9 +196,22 @@ class SymmetricWormhole:
return d return d
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
assert self.key is not None # call after get_verifier() or get_data()
assert type(purpose) == type(b"") assert type(purpose) == type(b"")
return HKDF(self.key, length, CTXinfo=purpose) return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data):
assert len(key) == SecretBox.KEY_SIZE
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert len(key) == SecretBox.KEY_SIZE
box = SecretBox(key)
data = box.decrypt(encrypted)
return data
def _get_key(self): def _get_key(self):
# TODO: prevent multiple invocation # TODO: prevent multiple invocation
@ -222,7 +219,7 @@ class SymmetricWormhole:
return defer.succeed(self.key) return defer.succeed(self.key)
data = {"message": hexlify(self.msg1).decode("ascii")} data = {"message": hexlify(self.msg1).decode("ascii")}
d = self.post(self.url("post", "pake"), data) d = self.post(self.url("post", "pake"), data)
d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "pake")) d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake"))
def _got_pake(msgs): def _got_pake(msgs):
pake_msg = unhexlify(msgs[0].encode("ascii")) pake_msg = unhexlify(msgs[0].encode("ascii"))
key = self.sp.finish(pake_msg) key = self.sp.finish(pake_msg)
@ -233,25 +230,33 @@ class SymmetricWormhole:
return d return d
def get_verifier(self): def get_verifier(self):
if self.code is None: raise UsageError
d = self._get_key() d = self._get_key()
d.addCallback(lambda _: self.verifier) d.addCallback(lambda _: self.verifier)
return d return d
def get_data(self, outbound_data): def get_data(self, outbound_data):
# only call this once # only call this once
if self.code is None: raise UsageError
d = self._get_key() d = self._get_key()
def _got_key(_): d.addCallback(self._get_data2, outbound_data)
outbound_key = self.derive_key(b"sender") return d
outbound_encrypted = self._encrypt_data(outbound_key, outbound_data)
data = {"message": hexlify(outbound_encrypted).decode("ascii")} def _get_data2(self, key, outbound_data):
return self.post(self.url("post", "data"), data) # Without predefined roles, we can't derive predictably unique keys
d.addCallback(lambda j: self.get_msgs(j["messages"], "poll", "data")) # for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and check for reflection.
data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
data = {"message": hexlify(outbound_encrypted).decode("ascii")}
d = self.post(self.url("post", "data"), data)
d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data"))
def _got_data(msgs): def _got_data(msgs):
inbound_encrypted = unhexlify(msgs[0].encode("ascii")) inbound_encrypted = unhexlify(msgs[0].encode("ascii"))
inbound_key = self.derive_key(b"receiver") if inbound_encrypted == outbound_encrypted:
raise ReflectionAttack
try: try:
inbound_data = self._decrypt_data(inbound_key, inbound_data = self._decrypt_data(data_key, inbound_encrypted)
inbound_encrypted)
return inbound_data return inbound_data
except CryptoError: except CryptoError:
raise WrongPasswordError raise WrongPasswordError
@ -259,25 +264,8 @@ class SymmetricWormhole:
d.addBoth(self._deallocate) d.addBoth(self._deallocate)
return d return d
def _deallocate(self, res):
def serialize(self): # only try once, no retries
assert self.code is not None d = self.agent.request("POST", self.url("deallocate"))
data = { d.addBoth(lambda _: res) # ignore POST failure, pass-through result
"appid": self.appid, return d
"payload_for_them": self.payload_for_them.encode("hex"),
"relay": self.relay,
"code": self.code,
"wormhole": json.loads(self.sp.serialize()),
"msg1": self.msg1.encode("hex"),
}
return json.dumps(data)
@classmethod
def from_serialized(klass, data):
d = json.loads(data)
self = klass(str(d["appid"]), d["payload_for_them"].decode("hex"),
str(d["relay"]))
self.code = str(d["code"])
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["wormhole"]))
self.msg1 = d["msg1"].decode("hex")
return self