Merge branch 'twisted'
This commit is contained in:
		
						commit
						eb18b1359e
					
				
							
								
								
									
										72
									
								
								docs/api.md
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								docs/api.md
									
									
									
									
									
								
							| 
						 | 
				
			
			@ -50,39 +50,45 @@ theirdata = r.get_data(mydata)
 | 
			
		|||
print("Their data: %s" % theirdata.decode("ascii"))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Twisted (TODO)
 | 
			
		||||
## Twisted
 | 
			
		||||
 | 
			
		||||
The Twisted-friendly flow, which is not yet implemented, may look like this:
 | 
			
		||||
The Twisted-friendly flow looks like this:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from twisted.internet import reactor
 | 
			
		||||
from wormhole.transcribe import TwistedInitiator
 | 
			
		||||
data = b"initiator's data"
 | 
			
		||||
ti = TwistedInitiator("appid", data, reactor)
 | 
			
		||||
ti.startService()
 | 
			
		||||
d1 = ti.when_get_code()
 | 
			
		||||
d1.addCallback(lambda code: print("Invitation Code: %s" % code))
 | 
			
		||||
d2 = ti.when_get_data()
 | 
			
		||||
d2.addCallback(lambda theirdata:
 | 
			
		||||
               print("Their data: %s" % theirdata.decode("ascii")))
 | 
			
		||||
d2.addCallback(labmda _: reactor.stop())
 | 
			
		||||
from wormhole.public_relay import RENDEZVOUS_RELAY
 | 
			
		||||
from wormhole.twisted.transcribe import SymmetricWormhole
 | 
			
		||||
outbound_message = b"outbound data"
 | 
			
		||||
w1 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
 | 
			
		||||
d = w1.get_code()
 | 
			
		||||
def _got_code(code):
 | 
			
		||||
    print "Invitation Code:", code
 | 
			
		||||
    return w1.get_data(outbound_message)
 | 
			
		||||
d.addCallback(_got_code)
 | 
			
		||||
def _got_data(inbound_message):
 | 
			
		||||
    print "Inbound message:", inbound_message
 | 
			
		||||
d.addCallback(_got_data)
 | 
			
		||||
d.addBoth(lambda _: reactor.stop())
 | 
			
		||||
reactor.run()
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
On the other side, you call `set_code()` instead of waiting for `get_code()`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from twisted.internet import reactor
 | 
			
		||||
from wormhole.transcribe import TwistedReceiver
 | 
			
		||||
data = b"receiver's data"
 | 
			
		||||
code = sys.argv[1]
 | 
			
		||||
tr = TwistedReceiver("appid", code, data, reactor)
 | 
			
		||||
tr.startService()
 | 
			
		||||
d = tr.when_get_data()
 | 
			
		||||
d.addCallback(lambda theirdata:
 | 
			
		||||
              print("Their data: %s" % theirdata.decode("ascii")))
 | 
			
		||||
d.addCallback(lambda _: reactor.stop())
 | 
			
		||||
reactor.run()
 | 
			
		||||
w2 = SymmetricWormhole("appid", RENDEZVOUS_RELAY)
 | 
			
		||||
w2.set_code(code)
 | 
			
		||||
d = w2.get_data(my_message)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
You can call `d=w.get_verifier()` before `get_data()`: this will perform the
 | 
			
		||||
first half of the PAKE negotiation, then fire the Deferred with a verifier
 | 
			
		||||
object (bytes) which can be converted into a printable representation and
 | 
			
		||||
manually compared. When the users are convinced that `get_verifier()` from
 | 
			
		||||
both sides are the same, call `d=get_data()` to continue the transfer. If you
 | 
			
		||||
call `get_data()` first, it will perform the complete transfer without
 | 
			
		||||
pausing.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## Application Identifier
 | 
			
		||||
 | 
			
		||||
Applications using this library must provide an "application identifier", a
 | 
			
		||||
| 
						 | 
				
			
			@ -135,17 +141,27 @@ sync will be abandoned, and all callbacks will errback with a TimeoutError.
 | 
			
		|||
 | 
			
		||||
Both have defaults suitable for face-to-face realtime setup environments.
 | 
			
		||||
 | 
			
		||||
## Serialization (TODO)
 | 
			
		||||
## Serialization
 | 
			
		||||
 | 
			
		||||
TODO: only the Twisted form supports serialization so far
 | 
			
		||||
 | 
			
		||||
You may not be able to hold the Initiator/Receiver object in memory for the
 | 
			
		||||
whole sync process: maybe you allow it to wait for several days, but the
 | 
			
		||||
program will be restarted during that time. To support this, you can persist
 | 
			
		||||
the state of the object by calling `data = i.serialize()`, which will return
 | 
			
		||||
the state of the object by calling `data = w.serialize()`, which will return
 | 
			
		||||
a printable bytestring (the JSON-encoding of a small dictionary). To restore,
 | 
			
		||||
call `Initiator.from_serialized(data)`.
 | 
			
		||||
use the `from_serialized(data)` classmethod (e.g. `w =
 | 
			
		||||
SymmetricWormhole.from_serialized(data)`).
 | 
			
		||||
 | 
			
		||||
Note that callbacks are not serialized: they must be restored after
 | 
			
		||||
