relay: expire any rendezvous channel after one hour

This commit is contained in:
Brian Warner 2015-03-02 21:22:56 -08:00
parent 20fd7c40ae
commit 9a11f355ea

View File

@ -1,14 +1,18 @@
from __future__ import print_function from __future__ import print_function
import re, json import re, json, time
from collections import defaultdict from collections import defaultdict
from twisted.python import log from twisted.python import log
from twisted.internet import protocol from twisted.internet import protocol
from twisted.application import strports, service from twisted.application import strports, service, internet
from twisted.web import server, static, resource, http from twisted.web import server, static, resource, http
SECONDS = 1.0 SECONDS = 1.0
MINUTE = 60*SECONDS
HOUR = 60*MINUTE
MB = 1000*1000 MB = 1000*1000
CHANNEL_EXPIRATION_TIME = 1*HOUR
class Channel(resource.Resource): class Channel(resource.Resource):
isLeaf = True isLeaf = True
@ -22,6 +26,7 @@ class Channel(resource.Resource):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.channel_id = channel_id self.channel_id = channel_id
self.relay = relay self.relay = relay
self.expire_at = time.time() + CHANNEL_EXPIRATION_TIME
self.sides = set() self.sides = set()
self.messages = {"pake": defaultdict(list), # side -> [strings] self.messages = {"pake": defaultdict(list), # side -> [strings]
"data": defaultdict(list), # side -> [strings] "data": defaultdict(list), # side -> [strings]
@ -83,6 +88,14 @@ class Relay(resource.Resource):
self.channels = {} self.channels = {}
self.next_channel = 1 self.next_channel = 1
def prune_old_channels(self):
now = time.time()
for channel_id in list(self.channels):
c = self.channels[channel_id]
if c.expire_at < now:
log.msg("expiring %d" % channel_id)
self.free_child(channel_id)
def getChild(self, path, request): def getChild(self, path, request):
if path == "allocate": if path == "allocate":
# be more clever later. Rotate through 1-99 unless they're all # be more clever later. Rotate through 1-99 unless they're all
@ -243,8 +256,10 @@ class RelayServer(service.MultiService):
site = server.Site(self.root) site = server.Site(self.root)
self.relayport_service = strports.service(relayport, site) self.relayport_service = strports.service(relayport, site)
self.relayport_service.setServiceParent(self) self.relayport_service.setServiceParent(self)
self.relay = Relay() # for tests self.relay = Relay() # accessible from tests
self.root.putChild("relay", self.relay) self.root.putChild("relay", self.relay)
t = internet.TimerService(5*MINUTE, self.relay.prune_old_channels)
t.setServiceParent(self)
self.transit = Transit() self.transit = Transit()
self.transit.setServiceParent(self) # for the timer self.transit.setServiceParent(self) # for the timer
self.transport_service = strports.service(transitport, self.transit) self.transport_service = strports.service(transitport, self.transit)