From a34fb2a98ba601a9332be3ec094d3de9d70cbc37 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 12 May 2016 16:56:19 -0700 Subject: [PATCH] remove plain-HTTP (non-WebSocket) rendezvous server --- src/wormhole/server/rendezvous_web.py | 223 ----------------- src/wormhole/server/server.py | 4 +- src/wormhole/test/test_server.py | 337 -------------------------- src/wormhole/twisted/eventsource.py | 238 ------------------ 4 files changed, 1 insertion(+), 801 deletions(-) delete mode 100644 src/wormhole/server/rendezvous_web.py delete mode 100644 src/wormhole/twisted/eventsource.py diff --git a/src/wormhole/server/rendezvous_web.py b/src/wormhole/server/rendezvous_web.py deleted file mode 100644 index c89654e..0000000 --- a/src/wormhole/server/rendezvous_web.py +++ /dev/null @@ -1,223 +0,0 @@ -import json, time -from twisted.web import server, resource -from twisted.python import log - -def json_response(request, data): - request.setHeader(b"content-type", b"application/json; charset=utf-8") - return (json.dumps(data)+"\n").encode("utf-8") - -class EventsProtocol: - def __init__(self, request): - self.request = request - - def sendComment(self, comment): - # this is ignored by clients, but can keep the connection open in the - # face of firewall/NAT timeouts. It also helps unit tests, since - # apparently twisted.web.client.Agent doesn't consider the connection - # to be established until it sees the first byte of the reponse body. - self.request.write(b": " + comment + b"\n\n") - - def sendEvent(self, data, name=None, id=None, retry=None): - if name: - self.request.write(b"event: " + name.encode("utf-8") + b"\n") - # e.g. if name=foo, then the client web page should do: - # (new EventSource(url)).addEventListener("foo", handlerfunc) - # Note that this basically defaults to "message". - if id: - self.request.write(b"id: " + id.encode("utf-8") + b"\n") - if retry: - self.request.write(b"retry: " + retry + b"\n") # milliseconds - for line in data.splitlines(): - self.request.write(b"data: " + line.encode("utf-8") + b"\n") - self.request.write(b"\n") - - def stop(self): - self.request.finish() - - def send_rendezvous_event(self, data): - data = data.copy() - data["sent"] = time.time() - self.sendEvent(json.dumps(data)) - def stop_rendezvous_watcher(self): - self.stop() - -# note: no versions of IE (including the current IE11) support EventSource - -# relay URLs are as follows: (MESSAGES=[{phase:,body:}..]) -# ("-" indicates a deprecated URL) -# GET /list?appid= -> {channelids: [INT..]} -# POST /allocate {appid:,side:} -> {channelid: INT} -# these return all messages (base64) for appid=/channelid= : -# POST /add {appid:,channelid:,side:,phase:,body:} -> {messages: MESSAGES} -# GET /get?appid=&channelid= (no-eventsource) -> {messages: MESSAGES} -#- GET /get?appid=&channelid= (eventsource) -> {phase:, body:}.. -# GET /watch?appid=&channelid= (eventsource) -> {phase:, body:}.. -# POST /deallocate {appid:,channelid:,side:} -> {status: waiting | deleted} -# all JSON responses include a "welcome:{..}" key - -class RelayResource(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self._welcome = rendezvous.get_welcome() - -class ChannelLister(RelayResource): - def render_GET(self, request): - if b"appid" not in request.args: - e = NeedToUpgradeErrorResource(self._welcome) - return e.get_message() - appid = request.args[b"appid"][0].decode("utf-8") - #print("LIST", appid) - app = self._rendezvous.get_app(appid) - allocated = app.get_allocated() - data = {"welcome": self._welcome, "channelids": sorted(allocated), - "sent": time.time()} - return json_response(request, data) - -class Allocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - #print("ALLOCATE", appid, side) - app = self._rendezvous.get_app(appid) - channelid = app.find_available_channelid() - app.allocate_channel(channelid, side) - if self._rendezvous.get_log_requests(): - log.msg("allocated #%d, now have %d DB channels" % - (channelid, len(app.get_allocated()))) - response = {"welcome": self._welcome, "channelid": channelid, - "sent": time.time()} - return json_response(request, response) - - def getChild(self, path, req): - # wormhole-0.4.0 "send" started with "POST /allocate/SIDE". - # wormhole-0.5.0 changed that to "POST /allocate". We catch the old - # URL here to deliver a nicer error message (with upgrade - # instructions) than an ugly 404. - return NeedToUpgradeErrorResource(self._welcome) - -class NeedToUpgradeErrorResource(resource.Resource): - def __init__(self, welcome): - resource.Resource.__init__(self) - w = welcome.copy() - w["error"] = "Sorry, you must upgrade your client to use this server." - message = {"welcome": w} - self._message = (json.dumps(message)+"\n").encode("utf-8") - def get_message(self): - return self._message - def render_POST(self, request): - return self._message - def render_GET(self, request): - return self._message - def getChild(self, path, req): - return self - -class Adder(RelayResource): - def render_POST(self, request): - #content = json.load(request.content, encoding="utf-8") - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - phase = data["phase"] - if not isinstance(phase, type(u"")): - raise TypeError("phase must be string, not %s" % type(phase)) - body = data["body"] - #print("ADD", appid, channelid, side, phase, body) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - messages = channel.add_message(side, phase, body, time.time(), None) - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - -class GetterOrWatcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - #print("GET", appid, channelid) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - messages = channel.get_messages() - response = {"welcome": self._welcome, "messages": messages, - "sent": time.time()} - return json_response(request, response) - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Watcher(RelayResource): - def render_GET(self, request): - appid = request.args[b"appid"][0].decode("utf-8") - channelid = int(request.args[b"channelid"][0]) - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - if b"text/event-stream" not in (request.getHeader(b"accept") or b""): - raise TypeError("/watch is for EventSource only") - - request.setHeader(b"content-type", b"text/event-stream; charset=utf-8") - ep = EventsProtocol(request) - ep.sendEvent(json.dumps(self._welcome), name="welcome") - old_events = channel.add_listener(ep) - request.notifyFinish().addErrback(lambda f: - channel.remove_listener(ep)) - for old_event in old_events: - ep.send_rendezvous_event(old_event) - return server.NOT_DONE_YET - -class Deallocator(RelayResource): - def render_POST(self, request): - content = request.content.read() - data = json.loads(content.decode("utf-8")) - appid = data["appid"] - channelid = int(data["channelid"]) - side = data["side"] - if not isinstance(side, type(u"")): - raise TypeError("side must be string, not '%s'" % type(side)) - mood = data.get("mood") - #print("DEALLOCATE", appid, channelid, side) - - app = self._rendezvous.get_app(appid) - channel = app.get_channel(channelid) - deleted = channel.deallocate(side, mood) - response = {"status": "waiting", "sent": time.time()} - if deleted: - response = {"status": "deleted", "sent": time.time()} - return json_response(request, response) - - -class WebRendezvous(resource.Resource): - def __init__(self, rendezvous): - resource.Resource.__init__(self) - self._rendezvous = rendezvous - self.putChild(b"list", ChannelLister(rendezvous)) - self.putChild(b"allocate", Allocator(rendezvous)) - self.putChild(b"add", Adder(rendezvous)) - self.putChild(b"get", GetterOrWatcher(rendezvous)) - self.putChild(b"watch", Watcher(rendezvous)) - self.putChild(b"deallocate", Deallocator(rendezvous)) - - def getChild(self, path, req): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - # give a nicer error message to old clients - if (len(req.postpath) >= 2 - and req.postpath[1] in (b"post", b"poll", b"deallocate")): - welcome = self._rendezvous.get_welcome() - return NeedToUpgradeErrorResource(welcome) - return resource.NoResource("No such child resource.") diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 166a6ba..694ed19 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -8,7 +8,6 @@ from .endpoint_service import ServerEndpointService from .. import __version__ from .database import get_db from .rendezvous import Rendezvous -from .rendezvous_web import WebRendezvous from .rendezvous_websocket import WebSocketRendezvousFactory from .transit_server import Transit @@ -47,7 +46,7 @@ class RelayServer(service.MultiService): rendezvous.setServiceParent(self) # for the pruning timer root = Root() - wr = WebRendezvous(rendezvous) + wr = resource.Resource() root.putChild(b"wormhole-relay", wr) wsrf = WebSocketRendezvousFactory(None, rendezvous) @@ -72,7 +71,6 @@ class RelayServer(service.MultiService): self._db = db self._rendezvous = rendezvous self._root = root - self._rendezvous_web = wr self._rendezvous_web_service = rendezvous_web_service self._rendezvous_websocket = wsrf if transit_port: diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index d9010f1..c52bd50 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1,19 +1,15 @@ from __future__ import print_function import json, itertools from binascii import hexlify -import requests -from six.moves.urllib_parse import urlencode from twisted.trial import unittest from twisted.internet import protocol, reactor, defer from twisted.internet.defer import inlineCallbacks, returnValue -from twisted.internet.threads import deferToThread from twisted.internet.endpoints import clientFromString, connectProtocol from twisted.web.client import getPage, Agent, readBody from autobahn.twisted import websocket from .. import __version__ from .common import ServerBase from ..server import rendezvous, transit_server -from ..twisted.eventsource import EventSource class Reachable(ServerBase, unittest.TestCase): @@ -39,22 +35,6 @@ class Reachable(ServerBase, unittest.TestCase): d.addCallback(_got) return d - def test_requests(self): - # requests requires bytes URL, returns unicode - url = self.relayurl.replace("wormhole-relay/", "") - def _get(url): - r = requests.get(url) - r.raise_for_status() - return r.text - d = deferToThread(_get, url) - def _got(res): - self.failUnlessEqual(res, "Wormhole Relay\n") - d.addCallback(_got) - return d - -def unjson(data): - return json.loads(data.decode("utf-8")) - def strip_message(msg): m2 = msg.copy() m2.pop("id", None) @@ -64,323 +44,6 @@ def strip_message(msg): def strip_messages(messages): return [strip_message(m) for m in messages] -class WebAPI(ServerBase, unittest.TestCase): - def build_url(self, path, appid, channelid): - url = self.relayurl+path - queryargs = [] - if appid: - queryargs.append(("appid", appid)) - if channelid: - queryargs.append(("channelid", channelid)) - if queryargs: - url += "?" + urlencode(queryargs) - return url - - def get(self, path, appid=None, channelid=None): - url = self.build_url(path, appid, channelid) - d = getPage(url.encode("ascii")) - d.addCallback(unjson) - return d - - def post(self, path, data): - url = self.relayurl+path - d = getPage(url.encode("ascii"), method=b"POST", - postdata=json.dumps(data).encode("utf-8")) - d.addCallback(unjson) - return d - - def check_welcome(self, data): - self.failUnlessIn("welcome", data) - self.failUnlessEqual(data["welcome"], {"current_version": __version__}) - - def test_allocate_1(self): - d = self.get("list", "app1") - def _check_list_1(data): - self.check_welcome(data) - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_1) - - d.addCallback(lambda _: self.post("allocate", {"appid": "app1", - "side": "abc"})) - def _allocated(data): - data.pop("sent", None) - self.failUnlessEqual(set(data.keys()), - set(["welcome", "channelid"])) - self.failUnlessIsInstance(data["channelid"], int) - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_2(data): - self.failUnlessEqual(data["channelids"], [self.cid]) - d.addCallback(_check_list_2) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - def _check_deallocate(res): - self.failUnlessEqual(res["status"], "deleted") - d.addCallback(_check_deallocate) - - d.addCallback(lambda _: self.get("list", "app1")) - def _check_list_3(data): - self.failUnlessEqual(data["channelids"], []) - d.addCallback(_check_list_3) - - return d - - def test_allocate_2(self): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - # second caller increases the number of known sides to 2 - d.addCallback(lambda _: self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def", - "phase": "1", - "body": ""})) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [self.cid])) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "abc"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "NOT"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "waiting")) - - d.addCallback(lambda _: self.post("deallocate", - {"appid": "app1", - "channelid": str(self.cid), - "side": "def"})) - d.addCallback(lambda res: - self.failUnlessEqual(res["status"], "deleted")) - - d.addCallback(lambda _: self.get("list", "app1")) - d.addCallback(lambda data: - self.failUnlessEqual(data["channelids"], [])) - - return d - - UPGRADE_ERROR = "Sorry, you must upgrade your client to use this server." - def test_old_allocate(self): - # 0.4.0 used "POST /allocate/SIDE". - # 0.5.0 replaced it with "POST /allocate". - # test that an old client gets a useful error message, not a 404. - d = self.post("allocate/abc", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_list(self): - # 0.4.0 used "GET /list". - # 0.5.0 replaced it with "GET /list?appid=" - d = self.get("list", {}) # no appid - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def test_old_post(self): - # 0.4.0 used "POST /CID/SIDE/post/MSGNUM" - # 0.5.0 replaced it with "POST /add (json body)" - d = self.post("1/abc/post/pake", {}) - def _check(data): - self.failUnlessEqual(data["welcome"]["error"], self.UPGRADE_ERROR) - d.addCallback(_check) - return d - - def add_message(self, message, side="abc", phase="1"): - return self.post("add", - {"appid": "app1", - "channelid": str(self.cid), - "side": side, - "phase": phase, - "body": message}) - - def parse_messages(self, messages): - out = set() - for m in messages: - self.failUnlessEqual(sorted(m.keys()), sorted(["phase", "body"])) - self.failUnlessIsInstance(m["phase"], type(u"")) - self.failUnlessIsInstance(m["body"], type(u"")) - out.add( (m["phase"], m["body"]) ) - return out - - def check_messages(self, one, two): - # Comparing lists-of-dicts is non-trivial in python3 because we can - # neither sort them (dicts are uncomparable), nor turn them into sets - # (dicts are unhashable). This is close enough. - self.failUnlessEqual(len(one), len(two), (one,two)) - for d in one: - self.failUnlessIn(d, two) - - def test_message(self): - # exercise POST /add - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - d.addCallback(_allocated) - - d.addCallback(lambda _: self.add_message("msg1A")) - def _check1(data): - self.check_welcome(data) - self.failUnlessEqual(strip_messages(data["messages"]), - [{"phase": "1", "body": "msg1A"}]) - d.addCallback(_check1) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check1) - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check2(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check2) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check2) - - # adding a duplicate message is not an error, is ignored by clients - d.addCallback(lambda _: self.add_message("msg1B", side="def")) - def _check3(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B")])) - d.addCallback(_check3) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check3) - - d.addCallback(lambda _: self.add_message("msg2A", side="abc", - phase="2")) - def _check4(data): - self.check_welcome(data) - self.failUnlessEqual(self.parse_messages(strip_messages(data["messages"])), - set([("1", "msg1A"), - ("1", "msg1B"), - ("2", "msg2A"), - ])) - d.addCallback(_check4) - d.addCallback(lambda _: self.get("get", "app1", str(self.cid))) - d.addCallback(_check4) - - return d - - def test_watch_message(self): - # exercise GET /get (the EventSource version) - # this API is scheduled to be removed after 0.6.0 - return self._do_watch("get") - - def test_watch(self): - # exercise GET /watch (the EventSource version) - return self._do_watch("watch") - - def _do_watch(self, endpoint_name): - d = self.post("allocate", {"appid": "app1", "side": "abc"}) - def _allocated(data): - self.cid = data["channelid"] - url = self.build_url(endpoint_name, "app1", self.cid) - self.o = OneEventAtATime(url, parser=json.loads) - return self.o.wait_for_connection() - d.addCallback(_allocated) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_welcome(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "welcome") - self.failUnlessEqual(data, {"current_version": __version__}) - d.addCallback(_check_welcome) - d.addCallback(lambda _: self.add_message("msg1A")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg1(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1A"}) - d.addCallback(_check_msg1) - - d.addCallback(lambda _: self.add_message("msg1B")) - d.addCallback(lambda _: self.add_message("msg2A", phase="2")) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg2(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "1", "body": "msg1B"}) - d.addCallback(_check_msg2) - d.addCallback(lambda _: self.o.wait_for_next_event()) - def _check_msg3(ev): - eventtype, data = ev - self.failUnlessEqual(eventtype, "message") - data.pop("sent", None) - self.failUnlessEqual(strip_message(data), - {"phase": "2", "body": "msg2A"}) - d.addCallback(_check_msg3) - - d.addCallback(lambda _: self.o.close()) - d.addCallback(lambda _: self.o.wait_for_disconnection()) - return d - -class OneEventAtATime: - def __init__(self, url, parser=lambda e: e): - self.parser = parser - self.d = None - self._connected = False - self.connected_d = defer.Deferred() - self.disconnected_d = defer.Deferred() - self.events = [] - self.es = EventSource(url, self.handler, when_connected=self.connected) - d = self.es.start() - d.addBoth(self.disconnected) - - def close(self): - self.es.cancel() - - def wait_for_next_event(self): - assert not self.d - if self.events: - event = self.events.pop(0) - return defer.succeed(event) - self.d = defer.Deferred() - return self.d - - def handler(self, eventtype, data): - event = (eventtype, self.parser(data)) - if self.d: - assert not self.events - d,self.d = self.d,None - d.callback(event) - return - self.events.append(event) - - def wait_for_connection(self): - return self.connected_d - def connected(self): - self._connected = True - self.connected_d.callback(None) - - def wait_for_disconnection(self): - return self.disconnected_d - def disconnected(self, why): - if not self._connected: - self.connected_d.errback(why) - self.disconnected_d.callback((why,)) - class WSClient(websocket.WebSocketClientProtocol): def __init__(self): websocket.WebSocketClientProtocol.__init__(self) diff --git a/src/wormhole/twisted/eventsource.py b/src/wormhole/twisted/eventsource.py deleted file mode 100644 index 19272ca..0000000 --- a/src/wormhole/twisted/eventsource.py +++ /dev/null @@ -1,238 +0,0 @@ -#import sys -from twisted.python import log, failure -from twisted.internet import reactor, defer, protocol -from twisted.application import service -from twisted.protocols import basic -from twisted.web.client import Agent, ResponseDone -from twisted.web.http_headers import Headers -from cgi import parse_header -from .eventual import eventually - -#if sys.version_info[0] == 2: -# to_unicode = unicode -#else: -# to_unicode = str - -class EventSourceParser(basic.LineOnlyReceiver): - # http://www.w3.org/TR/eventsource/ - delimiter = b"\n" - - def __init__(self, handler): - self.current_field = None - self.current_lines = [] - self.handler = handler - self.done_deferred = defer.Deferred() - self.eventtype = u"message" - self.encoding = "utf-8" - - def set_encoding(self, encoding): - self.encoding = encoding - - def connectionLost(self, why): - if why.check(ResponseDone): - why = None - self.done_deferred.callback(why) - - def dataReceived(self, data): - # exceptions here aren't being logged properly, and tests will hang - # rather than halt. I suspect twisted.web._newclient's - # HTTP11ClientProtocol.dataReceived(), which catches everything and - # responds with self._giveUp() but doesn't log.err. - try: - basic.LineOnlyReceiver.dataReceived(self, data) - except: - log.err() - raise - - def lineReceived(self, line): - #line = to_unicode(line, self.encoding) - line = line.decode(self.encoding) - if not line: - # blank line ends the field: deliver event, reset for next - self.eventReceived(self.eventtype, "\n".join(self.current_lines)) - self.eventtype = u"message" - self.current_lines[:] = [] - return - if u":" in line: - fieldname, data = line.split(u":", 1) - if data.startswith(u" "): - data = data[1:] - else: - fieldname = line - data = u"" - if fieldname == u"event": - self.eventtype = data - elif fieldname == u"data": - self.current_lines.append(data) - elif fieldname in (u"id", u"retry"): - # documented but unhandled - pass - else: - log.msg("weird fieldname", fieldname, data) - - def eventReceived(self, eventtype, data): - self.handler(eventtype, data) - -class EventSourceError(Exception): - pass - -# es = EventSource(url, handler) -# d = es.start() -# es.cancel() - -class EventSource: # TODO: service.Service - def __init__(self, url, handler, when_connected=None, agent=None): - assert isinstance(url, type(u"")) - self.url = url - self.handler = handler - self.when_connected = when_connected - self.started = False - self.cancelled = False - self.proto = EventSourceParser(self.handler) - if not agent: - agent = Agent(reactor) - self.agent = agent - - def start(self): - assert not self.started, "single-use" - self.started = True - assert self.url - d = self.agent.request(b"GET", self.url.encode("utf-8"), - Headers({b"accept": [b"text/event-stream"]})) - d.addCallback(self._connected) - return d - - def _connected(self, resp): - if resp.code != 200: - raise EventSourceError("%d: %s" % (resp.code, resp.phrase)) - if self.when_connected: - self.when_connected() - default_ct = "text/event-stream; charset=utf-8" - ct_headers = resp.headers.getRawHeaders("content-type", [default_ct]) - ct, ct_params = parse_header(ct_headers[0]) - assert ct == "text/event-stream", ct - self.proto.set_encoding(ct_params.get("charset", "utf-8")) - resp.deliverBody(self.proto) - if self.cancelled: - self.kill_connection() - return self.proto.done_deferred - - def cancel(self): - self.cancelled = True - if not self.proto.transport: - # _connected hasn't been called yet, but that self.cancelled - # should take care of it when the connection is established - def kill(data): - # this should kill it as soon as any data is delivered - raise ValueError("dead") - self.proto.dataReceived = kill # just in case - return - self.kill_connection() - - def kill_connection(self): - if (hasattr(self.proto.transport, "_producer") - and self.proto.transport._producer): - # This is gross and fragile. We need a clean way to stop the - # client connection. p.transport is a - # twisted.web._newclient.TransportProxyProducer , and its - # ._producer is the tcp.Port. - self.proto.transport._producer.loseConnection() - else: - log.err("get_events: unable to stop connection") - # oh well - #err = EventSourceError("unable to cancel") - try: - self.proto.done_deferred.callback(None) - except defer.AlreadyCalledError: - pass - - -class Connector: - # behave enough like an IConnector to appease ReconnectingClientFactory - def __init__(self, res): - self.res = res - def connect(self): - self.res._maybeStart() - def stopConnecting(self): - self.res._stop_eventsource() - -class ReconnectingEventSource(service.MultiService, - protocol.ReconnectingClientFactory): - def __init__(self, url, handler, agent=None): - service.MultiService.__init__(self) - # we don't use any of the basic Factory/ClientFactory methods of - # this, just the ReconnectingClientFactory.retry, stopTrying, and - # resetDelay methods. - - self.url = url - self.handler = handler - self.agent = agent - # IService provides self.running, toggled by {start,stop}Service. - # self.active is toggled by {,de}activate. If both .running and - # .active are True, then we want to have an outstanding EventSource - # and will start one if necessary. If either is False, then we don't - # want one to be outstanding, and will initiate shutdown. - self.active = False - self.connector = Connector(self) - self.es = None # set we have an outstanding EventSource - self.when_stopped = [] # list of Deferreds - - def isStopped(self): - return not self.es - - def startService(self): - service.MultiService.startService(self) # sets self.running - self._maybeStart() - - def stopService(self): - # clears self.running - d = defer.maybeDeferred(service.MultiService.stopService, self) - d.addCallback(self._maybeStop) - return d - - def activate(self): - assert not self.active - self.active = True - self._maybeStart() - - def deactivate(self): - assert self.active # XXX - self.active = False - return self._maybeStop() - - def _maybeStart(self): - if not (self.active and self.running): - return - self.continueTrying = True - self.es = EventSource(self.url, self.handler, self.resetDelay, - agent=self.agent) - d = self.es.start() - d.addBoth(self._stopped) - - def _stopped(self, res): - self.es = None - # we might have stopped because of a connection error, or because of - # an intentional shutdown. - if self.active and self.running: - # we still want to be connected, so schedule a reconnection - if isinstance(res, failure.Failure): - log.err(res) - self.retry() # will eventually call _maybeStart - return - # intentional shutdown - self.stopTrying() - for d in self.when_stopped: - eventually(d.callback, None) - self.when_stopped = [] - - def _stop_eventsource(self): - if self.es: - eventually(self.es.cancel) - - def _maybeStop(self, _=None): - self.stopTrying() # cancels timer, calls _stop_eventsource() - if not self.es: - return defer.succeed(None) - d = defer.Deferred() - self.when_stopped.append(d) - return d