remove plain-HTTP (non-WebSocket) rendezvous server

This commit is contained in:
Brian Warner 2016-05-12 16:56:19 -07:00
parent 104ef44d53
commit a34fb2a98b
4 changed files with 1 additions and 801 deletions

View File

@ -1,223 +0,0 @@
import json, time
from twisted.web import server, resource
from twisted.python import log
def json_response(request, data):
request.setHeader(b"content-type", b"application/json; charset=utf-8")
return (json.dumps(data)+"\n").encode("utf-8")
class EventsProtocol:
def __init__(self, request):
self.request = request
def sendComment(self, comment):
# this is ignored by clients, but can keep the connection open in the
# face of firewall/NAT timeouts. It also helps unit tests, since
# apparently twisted.web.client.Agent doesn't consider the connection
# to be established until it sees the first byte of the reponse body.
self.request.write(b": " + comment + b"\n\n")
def sendEvent(self, data, name=None, id=None, retry=None):
if name:
self.request.write(b"event: " + name.encode("utf-8") + b"\n")
# e.g. if name=foo, then the client web page should do:
# (new EventSource(url)).addEventListener("foo", handlerfunc)
# Note that this basically defaults to "message".
if id:
self.request.write(b"id: " + id.encode("utf-8") + b"\n")
if retry:
self.request.write(b"retry: " + retry + b"\n") # milliseconds
for line in data.splitlines():
self.request.write(b"data: " + line.encode("utf-8") + b"\n")
self.request.write(b"\n")
def stop(self):
self.request.finish()
def send_rendezvous_event(self, data):
data = data.copy()
data["sent"] = time.time()
self.sendEvent(json.dumps(data))
def stop_rendezvous_watcher(self):
self.stop()
# note: no versions of IE (including the current IE11) support EventSource
# relay URLs are as follows: (MESSAGES=[{phase:,body:}..])
# ("-" indicates a deprecated URL)
# GET /list?appid= -> {channelids: [INT..]}
# POST /allocate {appid:,side:} -> {channelid: INT}
# these return all messages (base64) for appid=/channelid= :
# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES}
# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES}
#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}..
# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}..
# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted}
# all JSON responses include a "welcome:{..}" key
class RelayResource(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self._welcome = rendezvous.get_welcome()
class ChannelLister(RelayResource):
def render_GET(self, request):
if b"appid" not in request.args:
e = NeedToUpgradeErrorResource(self._welcome)
return e.get_message()
appid = request.args[b"appid"][0].decode("utf-8")
#print("LIST", appid)
app = self._rendezvous.get_app(appid)
allocated = app.get_allocated()
data = {"welcome": self._welcome, "channelids": sorted(allocated),
"sent": time.time()}
return json_response(request, data)
class Allocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
#print("ALLOCATE", appid, side)
app = self._rendezvous.get_app(appid)
channelid = app.find_available_channelid()
app.allocate_channel(channelid, side)
if self._rendezvous.get_log_requests():
log.msg("allocated #%d, now have %d DB channels" %
(channelid, len(app.get_allocated())))
response = {"welcome": self._welcome, "channelid": channelid,
"sent": time.time()}
return json_response(request, response)
def getChild(self, path, req):
# wormhole-0.4.0 "send" started with "POST /allocate/SIDE".
# wormhole-0.5.0 changed that to "POST /allocate". We catch the old
# URL here to deliver a nicer error message (with upgrade
# instructions) than an ugly 404.
return NeedToUpgradeErrorResource(self._welcome)
class NeedToUpgradeErrorResource(resource.Resource):
def __init__(self, welcome):
resource.Resource.__init__(self)
w = welcome.copy()
w["error"] = "Sorry, you must upgrade your client to use this server."
message = {"welcome": w}
self._message = (json.dumps(message)+"\n").encode("utf-8")
def get_message(self):
return self._message
def render_POST(self, request):
return self._message
def render_GET(self, request):
return self._message
def getChild(self, path, req):
return self
class Adder(RelayResource):
def render_POST(self, request):
#content = json.load(request.content, encoding="utf-8")
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
phase = data["phase"]
if not isinstance(phase, type(u"")):
raise TypeError("phase must be string, not %s" % type(phase))
body = data["body"]
#print("ADD", appid, channelid, side, phase, body)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
messages = channel.add_message(side, phase, body, time.time(), None)
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
class GetterOrWatcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
#print("GET", appid, channelid)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
messages = channel.get_messages()
response = {"welcome": self._welcome, "messages": messages,
"sent": time.time()}
return json_response(request, response)
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.send_rendezvous_event(old_event)
return server.NOT_DONE_YET
class Watcher(RelayResource):
def render_GET(self, request):
appid = request.args[b"appid"][0].decode("utf-8")
channelid = int(request.args[b"channelid"][0])
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
if b"text/event-stream" not in (request.getHeader(b"accept") or b""):
raise TypeError("/watch is for EventSource only")
request.setHeader(b"content-type", b"text/event-stream; charset=utf-8")
ep = EventsProtocol(request)
ep.sendEvent(json.dumps(self._welcome), name="welcome")
old_events = channel.add_listener(ep)
request.notifyFinish().addErrback(lambda f:
channel.remove_listener(ep))
for old_event in old_events:
ep.send_rendezvous_event(old_event)
return server.NOT_DONE_YET
class Deallocator(RelayResource):
def render_POST(self, request):
content = request.content.read()
data = json.loads(content.decode("utf-8"))
appid = data["appid"]
channelid = int(data["channelid"])
side = data["side"]
if not isinstance(side, type(u"")):
raise TypeError("side must be string, not '%s'" % type(side))
mood = data.get("mood")
#print("DEALLOCATE", appid, channelid, side)
app = self._rendezvous.get_app(appid)
channel = app.get_channel(channelid)
deleted = channel.deallocate(side, mood)
response = {"status": "waiting", "sent": time.time()}
if deleted:
response = {"status": "deleted", "sent": time.time()}
return json_response(request, response)
class WebRendezvous(resource.Resource):
def __init__(self, rendezvous):
resource.Resource.__init__(self)
self._rendezvous = rendezvous
self.putChild(b"list", ChannelLister(rendezvous))
self.putChild(b"allocate", Allocator(rendezvous))
self.putChild(b"add", Adder(rendezvous))
self.putChild(b"get", GetterOrWatcher(rendezvous))
self.putChild(b"watch", Watcher(rendezvous))
self.putChild(b"deallocate", Deallocator(rendezvous))
def getChild(self, path, req):
# 0.4.0 used "POST /CID/SIDE/post/MSGNUM"
# 0.5.0 replaced it with "POST /add (json body)"
# give a nicer error message to old clients
if (len(req.postpath) >= 2
and req.postpath[1] in (b"post", b"poll", b"deallocate")):
welcome = self._rendezvous.get_welcome()
return NeedToUpgradeErrorResource(welcome)
return resource.NoResource("No such child resource.")

View File

@ -8,7 +8,6 @@ from .endpoint_service import ServerEndpointService
from .. import __version__ from .. import __version__
from .database import get_db from .database import get_db
from .rendezvous import Rendezvous from .rendezvous import Rendezvous
from .rendezvous_web import WebRendezvous
from .rendezvous_websocket import WebSocketRendezvousFactory from .rendezvous_websocket import WebSocketRendezvousFactory
from .transit_server import Transit from .transit_server import Transit
@ -47,7 +46,7 @@ class RelayServer(service.MultiService):
rendezvous.setServiceParent(self) # for the pruning timer rendezvous.setServiceParent(self) # for the pruning timer
root = Root() root = Root()
wr = WebRendezvous(rendezvous) wr = resource.Resource()
root.putChild(b"wormhole-relay", wr) root.putChild(b"wormhole-relay", wr)
wsrf = WebSocketRendezvousFactory(None, rendezvous) wsrf = WebSocketRendezvousFactory(None, rendezvous)
@ -72,7 +71,6 @@ class RelayServer(service.MultiService):
self._db = db self._db = db
self._rendezvous = rendezvous self._rendezvous = rendezvous
self._root = root self._root = root
self._rendezvous_web = wr
self._rendezvous_web_service = rendezvous_web_service self._rendezvous_web_service = rendezvous_web_service
self._rendezvous_websocket = wsrf self._rendezvous_websocket = wsrf
if transit_port: if transit_port:

View File

@ -1,19 +1,15 @@
from __future__ import print_function from __future__ import print_function
import json, itertools import json, itertools
from binascii import hexlify from binascii import hexlify
import requests
from six.moves.urllib_parse import urlencode
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import protocol, reactor, defer from twisted.internet import protocol, reactor, defer
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.threads import deferToThread
from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.internet.endpoints import clientFromString, connectProtocol
from twisted.web.client import getPage, Agent, readBody from twisted.web.client import getPage, Agent, readBody
from autobahn.twisted import websocket from autobahn.twisted import websocket
from .. import __version__ from .. import __version__
from .common import ServerBase from .common import ServerBase
from ..server import rendezvous, transit_server from ..server import rendezvous, transit_server
from ..twisted.eventsource import EventSource
class Reachable(ServerBase, unittest.TestCase): class Reachable(ServerBase, unittest.TestCase):
@ -39,22 +35,6 @@ class Reachable(ServerBase, unittest.TestCase):
d.addCallback(_got) d.addCallback(_got)
return d return d
def test_requests(self):
# requests requires bytes URL, returns unicode
url = self.relayurl.replace("wormhole-relay/", "")
def _get(url):
r = requests.get(url)
r.raise_for_status()
return r.text
d = deferToThread(_get, url)
def _got(res):
self.failUnlessEqual(res, "Wormhole Relay\n")
d.addCallback(_got)
return d
def unjson(data):
return json.loads(data.decode("utf-8"))
def strip_message(msg): def strip_message(msg):
m2 = msg.copy() m2 = msg.copy()
m2.pop("id", None) m2.pop("id", None)
@ -64,323 +44,6 @@ def strip_message(msg):
def strip_messages(messages): def strip_messages(messages):
return [strip_message(m) for m in messages] return [strip_message(m) for m in messages]
class WebAPI(ServerBase, unittest.TestCase):
def build_url(self, path, appid, channelid):
url = self.relayurl+path
queryargs = []
if appid:
queryargs.append(("appid", appid))
if channelid:
queryargs.append(("channelid", channelid))
if queryargs:
url += "?" + urlencode(queryargs)
return url
def get(self, path, appid=None, channelid=None):
url = self.build_url(path, appid, channelid)
d = getPage(url.encode("ascii"))
d.addCallback(unjson)
return d
def post(self, path, data):
url = self.relayurl+path
d = getPage(url.encode("ascii"), method=b"POST",
postdata=json.dumps(data).encode("utf-8"))
d.addCallback(unjson)
return d
def check_welcome(self, data):
self.failUnlessIn("welcome", data)
self.failUnlessEqual(data["welcome"], {"current_version": __version__})
def test_allocate_1(self):
d = self.get("list", "app1")
def _check_list_1(data):
self.check_welcome(data)
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_1)
d.addCallback(lambda _: self.post("allocate", {"appid": "app1",
"side": "abc"}))
def _allocated(data):
data.pop("sent", None)
self.failUnlessEqual(set(data.keys()),
set(["welcome", "channelid"]))
self.failUnlessIsInstance(data["channelid"], int)
self.cid = data["channelid"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_2(data):
self.failUnlessEqual(data["channelids"], [self.cid])
d.addCallback(_check_list_2)
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
def _check_deallocate(res):
self.failUnlessEqual(res["status"], "deleted")
d.addCallback(_check_deallocate)
d.addCallback(lambda _: self.get("list", "app1"))
def _check_list_3(data):
self.failUnlessEqual(data["channelids"], [])
d.addCallback(_check_list_3)
return d
def test_allocate_2(self):
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
d.addCallback(_allocated)
# second caller increases the number of known sides to 2
d.addCallback(lambda _: self.post("add",
{"appid": "app1",
"channelid": str(self.cid),
"side": "def",
"phase": "1",
"body": ""}))
d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], [self.cid]))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "abc"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "NOT"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "waiting"))
d.addCallback(lambda _: self.post("deallocate",
{"appid": "app1",
"channelid": str(self.cid),
"side": "def"}))
d.addCallback(lambda res:
self.failUnlessEqual(res["status"], "deleted"))
d.addCallback(lambda _: self.get("list", "app1"))
d.addCallback(lambda data:
self.failUnlessEqual(data["channelids"], []))
return d
UPGRADE_ERROR = "Sorry, you must upgrade your client to use this server."
def test_old_allocate(self):
# 0.4.0 used "POST /allocate/SIDE".
# 0.5.0 replaced it with "POST /allocate".
# test that an old client gets a useful error message, not a 404.
d = self.post("allocate/abc", {})
def _check(data):
self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR)
d.addCallback(_check)
return d
def test_old_list(self):
# 0.4.0 used "GET /list".
# 0.5.0 replaced it with "GET /list?appid="
d = self.get("list", {}) # no appid
def _check(data):
self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR)
d.addCallback(_check)
return d
def test_old_post(self):
# 0.4.0 used "POST /CID/SIDE/post/MSGNUM"
# 0.5.0 replaced it with "POST /add (json body)"
d = self.post("1/abc/post/pake", {})
def _check(data):
self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR)
d.addCallback(_check)
return d
def add_message(self, message, side="abc", phase="1"):
return self.post("add",
{"appid": "app1",
"channelid": str(self.cid),
"side": side,
"phase": phase,
"body": message})
def parse_messages(self, messages):
out = set()
for m in messages:
self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"]))
self.failUnlessIsInstance(m["phase"], type(u""))
self.failUnlessIsInstance(m["body"], type(u""))
out.add( (m["phase"], m["body"]) )
return out
def check_messages(self, one, two):
# Comparing lists-of-dicts is non-trivial in python3 because we can
# neither sort them (dicts are uncomparable), nor turn them into sets
# (dicts are unhashable). This is close enough.
self.failUnlessEqual(len(one), len(two), (one,two))
for d in one:
self.failUnlessIn(d, two)
def test_message(self):
# exercise POST /add
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
d.addCallback(_allocated)
d.addCallback(lambda _: self.add_message("msg1A"))
def _check1(data):
self.check_welcome(data)
self.failUnlessEqual(strip_messages(data["messages"]),
[{"phase": "1", "body": "msg1A"}])
d.addCallback(_check1)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check1)
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check2(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check2)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check2)
# adding a duplicate message is not an error, is ignored by clients
d.addCallback(lambda _: self.add_message("msg1B", side="def"))
def _check3(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])),
set([("1", "msg1A"),
("1", "msg1B")]))
d.addCallback(_check3)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check3)
d.addCallback(lambda _: self.add_message("msg2A", side="abc",
phase="2"))
def _check4(data):
self.check_welcome(data)
self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])),
set([("1", "msg1A"),
("1", "msg1B"),
("2", "msg2A"),
]))
d.addCallback(_check4)
d.addCallback(lambda _: self.get("get", "app1", str(self.cid)))
d.addCallback(_check4)
return d
def test_watch_message(self):
# exercise GET /get (the EventSource version)
# this API is scheduled to be removed after 0.6.0
return self._do_watch("get")
def test_watch(self):
# exercise GET /watch (the EventSource version)
return self._do_watch("watch")
def _do_watch(self, endpoint_name):
d = self.post("allocate", {"appid": "app1", "side": "abc"})
def _allocated(data):
self.cid = data["channelid"]
url = self.build_url(endpoint_name, "app1", self.cid)
self.o = OneEventAtATime(url, parser=json.loads)
return self.o.wait_for_connection()
d.addCallback(_allocated)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_welcome(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "welcome")
self.failUnlessEqual(data, {"current_version": __version__})
d.addCallback(_check_welcome)
d.addCallback(lambda _: self.add_message("msg1A"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg1(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
data.pop("sent", None)
self.failUnlessEqual(strip_message(data),
{"phase": "1", "body": "msg1A"})
d.addCallback(_check_msg1)
d.addCallback(lambda _: self.add_message("msg1B"))
d.addCallback(lambda _: self.add_message("msg2A", phase="2"))
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg2(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
data.pop("sent", None)
self.failUnlessEqual(strip_message(data),
{"phase": "1", "body": "msg1B"})
d.addCallback(_check_msg2)
d.addCallback(lambda _: self.o.wait_for_next_event())
def _check_msg3(ev):
eventtype, data = ev
self.failUnlessEqual(eventtype, "message")
data.pop("sent", None)
self.failUnlessEqual(strip_message(data),
{"phase": "2", "body": "msg2A"})
d.addCallback(_check_msg3)
d.addCallback(lambda _: self.o.close())
d.addCallback(lambda _: self.o.wait_for_disconnection())
return d
class OneEventAtATime:
def __init__(self, url, parser=lambda e: e):
self.parser = parser
self.d = None
self._connected = False
self.connected_d = defer.Deferred()
self.disconnected_d = defer.Deferred()
self.events = []
self.es = EventSource(url, self.handler, when_connected=self.connected)
d = self.es.start()
d.addBoth(self.disconnected)
def close(self):
self.es.cancel()
def wait_for_next_event(self):
assert not self.d
if self.events:
event = self.events.pop(0)
return defer.succeed(event)
self.d = defer.Deferred()
return self.d
def handler(self, eventtype, data):
event = (eventtype, self.parser(data))
if self.d:
assert not self.events
d,self.d = self.d,None
d.callback(event)
return
self.events.append(event)
def wait_for_connection(self):
return self.connected_d
def connected(self):
self._connected = True
self.connected_d.callback(None)
def wait_for_disconnection(self):
return self.disconnected_d
def disconnected(self, why):
if not self._connected:
self.connected_d.errback(why)
self.disconnected_d.callback((why,))
class WSClient(websocket.WebSocketClientProtocol): class WSClient(websocket.WebSocketClientProtocol):
def __init__(self): def __init__(self):
websocket.WebSocketClientProtocol.__init__(self) websocket.WebSocketClientProtocol.__init__(self)

View File

@ -1,238 +0,0 @@
#import sys
from twisted.python import log, failure
from twisted.internet import reactor, defer, protocol
from twisted.application import service
from twisted.protocols import basic
from twisted.web.client import Agent, ResponseDone
from twisted.web.http_headers import Headers
from cgi import parse_header
from .eventual import eventually
#if sys.version_info[0] == 2:
# to_unicode = unicode
#else:
# to_unicode = str
class EventSourceParser(basic.LineOnlyReceiver):
# http://www.w3.org/TR/eventsource/
delimiter = b"\n"
def __init__(self, handler):
self.current_field = None
self.current_lines = []
self.handler = handler
self.done_deferred = defer.Deferred()
self.eventtype = u"message"
self.encoding = "utf-8"
def set_encoding(self, encoding):
self.encoding = encoding
def connectionLost(self, why):
if why.check(ResponseDone):
why = None
self.done_deferred.callback(why)
def dataReceived(self, data):
# exceptions here aren't being logged properly, and tests will hang
# rather than halt. I suspect twisted.web._newclient's
# HTTP11ClientProtocol.dataReceived(), which catches everything and
# responds with self._giveUp() but doesn't log.err.
try:
basic.LineOnlyReceiver.dataReceived(self, data)
except:
log.err()
raise
def lineReceived(self, line):
#line = to_unicode(line, self.encoding)
line = line.decode(self.encoding)
if not line:
# blank line ends the field: deliver event, reset for next
self.eventReceived(self.eventtype, "\n".join(self.current_lines))
self.eventtype = u"message"
self.current_lines[:] = []
return
if u":" in line:
fieldname, data = line.split(u":", 1)
if data.startswith(u" "):
data = data[1:]
else:
fieldname = line
data = u""
if fieldname == u"event":
self.eventtype = data
elif fieldname == u"data":
self.current_lines.append(data)
elif fieldname in (u"id", u"retry"):
# documented but unhandled
pass
else:
log.msg("weird fieldname", fieldname, data)
def eventReceived(self, eventtype, data):
self.handler(eventtype, data)
class EventSourceError(Exception):
pass
# es = EventSource(url, handler)
# d = es.start()
# es.cancel()
class EventSource: # TODO: service.Service
def __init__(self, url, handler, when_connected=None, agent=None):
assert isinstance(url, type(u""))
self.url = url
self.handler = handler
self.when_connected = when_connected
self.started = False
self.cancelled = False
self.proto = EventSourceParser(self.handler)
if not agent:
agent = Agent(reactor)
self.agent = agent
def start(self):
assert not self.started, "single-use"
self.started = True
assert self.url
d = self.agent.request(b"GET", self.url.encode("utf-8"),
Headers({b"accept": [b"text/event-stream"]}))
d.addCallback(self._connected)
return d
def _connected(self, resp):
if resp.code != 200:
raise EventSourceError("%d: %s" % (resp.code, resp.phrase))
if self.when_connected:
self.when_connected()
default_ct = "text/event-stream; charset=utf-8"
ct_headers = resp.headers.getRawHeaders("content-type", [default_ct])
ct, ct_params = parse_header(ct_headers[0])
assert ct == "text/event-stream", ct
self.proto.set_encoding(ct_params.get("charset", "utf-8"))
resp.deliverBody(self.proto)
if self.cancelled:
self.kill_connection()
return self.proto.done_deferred
def cancel(self):
self.cancelled = True
if not self.proto.transport:
# _connected hasn't been called yet, but that self.cancelled
# should take care of it when the connection is established
def kill(data):
# this should kill it as soon as any data is delivered
raise ValueError("dead")
self.proto.dataReceived = kill # just in case
return
self.kill_connection()
def kill_connection(self):
if (hasattr(self.proto.transport, "_producer")
and self.proto.transport._producer):
# This is gross and fragile. We need a clean way to stop the
# client connection. p.transport is a
# twisted.web._newclient.TransportProxyProducer , and its
# ._producer is the tcp.Port.
self.proto.transport._producer.loseConnection()
else:
log.err("get_events: unable to stop connection")
# oh well
#err = EventSourceError("unable to cancel")
try:
self.proto.done_deferred.callback(None)
except defer.AlreadyCalledError:
pass
class Connector:
# behave enough like an IConnector to appease ReconnectingClientFactory
def __init__(self, res):
self.res = res
def connect(self):
self.res._maybeStart()
def stopConnecting(self):
self.res._stop_eventsource()
class ReconnectingEventSource(service.MultiService,
protocol.ReconnectingClientFactory):
def __init__(self, url, handler, agent=None):
service.MultiService.__init__(self)
# we don't use any of the basic Factory/ClientFactory methods of
# this, just the ReconnectingClientFactory.retry, stopTrying, and
# resetDelay methods.
self.url = url
self.handler = handler
self.agent = agent
# IService provides self.running, toggled by {start,stop}Service.
# self.active is toggled by {,de}activate. If both .running and
# .active are True, then we want to have an outstanding EventSource
# and will start one if necessary. If either is False, then we don't
# want one to be outstanding, and will initiate shutdown.
self.active = False
self.connector = Connector(self)
self.es = None # set we have an outstanding EventSource
self.when_stopped = [] # list of Deferreds
def isStopped(self):
return not self.es
def startService(self):
service.MultiService.startService(self) # sets self.running
self._maybeStart()
def stopService(self):
# clears self.running
d = defer.maybeDeferred(service.MultiService.stopService, self)
d.addCallback(self._maybeStop)
return d
def activate(self):
assert not self.active
self.active = True
self._maybeStart()
def deactivate(self):
assert self.active # XXX
self.active = False
return self._maybeStop()
def _maybeStart(self):
if not (self.active and self.running):
return
self.continueTrying = True
self.es = EventSource(self.url, self.handler, self.resetDelay,
agent=self.agent)
d = self.es.start()
d.addBoth(self._stopped)
def _stopped(self, res):
self.es = None
# we might have stopped because of a connection error, or because of
# an intentional shutdown.
if self.active and self.running:
# we still want to be connected, so schedule a reconnection
if isinstance(res, failure.Failure):
log.err(res)
self.retry() # will eventually call _maybeStart
return
# intentional shutdown
self.stopTrying()
for d in self.when_stopped:
eventually(d.callback, None)
self.when_stopped = []
def _stop_eventsource(self):
if self.es:
eventually(self.es.cancel)
def _maybeStop(self, _=None):
self.stopTrying() # cancels timer, calls _stop_eventsource()
if not self.es:
return defer.succeed(None)
d = defer.Deferred()
self.when_stopped.append(d)
return d