Merge branch 'twisted'
This commit is contained in:
commit
eb18b1359e
72
docs/api.md
72
docs/api.md
|
@ -50,39 +50,45 @@ theirdata = r.get_data(mydata)
|
|||
print("Their data: %s" % theirdata.decode("ascii"))
|
||||
```
|
||||
|
||||
## Twisted (TODO)
|
||||
## Twisted
|
||||
|
||||
The Twisted-friendly flow, which is not yet implemented, may look like this:
|
||||
The Twisted-friendly flow looks like this:
|
||||
|
||||
```python
|
||||
from twisted.internet import reactor
|
||||
from wormhole.transcribe import TwistedInitiator
|
||||
data = b"initiator's data"
|
||||
ti = TwistedInitiator("appid", data, reactor)
|
||||
ti.startService()
|
||||
d1 = ti.when_get_code()
|
||||
d1.addCallback(lambda code: print("Invitation Code: %s" % code))
|
||||
d2 = ti.when_get_data()
|
||||
d2.addCallback(lambda theirdata:
|
||||
print("Their data: %s" % theirdata.decode("ascii")))
|
||||
d2.addCallback(labmda _: reactor.stop())
|
||||
from wormhole.public_relay import RENDEZVOUS_RELAY
|
||||
from wormhole.twisted.transcribe import SymmetricWormhole
|
||||
outbound_message = b"outbound data"
|
||||
w1 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
|
||||
d = w1.get_code()
|
||||
def _got_code(code):
|
||||
print "Invitation Code:", code
|
||||
return w1.get_data(outbound_message)
|
||||
d.addCallback(_got_code)
|
||||
def _got_data(inbound_message):
|
||||
print "Inbound message:", inbound_message
|
||||
d.addCallback(_got_data)
|
||||
d.addBoth(lambda _: reactor.stop())
|
||||
reactor.run()
|
||||
```
|
||||
|
||||
On the other side, you call `set_code()` instead of waiting for `get_code()`:
|
||||
|
||||
```python
|
||||
from twisted.internet import reactor
|
||||
from wormhole.transcribe import TwistedReceiver
|
||||
data = b"receiver's data"
|
||||
code = sys.argv[1]
|
||||
tr = TwistedReceiver("appid", code, data, reactor)
|
||||
tr.startService()
|
||||
d = tr.when_get_data()
|
||||
d.addCallback(lambda theirdata:
|
||||
print("Their data: %s" % theirdata.decode("ascii")))
|
||||
d.addCallback(lambda _: reactor.stop())
|
||||
reactor.run()
|
||||
w2 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
|
||||
w2.set_code(code)
|
||||
d = w2.get_data(my_message)
|
||||
```
|
||||
|
||||
You can call `d=w.get_verifier()` before `get_data()`: this will perform the
|
||||
first half of the PAKE negotiation, then fire the Deferred with a verifier
|
||||
object (bytes) which can be converted into a printable representation and
|
||||
manually compared. When the users are convinced that `get_verifier()` from
|
||||
both sides are the same, call `d=get_data()` to continue the transfer. If you
|
||||
call `get_data()` first, it will perform the complete transfer without
|
||||
pausing.
|
||||
|
||||
|
||||
## Application Identifier
|
||||
|
||||
Applications using this library must provide an "application identifier", a
|
||||
|
@ -135,17 +141,27 @@ sync will be abandoned, and all callbacks will errback with a TimeoutError.
|
|||
|
||||
Both have defaults suitable for face-to-face realtime setup environments.
|
||||
|
||||
## Serialization (TODO)
|
||||
## Serialization
|
||||
|
||||
TODO: only the Twisted form supports serialization so far
|
||||
|
||||
You may not be able to hold the Initiator/Receiver object in memory for the
|
||||
whole sync process: maybe you allow it to wait for several days, but the
|
||||
program will be restarted during that time. To support this, you can persist
|
||||
the state of the object by calling `data = i.serialize()`, which will return
|
||||
the state of the object by calling `data = w.serialize()`, which will return
|
||||
a printable bytestring (the JSON-encoding of a small dictionary). To restore,
|
||||
call `Initiator.from_serialized(data)`.
|
||||
use the `from_serialized(data)` classmethod (e.g. `w =
|
||||
SymmetricWormhole.from_serialized(data)`).
|
||||
|
||||
Note that callbacks are not serialized: they must be restored after
|
||||
deserialization.
|
||||
There is exactly one point at which you can serialize the wormhole: *after*
|
||||
establishing the invitation code, but before waiting for `get_verifier()` or
|
||||
`get_data()`. If you are creating a new code, the correct time is during the
|
||||
callback fired by `get_code()`. If you are accepting a pre-generated code,
|
||||
the time is just after calling `set_code()`.
|
||||
|
||||
To properly checkpoint the process, you should store the first message
|
||||
(returned by `start()`) next to the serialized wormhole instance, so you can
|
||||
re-send it if necessary.
|
||||
|
||||
## Detailed Example
|
||||
|
||||
|
|
132
src/wormhole/test/test_twisted.py
Normal file
132
src/wormhole/test/test_twisted.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import defer, protocol, endpoints, reactor
|
||||
from twisted.application import service
|
||||
from ..servers.relay import RelayServer
|
||||
from ..twisted.transcribe import SymmetricWormhole, UsageError
|
||||
from .. import __version__
|
||||
#from twisted.python import log
|
||||
#import sys
|
||||
#log.startLogging(sys.stdout)
|
||||
|
||||
def allocate_port():
|
||||
ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
|
||||
d = ep.listen(protocol.Factory())
|
||||
def _listening(lp):
|
||||
port = lp.getHost().port
|
||||
d2 = lp.stopListening()
|
||||
d2.addCallback(lambda _: port)
|
||||
return d2
|
||||
d.addCallback(_listening)
|
||||
return d
|
||||
|
||||
def allocate_ports():
|
||||
d = defer.DeferredList([allocate_port(), allocate_port()])
|
||||
def _done(results):
|
||||
port1 = results[0][1]
|
||||
port2 = results[1][1]
|
||||
return (port1, port2)
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
class Basic(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.sp = service.MultiService()
|
||||
self.sp.startService()
|
||||
d = allocate_ports()
|
||||
def _got_ports(ports):
|
||||
relayport, transitport = ports
|
||||
s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
|
||||
"tcp:%s:interface=127.0.0.1" % transitport,
|
||||
__version__)
|
||||
s.setServiceParent(self.sp)
|
||||
self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport
|
||||
self.transit = "tcp:127.0.0.1:%d" % transitport
|
||||
d.addCallback(_got_ports)
|
||||
return d
|
||||
|
||||
def tearDown(self):
|
||||
return self.sp.stopService()
|
||||
|
||||
def test_basic(self):
|
||||
appid = "appid"
|
||||
w1 = SymmetricWormhole(appid, self.relayurl)
|
||||
w2 = SymmetricWormhole(appid, self.relayurl)
|
||||
d = w1.get_code()
|
||||
def _got_code(code):
|
||||
w2.set_code(code)
|
||||
d1 = w1.get_data("data1")
|
||||
d2 = w2.get_data("data2")
|
||||
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
|
||||
d.addCallback(_got_code)
|
||||
def _done(dl):
|
||||
((success1, dataX), (success2, dataY)) = dl
|
||||
r1,r2 = dl
|
||||
self.assertTrue(success1)
|
||||
self.assertTrue(success2)
|
||||
self.assertEqual(dataX, "data2")
|
||||
self.assertEqual(dataY, "data1")
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_fixed_code(self):
|
||||
appid = "appid"
|
||||
w1 = SymmetricWormhole(appid, self.relayurl)
|
||||
w2 = SymmetricWormhole(appid, self.relayurl)
|
||||
w1.set_code("123-purple-elephant")
|
||||
w2.set_code("123-purple-elephant")
|
||||
d1 = w1.get_data("data1")
|
||||
d2 = w2.get_data("data2")
|
||||
d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
|
||||
def _done(dl):
|
||||
((success1, dataX), (success2, dataY)) = dl
|
||||
r1,r2 = dl
|
||||
self.assertTrue(success1)
|
||||
self.assertTrue(success2)
|
||||
self.assertEqual(dataX, "data2")
|
||||
self.assertEqual(dataY, "data1")
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_errors(self):
|
||||
appid = "appid"
|
||||
w1 = SymmetricWormhole(appid, self.relayurl)
|
||||
self.assertRaises(UsageError, w1.get_verifier)
|
||||
self.assertRaises(UsageError, w1.get_data, "data")
|
||||
w1.set_code("123-purple-elephant")
|
||||
self.assertRaises(UsageError, w1.set_code, "123-nope")
|
||||
self.assertRaises(UsageError, w1.get_code)
|
||||
w2 = SymmetricWormhole(appid, self.relayurl)
|
||||
d = w2.get_code()
|
||||
self.assertRaises(UsageError, w2.get_code)
|
||||
return d
|
||||
|
||||
def test_serialize(self):
|
||||
appid = "appid"
|
||||
w1 = SymmetricWormhole(appid, self.relayurl)
|
||||
self.assertRaises(UsageError, w1.serialize) # too early
|
||||
w2 = SymmetricWormhole(appid, self.relayurl)
|
||||
d = w1.get_code()
|
||||
def _got_code(code):
|
||||
self.assertRaises(UsageError, w2.serialize) # too early
|
||||
w2.set_code(code)
|
||||
w2.serialize() # ok
|
||||
s = w1.serialize()
|
||||
self.assertEqual(type(s), type(""))
|
||||
unpacked = json.loads(s) # this is supposed to be JSON
|
||||
self.assertEqual(type(unpacked), dict)
|
||||
new_w1 = SymmetricWormhole.from_serialized(s)
|
||||
d1 = new_w1.get_data("data1")
|
||||
d2 = w2.get_data("data2")
|
||||
return defer.DeferredList([d1,d2], fireOnOneErrback=False)
|
||||
d.addCallback(_got_code)
|
||||
def _done(dl):
|
||||
((success1, dataX), (success2, dataY)) = dl
|
||||
r1,r2 = dl
|
||||
self.assertTrue(success1)
|
||||
self.assertTrue(success2)
|
||||
self.assertEqual(dataX, "data2")
|
||||
self.assertEqual(dataY, "data1")
|
||||
self.assertRaises(UsageError, w2.serialize) # too late
|
||||
d.addCallback(_done)
|
||||
return d
|
0
src/wormhole/twisted/__init__.py
Normal file
0
src/wormhole/twisted/__init__.py
Normal file
30
src/wormhole/twisted/demo.py
Normal file
30
src/wormhole/twisted/demo.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import sys
|
||||
from twisted.internet import reactor
|
||||
from .transcribe import SymmetricWormhole
|
||||
from .. import public_relay
|
||||
|
||||
APPID = "lothar.com/wormhole/text-xfer"
|
||||
|
||||
w = SymmetricWormhole(APPID, public_relay.RENDEZVOUS_RELAY)
|
||||
|
||||
if sys.argv[1] == "send-text":
|
||||
message = sys.argv[2]
|
||||
d = w.get_code()
|
||||
def _got_code(code):
|
||||
print "code is:", code
|
||||
return w.get_data(message)
|
||||
d.addCallback(_got_code)
|
||||
def _got_data(their_data):
|
||||
print "ack:", their_data
|
||||
d.addCallback(_got_data)
|
||||
elif sys.argv[1] == "receive-text":
|
||||
code = sys.argv[2]
|
||||
w.set_code(code)
|
||||
d = w.get_data("ok")
|
||||
def _got_data(their_data):
|
||||
print their_data
|
||||
d.addCallback(_got_data)
|
||||
else:
|
||||
raise ValueError("bad command")
|
||||
d.addCallback(lambda _: reactor.stop())
|
||||
reactor.run()
|
219
src/wormhole/twisted/eventsource.py
Normal file
219
src/wormhole/twisted/eventsource.py
Normal file
|
@ -0,0 +1,219 @@
|
|||
from twisted.python import log, failure
|
||||
from twisted.internet import reactor, defer, protocol
|
||||
from twisted.application import service
|
||||
from twisted.protocols import basic
|
||||
from twisted.web.client import Agent, ResponseDone
|
||||
from twisted.web.http_headers import Headers
|
||||
from ..util.eventual import eventually
|
||||
|
||||
class EventSourceParser(basic.LineOnlyReceiver):
|
||||
delimiter = "\n"
|
||||
|
||||
def __init__(self, handler):
|
||||
self.current_field = None
|
||||
self.current_lines = []
|
||||
self.handler = handler
|
||||
self.done_deferred = defer.Deferred()
|
||||
self.eventtype = "message"
|
||||
|
||||
def connectionLost(self, why):
|
||||
if why.check(ResponseDone):
|
||||
why = None
|
||||
self.done_deferred.callback(why)
|
||||
|
||||
def dataReceived(self, data):
|
||||
# exceptions here aren't being logged properly, and tests will hang
|
||||
# rather than halt. I suspect twisted.web._newclient's
|
||||
# HTTP11ClientProtocol.dataReceived(), which catches everything and
|
||||
# responds with self._giveUp() but doesn't log.err.
|
||||
try:
|
||||
basic.LineOnlyReceiver.dataReceived(self, data)
|
||||
except:
|
||||
log.err()
|
||||
raise
|
||||
|
||||
def lineReceived(self, line):
|
||||
if not line:
|
||||
# blank line ends the field
|
||||
self.fieldReceived(self.current_field,
|
||||
"\n".join(self.current_lines))
|
||||
self.current_field = None
|
||||
self.current_lines[:] = []
|
||||
return
|
||||
if self.current_field is None:
|
||||
self.current_field, data = line.split(": ", 1)
|
||||
self.current_lines.append(data)
|
||||
else:
|
||||
self.current_lines.append(line)
|
||||
|
||||
def fieldReceived(self, fieldname, data):
|
||||
if fieldname == "event":
|
||||
self.eventtype = data
|
||||
elif fieldname == "data":
|
||||
self.eventReceived(self.eventtype, data)
|
||||
self.eventtype = "message"
|
||||
else:
|
||||
log.msg("weird fieldname", fieldname, data)
|
||||
|
||||
def eventReceived(self, eventtype, data):
|
||||
self.handler(eventtype, data)
|
||||
|
||||
class EventSourceError(Exception):
|
||||
pass
|
||||
|
||||
# es = EventSource(url, handler)
|
||||
# d = es.start()
|
||||
# es.cancel()
|
||||
|
||||
class EventSource: # TODO: service.Service
|
||||
def __init__(self, url, handler, when_connected=None, agent=None):
|
||||
self.url = url
|
||||
self.handler = handler
|
||||
self.when_connected = when_connected
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
self.proto = EventSourceParser(self.handler)
|
||||
if not agent:
|
||||
agent = Agent(reactor)
|
||||
self.agent = agent
|
||||
|
||||
def start(self):
|
||||
assert not self.started, "single-use"
|
||||
self.started = True
|
||||
d = self.agent.request("GET", self.url,
|
||||
Headers({"accept": ["text/event-stream"]}))
|
||||
d.addCallback(self._connected)
|
||||
return d
|
||||
|
||||
def _connected(self, resp):
|
||||
if resp.code != 200:
|
||||
raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
|
||||
if self.when_connected:
|
||||
self.when_connected()
|
||||
#if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]:
|
||||
resp.deliverBody(self.proto)
|
||||
if self.cancelled:
|
||||
self.kill_connection()
|
||||
return self.proto.done_deferred
|
||||
|
||||
def cancel(self):
|
||||
self.cancelled = True
|
||||
if not self.proto.transport:
|
||||
# _connected hasn't been called yet, but that self.cancelled
|
||||
# should take care of it when the connection is established
|
||||
def kill(data):
|
||||
# this should kill it as soon as any data is delivered
|
||||
raise ValueError("dead")
|
||||
self.proto.dataReceived = kill # just in case
|
||||
return
|
||||
self.kill_connection()
|
||||
|
||||
def kill_connection(self):
|
||||
if (hasattr(self.proto.transport, "_producer")
|
||||
and self.proto.transport._producer):
|
||||
# This is gross and fragile. We need a clean way to stop the
|
||||
# client connection. p.transport is a
|
||||
# twisted.web._newclient.TransportProxyProducer , and its
|
||||
# ._producer is the tcp.Port.
|
||||
self.proto.transport._producer.loseConnection()
|
||||
else:
|
||||
log.err("get_events: unable to stop connection")
|
||||
# oh well
|
||||
#err = EventSourceError("unable to cancel")
|
||||
try:
|
||||
self.proto.done_deferred.callback(None)
|
||||
except defer.AlreadyCalledError:
|
||||
pass
|
||||
|
||||
|
||||
class Connector:
|
||||
# behave enough like an IConnector to appease ReconnectingClientFactory
|
||||
def __init__(self, res):
|
||||
self.res = res
|
||||
def connect(self):
|
||||
self.res._maybeStart()
|
||||
def stopConnecting(self):
|
||||
self.res._stop_eventsource()
|
||||
|
||||
class ReconnectingEventSource(service.MultiService,
|
||||
protocol.ReconnectingClientFactory):
|
||||
def __init__(self, baseurl, connection_starting, handler, agent=None):
|
||||
service.MultiService.__init__(self)
|
||||
# we don't use any of the basic Factory/ClientFactory methods of
|
||||
# this, just the ReconnectingClientFactory.retry, stopTrying, and
|
||||
# resetDelay methods.
|
||||
|
||||
self.baseurl = baseurl
|
||||
self.connection_starting = connection_starting
|
||||
self.handler = handler
|
||||
self.agent = agent
|
||||
# IService provides self.running, toggled by {start,stop}Service.
|
||||
# self.active is toggled by {,de}activate. If both .running and
|
||||
# .active are True, then we want to have an outstanding EventSource
|
||||
# and will start one if necessary. If either is False, then we don't
|
||||
# want one to be outstanding, and will initiate shutdown.
|
||||
self.active = False
|
||||
self.connector = Connector(self)
|
||||
self.es = None # set we have an outstanding EventSource
|
||||
self.when_stopped = [] # list of Deferreds
|
||||
|
||||
def isStopped(self):
|
||||
return not self.es
|
||||
|
||||
def startService(self):
|
||||
service.MultiService.startService(self) # sets self.running
|
||||
self._maybeStart()
|
||||
|
||||
def stopService(self):
|
||||
# clears self.running
|
||||
d = defer.maybeDeferred(service.MultiService.stopService, self)
|
||||
d.addCallback(self._maybeStop)
|
||||
return d
|
||||
|
||||
def activate(self):
|
||||
assert not self.active
|
||||
self.active = True
|
||||
self._maybeStart()
|
||||
|
||||
def deactivate(self):
|
||||
assert self.active # XXX
|
||||
self.active = False
|
||||
return self._maybeStop()
|
||||
|
||||
def _maybeStart(self):
|
||||
if not (self.active and self.running):
|
||||
return
|
||||
self.continueTrying = True
|
||||
url = self.connection_starting()
|
||||
self.es = EventSource(url, self.handler, self.resetDelay,
|
||||
agent=self.agent)
|
||||
d = self.es.start()
|
||||
d.addBoth(self._stopped)
|
||||
|
||||
def _stopped(self, res):
|
||||
self.es = None
|
||||
# we might have stopped because of a connection error, or because of
|
||||
# an intentional shutdown.
|
||||
if self.active and self.running:
|
||||
# we still want to be connected, so schedule a reconnection
|
||||
if isinstance(res, failure.Failure):
|
||||
log.err(res)
|
||||
self.retry() # will eventually call _maybeStart
|
||||
return
|
||||
# intentional shutdown
|
||||
self.stopTrying()
|
||||
for d in self.when_stopped:
|
||||
eventually(d.callback, None)
|
||||
self.when_stopped = []
|
||||
|
||||
def _stop_eventsource(self):
|
||||
if self.es:
|
||||
eventually(self.es.cancel)
|
||||
|
||||
def _maybeStop(self, _=None):
|
||||
self.stopTrying() # cancels timer, calls _stop_eventsource()
|
||||
if not self.es:
|
||||
return defer.succeed(None)
|
||||
d = defer.Deferred()
|
||||
self.when_stopped.append(d)
|
||||
return d
|
|
@ -1,27 +1,271 @@
|
|||
from twisted.application import service
|
||||
from ..const import RELAY
|
||||
from __future__ import print_function
|
||||
import os, sys, json, re
|
||||
from binascii import hexlify, unhexlify
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web import client as web_client
|
||||
from twisted.web import error as web_error
|
||||
from twisted.web.iweb import IBodyProducer
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl import utils
|
||||
from spake2 import SPAKE2_Symmetric
|
||||
from .eventsource import ReconnectingEventSource
|
||||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import ServerError
|
||||
from ..util.hkdf import HKDF
|
||||
|
||||
class TwistedInitiator(service.MultiService):
|
||||
def __init__(self, appid, data, reactor, relay=RELAY):
|
||||
self.appid = appid
|
||||
class WrongPasswordError(Exception):
|
||||
"""
|
||||
Key confirmation failed.
|
||||
"""
|
||||
|
||||
class ReflectionAttack(Exception):
|
||||
"""An attacker (or bug) reflected our outgoing message back to us."""
|
||||
|
||||
class UsageError(Exception):
|
||||
"""The programmer did something wrong."""
|
||||
|
||||
@implementer(IBodyProducer)
|
||||
class DataProducer:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.reactor = reactor
|
||||
self.length = len(data)
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.data)
|
||||
return defer.succeed(None)
|
||||
def stopProducing(self):
|
||||
pass
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
def resumeProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
class SymmetricWormhole:
|
||||
def __init__(self, appid, relay):
|
||||
self.appid = appid
|
||||
self.relay = relay
|
||||
self.agent = web_client.Agent(reactor)
|
||||
self.side = None
|
||||
self.code = None
|
||||
self.key = None
|
||||
self._started_get_code = False
|
||||
|
||||
def when_get_code(self):
|
||||
pass # return Deferred
|
||||
def get_code(self, code_length=2):
|
||||
if self.code is not None: raise UsageError
|
||||
if self._started_get_code: raise UsageError
|
||||
self._started_get_code = True
|
||||
self.side = hexlify(os.urandom(5))
|
||||
d = self._allocate_channel()
|
||||
def _got_channel_id(channel_id):
|
||||
code = codes.make_code(channel_id, code_length)
|
||||
self._set_code_and_channel_id(code)
|
||||
self._start()
|
||||
return code
|
||||
d.addCallback(_got_channel_id)
|
||||
return d
|
||||
|
||||
def when_get_data(self):
|
||||
pass # return Deferred
|
||||
def _allocate_channel(self):
|
||||
url = self.relay + "allocate/%s" % self.side
|
||||
d = self.post(url)
|
||||
def _got_channel(data):
|
||||
if "welcome" in data:
|
||||
self.handle_welcome(data["welcome"])
|
||||
return data["channel-id"]
|
||||
d.addCallback(_got_channel)
|
||||
return d
|
||||
|
||||
class TwistedReceiver(service.MultiService):
|
||||
def __init__(self, appid, data, code, reactor, relay=RELAY):
|
||||
self.appid = appid
|
||||
self.data = data
|
||||
def set_code(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
if self.side is not None: raise UsageError
|
||||
self._set_code_and_channel_id(code)
|
||||
self.side = hexlify(os.urandom(5))
|
||||
self._start()
|
||||
|
||||
def _set_code_and_channel_id(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
self.channel_id = int(mo.group(1))
|
||||
self.code = code
|
||||
self.reactor = reactor
|
||||
self.relay = relay
|
||||
|
||||
def when_get_data(self):
|
||||
pass # return Deferred
|
||||
def _start(self):
|
||||
# allocate the rest now too, so it can be serialized
|
||||
self.sp = SPAKE2_Symmetric(self.code.encode("ascii"),
|
||||
idA=self.appid+":SymmetricA",
|
||||
idB=self.appid+":SymmetricB")
|
||||
self.msg1 = self.sp.start()
|
||||
|
||||
def serialize(self):
|
||||
# I can only be serialized after get_code/set_code and before
|
||||
# get_verifier/get_data
|
||||
if self.code is None: raise UsageError
|
||||
if self.key is not None: raise UsageError
|
||||
data = {
|
||||
"appid": self.appid,
|
||||
"relay": self.relay,
|
||||
"code": self.code,
|
||||
"side": self.side,
|
||||
"spake2": json.loads(self.sp.serialize()),
|
||||
"msg1": self.msg1.encode("hex"),
|
||||
}
|
||||
return json.dumps(data)
|
||||
|
||||
@classmethod
|
||||
def from_serialized(klass, data):
|
||||
d = json.loads(data)
|
||||
self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
|
||||
self._set_code_and_channel_id(d["code"].encode("ascii"))
|
||||
self.side = d["side"].encode("ascii")
|
||||
self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
|
||||
self.msg1 = d["msg1"].decode("hex")
|
||||
return self
|
||||
|
||||
motd_displayed = False
|
||||
version_warning_displayed = False
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if ("motd" in welcome and
|
||||
not self.motd_displayed):
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" % (self.relay, motd_formatted),
|
||||
file=sys.stderr)
|
||||
self.motd_displayed = True
|
||||
|
||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
||||
if ("-" not in __version__ and
|
||||
not self.version_warning_displayed and
|
||||
welcome["current_version"] != __version__):
|
||||
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
||||
print("Server claims %s is current, but ours is %s"
|
||||
% (welcome["current_version"], __version__), file=sys.stderr)
|
||||
self.version_warning_displayed = True
|
||||
|
||||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self.relay)
|
||||
|
||||
def url(self, verb, msgnum=None):
|
||||
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
|
||||
if msgnum is not None:
|
||||
url += "/" + msgnum
|
||||
return url
|
||||
|
||||
def post(self, url, post_json=None):
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
p = None
|
||||
if post_json:
|
||||
data = json.dumps(post_json).encode("utf-8")
|
||||
p = DataProducer(data)
|
||||
d = self.agent.request("POST", url, bodyProducer=p)
|
||||
def _check_error(resp):
|
||||
if resp.code != 200:
|
||||
raise web_error.Error(resp.code, resp.phrase)
|
||||
return resp
|
||||
d.addCallback(_check_error)
|
||||
d.addCallback(web_client.readBody)
|
||||
d.addCallback(lambda data: json.loads(data))
|
||||
return d
|
||||
|
||||
def _get_msgs(self, old_msgs, verb, msgnum):
|
||||
# fire with a list of messages that match verb/msgnum, which either
|
||||
# came from old_msgs, or from an EventSource that we attached to the
|
||||
# corresponding URL
|
||||
if old_msgs:
|
||||
return defer.succeed(old_msgs)
|
||||
d = defer.Deferred()
|
||||
msgs = []
|
||||
def _handle(name, data):
|
||||
if name == "welcome":
|
||||
self.handle_welcome(json.loads(data))
|
||||
if name == "message":
|
||||
msgs.append(json.loads(data)["message"])
|
||||
d.callback(None)
|
||||
es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum),
|
||||
_handle)#, agent=self.agent)
|
||||
es.startService() # TODO: .setServiceParent(self)
|
||||
es.activate()
|
||||
d.addCallback(lambda _: es.deactivate())
|
||||
d.addCallback(lambda _: es.stopService())
|
||||
d.addCallback(lambda _: msgs)
|
||||
return d
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
assert self.key is not None # call after get_verifier() or get_data()
|
||||
assert type(purpose) == type(b"")
|
||||
return HKDF(self.key, length, CTXinfo=purpose)
|
||||
|
||||
def _encrypt_data(self, key, data):
|
||||
assert len(key) == SecretBox.KEY_SIZE
|
||||
box = SecretBox(key)
|
||||
nonce = utils.random(SecretBox.NONCE_SIZE)
|
||||
return box.encrypt(data, nonce)
|
||||
|
||||
def _decrypt_data(self, key, encrypted):
|
||||
assert len(key) == SecretBox.KEY_SIZE
|
||||
box = SecretBox(key)
|
||||
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
|
||||
def _get_key(self):
|
||||
# TODO: prevent multiple invocation
|
||||
if self.key:
|
||||
return defer.succeed(self.key)
|
||||
data = {"message": hexlify(self.msg1).decode("ascii")}
|
||||
d = self.post(self.url("post", "pake"), data)
|
||||
d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake"))
|
||||
def _got_pake(msgs):
|
||||
pake_msg = unhexlify(msgs[0].encode("ascii"))
|
||||
key = self.sp.finish(pake_msg)
|
||||
self.key = key
|
||||
self.verifier = self.derive_key(self.appid+b":Verifier")
|
||||
return key
|
||||
d.addCallback(_got_pake)
|
||||
return d
|
||||
|
||||
def get_verifier(self):
|
||||
if self.code is None: raise UsageError
|
||||
d = self._get_key()
|
||||
d.addCallback(lambda _: self.verifier)
|
||||
return d
|
||||
|
||||
def get_data(self, outbound_data):
|
||||
# only call this once
|
||||
if self.code is None: raise UsageError
|
||||
d = self._get_key()
|
||||
d.addCallback(self._get_data2, outbound_data)
|
||||
return d
|
||||
|
||||
def _get_data2(self, key, outbound_data):
|
||||
# Without predefined roles, we can't derive predictably unique keys
|
||||
# for each side, so we use the same key for both. We use random
|
||||
# nonces to keep the messages distinct, and check for reflection.
|
||||
data_key = self.derive_key(b"data-key")
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
data = {"message": hexlify(outbound_encrypted).decode("ascii")}
|
||||
d = self.post(self.url("post", "data"), data)
|
||||
d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data"))
|
||||
def _got_data(msgs):
|
||||
inbound_encrypted = unhexlify(msgs[0].encode("ascii"))
|
||||
if inbound_encrypted == outbound_encrypted:
|
||||
raise ReflectionAttack
|
||||
try:
|
||||
inbound_data = self._decrypt_data(data_key, inbound_encrypted)
|
||||
return inbound_data
|
||||
except CryptoError:
|
||||
raise WrongPasswordError
|
||||
d.addCallback(_got_data)
|
||||
d.addBoth(self._deallocate)
|
||||
return d
|
||||
|
||||
def _deallocate(self, res):
|
||||
# only try once, no retries
|
||||
d = self.agent.request("POST", self.url("deallocate"))
|
||||
d.addBoth(lambda _: res) # ignore POST failure, pass-through result
|
||||
return d
|
||||
|
|
Loading…
Reference in New Issue
Block a user