Add an allow_list option to control nameplate-listings
This commit is contained in:
parent
95651f24f9
commit
6b31517b67
|
@ -48,9 +48,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
|
|||
"--signal-error", is_flag=True,
|
||||
help="force all clients to fail with a message",
|
||||
)
|
||||
@click.option(
|
||||
"--disallow-list", is_flag=True,
|
||||
help="never send list of allocated nameplates",
|
||||
)
|
||||
@click.pass_obj
|
||||
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
|
||||
transit, rendezvous):
|
||||
transit, rendezvous, disallow_list):
|
||||
"""
|
||||
Start a relay server
|
||||
"""
|
||||
|
@ -61,6 +65,7 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
|
|||
cfg.transit = str(transit)
|
||||
cfg.rendezvous = str(rendezvous)
|
||||
cfg.signal_error = signal_error
|
||||
cfg.allow_list = not disallow_list
|
||||
|
||||
start_server(cfg)
|
||||
|
||||
|
|
|
@ -3,20 +3,26 @@ import os, time
|
|||
from twisted.python import usage
|
||||
from twisted.scripts import twistd
|
||||
|
||||
class MyPlugin:
|
||||
class MyPlugin(object):
|
||||
tapname = "xyznode"
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def makeService(self, so):
|
||||
# delay this import as late as possible, to allow twistd's code to
|
||||
# accept --reactor= selection
|
||||
from .server import RelayServer
|
||||
return RelayServer(self.args.rendezvous, self.args.transit,
|
||||
self.args.advertise_version,
|
||||
"relay.sqlite", self.args.blur_usage,
|
||||
signal_error=self.args.signal_error,
|
||||
stats_file="stats.json",
|
||||
)
|
||||
return RelayServer(
|
||||
self.args.rendezvous,
|
||||
self.args.transit,
|
||||
self.args.advertise_version,
|
||||
"relay.sqlite",
|
||||
self.args.blur_usage,
|
||||
signal_error=self.args.signal_error,
|
||||
stats_file="stats.json",
|
||||
allow_list=self.args.allow_list,
|
||||
)
|
||||
|
||||
class MyTwistdConfig(twistd.ServerOptions):
|
||||
subCommands = [("XYZ", None, usage.Options, "node")]
|
||||
|
|
|
@ -159,8 +159,10 @@ class Mailbox:
|
|||
stop_f()
|
||||
self._listeners = {}
|
||||
|
||||
class AppNamespace:
|
||||
def __init__(self, db, blur_usage, log_requests, app_id):
|
||||
|
||||
class AppNamespace(object):
|
||||
|
||||
def __init__(self, db, blur_usage, log_requests, app_id, allow_list):
|
||||
self._db = db
|
||||
self._blur_usage = blur_usage
|
||||
self._log_requests = log_requests
|
||||
|
@ -168,8 +170,14 @@ class AppNamespace:
|
|||
self._mailboxes = {}
|
||||
self._nameplate_counts = collections.defaultdict(int)
|
||||
self._mailbox_counts = collections.defaultdict(int)
|
||||
self._allow_list = allow_list
|
||||
|
||||
def get_nameplate_ids(self):
|
||||
if not self._allow_list:
|
||||
return []
|
||||
return self._get_nameplate_ids()
|
||||
|
||||
def _get_nameplate_ids(self):
|
||||
db = self._db
|
||||
# TODO: filter this to numeric ids?
|
||||
c = db.execute("SELECT DISTINCT `name` FROM `nameplates`"
|
||||
|
@ -177,7 +185,7 @@ class AppNamespace:
|
|||
return set([row["name"] for row in c.fetchall()])
|
||||
|
||||
def _find_available_nameplate_id(self):
|
||||
claimed = self.get_nameplate_ids()
|
||||
claimed = self._get_nameplate_ids()
|
||||
for size in range(1,4): # stick to 1-999 for now
|
||||
available = set()
|
||||
for id_int in range(10**(size-1), 10**size):
|
||||
|
@ -505,14 +513,17 @@ class AppNamespace:
|
|||
for channel in self._mailboxes.values():
|
||||
channel._shutdown()
|
||||
|
||||
|
||||
class Rendezvous(service.MultiService):
|
||||
def __init__(self, db, welcome, blur_usage):
|
||||
|
||||
def __init__(self, db, welcome, blur_usage, allow_list):
|
||||
service.MultiService.__init__(self)
|
||||
self._db = db
|
||||
self._welcome = welcome
|
||||
self._blur_usage = blur_usage
|
||||
log_requests = blur_usage is None
|
||||
self._log_requests = log_requests
|
||||
self._allow_list = allow_list
|
||||
self._apps = {}
|
||||
|
||||
def get_welcome(self):
|
||||
|
@ -525,9 +536,13 @@ class Rendezvous(service.MultiService):
|
|||
if not app_id in self._apps:
|
||||
if self._log_requests:
|
||||
log.msg("spawning app_id %s" % (app_id,))
|
||||
self._apps[app_id] = AppNamespace(self._db,
|
||||
self._blur_usage,
|
||||
self._log_requests, app_id)
|
||||
self._apps[app_id] = AppNamespace(
|
||||
self._db,
|
||||
self._blur_usage,
|
||||
self._log_requests,
|
||||
app_id,
|
||||
self._allow_list,
|
||||
)
|
||||
return self._apps[app_id]
|
||||
|
||||
def get_all_apps(self):
|
||||
|
|
|
@ -292,6 +292,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
|
|||
|
||||
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
|
||||
protocol = WebSocketRendezvous
|
||||
|
||||
def __init__(self, url, rendezvous):
|
||||
websocket.WebSocketServerFactory.__init__(self, url)
|
||||
self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||
|
|
|
@ -32,11 +32,13 @@ class PrivacyEnhancedSite(server.Site):
|
|||
return server.Site.log(self, request)
|
||||
|
||||
class RelayServer(service.MultiService):
|
||||
|
||||
def __init__(self, rendezvous_web_port, transit_port,
|
||||
advertise_version, db_url=":memory:", blur_usage=None,
|
||||
signal_error=None, stats_file=None):
|
||||
signal_error=None, stats_file=None, allow_list=True):
|
||||
service.MultiService.__init__(self)
|
||||
self._blur_usage = blur_usage
|
||||
self._allow_list = allow_list
|
||||
|
||||
db = get_db(db_url)
|
||||
welcome = {
|
||||
|
@ -58,7 +60,7 @@ class RelayServer(service.MultiService):
|
|||
if signal_error:
|
||||
welcome["error"] = signal_error
|
||||
|
||||
self._rendezvous = Rendezvous(db, welcome, blur_usage)
|
||||
self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list)
|
||||
self._rendezvous.setServiceParent(self) # for the pruning timer
|
||||
|
||||
root = Root()
|
||||
|
@ -108,6 +110,8 @@ class RelayServer(service.MultiService):
|
|||
log.msg("not logging HTTP requests or Transit connections")
|
||||
else:
|
||||
log.msg("not blurring access times")
|
||||
if not self._allow_list:
|
||||
log.msg("listing of allocated nameplates disallowed")
|
||||
|
||||
def timer(self):
|
||||
now = time.time()
|
||||
|
|
|
@ -3,6 +3,7 @@ import os, sys, re, io, zipfile, six, stat
|
|||
from textwrap import fill, dedent
|
||||
from humanize import naturalsize
|
||||
import mock
|
||||
import click.testing
|
||||
from twisted.trial import unittest
|
||||
from twisted.python import procutils, log
|
||||
from twisted.internet import defer, endpoints, reactor
|
||||
|
@ -12,6 +13,8 @@ from .. import __version__
|
|||
from .common import ServerBase, config
|
||||
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
||||
from wormhole.server.cmd_server import MyPlugin
|
||||
from wormhole.server.cli import server
|
||||
|
||||
|
||||
def build_offer(args):
|
||||
|
@ -1065,3 +1068,28 @@ class Dispatch(unittest.TestCase):
|
|||
expected = "<TRACEBACK>\nERROR: abcd\n"
|
||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||
|
||||
|
||||
class Server(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.runner = click.testing.CliRunner()
|
||||
|
||||
@mock.patch('wormhole.server.cmd_server.twistd')
|
||||
def test_server_disallow_list(self, fake_twistd):
|
||||
result = self.runner.invoke(server, ['start', '--no-daemon', '--disallow-list'])
|
||||
self.assertEqual(0, result.exit_code)
|
||||
|
||||
def test_server_plugin(self):
|
||||
class FakeConfig(object):
|
||||
no_daemon = True
|
||||
blur_usage = True
|
||||
advertise_version = u"fake.version.1"
|
||||
transit = str('tcp:4321')
|
||||
rendezvous = str('tcp:1234')
|
||||
signal_error = True
|
||||
allow_list = False
|
||||
|
||||
cfg = FakeConfig()
|
||||
plugin = MyPlugin(cfg)
|
||||
relay = plugin.makeService(None)
|
||||
self.assertEqual(False, relay._allow_list)
|
||||
|
|
|
@ -301,7 +301,7 @@ class Prune(unittest.TestCase):
|
|||
|
||||
def test_update(self):
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, None)
|
||||
rv = rendezvous.Rendezvous(db, None, None, True)
|
||||
app = rv.get_app("appid")
|
||||
mbox_id = "mbox1"
|
||||
app.open_mailbox(mbox_id, "side1", 1)
|
||||
|
@ -315,7 +315,7 @@ class Prune(unittest.TestCase):
|
|||
self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3)
|
||||
|
||||
def test_apps(self):
|
||||
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None)
|
||||
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None, True)
|
||||
app = rv.get_app("appid")
|
||||
app.allocate_nameplate("side", 121)
|
||||
app.prune = mock.Mock()
|
||||
|
@ -324,7 +324,7 @@ class Prune(unittest.TestCase):
|
|||
|
||||
def test_nameplates(self):
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
||||
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||
|
||||
# timestamps <=50 are "old", >=51 are "new"
|
||||
#OLD = "old"; NEW = "new"
|
||||
|
@ -358,7 +358,7 @@ class Prune(unittest.TestCase):
|
|||
|
||||
def test_mailboxes(self):
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
||||
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||
|
||||
# timestamps <=50 are "old", >=51 are "new"
|
||||
#OLD = "old"; NEW = "new"
|
||||
|
@ -404,7 +404,7 @@ class Prune(unittest.TestCase):
|
|||
log.msg(desc)
|
||||
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
||||
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||
APPID = "appid"
|
||||
app = rv.get_app(APPID)
|
||||
|
||||
|
@ -1174,7 +1174,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase):
|
|||
|
||||
class Summary(unittest.TestCase):
|
||||
def test_mailbox(self):
|
||||
app = rendezvous.AppNamespace(None, None, False, None)
|
||||
app = rendezvous.AppNamespace(None, None, False, None, True)
|
||||
# starts at time 1, maybe gets second open at time 3, closes at 5
|
||||
def s(rows, pruned=False):
|
||||
return app._summarize_mailbox(rows, 5, pruned)
|
||||
|
@ -1217,7 +1217,7 @@ class Summary(unittest.TestCase):
|
|||
self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded"))
|
||||
|
||||
def test_nameplate(self):
|
||||
a = rendezvous.AppNamespace(None, None, False, None)
|
||||
a = rendezvous.AppNamespace(None, None, False, None, True)
|
||||
# starts at time 1, maybe gets second open at time 3, closes at 5
|
||||
def s(rows, pruned=False):
|
||||
return a._summarize_nameplate_usage(rows, 5, pruned)
|
||||
|
@ -1233,10 +1233,21 @@ class Summary(unittest.TestCase):
|
|||
rows = [dict(added=1), dict(added=3), dict(added=4)]
|
||||
self.assertEqual(s(rows), Usage(1, 2, 4, "crowded"))
|
||||
|
||||
def test_nameplate_disallowed(self):
|
||||
db = get_db(":memory:")
|
||||
a = rendezvous.AppNamespace(db, None, False, "some_app_id", False)
|
||||
a.allocate_nameplate("side1", "123")
|
||||
self.assertEqual([], a.get_nameplate_ids())
|
||||
|
||||
def test_nameplate_allowed(self):
|
||||
db = get_db(":memory:")
|
||||
a = rendezvous.AppNamespace(db, None, False, "some_app_id", True)
|
||||
np = a.allocate_nameplate("side1", "321")
|
||||
self.assertEqual(set([np]), a.get_nameplate_ids())
|
||||
|
||||
def test_blur(self):
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
||||
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||
APPID = "appid"
|
||||
app = rv.get_app(APPID)
|
||||
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
||||
|
@ -1253,7 +1264,7 @@ class Summary(unittest.TestCase):
|
|||
|
||||
def test_no_blur(self):
|
||||
db = get_db(":memory:")
|
||||
rv = rendezvous.Rendezvous(db, None, None)
|
||||
rv = rendezvous.Rendezvous(db, None, None, True)
|
||||
APPID = "appid"
|
||||
app = rv.get_app(APPID)
|
||||
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
||||
|
@ -1292,3 +1303,22 @@ class DumpStats(unittest.TestCase):
|
|||
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
|
||||
self.assertEqual(data["transit"]["all_time"]["total"], 0)
|
||||
|
||||
|
||||
class Startup(unittest.TestCase):
|
||||
|
||||
@mock.patch('wormhole.server.server.log')
|
||||
def test_empty(self, fake_log):
|
||||
rs = server.RelayServer(
|
||||
str("tcp:0"),
|
||||
str("tcp:0"),
|
||||
None,
|
||||
allow_list=False,
|
||||
)
|
||||
rs.startService()
|
||||
try:
|
||||
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
|
||||
self.assertTrue(
|
||||
'listing of allocated nameplates disallowed' in logs
|
||||
)
|
||||
finally:
|
||||
rs.stopService()
|
||||
|
|
Loading…
Reference in New Issue
Block a user