deserialization.
 | 
			
		||||
There is exactly one point at which you can serialize the wormhole: *after*
 | 
			
		||||
establishing the invitation code, but before waiting for `get_verifier()` or
 | 
			
		||||
`get_data()`. If you are creating a new code, the correct time is during the
 | 
			
		||||
callback fired by `get_code()`. If you are accepting a pre-generated code,
 | 
			
		||||
the time is just after calling `set_code()`.
 | 
			
		||||
 | 
			
		||||
To properly checkpoint the process, you should store the first message
 | 
			
		||||
(returned by `start()`) next to the serialized wormhole instance, so you can
 | 
			
		||||
re-send it if necessary.
 | 
			
		||||
 | 
			
		||||
## Detailed Example
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										132
									
								
								src/wormhole/test/test_twisted.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								src/wormhole/test/test_twisted.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,132 @@
 | 
			
		|||
import json
 | 
			
		||||
from twisted.trial import unittest
 | 
			
		||||
from twisted.internet import defer, protocol, endpoints, reactor
 | 
			
		||||
from twisted.application import service
 | 
			
		||||
from ..servers.relay import RelayServer
 | 
			
		||||
from ..twisted.transcribe import SymmetricWormhole, UsageError
 | 
			
		||||
from .. import __version__
 | 
			
		||||
#from twisted.python import log
 | 
			
		||||
#import sys
 | 
			
		||||
#log.startLogging(sys.stdout)
 | 
			
		||||
 | 
			
		||||
