Merge branch 'twisted'

This commit is contained in:
Brian Warner 2015-06-20 19:21:16 -07:00
commit eb18b1359e
6 changed files with 687 additions and 46 deletions

View File

@ -50,39 +50,45 @@ theirdata = r.get_data(mydata)
print("Their data: %s" % theirdata.decode("ascii"))
```
## Twisted (TODO)
## Twisted
The Twisted-friendly flow, which is not yet implemented, may look like this:
The Twisted-friendly flow looks like this:
```python
from twisted.internet import reactor
from wormhole.transcribe import TwistedInitiator
data = b"initiator's data"
ti = TwistedInitiator("appid", data, reactor)
ti.startService()
d1 = ti.when_get_code()
d1.addCallback(lambda code: print("Invitation Code: %s" % code))
d2 = ti.when_get_data()
d2.addCallback(lambda theirdata:
print("Their data: %s" % theirdata.decode("ascii")))
d2.addCallback(labmda _: reactor.stop())
from wormhole.public_relay import RENDEZVOUS_RELAY
from wormhole.twisted.transcribe import SymmetricWormhole
outbound_message = b"outbound data"
w1 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
d = w1.get_code()
def _got_code(code):
print "Invitation Code:", code
return w1.get_data(outbound_message)
d.addCallback(_got_code)
def _got_data(inbound_message):
print "Inbound message:", inbound_message
d.addCallback(_got_data)
d.addBoth(lambda _: reactor.stop())
reactor.run()
```
On the other side, you call `set_code()` instead of waiting for `get_code()`:
```python
from twisted.internet import reactor
from wormhole.transcribe import TwistedReceiver
data = b"receiver's data"
code = sys.argv[1]
tr = TwistedReceiver("appid", code, data, reactor)
tr.startService()
d = tr.when_get_data()
d.addCallback(lambda theirdata:
print("Their data: %s" % theirdata.decode("ascii")))
d.addCallback(lambda _: reactor.stop())
reactor.run()
w2 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
w2.set_code(code)
d = w2.get_data(my_message)
```
You can call `d=w.get_verifier()` before `get_data()`: this will perform the
first half of the PAKE negotiation, then fire the Deferred with a verifier
object (bytes) which can be converted into a printable representation and
manually compared. When the users are convinced that `get_verifier()` from
both sides are the same, call `d=get_data()` to continue the transfer. If you
call `get_data()` first, it will perform the complete transfer without
pausing.
## Application Identifier
Applications using this library must provide an "application identifier", a
@ -135,17 +141,27 @@ sync will be abandoned, and all callbacks will errback with a TimeoutError.
Both have defaults suitable for face-to-face realtime setup environments.
## Serialization (TODO)
## Serialization
TODO: only the Twisted form supports serialization so far
You may not be able to hold the Initiator/Receiver object in memory for the
whole sync process: maybe you allow it to wait for several days, but the
program will be restarted during that time. To support this, you can persist
the state of the object by calling `data = i.serialize()`, which will return
the state of the object by calling `data = w.serialize()`, which will return
a printable bytestring (the JSON-encoding of a small dictionary). To restore,
call `Initiator.from_serialized(data)`.
use the `from_serialized(data)` classmethod (e.g. `w =
SymmetricWormhole.from_serialized(data)`).
Note that callbacks are not serialized: they must be restored after
deserialization.
There is exactly one point at which you can serialize the wormhole: *after*
establishing the invitation code, but before waiting for `get_verifier()` or
`get_data()`. If you are creating a new code, the correct time is during the
callback fired by `get_code()`. If you are accepting a pre-generated code,
the time is just after calling `set_code()`.
To properly checkpoint the process, you should store the first message
(returned by `start()`) next to the serialized wormhole instance, so you can
re-send it if necessary.
## Detailed Example

View File

@ -0,0 +1,132 @@
import json
from twisted.trial import unittest
from twisted.internet import defer, protocol, endpoints, reactor
from twisted.application import service
from ..servers.relay import RelayServer
from ..twisted.transcribe import SymmetricWormhole, UsageError
from .. import __version__
#from twisted.python import log
#import sys
#log.startLogging(sys.stdout)
def allocate_port():
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
d = ep.listen(protocol.Factory())
def _listening(lp):
port = lp.getHost().port
d2 = lp.stopListening()
d2.addCallback(lambda _: port)
return d2
d.addCallback(_listening)
return d
def allocate_ports():
d = defer.DeferredList([allocate_port(), allocate_port()])
def _done(results):
port1 = results[0][1]
port2 = results[1][1]
return (port1, port2)
d.addCallback(_done)
return d
class Basic(unittest.TestCase):
def setUp(self):
self.sp = service.MultiService()
self.sp.startService()
d = allocate_ports()
def _got_ports(ports):
relayport, transitport = ports
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
"tcp:%s:interface=127.0.0.1" % transitport,
__version__)
s.setServiceParent(self.sp)
self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport
self.transit = "tcp:127.0.0.1:%d" % transitport
d.addCallback(_got_ports)
return d
def tearDown(self):
return self.sp.stopService()
def test_basic(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
w2 = SymmetricWormhole(appid, self.relayurl)
d = w1.get_code()
def _got_code(code):
w2.set_code(code)
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
return d
def test_fixed_code(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
w2 = SymmetricWormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant")
d1 = w1.get_data("data1")
d2 = w2.get_data("data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
return d
def test_errors(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data")
w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code)
w2 = SymmetricWormhole(appid, self.relayurl)
d = w2.get_code()
self.assertRaises(UsageError, w2.get_code)
return d
def test_serialize(self):
appid = "appid"
w1 = SymmetricWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.serialize) # too early
w2 = SymmetricWormhole(appid, self.relayurl)
d = w1.get_code()
def _got_code(code):
self.assertRaises(UsageError, w2.serialize) # too early
w2.set_code(code)
w2.serialize() # ok
s = w1.serialize()
self.assertEqual(type(s), type(""))
unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict)
new_w1 = SymmetricWormhole.from_serialized(s)
d1 = new_w1.get_data("data1")
d2 = w2.get_data("data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code)
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done)
return d

View File

View File

@ -0,0 +1,30 @@
import sys
from twisted.internet import reactor
from .transcribe import SymmetricWormhole
from .. import public_relay
APPID = "lothar.com/wormhole/text-xfer"
w = SymmetricWormhole(APPID, public_relay.RENDEZVOUS_RELAY)
if sys.argv[1] == "send-text":
message = sys.argv[2]
d = w.get_code()
def _got_code(code):
print "code is:", code
return w.get_data(message)
d.addCallback(_got_code)
def _got_data(their_data):
print "ack:", their_data
d.addCallback(_got_data)
elif sys.argv[1] == "receive-text":
code = sys.argv[2]
w.set_code(code)
d = w.get_data("ok")
def _got_data(their_data):
print their_data
d.addCallback(_got_data)
else:
raise ValueError("bad command")
d.addCallback(lambda _: reactor.stop())
reactor.run()

View File

@ -0,0 +1,219 @@
from twisted.python import log, failure
from twisted.internet import reactor, defer, protocol
from twisted.application import service
from twisted.protocols import basic
from twisted.web.client import Agent, ResponseDone
from twisted.web.http_headers import Headers
from ..util.eventual import eventually
class EventSourceParser(basic.LineOnlyReceiver):
delimiter = "\n"
def __init__(self, handler):
self.current_field = None
self.current_lines = []
self.handler = handler
self.done_deferred = defer.Deferred()
self.eventtype = "message"
def connectionLost(self, why):
if why.check(ResponseDone):
why = None
self.done_deferred.callback(why)
def dataReceived(self, data):
# exceptions here aren't being logged properly, and tests will hang
# rather than halt. I suspect twisted.web._newclient's
# HTTP11ClientProtocol.dataReceived(), which catches everything and
# responds with self._giveUp() but doesn't log.err.
try:
basic.LineOnlyReceiver.dataReceived(self, data)
except:
log.err()
raise
def lineReceived(self, line):
if not line:
# blank line ends the field
self.fieldReceived(self.current_field,
"\n".join(self.current_lines))
self.current_field = None
self.current_lines[:] = []
return
if self.current_field is None:
self.current_field, data = line.split(": ", 1)
self.current_lines.append(data)
else:
self.current_lines.append(line)
def fieldReceived(self, fieldname, data):
if fieldname == "event":
self.eventtype = data
elif fieldname == "data":
self.eventReceived(self.eventtype, data)
self.eventtype = "message"
else:
log.msg("weird fieldname", fieldname, data)
def eventReceived(self, eventtype, data):
self.handler(eventtype, data)
class EventSourceError(Exception):
pass
# es = EventSource(url, handler)
# d = es.start()
# es.cancel()
class EventSource: # TODO: service.Service
def __init__(self, url, handler, when_connected=None, agent=None):
self.url = url
self.handler = handler
self.when_connected = when_connected
self.started = False
self.cancelled = False
self.proto = EventSourceParser(self.handler)
if not agent:
agent = Agent(reactor)
self.agent = agent
def start(self):
assert not self.started, "single-use"
self.started = True
d = self.agent.request("GET", self.url,
Headers({"accept": ["text/event-stream"]}))
d.addCallback(self._connected)
return d
def _connected(self, resp):
if resp.code != 200:
raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
if self.when_connected:
self.when_connected()
#if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]:
resp.deliverBody(self.proto)
if self.cancelled:
self.kill_connection()
return self.proto.done_deferred
def cancel(self):
self.cancelled = True
if not self.proto.transport:
# _connected hasn't been called yet, but that self.cancelled
# should take care of it when the connection is established
def kill(data):
# this should kill it as soon as any data is delivered
raise ValueError("dead")
self.proto.dataReceived = kill # just in case
return
self.kill_connection()
def kill_connection(self):
if (hasattr(self.proto.transport, "_producer")
and self.proto.transport._producer):
# This is gross and fragile. We need a clean way to stop the
# client connection. p.transport is a
# twisted.web._newclient.TransportProxyProducer , and its
# ._producer is the tcp.Port.
self.proto.transport._producer.loseConnection()
else:
log.err("get_events: unable to stop connection")
# oh well
#err = EventSourceError("unable to cancel")
try:
self.proto.done_deferred.callback(None)
except defer.AlreadyCalledError:
pass
class Connector:
# behave enough like an IConnector to appease ReconnectingClientFactory
def __init__(self, res):
self.res = res
def connect(self):
self.res._maybeStart()
def stopConnecting(self):
self.res._stop_eventsource()
class ReconnectingEventSource(service.MultiService,
protocol.ReconnectingClientFactory):
def __init__(self, baseurl, connection_starting, handler, agent=None):
service.MultiService.__init__(self)
# we don't use any of the basic Factory/ClientFactory methods of
# this, just the ReconnectingClientFactory.retry, stopTrying, and
# resetDelay methods.
self.baseurl = baseurl
self.connection_starting = connection_starting
self.handler = handler
self.agent = agent
# IService provides self.running, toggled by {start,stop}Service.
# self.active is toggled by {,de}activate. If both .running and
# .active are True, then we want to have an outstanding EventSource
# and will start one if necessary. If either is False, then we don't
# want one to be outstanding, and will initiate shutdown.
self.active = False
self.connector = Connector(self)
self.es = None # set we have an outstanding EventSource
self.when_stopped = [] # list of Deferreds
def isStopped(self):
return not self.es
def startService(self):
service.MultiService.startService(self) # sets self.running
self._maybeStart()
def stopService(self):
# clears self.running
d = defer.maybeDeferred(service.MultiService.stopService, self)
d.addCallback(self._maybeStop)
return d
def activate(self):
assert not self.active
self.active = True
self._maybeStart()
def deactivate(self):
assert self.active # XXX
self.active = False
return self._maybeStop()
def _maybeStart(self):
if not (self.active and self.running):
return
self.continueTrying = True
url = self.connection_starting()
self.es = EventSource(url, self.handler, self.resetDelay,
agent=self.agent)
d = self.es.start()
d.addBoth(self._stopped)
def _stopped(self, res):
self.es = None
# we might have stopped because of a connection error, or because of
# an intentional shutdown.
if self.active and self.running:
# we still want to be connected, so schedule a reconnection
if isinstance(res, failure.Failure):
log.err(res)
self.retry() # will eventually call _maybeStart
return
# intentional shutdown
self.stopTrying()
for d in self.when_stopped:
eventually(d.callback, None)
self.when_stopped = []
def _stop_eventsource(self):
if self.es:
eventually(self.es.cancel)
def _maybeStop(self, _=None):
self.stopTrying() # cancels timer, calls _stop_eventsource()
if not self.es:
return defer.succeed(None)
d = defer.Deferred()
self.when_stopped.append(d)
return d

View File

@ -1,27 +1,271 @@
from twisted.application import service
from ..const import RELAY
from __future__ import print_function
import os, sys, json, re
from binascii import hexlify, unhexlify
from zope.interface import implementer
from twisted.internet import reactor, defer
from twisted.web import client as web_client
from twisted.web import error as web_error
from twisted.web.iweb import IBodyProducer
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
from nacl import utils
from spake2 import SPAKE2_Symmetric
from .eventsource import ReconnectingEventSource
from .. import __version__
from .. import codes
from ..errors import ServerError
from ..util.hkdf import HKDF
class TwistedInitiator(service.MultiService):
def __init__(self, appid, data, reactor, relay=RELAY):
self.appid = appid
class WrongPasswordError(Exception):
"""
Key confirmation failed.
"""
class ReflectionAttack(Exception):
"""An attacker (or bug) reflected our outgoing message back to us."""
class UsageError(Exception):
"""The programmer did something wrong."""
@implementer(IBodyProducer)
class DataProducer:
def __init__(self, data):
self.data = data
self.reactor = reactor
self.length = len(data)
def startProducing(self, consumer):
consumer.write(self.data)
return defer.succeed(None)
def stopProducing(self):
pass
def pauseProducing(self):
pass
def resumeProducing(self):
pass
class SymmetricWormhole:
def __init__(self, appid, relay):
self.appid = appid
self.relay = relay
self.agent = web_client.Agent(reactor)
self.side = None
self.code = None
self.key = None
self._started_get_code = False
def when_get_code(self):
pass # return Deferred
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
if self._started_get_code: raise UsageError
self._started_get_code = True
self.side = hexlify(os.urandom(5))
d = self._allocate_channel()
def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length)
self._set_code_and_channel_id(code)
self._start()
return code
d.addCallback(_got_channel_id)
return d
def when_get_data(self):
pass # return Deferred
def _allocate_channel(self):
url = self.relay + "allocate/%s" % self.side
d = self.post(url)
def _got_channel(data):
if "welcome" in data:
self.handle_welcome(data["welcome"])
return data["channel-id"]
d.addCallback(_got_channel)
return d
class TwistedReceiver(service.MultiService):
def __init__(self, appid, data, code, reactor, relay=RELAY):
self.appid = appid
self.data = data
def set_code(self, code):
if self.code is not None: raise UsageError
if self.side is not None: raise UsageError
self._set_code_and_channel_id(code)
self.side = hexlify(os.urandom(5))
self._start()
def _set_code_and_channel_id(self, code):
if self.code is not None: raise UsageError
mo = re.search(r'^(\d+)-', code)
if not mo:
raise ValueError("code (%s) must start with NN-" % code)
self.channel_id = int(mo.group(1))
self.code = code
self.reactor = reactor
self.relay = relay
def when_get_data(self):
pass # return Deferred
def _start(self):
# allocate the rest now too, so it can be serialized
self.sp = SPAKE2_Symmetric(self.code.encode("ascii"),
idA=self.appid+":SymmetricA",
idB=self.appid+":SymmetricB")
self.msg1 = self.sp.start()
def serialize(self):
# I can only be serialized after get_code/set_code and before
# get_verifier/get_data
if self.code is None: raise UsageError
if self.key is not None: raise UsageError
data = {
"appid": self.appid,
"relay": self.relay,
"code": self.code,
"side": self.side,
"spake2": json.loads(self.sp.serialize()),
"msg1": self.msg1.encode("hex"),
}
return json.dumps(data)
@classmethod
def from_serialized(klass, data):
d = json.loads(data)
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
self._set_code_and_channel_id(d["code"].encode("ascii"))
self.side = d["side"].encode("ascii")
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
self.msg1 = d["msg1"].decode("hex")
return self
motd_displayed = False
version_warning_displayed = False
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self.relay, motd_formatted),
file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("-" not in __version__ and
not self.version_warning_displayed and
welcome["current_version"] != __version__):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], __version__), file=sys.stderr)
self.version_warning_displayed = True
if "error" in welcome:
raise ServerError(welcome["error"], self.relay)
def url(self, verb, msgnum=None):
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
if msgnum is not None:
url += "/" + msgnum
return url
def post(self, url, post_json=None):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
p = None
if post_json:
data = json.dumps(post_json).encode("utf-8")
p = DataProducer(data)
d = self.agent.request("POST", url, bodyProducer=p)
def _check_error(resp):
if resp.code != 200:
raise web_error.Error(resp.code, resp.phrase)
return resp
d.addCallback(_check_error)
d.addCallback(web_client.readBody)
d.addCallback(lambda data: json.loads(data))
return d
def _get_msgs(self, old_msgs, verb, msgnum):
# fire with a list of messages that match verb/msgnum, which either
# came from old_msgs, or from an EventSource that we attached to the
# corresponding URL
if old_msgs:
return defer.succeed(old_msgs)
d = defer.Deferred()
msgs = []
def _handle(name, data):
if name == "welcome":
self.handle_welcome(json.loads(data))
if name == "message":
msgs.append(json.loads(data)["message"])
d.callback(None)
es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum),
_handle)#, agent=self.agent)
es.startService() # TODO: .setServiceParent(self)
es.activate()
d.addCallback(lambda _: es.deactivate())
d.addCallback(lambda _: es.stopService())
d.addCallback(lambda _: msgs)
return d
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
assert self.key is not None # call after get_verifier() or get_data()
assert type(purpose) == type(b"")
return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data):
assert len(key) == SecretBox.KEY_SIZE
box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted):
assert len(key) == SecretBox.KEY_SIZE
box = SecretBox(key)
data = box.decrypt(encrypted)
return data
def _get_key(self):
# TODO: prevent multiple invocation
if self.key:
return defer.succeed(self.key)
data = {"message": hexlify(self.msg1).decode("ascii")}
d = self.post(self.url("post", "pake"), data)
d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake"))
def _got_pake(msgs):
pake_msg = unhexlify(msgs[0].encode("ascii"))
key = self.sp.finish(pake_msg)
self.key = key
self.verifier = self.derive_key(self.appid+b":Verifier")
return key
d.addCallback(_got_pake)
return d
def get_verifier(self):
if self.code is None: raise UsageError
d = self._get_key()
d.addCallback(lambda _: self.verifier)
return d
def get_data(self, outbound_data):
# only call this once
if self.code is None: raise UsageError
d = self._get_key()
d.addCallback(self._get_data2, outbound_data)
return d
def _get_data2(self, key, outbound_data):
# Without predefined roles, we can't derive predictably unique keys
# for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and check for reflection.
data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
data = {"message": hexlify(outbound_encrypted).decode("ascii")}
d = self.post(self.url("post", "data"), data)
d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data"))
def _got_data(msgs):
inbound_encrypted = unhexlify(msgs[0].encode("ascii"))
if inbound_encrypted == outbound_encrypted:
raise ReflectionAttack
try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data
except CryptoError:
raise WrongPasswordError
d.addCallback(_got_data)
d.addBoth(self._deallocate)
return d
def _deallocate(self, res):
# only try once, no retries
d = self.agent.request("POST", self.url("deallocate"))
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
return d