diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index cf25183..1f4eb3c 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -1,9 +1,10 @@ from __future__ import print_function import re, json, time, random from twisted.python import log -from twisted.internet import protocol -from twisted.application import strports, service, internet +from twisted.internet import reactor, protocol, endpoints +from twisted.application import service, internet from twisted.web import server, static, resource, http +from ..util.endpoint_service import EndpointServerService from .. import __version__ from ..database import get_db @@ -400,7 +401,8 @@ class RelayServer(service.MultiService): welcome["current_version"] = advertise_version self.root = Root() site = server.Site(self.root) - self.relayport_service = strports.service(relayport, site) + r = endpoints.serverFromString(reactor, relayport) + self.relayport_service = EndpointServerService(r, site) self.relayport_service.setServiceParent(self) self.relay = Relay(self.db, welcome) # accessible from tests self.root.putChild("wormhole-relay", self.relay) @@ -410,5 +412,6 @@ class RelayServer(service.MultiService): if transitport: self.transit = Transit() self.transit.setServiceParent(self) # for the timer - self.transport_service = strports.service(transitport, self.transit) + t = endpoints.serverFromString(reactor, transitport) + self.transport_service = EndpointServerService(t, self.transit) self.transport_service.setServiceParent(self) diff --git a/src/wormhole/util/endpoint_service.py b/src/wormhole/util/endpoint_service.py new file mode 100644 index 0000000..c2b3069 --- /dev/null +++ b/src/wormhole/util/endpoint_service.py @@ -0,0 +1,26 @@ +from twisted.python import log +from twisted.internet import defer +from twisted.application import service + +# this should probably live in Twisted + +class EndpointServerService(service.Service): + def __init__(self, endpoint, factory): + self.endpoint = endpoint + self.factory = factory + self._started = defer.Deferred() + self._listeningport = None + + def startService(self): + d = self.endpoint.listen(self.factory) + def _set_port(listeningport): + self._listeningport = listeningport + self._started.callback(listeningport) + d.addCallback(_set_port) + d.addErrback(log.err) + + def stopService(self): + def _stop(port): + return port.stopListening() + self._started.addCallback(_stop) + return self._started