def allocate_port():
 | 
			
		||||
    ep = endpoints.serverFromString(reactor, "tcp:0:interface=127.0.0.1")
 | 
			
		||||
    d = ep.listen(protocol.Factory())
 | 
			
		||||
    def _listening(lp):
 | 
			
		||||
        port = lp.getHost().port
 | 
			
		||||
        d2 = lp.stopListening()
 | 
			
		||||
        d2.addCallback(lambda _: port)
 | 
			
		||||
        return d2
 | 
			
		||||
    d.addCallback(_listening)
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
def allocate_ports():
 | 
			
		||||
    d = defer.DeferredList([allocate_port(), allocate_port()])
 | 
			
		||||
    def _done(results):
 | 
			
		||||
        port1 = results[0][1]
 | 
			
		||||
        port2 = results[1][1]
 | 
			
		||||
        return (port1, port2)
 | 
			
		||||
    d.addCallback(_done)
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
class Basic(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.sp = service.MultiService()
 | 
			
		||||
        self.sp.startService()
 | 
			
		||||
        d = allocate_ports()
 | 
			
		||||
        def _got_ports(ports):
 | 
			
		||||
            relayport, transitport = ports
 | 
			
		||||
            s = RelayServer("tcp:%d:interface=127.0.0.1" % relayport,
 | 
			
		||||
                            "tcp:%s:interface=127.0.0.1" % transitport,
 | 
			
		||||
                            __version__)
 | 
			
		||||
            s.setServiceParent(self.sp)
 | 
			
		||||
            self.relayurl = "http://127.0.0.1:%d/wormhole-relay/" % relayport
 | 
			
		||||
            self.transit = "tcp:127.0.0.1:%d" % transitport
 | 
			
		||||
        d.addCallback(_got_ports)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
        return self.sp.stopService()
 | 
			
		||||
 | 
			
		||||
    def test_basic(self):
 | 
			
		||||
        appid = "appid"
 | 
			
		||||
        w1 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        w2 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        d = w1.get_code()
 | 
			
		||||
        def _got_code(code):
 | 
			
		||||
            w2.set_code(code)
 | 
			
		||||
            d1 = w1.get_data("data1")
 | 
			
		||||
            d2 = w2.get_data("data2")
 | 
			
		||||
            return defer.DeferredList([d1,d2], fireOnOneErrback=False)
 | 
			
		||||
        d.addCallback(_got_code)
 | 
			
		||||
        def _done(dl):
 | 
			
		||||
            ((success1, dataX), (success2, dataY)) = dl
 | 
			
		||||
            r1,r2 = dl
 | 
			
		||||
            self.assertTrue(success1)
 | 
			
		||||
            self.assertTrue(success2)
 | 
			
		||||
            self.assertEqual(dataX, "data2")
 | 
			
		||||
            self.assertEqual(dataY, "data1")
 | 
			
		||||
        d.addCallback(_done)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def test_fixed_code(self):
 | 
			
		||||
        appid = "appid"
 | 
			
		||||
        w1 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        w2 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        w1.set_code("123-purple-elephant")
 | 
			
		||||
        w2.set_code("123-purple-elephant")
 | 
			
		||||
        d1 = w1.get_data("data1")
 | 
			
		||||
        d2 = w2.get_data("data2")
 | 
			
		||||
        d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
 | 
			
		||||
        def _done(dl):
 | 
			
		||||
            ((success1, dataX), (success2, dataY)) = dl
 | 
			
		||||
            r1,r2 = dl
 | 
			
		||||
            self.assertTrue(success1)
 | 
			
		||||
            self.assertTrue(success2)
 | 
			
		||||
            self.assertEqual(dataX, "data2")
 | 
			
		||||
            self.assertEqual(dataY, "data1")
 | 
			
		||||
        d.addCallback(_done)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def test_errors(self):
 | 
			
		||||
        appid = "appid"
 | 
			
		||||
        w1 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        self.assertRaises(UsageError, w1.get_verifier)
 | 
			
		||||
        self.assertRaises(UsageError, w1.get_data, "data")
 | 
			
		||||
        w1.set_code("123-purple-elephant")
 | 
			
		||||
        self.assertRaises(UsageError, w1.set_code, "123-nope")
 | 
			
		||||
        self.assertRaises(UsageError, w1.get_code)
 | 
			
		||||
        w2 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        d = w2.get_code()
 | 
			
		||||
        self.assertRaises(UsageError, w2.get_code)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def test_serialize(self):
 | 
			
		||||
        appid = "appid"
 | 
			
		||||
        w1 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        self.assertRaises(UsageError, w1.serialize) # too early
 | 
			
		||||
        w2 = SymmetricWormhole(appid, self.relayurl)
 | 
			
		||||
        d = w1.get_code()
 | 
			
		||||
        def _got_code(code):
 | 
			
		||||
            self.assertRaises(UsageError, w2.serialize) # too early
 | 
			
		||||
            w2.set_code(code)
 | 
			
		||||
            w2.serialize() # ok
 | 
			
		||||
            s = w1.serialize()
 | 
			
		||||
            self.assertEqual(type(s), type(""))
 | 
			
		||||
            unpacked = json.loads(s) # this is supposed to be JSON
 | 
			
		||||
            self.assertEqual(type(unpacked), dict)
 | 
			
		||||
            new_w1 = SymmetricWormhole.from_serialized(s)
 | 
			
		||||
            d1 = new_w1.get_data("data1")
 | 
			
		||||
            d2 = w2.get_data("data2")
 | 
			
		||||
            return defer.DeferredList([d1,d2], fireOnOneErrback=False)
 | 
			
		||||
        d.addCallback(_got_code)
 | 
			
		||||
        def _done(dl):
 | 
			
		||||
            ((success1, dataX), (success2, dataY)) = dl
 | 
			
		||||
            r1,r2 = dl
 | 
			
		||||
            self.assertTrue(success1)
 | 
			
		||||
            self.assertTrue(success2)
 | 
			
		||||
            self.assertEqual(dataX, "data2")
 | 
			
		||||
            self.assertEqual(dataY, "data1")
 | 
			
		||||
            self.assertRaises(UsageError, w2.serialize) # too late
 | 
			
		||||
        d.addCallback(_done)
 | 
			
		||||
        return d
 | 
			
		||||
							
								
								
									
										0
									
								
								src/wormhole/twisted/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/wormhole/twisted/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										30
									
								
								src/wormhole/twisted/demo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								src/wormhole/twisted/demo.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,30 @@
 | 
			
		|||
import sys
 | 
			
		||||
from twisted.internet import reactor
 | 
			
		||||
from .transcribe import SymmetricWormhole
 | 
			
		||||
from .. import public_relay
 | 
			
		||||
 | 
			
		||||
APPID = "lothar.com/wormhole/text-xfer"
 | 
			
		||||
 | 
			
		||||
w = SymmetricWormhole(APPID, public_relay.RENDEZVOUS_RELAY)
 | 
			
		||||
 | 
			
		||||
if sys.argv[1] == "send-text":
 | 
			
		||||
    message = sys.argv[2]
 | 
			
		||||
    d = w.get_code()
 | 
			
		||||
    def _got_code(code):
 | 
			
		||||
        print "code is:", code
 | 
			
		||||
        return w.get_data(message)
 | 
			
		||||
    d.addCallback(_got_code)
 | 
			
		||||
    def _got_data(their_data):
 | 
			
		||||
        print "ack:", their_data
 | 
			
		||||
    d.addCallback(_got_data)
 | 
			
		||||
elif sys.argv[1] == "receive-text":
 | 
			
		||||
    code = sys.argv[2]
 | 
			
		||||
    w.set_code(code)
 | 
			
		||||
    d = w.get_data("ok")
 | 
			
		||||
    def _got_data(their_data):
 | 
			
		||||
        print their_data
 | 
			
		||||
    d.addCallback(_got_data)
 | 
			
		||||
else:
 | 
			
		||||
    raise ValueError("bad command")
 | 
			
		||||
d.addCallback(lambda _: reactor.stop())
 | 
			
		||||
reactor.run()
 | 
			
		||||
							
								
								
									
										219
									
								
								src/wormhole/twisted/eventsource.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								src/wormhole/twisted/eventsource.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,219 @@
 | 
			
		|||
from twisted.python import log, failure
 | 
			
		||||
from twisted.internet import reactor, defer, protocol
 | 
			
		||||
from twisted.application import service
 | 
			
		||||
from twisted.protocols import basic
 | 
			
		||||
from twisted.web.client import Agent, ResponseDone
 | 
			
		||||
from twisted.web.http_headers import Headers
 | 
			
		||||
from ..util.eventual import eventually
 | 
			
		||||
 | 
			
		||||
class EventSourceParser(basic.LineOnlyReceiver):
 | 
			
		||||
    delimiter = "\n"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, handler):
 | 
			
		||||
        self.current_field = None
 | 
			
		||||
        self.current_lines = []
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
        self.done_deferred = defer.Deferred()
 | 
			
		||||
        self.eventtype = "message"
 | 
			
		||||
 | 
			
		||||
    def connectionLost(self, why):
 | 
			
		||||
        if why.check(ResponseDone):
 | 
			
		||||
            why = None
 | 
			
		||||
        self.done_deferred.callback(why)
 | 
			
		||||
 | 
			
		||||
    def dataReceived(self, data):
 | 
			
		||||
        # exceptions here aren't being logged properly, and tests will hang
 | 
			
		||||
        # rather than halt. I suspect twisted.web._newclient's
 | 
			
		||||
        # HTTP11ClientProtocol.dataReceived(), which catches everything and
 | 
			
		||||
        # responds with self._giveUp() but doesn't log.err.
 | 
			
		||||
        try:
 | 
			
		||||
            basic.LineOnlyReceiver.dataReceived(self, data)
 | 
			
		||||
        except:
 | 
			
		||||
            log.err()
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    def lineReceived(self, line):
 | 
			
		||||
        if not line:
 | 
			
		||||
            # blank line ends the field
 | 
			
		||||
            self.fieldReceived(self.current_field,
 | 
			
		||||
                               "\n".join(self.current_lines))
 | 
			
		||||
            self.current_field = None
 | 
			
		||||
            self.current_lines[:] = []
 | 
			
		||||
            return
 | 
			
		||||
        if self.current_field is None:
 | 
			
		||||
            self.current_field, data = line.split(": ", 1)
 | 
			
		||||
            self.current_lines.append(data)
 | 
			
		||||
        else:
 | 
			
		||||
            self.current_lines.append(line)
 | 
			
		||||
 | 
			
		||||
    def fieldReceived(self, fieldname, data):
 | 
			
		||||
        if fieldname == "event":
 | 
			
		||||
            self.eventtype = data
 | 
			
		||||
        elif fieldname == "data":
 | 
			
		||||
            self.eventReceived(self.eventtype, data)
 | 
			
		||||
            self.eventtype = "message"
 | 
			
		||||
        else:
 | 
			
		||||
            log.msg("weird fieldname", fieldname, data)
 | 
			
		||||
 | 
			
		||||
    def eventReceived(self, eventtype, data):
 | 
			
		||||
        self.handler(eventtype, data)
 | 
			
		||||
 | 
			
		||||
