From 6b31517b673aefc396d9111035cca8d44188861a Mon Sep 17 00:00:00 2001 From: meejah Date: Sun, 7 May 2017 21:02:30 -0600 Subject: [PATCH] Add an allow_list option to control nameplate-listings --- src/wormhole/server/cli.py | 7 ++- src/wormhole/server/cmd_server.py | 20 ++++++--- src/wormhole/server/rendezvous.py | 29 ++++++++++--- src/wormhole/server/rendezvous_websocket.py | 1 + src/wormhole/server/server.py | 8 +++- src/wormhole/test/test_cli.py | 28 ++++++++++++ src/wormhole/test/test_server.py | 48 +++++++++++++++++---- 7 files changed, 115 insertions(+), 26 deletions(-) diff --git a/src/wormhole/server/cli.py b/src/wormhole/server/cli.py index 8279e10..3377af3 100644 --- a/src/wormhole/server/cli.py +++ b/src/wormhole/server/cli.py @@ -48,9 +48,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server "--signal-error", is_flag=True, help="force all clients to fail with a message", ) +@click.option( + "--disallow-list", is_flag=True, + help="never send list of allocated nameplates", +) @click.pass_obj def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, - transit, rendezvous): + transit, rendezvous, disallow_list): """ Start a relay server """ @@ -61,6 +65,7 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, cfg.transit = str(transit) cfg.rendezvous = str(rendezvous) cfg.signal_error = signal_error + cfg.allow_list = not disallow_list start_server(cfg) diff --git a/src/wormhole/server/cmd_server.py b/src/wormhole/server/cmd_server.py index 7abbf0f..7aff110 100644 --- a/src/wormhole/server/cmd_server.py +++ b/src/wormhole/server/cmd_server.py @@ -3,20 +3,26 @@ import os, time from twisted.python import usage from twisted.scripts import twistd -class MyPlugin: +class MyPlugin(object): tapname = "xyznode" + def __init__(self, args): self.args = args + def makeService(self, so): # delay this import as late as possible, to allow twistd's code to # accept --reactor= selection from .server import RelayServer - return RelayServer(self.args.rendezvous, self.args.transit, - self.args.advertise_version, - "relay.sqlite", self.args.blur_usage, - signal_error=self.args.signal_error, - stats_file="stats.json", - ) + return RelayServer( + self.args.rendezvous, + self.args.transit, + self.args.advertise_version, + "relay.sqlite", + self.args.blur_usage, + signal_error=self.args.signal_error, + stats_file="stats.json", + allow_list=self.args.allow_list, + ) class MyTwistdConfig(twistd.ServerOptions): subCommands = [("XYZ", None, usage.Options, "node")] diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 8ad206a..50486aa 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -159,8 +159,10 @@ class Mailbox: stop_f() self._listeners = {} -class AppNamespace: - def __init__(self, db, blur_usage, log_requests, app_id): + +class AppNamespace(object): + + def __init__(self, db, blur_usage, log_requests, app_id, allow_list): self._db = db self._blur_usage = blur_usage self._log_requests = log_requests @@ -168,8 +170,14 @@ class AppNamespace: self._mailboxes = {} self._nameplate_counts = collections.defaultdict(int) self._mailbox_counts = collections.defaultdict(int) + self._allow_list = allow_list def get_nameplate_ids(self): + if not self._allow_list: + return [] + return self._get_nameplate_ids() + + def _get_nameplate_ids(self): db = self._db # TODO: filter this to numeric ids? c = db.execute("SELECT DISTINCT `name` FROM `nameplates`" @@ -177,7 +185,7 @@ class AppNamespace: return set([row["name"] for row in c.fetchall()]) def _find_available_nameplate_id(self): - claimed = self.get_nameplate_ids() + claimed = self._get_nameplate_ids() for size in range(1,4): # stick to 1-999 for now available = set() for id_int in range(10**(size-1), 10**size): @@ -505,14 +513,17 @@ class AppNamespace: for channel in self._mailboxes.values(): channel._shutdown() + class Rendezvous(service.MultiService): - def __init__(self, db, welcome, blur_usage): + + def __init__(self, db, welcome, blur_usage, allow_list): service.MultiService.__init__(self) self._db = db self._welcome = welcome self._blur_usage = blur_usage log_requests = blur_usage is None self._log_requests = log_requests + self._allow_list = allow_list self._apps = {} def get_welcome(self): @@ -525,9 +536,13 @@ class Rendezvous(service.MultiService): if not app_id in self._apps: if self._log_requests: log.msg("spawning app_id %s" % (app_id,)) - self._apps[app_id] = AppNamespace(self._db, - self._blur_usage, - self._log_requests, app_id) + self._apps[app_id] = AppNamespace( + self._db, + self._blur_usage, + self._log_requests, + app_id, + self._allow_list, + ) return self._apps[app_id] def get_all_apps(self): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index e489f87..dc4eb51 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -292,6 +292,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): protocol = WebSocketRendezvous + def __init__(self, url, rendezvous): websocket.WebSocketServerFactory.__init__(self, url) self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) diff --git a/src/wormhole/server/server.py b/src/wormhole/server/server.py index 72ef694..4f94294 100644 --- a/src/wormhole/server/server.py +++ b/src/wormhole/server/server.py @@ -32,11 +32,13 @@ class PrivacyEnhancedSite(server.Site): return server.Site.log(self, request) class RelayServer(service.MultiService): + def __init__(self, rendezvous_web_port, transit_port, advertise_version, db_url=":memory:", blur_usage=None, - signal_error=None, stats_file=None): + signal_error=None, stats_file=None, allow_list=True): service.MultiService.__init__(self) self._blur_usage = blur_usage + self._allow_list = allow_list db = get_db(db_url) welcome = { @@ -58,7 +60,7 @@ class RelayServer(service.MultiService): if signal_error: welcome["error"] = signal_error - self._rendezvous = Rendezvous(db, welcome, blur_usage) + self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list) self._rendezvous.setServiceParent(self) # for the pruning timer root = Root() @@ -108,6 +110,8 @@ class RelayServer(service.MultiService): log.msg("not logging HTTP requests or Transit connections") else: log.msg("not blurring access times") + if not self._allow_list: + log.msg("listing of allocated nameplates disallowed") def timer(self): now = time.time() diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 71393d1..3dd23c7 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -3,6 +3,7 @@ import os, sys, re, io, zipfile, six, stat from textwrap import fill, dedent from humanize import naturalsize import mock +import click.testing from twisted.trial import unittest from twisted.python import procutils, log from twisted.internet import defer, endpoints, reactor @@ -12,6 +13,8 @@ from .. import __version__ from .common import ServerBase, config from ..cli import cmd_send, cmd_receive, welcome, cli from ..errors import TransferError, WrongPasswordError, WelcomeError +from wormhole.server.cmd_server import MyPlugin +from wormhole.server.cli import server def build_offer(args): @@ -1065,3 +1068,28 @@ class Dispatch(unittest.TestCase): expected = "\nERROR: abcd\n" self.assertEqual(cfg.stderr.getvalue(), expected) + +class Server(unittest.TestCase): + + def setUp(self): + self.runner = click.testing.CliRunner() + + @mock.patch('wormhole.server.cmd_server.twistd') + def test_server_disallow_list(self, fake_twistd): + result = self.runner.invoke(server, ['start', '--no-daemon', '--disallow-list']) + self.assertEqual(0, result.exit_code) + + def test_server_plugin(self): + class FakeConfig(object): + no_daemon = True + blur_usage = True + advertise_version = u"fake.version.1" + transit = str('tcp:4321') + rendezvous = str('tcp:1234') + signal_error = True + allow_list = False + + cfg = FakeConfig() + plugin = MyPlugin(cfg) + relay = plugin.makeService(None) + self.assertEqual(False, relay._allow_list) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 06daff3..351729c 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -301,7 +301,7 @@ class Prune(unittest.TestCase): def test_update(self): db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, None) + rv = rendezvous.Rendezvous(db, None, None, True) app = rv.get_app("appid") mbox_id = "mbox1" app.open_mailbox(mbox_id, "side1", 1) @@ -315,7 +315,7 @@ class Prune(unittest.TestCase): self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3) def test_apps(self): - rv = rendezvous.Rendezvous(get_db(":memory:"), None, None) + rv = rendezvous.Rendezvous(get_db(":memory:"), None, None, True) app = rv.get_app("appid") app.allocate_nameplate("side", 121) app.prune = mock.Mock() @@ -324,7 +324,7 @@ class Prune(unittest.TestCase): def test_nameplates(self): db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, 3600) + rv = rendezvous.Rendezvous(db, None, 3600, True) # timestamps <=50 are "old", >=51 are "new" #OLD = "old"; NEW = "new" @@ -358,7 +358,7 @@ class Prune(unittest.TestCase): def test_mailboxes(self): db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, 3600) + rv = rendezvous.Rendezvous(db, None, 3600, True) # timestamps <=50 are "old", >=51 are "new" #OLD = "old"; NEW = "new" @@ -404,7 +404,7 @@ class Prune(unittest.TestCase): log.msg(desc) db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, 3600) + rv = rendezvous.Rendezvous(db, None, 3600, True) APPID = "appid" app = rv.get_app(APPID) @@ -1174,7 +1174,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase): class Summary(unittest.TestCase): def test_mailbox(self): - app = rendezvous.AppNamespace(None, None, False, None) + app = rendezvous.AppNamespace(None, None, False, None, True) # starts at time 1, maybe gets second open at time 3, closes at 5 def s(rows, pruned=False): return app._summarize_mailbox(rows, 5, pruned) @@ -1217,7 +1217,7 @@ class Summary(unittest.TestCase): self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded")) def test_nameplate(self): - a = rendezvous.AppNamespace(None, None, False, None) + a = rendezvous.AppNamespace(None, None, False, None, True) # starts at time 1, maybe gets second open at time 3, closes at 5 def s(rows, pruned=False): return a._summarize_nameplate_usage(rows, 5, pruned) @@ -1233,10 +1233,21 @@ class Summary(unittest.TestCase): rows = [dict(added=1), dict(added=3), dict(added=4)] self.assertEqual(s(rows), Usage(1, 2, 4, "crowded")) + def test_nameplate_disallowed(self): + db = get_db(":memory:") + a = rendezvous.AppNamespace(db, None, False, "some_app_id", False) + a.allocate_nameplate("side1", "123") + self.assertEqual([], a.get_nameplate_ids()) + + def test_nameplate_allowed(self): + db = get_db(":memory:") + a = rendezvous.AppNamespace(db, None, False, "some_app_id", True) + np = a.allocate_nameplate("side1", "321") + self.assertEqual(set([np]), a.get_nameplate_ids()) def test_blur(self): db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, 3600) + rv = rendezvous.Rendezvous(db, None, 3600, True) APPID = "appid" app = rv.get_app(APPID) app.claim_nameplate("npid", "side1", 10) # start time is 10 @@ -1253,7 +1264,7 @@ class Summary(unittest.TestCase): def test_no_blur(self): db = get_db(":memory:") - rv = rendezvous.Rendezvous(db, None, None) + rv = rendezvous.Rendezvous(db, None, None, True) APPID = "appid" app = rv.get_app(APPID) app.claim_nameplate("npid", "side1", 10) # start time is 10 @@ -1292,3 +1303,22 @@ class DumpStats(unittest.TestCase): self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0) self.assertEqual(data["transit"]["all_time"]["total"], 0) + +class Startup(unittest.TestCase): + + @mock.patch('wormhole.server.server.log') + def test_empty(self, fake_log): + rs = server.RelayServer( + str("tcp:0"), + str("tcp:0"), + None, + allow_list=False, + ) + rs.startService() + try: + logs = '\n'.join([call[1][0] for call in fake_log.mock_calls]) + self.assertTrue( + 'listing of allocated nameplates disallowed' in logs + ) + finally: + rs.stopService()