class EventSourceError(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
# es = EventSource(url, handler)
 | 
			
		||||
# d = es.start()
 | 
			
		||||
# es.cancel()
 | 
			
		||||
 | 
			
		||||
class EventSource: # TODO: service.Service
 | 
			
		||||
    def __init__(self, url, handler, when_connected=None, agent=None):
 | 
			
		||||
        self.url = url
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
        self.when_connected = when_connected
 | 
			
		||||
        self.started = False
 | 
			
		||||
        self.cancelled = False
 | 
			
		||||
        self.proto = EventSourceParser(self.handler)
 | 
			
		||||
        if not agent:
 | 
			
		||||
            agent = Agent(reactor)
 | 
			
		||||
        self.agent = agent
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        assert not self.started, "single-use"
 | 
			
		||||
        self.started = True
 | 
			
		||||
        d = self.agent.request("GET", self.url,
 | 
			
		||||
                               Headers({"accept": ["text/event-stream"]}))
 | 
			
		||||
        d.addCallback(self._connected)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def _connected(self, resp):
 | 
			
		||||
        if resp.code != 200:
 | 
			
		||||
            raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
 | 
			
		||||
        if self.when_connected:
 | 
			
		||||
            self.when_connected()
 | 
			
		||||
        #if resp.headers.getRawHeaders("content-type") == ["text/event-stream"]:
 | 
			
		||||
        resp.deliverBody(self.proto)
 | 
			
		||||
        if self.cancelled:
 | 
			
		||||
            self.kill_connection()
 | 
			
		||||
        return self.proto.done_deferred
 | 
			
		||||
 | 
			
		||||
    def cancel(self):
 | 
			
		||||
        self.cancelled = True
 | 
			
		||||
        if not self.proto.transport:
 | 
			
		||||
            # _connected hasn't been called yet, but that self.cancelled
 | 
			
		||||
            # should take care of it when the connection is established
 | 
			
		||||
            def kill(data):
 | 
			
		||||
                # this should kill it as soon as any data is delivered
 | 
			
		||||
                raise ValueError("dead")
 | 
			
		||||
            self.proto.dataReceived = kill # just in case
 | 
			
		||||
            return
 | 
			
		||||
        self.kill_connection()
 | 
			
		||||
 | 
			
		||||
    def kill_connection(self):
 | 
			
		||||
        if (hasattr(self.proto.transport, "_producer")
 | 
			
		||||
            and self.proto.transport._producer):
 | 
			
		||||
            # This is gross and fragile. We need a clean way to stop the
 | 
			
		||||
            # client connection. p.transport is a
 | 
			
		||||
            # twisted.web._newclient.TransportProxyProducer , and its
 | 
			
		||||
            # ._producer is the tcp.Port.
 | 
			
		||||
            self.proto.transport._producer.loseConnection()
 | 
			
		||||
        else:
 | 
			
		||||
            log.err("get_events: unable to stop connection")
 | 
			
		||||
            # oh well
 | 
			
		||||
            #err = EventSourceError("unable to cancel")
 | 
			
		||||
            try:
 | 
			
		||||
                self.proto.done_deferred.callback(None)
 | 
			
		||||
            except defer.AlreadyCalledError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Connector:
 | 
			
		||||
    # behave enough like an IConnector to appease ReconnectingClientFactory
 | 
			
		||||
    def __init__(self, res):
 | 
			
		||||
        self.res = res
 | 
			
		||||
    def connect(self):
 | 
			
		||||
        self.res._maybeStart()
 | 
			
		||||
    def stopConnecting(self):
 | 
			
		||||
        self.res._stop_eventsource()
 | 
			
		||||
 | 
			
		||||
class ReconnectingEventSource(service.MultiService,
 | 
			
		||||
                              protocol.ReconnectingClientFactory):
 | 
			
		||||
    def __init__(self, baseurl, connection_starting, handler, agent=None):
 | 
			
		||||
        service.MultiService.__init__(self)
 | 
			
		||||
        # we don't use any of the basic Factory/ClientFactory methods of
 | 
			
		||||
        # this, just the ReconnectingClientFactory.retry, stopTrying, and
 | 
			
		||||
        # resetDelay methods.
 | 
			
		||||
 | 
			
		||||
        self.baseurl = baseurl
 | 
			
		||||
        self.connection_starting = connection_starting
 | 
			
		||||
        self.handler = handler
 | 
			
		||||
        self.agent = agent
 | 
			
		||||
        # IService provides self.running, toggled by {start,stop}Service.
 | 
			
		||||
        # self.active is toggled by {,de}activate. If both .running and
 | 
			
		||||
        # .active are True, then we want to have an outstanding EventSource
 | 
			
		||||
        # and will start one if necessary. If either is False, then we don't
 | 
			
		||||
        # want one to be outstanding, and will initiate shutdown.
 | 
			
		||||
        self.active = False
 | 
			
		||||
        self.connector = Connector(self)
 | 
			
		||||
        self.es = None # set we have an outstanding EventSource
 | 
			
		||||
        self.when_stopped = [] # list of Deferreds
 | 
			
		||||
 | 
			
		||||
    def isStopped(self):
 | 
			
		||||
        return not self.es
 | 
			
		||||
 | 
			
		||||
    def startService(self):
 | 
			
		||||
        service.MultiService.startService(self) # sets self.running
 | 
			
		||||
        self._maybeStart()
 | 
			
		||||
 | 
			
		||||
    def stopService(self):
 | 
			
		||||
        # clears self.running
 | 
			
		||||
        d = defer.maybeDeferred(service.MultiService.stopService, self)
 | 
			
		||||
        d.addCallback(self._maybeStop)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def activate(self):
 | 
			
		||||
        assert not self.active
 | 
			
		||||
        self.active = True
 | 
			
		||||
        self._maybeStart()
 | 
			
		||||
 | 
			
		||||
    def deactivate(self):
 | 
			
		||||
        assert self.active # XXX
 | 
			
		||||
        self.active = False
 | 
			
		||||
        return self._maybeStop()
 | 
			
		||||
 | 
			
		||||
    def _maybeStart(self):
 | 
			
		||||
        if not (self.active and self.running):
 | 
			
		||||
            return
 | 
			
		||||
        self.continueTrying = True
 | 
			
		||||
        url = self.connection_starting()
 | 
			
		||||
        self.es = EventSource(url, self.handler, self.resetDelay,
 | 
			
		||||
                              agent=self.agent)
 | 
			
		||||
        d = self.es.start()
 | 
			
		||||
        d.addBoth(self._stopped)
 | 
			
		||||
 | 
			
		||||
    def _stopped(self, res):
 | 
			
		||||
        self.es = None
 | 
			
		||||
        # we might have stopped because of a connection error, or because of
 | 
			
		||||
        # an intentional shutdown.
 | 
			
		||||
        if self.active and self.running:
 | 
			
		||||
            # we still want to be connected, so schedule a reconnection
 | 
			
		||||
            if isinstance(res, failure.Failure):
 | 
			
		||||
                log.err(res)
 | 
			
		||||
            self.retry() # will eventually call _maybeStart
 | 
			
		||||
            return
 | 
			
		||||
        # intentional shutdown
 | 
			
		||||
        self.stopTrying()
 | 
			
		||||
        for d in self.when_stopped:
 | 
			
		||||
            eventually(d.callback, None)
 | 
			
		||||
        self.when_stopped = []
 | 
			
		||||
 | 
			
		||||
    def _stop_eventsource(self):
 | 
			
		||||
        if self.es:
 | 
			
		||||
            eventually(self.es.cancel)
 | 
			
		||||
 | 
			
		||||
    def _maybeStop(self, _=None):
 | 
			
		||||
        self.stopTrying() # cancels timer, calls _stop_eventsource()
 | 
			
		||||
        if not self.es:
 | 
			
		||||
            return defer.succeed(None)
 | 
			
		||||
        d = defer.Deferred()
 | 
			
		||||
        self.when_stopped.append(d)
 | 
			
		||||
        return d
 | 
			
		||||
| 
						 | 
				
			
			@ -1,27 +1,271 @@
 | 
			
		|||
from twisted.application import service
 | 
			
		||||
from ..const import RELAY
 | 
			
		||||
from __future__ import print_function
 | 
			
		||||
import os, sys, json, re
 | 
			
		||||
from binascii import hexlify, unhexlify
 | 
			
		||||
from zope.interface import implementer
 | 
			
		||||
from twisted.internet import reactor, defer
 | 
			
		||||
from twisted.web import client as web_client
 | 
			
		||||
from twisted.web import error as web_error
 | 
			
		||||
from twisted.web.iweb import IBodyProducer
 | 
			
		||||
from nacl.secret import SecretBox
 | 
			
		||||
from nacl.exceptions import CryptoError
 | 
			
		||||
from nacl import utils
 | 
			
		||||
from spake2 import SPAKE2_Symmetric
 | 
			
		||||
from .eventsource import ReconnectingEventSource
 | 
			
		||||
from .. import __version__
 | 
			
		||||
from .. import codes
 | 
			
		||||
from ..errors import ServerError
 | 
			
		||||
from ..util.hkdf import HKDF
 | 
			
		||||
 | 
			
		||||
class TwistedInitiator(service.MultiService):
 | 
			
		||||
    def __init__(self, appid, data, reactor, relay=RELAY):
 | 
			
		||||
        self.appid = appid
 | 
			
		||||
class WrongPasswordError(Exception):
 | 
			
		||||
    """
 | 
			
		||||
    Key confirmation failed.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
class ReflectionAttack(Exception):
 | 
			
		||||
    """An attacker (or bug) reflected our outgoing message back to us."""
 | 
			
		||||
 | 
			
		||||
class UsageError(Exception):
 | 
			
		||||
    """The programmer did something wrong."""
 | 
			
		||||
 | 
			
		||||
@implementer(IBodyProducer)
 | 
			
		||||
class DataProducer:
 | 
			
		||||
    def __init__(self, data):
 | 
			
		||||
        self.data = data
 | 
			
		||||
        self.reactor = reactor
 | 
			
		||||
        self.length = len(data)
 | 
			
		||||
    def startProducing(self, consumer):
 | 
			
		||||
        consumer.write(self.data)
 | 
			
		||||
        return defer.succeed(None)
 | 
			
		||||
    def stopProducing(self):
 | 
			
		||||
        pass
 | 
			
		||||
    def pauseProducing(self):
 | 
			
		||||
        pass
 | 
			
		||||
    def resumeProducing(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SymmetricWormhole:
 | 
			
		||||
    def __init__(self, appid, relay):
 | 
			
		||||
        self.appid = appid
 | 
			
		||||
        self.relay = relay
 | 
			
		||||
        self.agent = web_client.Agent(reactor)
 | 
			
		||||
        self.side = None
 | 
			
		||||
        self.code = None
 | 
			
		||||
        self.key = None
 | 
			
		||||
        self._started_get_code = False
 | 
			
		||||
 | 
			
		||||
    def when_get_code(self):
 | 
			
		||||
        pass # return Deferred
 | 
			
		||||
    def get_code(self, code_length=2):
 | 
			
		||||
        if self.code is not None: raise UsageError
 | 
			
		||||
        if self._started_get_code: raise UsageError
 | 
			
		||||
        self._started_get_code = True
 | 
			
		||||
        self.side = hexlify(os.urandom(5))
 | 
			
		||||
        d = self._allocate_channel()
 | 
			
		||||
        def _got_channel_id(channel_id):
 | 
			
		||||
            code = codes.make_code(channel_id, code_length)
 | 
			
		||||
            self._set_code_and_channel_id(code)
 | 
			
		||||
            self._start()
 | 
			
		||||
            return code
 | 
			
		||||
        d.addCallback(_got_channel_id)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def when_get_data(self):
 | 
			
		||||
        pass # return Deferred
 | 
			
		||||
    def _allocate_channel(self):
 | 
			
		||||
        url = self.relay + "allocate/%s" % self.side
 | 
			
		||||
        d = self.post(url)
 | 
			
		||||
        def _got_channel(data):
 | 
			
		||||
            if "welcome" in data:
 | 
			
		||||
                self.handle_welcome(data["welcome"])
 | 
			
		||||
            return data["channel-id"]
 | 
			
		||||
        d.addCallback(_got_channel)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
class TwistedReceiver(service.MultiService):
 | 
			
		||||
    def __init__(self, appid, data, code, reactor, relay=RELAY):
 | 
			
		||||
        self.appid = appid
 | 
			
		||||
        self.data = data
 | 
			
		||||
    def set_code(self, code):
 | 
			
		||||
        if self.code is not None: raise UsageError
 | 
			
		||||
        if self.side is not None: raise UsageError
 | 
			
		||||
        self._set_code_and_channel_id(code)
 | 
			
		||||
        self.side = hexlify(os.urandom(5))
 | 
			
		||||
        self._start()
 | 
			
		||||
 | 
			
		||||
    def _set_code_and_channel_id(self, code):
 | 
			
		||||
        if self.code is not None: raise UsageError
 | 
			
		||||
        mo = re.search(r'^(\d+)-', code)
 | 
			
		||||
        if not mo:
 | 
			
		||||
            raise ValueError("code (%s) must start with NN-" % code)
 | 
			
		||||
        self.channel_id = int(mo.group(1))
 | 
			
		||||
        self.code = code
 | 
			
		||||
        self.reactor = reactor
 | 
			
		||||
        self.relay = relay
 | 
			
		||||
 | 
			
		||||
    def when_get_data(self):
 | 
			
		||||
        pass # return Deferred
 | 
			
		||||
    def _start(self):
 | 
			
		||||
        # allocate the rest now too, so it can be serialized
 | 
			
		||||
        self.sp = SPAKE2_Symmetric(self.code.encode("ascii"),
 | 
			
		||||
                                   idA=self.appid+":SymmetricA",
 | 
			
		||||
                                   idB=self.appid+":SymmetricB")
 | 
			
		||||
        self.msg1 = self.sp.start()
 | 
			
		||||
 | 
			
		||||
    def serialize(self):
 | 
			
		||||
        # I can only be serialized after get_code/set_code and before
 | 
			
		||||
        # get_verifier/get_data
 | 
			
		||||
        if self.code is None: raise UsageError
 | 
			
		||||
        if self.key is not None: raise UsageError
 | 
			
		||||
        data = {
 | 
			
		||||
            "appid": self.appid,
 | 
			
		||||
            "relay": self.relay,
 | 
			
		||||
            "code": self.code,
 | 
			
		||||
            "side": self.side,
 | 
			
		||||
            "spake2": json.loads(self.sp.serialize()),
 | 
			
		||||
            "msg1": self.msg1.encode("hex"),
 | 
			
		||||
        }
 | 
			
		||||
        return json.dumps(data)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_serialized(klass, data):
 | 
			
		||||
        d = json.loads(data)
 | 
			
		||||
        self = klass(d["appid"].encode("ascii"), d["relay"].encode("ascii"))
 | 
			
		||||
        self._set_code_and_channel_id(d["code"].encode("ascii"))
 | 
			
		||||
        self.side = d["side"].encode("ascii")
 | 
			
		||||
        self.sp = SPAKE2_Symmetric.from_serialized(json.dumps(d["spake2"]))
 | 
			
		||||
        self.msg1 = d["msg1"].decode("hex")
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    motd_displayed = False
 | 
			
		||||
    version_warning_displayed = False
 | 
			
		||||
 | 
			
		||||
    def handle_welcome(self, welcome):
 | 
			
		||||
        if ("motd" in welcome and
 | 
			
		||||
            not self.motd_displayed):
 | 
			
		||||
            motd_lines = welcome["motd"].splitlines()
 | 
			
		||||
            motd_formatted = "\n ".join(motd_lines)
 | 
			
		||||
            print("Server (at %s) says:\n %s" % (self.relay, motd_formatted),
 | 
			
		||||
                  file=sys.stderr)
 | 
			
		||||
            self.motd_displayed = True
 | 
			
		||||
 | 
			
		||||
        # Only warn if we're running a release version (e.g. 0.0.6, not
 | 
			
		||||
        # 0.0.6-DISTANCE-gHASH). Only warn once.
 | 
			
		||||
        if ("-" not in __version__ and
 | 
			
		||||
            not self.version_warning_displayed and
 | 
			
		||||
            welcome["current_version"] != __version__):
 | 
			
		||||
            print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
 | 
			
		||||
            print("Server claims %s is current, but ours is %s"
 | 
			
		||||
                  % (welcome["current_version"], __version__), file=sys.stderr)
 | 
			
		||||
            self.version_warning_displayed = True
 | 
			
		||||
 | 
			
		||||
        if "error" in welcome:
 | 
			
		||||
            raise ServerError(welcome["error"], self.relay)
 | 
			
		||||
 | 
			
		||||
    def url(self, verb, msgnum=None):
 | 
			
		||||
        url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
 | 
			
		||||
        if msgnum is not None:
 | 
			
		||||
            url += "/" + msgnum
 | 
			
		||||
        return url
 | 
			
		||||
 | 
			
		||||
    def post(self, url, post_json=None):
 | 
			
		||||
        # TODO: retry on failure, with exponential backoff. We're guarding
 | 
			
		||||
        # against the rendezvous server being temporarily offline.
 | 
			
		||||
        p = None
 | 
			
		||||
        if post_json:
 | 
			
		||||
            data = json.dumps(post_json).encode("utf-8")
 | 
			
		||||
            p = DataProducer(data)
 | 
			
		||||
        d = self.agent.request("POST", url, bodyProducer=p)
 | 
			
		||||
        def _check_error(resp):
 | 
			
		||||
            if resp.code != 200:
 | 
			
		||||
                raise web_error.Error(resp.code, resp.phrase)
 | 
			
		||||
            return resp
 | 
			
		||||
        d.addCallback(_check_error)
 | 
			
		||||
        d.addCallback(web_client.readBody)
 | 
			
		||||
        d.addCallback(lambda data: json.loads(data))
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def _get_msgs(self, old_msgs, verb, msgnum):
 | 
			
		||||
        # fire with a list of messages that match verb/msgnum, which either
 | 
			
		||||
        # came from old_msgs, or from an EventSource that we attached to the
 | 
			
		||||
        # corresponding URL
 | 
			
		||||
        if old_msgs:
 | 
			
		||||
            return defer.succeed(old_msgs)
 | 
			
		||||
        d = defer.Deferred()
 | 
			
		||||
        msgs = []
 | 
			
		||||
        def _handle(name, data):
 | 
			
		||||
            if name == "welcome":
 | 
			
		||||
                self.handle_welcome(json.loads(data))
 | 
			
		||||
            if name == "message":
 | 
			
		||||
                msgs.append(json.loads(data)["message"])
 | 
			
		||||
                d.callback(None)
 | 
			
		||||
        es = ReconnectingEventSource(None, lambda: self.url(verb, msgnum),
 | 
			
		||||
                                     _handle)#, agent=self.agent)
 | 
			
		||||
        es.startService() # TODO: .setServiceParent(self)
 | 
			
		||||
        es.activate()
 | 
			
		||||
        d.addCallback(lambda _: es.deactivate())
 | 
			
		||||
        d.addCallback(lambda _: es.stopService())
 | 
			
		||||
        d.addCallback(lambda _: msgs)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
 | 
			
		||||
        assert self.key is not None # call after get_verifier() or get_data()
 | 
			
		||||
        assert type(purpose) == type(b"")
 | 
			
		||||
        return HKDF(self.key, length, CTXinfo=purpose)
 | 
			
		||||
 | 
			
		||||
    def _encrypt_data(self, key, data):
 | 
			
		||||
        assert len(key) == SecretBox.KEY_SIZE
 | 
			
		||||
        box = SecretBox(key)
 | 
			
		||||
        nonce = utils.random(SecretBox.NONCE_SIZE)
 | 
			
		||||
        return box.encrypt(data, nonce)
 | 
			
		||||
 | 
			
		||||
    def _decrypt_data(self, key, encrypted):
 | 
			
		||||
        assert len(key) == SecretBox.KEY_SIZE
 | 
			
		||||
        box = SecretBox(key)
 | 
			
		||||
        data = box.decrypt(encrypted)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def _get_key(self):
 | 
			
		||||
        # TODO: prevent multiple invocation
 | 
			
		||||
        if self.key:
 | 
			
		||||
            return defer.succeed(self.key)
 | 
			
		||||
        data = {"message": hexlify(self.msg1).decode("ascii")}
 | 
			
		||||
        d = self.post(self.url("post", "pake"), data)
 | 
			
		||||
        d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "pake"))
 | 
			
		||||
        def _got_pake(msgs):
 | 
			
		||||
            pake_msg = unhexlify(msgs[0].encode("ascii"))
 | 
			
		||||
            key = self.sp.finish(pake_msg)
 | 
			
		||||
            self.key = key
 | 
			
		||||
            self.verifier = self.derive_key(self.appid+b":Verifier")
 | 
			
		||||
            return key
 | 
			
		||||
        d.addCallback(_got_pake)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def get_verifier(self):
 | 
			
		||||
        if self.code is None: raise UsageError
 | 
			
		||||
        d = self._get_key()
 | 
			
		||||
        d.addCallback(lambda _: self.verifier)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def get_data(self, outbound_data):
 | 
			
		||||
        # only call this once
 | 
			
		||||
        if self.code is None: raise UsageError
 | 
			
		||||
        d = self._get_key()
 | 
			
		||||
        d.addCallback(self._get_data2, outbound_data)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def _get_data2(self, key, outbound_data):
 | 
			
		||||
        # Without predefined roles, we can't derive predictably unique keys
 | 
			
		||||
        # for each side, so we use the same key for both. We use random
 | 
			
		||||
        # nonces to keep the messages distinct, and check for reflection.
 | 
			
		||||
        data_key = self.derive_key(b"data-key")
 | 
			
		||||
        outbound_encrypted = self._encrypt_data(data_key, outbound_data)
 | 
			
		||||
        data = {"message": hexlify(outbound_encrypted).decode("ascii")}
 | 
			
		||||
        d = self.post(self.url("post", "data"), data)
 | 
			
		||||
        d.addCallback(lambda j: self._get_msgs(j["messages"], "poll", "data"))
 | 
			
		||||
        def _got_data(msgs):
 | 
			
		||||
            inbound_encrypted = unhexlify(msgs[0].encode("ascii"))
 | 
			
		||||
            if inbound_encrypted == outbound_encrypted:
 | 
			
		||||
                raise ReflectionAttack
 | 
			
		||||
            try:
 | 
			
		||||
                inbound_data = self._decrypt_data(data_key, inbound_encrypted)
 | 
			
		||||
                return inbound_data
 | 
			
		||||
            except CryptoError:
 | 
			
		||||
                raise WrongPasswordError
 | 
			
		||||
        d.addCallback(_got_data)
 | 
			
		||||
        d.addBoth(self._deallocate)
 | 
			
		||||
        return d
 | 
			
		||||
 | 
			
		||||
    def _deallocate(self, res):
 | 
			
		||||
        # only try once, no retries
 | 
			
		||||
        d = self.agent.request("POST", self.url("deallocate"))
 | 
			
		||||
        d.addBoth(lambda _: res) # ignore POST failure, pass-through result
 | 
			
		||||
        return d
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user