Merge PR151: add --disallow-list to server

refs #150
closes #53
This commit is contained in:
Brian Warner 2017-05-16 16:44:14 -07:00
commit 66e0d86db8
7 changed files with 115 additions and 26 deletions

View File

@ -48,9 +48,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
"--signal-error", is_flag=True, "--signal-error", is_flag=True,
help="force all clients to fail with a message", help="force all clients to fail with a message",
) )
@click.option(
"--disallow-list", is_flag=True,
help="never send list of allocated nameplates",
)
@click.pass_obj @click.pass_obj
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version, def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
transit, rendezvous): transit, rendezvous, disallow_list):
""" """
Start a relay server Start a relay server
""" """
@ -61,6 +65,7 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
cfg.transit = str(transit) cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous) cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error cfg.signal_error = signal_error
cfg.allow_list = not disallow_list
start_server(cfg) start_server(cfg)

View File

@ -3,20 +3,26 @@ import os, time
from twisted.python import usage from twisted.python import usage
from twisted.scripts import twistd from twisted.scripts import twistd
class MyPlugin: class MyPlugin(object):
tapname = "xyznode" tapname = "xyznode"
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
def makeService(self, so): def makeService(self, so):
# delay this import as late as possible, to allow twistd's code to # delay this import as late as possible, to allow twistd's code to
# accept --reactor= selection # accept --reactor= selection
from .server import RelayServer from .server import RelayServer
return RelayServer(self.args.rendezvous, self.args.transit, return RelayServer(
self.args.advertise_version, self.args.rendezvous,
"relay.sqlite", self.args.blur_usage, self.args.transit,
signal_error=self.args.signal_error, self.args.advertise_version,
stats_file="stats.json", "relay.sqlite",
) self.args.blur_usage,
signal_error=self.args.signal_error,
stats_file="stats.json",
allow_list=self.args.allow_list,
)
class MyTwistdConfig(twistd.ServerOptions): class MyTwistdConfig(twistd.ServerOptions):
subCommands = [("XYZ", None, usage.Options, "node")] subCommands = [("XYZ", None, usage.Options, "node")]

View File

@ -159,8 +159,10 @@ class Mailbox:
stop_f() stop_f()
self._listeners = {} self._listeners = {}
class AppNamespace:
def __init__(self, db, blur_usage, log_requests, app_id): class AppNamespace(object):
def __init__(self, db, blur_usage, log_requests, app_id, allow_list):
self._db = db self._db = db
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._log_requests = log_requests self._log_requests = log_requests
@ -168,8 +170,14 @@ class AppNamespace:
self._mailboxes = {} self._mailboxes = {}
self._nameplate_counts = collections.defaultdict(int) self._nameplate_counts = collections.defaultdict(int)
self._mailbox_counts = collections.defaultdict(int) self._mailbox_counts = collections.defaultdict(int)
self._allow_list = allow_list
def get_nameplate_ids(self): def get_nameplate_ids(self):
if not self._allow_list:
return []
return self._get_nameplate_ids()
def _get_nameplate_ids(self):
db = self._db db = self._db
# TODO: filter this to numeric ids? # TODO: filter this to numeric ids?
c = db.execute("SELECT DISTINCT `name` FROM `nameplates`" c = db.execute("SELECT DISTINCT `name` FROM `nameplates`"
@ -177,7 +185,7 @@ class AppNamespace:
return set([row["name"] for row in c.fetchall()]) return set([row["name"] for row in c.fetchall()])
def _find_available_nameplate_id(self): def _find_available_nameplate_id(self):
claimed = self.get_nameplate_ids() claimed = self._get_nameplate_ids()
for size in range(1,4): # stick to 1-999 for now for size in range(1,4): # stick to 1-999 for now
available = set() available = set()
for id_int in range(10**(size-1), 10**size): for id_int in range(10**(size-1), 10**size):
@ -505,14 +513,17 @@ class AppNamespace:
for channel in self._mailboxes.values(): for channel in self._mailboxes.values():
channel._shutdown() channel._shutdown()
class Rendezvous(service.MultiService): class Rendezvous(service.MultiService):
def __init__(self, db, welcome, blur_usage):
def __init__(self, db, welcome, blur_usage, allow_list):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._db = db self._db = db
self._welcome = welcome self._welcome = welcome
self._blur_usage = blur_usage self._blur_usage = blur_usage
log_requests = blur_usage is None log_requests = blur_usage is None
self._log_requests = log_requests self._log_requests = log_requests
self._allow_list = allow_list
self._apps = {} self._apps = {}
def get_welcome(self): def get_welcome(self):
@ -525,9 +536,13 @@ class Rendezvous(service.MultiService):
if not app_id in self._apps: if not app_id in self._apps:
if self._log_requests: if self._log_requests:
log.msg("spawning app_id %s" % (app_id,)) log.msg("spawning app_id %s" % (app_id,))
self._apps[app_id] = AppNamespace(self._db, self._apps[app_id] = AppNamespace(
self._blur_usage, self._db,
self._log_requests, app_id) self._blur_usage,
self._log_requests,
app_id,
self._allow_list,
)
return self._apps[app_id] return self._apps[app_id]
def get_all_apps(self): def get_all_apps(self):

View File

@ -292,6 +292,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory): class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
protocol = WebSocketRendezvous protocol = WebSocketRendezvous
def __init__(self, url, rendezvous): def __init__(self, url, rendezvous):
websocket.WebSocketServerFactory.__init__(self, url) websocket.WebSocketServerFactory.__init__(self, url)
self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)

View File

@ -32,11 +32,13 @@ class PrivacyEnhancedSite(server.Site):
return server.Site.log(self, request) return server.Site.log(self, request)
class RelayServer(service.MultiService): class RelayServer(service.MultiService):
def __init__(self, rendezvous_web_port, transit_port, def __init__(self, rendezvous_web_port, transit_port,
advertise_version, db_url=":memory:", blur_usage=None, advertise_version, db_url=":memory:", blur_usage=None,
signal_error=None, stats_file=None): signal_error=None, stats_file=None, allow_list=True):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._allow_list = allow_list
db = get_db(db_url) db = get_db(db_url)
welcome = { welcome = {
@ -58,7 +60,7 @@ class RelayServer(service.MultiService):
if signal_error: if signal_error:
welcome["error"] = signal_error welcome["error"] = signal_error
self._rendezvous = Rendezvous(db, welcome, blur_usage) self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list)
self._rendezvous.setServiceParent(self) # for the pruning timer self._rendezvous.setServiceParent(self) # for the pruning timer
root = Root() root = Root()
@ -108,6 +110,8 @@ class RelayServer(service.MultiService):
log.msg("not logging HTTP requests or Transit connections") log.msg("not logging HTTP requests or Transit connections")
else: else:
log.msg("not blurring access times") log.msg("not blurring access times")
if not self._allow_list:
log.msg("listing of allocated nameplates disallowed")
def timer(self): def timer(self):
now = time.time() now = time.time()

View File

@ -3,6 +3,7 @@ import os, sys, re, io, zipfile, six, stat
from textwrap import fill, dedent from textwrap import fill, dedent
from humanize import naturalsize from humanize import naturalsize
import mock import mock
import click.testing
from twisted.trial import unittest from twisted.trial import unittest
from twisted.python import procutils, log from twisted.python import procutils, log
from twisted.internet import defer, endpoints, reactor from twisted.internet import defer, endpoints, reactor
@ -12,6 +13,8 @@ from .. import __version__
from .common import ServerBase, config from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome, cli from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import TransferError, WrongPasswordError, WelcomeError from ..errors import TransferError, WrongPasswordError, WelcomeError
from wormhole.server.cmd_server import MyPlugin
from wormhole.server.cli import server
def build_offer(args): def build_offer(args):
@ -1065,3 +1068,28 @@ class Dispatch(unittest.TestCase):
expected = "<TRACEBACK>\nERROR: abcd\n" expected = "<TRACEBACK>\nERROR: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected) self.assertEqual(cfg.stderr.getvalue(), expected)
class Server(unittest.TestCase):
def setUp(self):
self.runner = click.testing.CliRunner()
@mock.patch('wormhole.server.cmd_server.twistd')
def test_server_disallow_list(self, fake_twistd):
result = self.runner.invoke(server, ['start', '--no-daemon', '--disallow-list'])
self.assertEqual(0, result.exit_code)
def test_server_plugin(self):
class FakeConfig(object):
no_daemon = True
blur_usage = True
advertise_version = u"fake.version.1"
transit = str('tcp:4321')
rendezvous = str('tcp:1234')
signal_error = True
allow_list = False
cfg = FakeConfig()
plugin = MyPlugin(cfg)
relay = plugin.makeService(None)
self.assertEqual(False, relay._allow_list)

View File

@ -301,7 +301,7 @@ class Prune(unittest.TestCase):
def test_update(self): def test_update(self):
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, None) rv = rendezvous.Rendezvous(db, None, None, True)
app = rv.get_app("appid") app = rv.get_app("appid")
mbox_id = "mbox1" mbox_id = "mbox1"
app.open_mailbox(mbox_id, "side1", 1) app.open_mailbox(mbox_id, "side1", 1)
@ -315,7 +315,7 @@ class Prune(unittest.TestCase):
self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3) self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3)
def test_apps(self): def test_apps(self):
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None) rv = rendezvous.Rendezvous(get_db(":memory:"), None, None, True)
app = rv.get_app("appid") app = rv.get_app("appid")
app.allocate_nameplate("side", 121) app.allocate_nameplate("side", 121)
app.prune = mock.Mock() app.prune = mock.Mock()
@ -324,7 +324,7 @@ class Prune(unittest.TestCase):
def test_nameplates(self): def test_nameplates(self):
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600) rv = rendezvous.Rendezvous(db, None, 3600, True)
# timestamps <=50 are "old", >=51 are "new" # timestamps <=50 are "old", >=51 are "new"
#OLD = "old"; NEW = "new" #OLD = "old"; NEW = "new"
@ -358,7 +358,7 @@ class Prune(unittest.TestCase):
def test_mailboxes(self): def test_mailboxes(self):
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600) rv = rendezvous.Rendezvous(db, None, 3600, True)
# timestamps <=50 are "old", >=51 are "new" # timestamps <=50 are "old", >=51 are "new"
#OLD = "old"; NEW = "new" #OLD = "old"; NEW = "new"
@ -404,7 +404,7 @@ class Prune(unittest.TestCase):
log.msg(desc) log.msg(desc)
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600) rv = rendezvous.Rendezvous(db, None, 3600, True)
APPID = "appid" APPID = "appid"
app = rv.get_app(APPID) app = rv.get_app(APPID)
@ -1174,7 +1174,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase):
class Summary(unittest.TestCase): class Summary(unittest.TestCase):
def test_mailbox(self): def test_mailbox(self):
app = rendezvous.AppNamespace(None, None, False, None) app = rendezvous.AppNamespace(None, None, False, None, True)
# starts at time 1, maybe gets second open at time 3, closes at 5 # starts at time 1, maybe gets second open at time 3, closes at 5
def s(rows, pruned=False): def s(rows, pruned=False):
return app._summarize_mailbox(rows, 5, pruned) return app._summarize_mailbox(rows, 5, pruned)
@ -1217,7 +1217,7 @@ class Summary(unittest.TestCase):
self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded")) self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded"))
def test_nameplate(self): def test_nameplate(self):
a = rendezvous.AppNamespace(None, None, False, None) a = rendezvous.AppNamespace(None, None, False, None, True)
# starts at time 1, maybe gets second open at time 3, closes at 5 # starts at time 1, maybe gets second open at time 3, closes at 5
def s(rows, pruned=False): def s(rows, pruned=False):
return a._summarize_nameplate_usage(rows, 5, pruned) return a._summarize_nameplate_usage(rows, 5, pruned)
@ -1233,10 +1233,21 @@ class Summary(unittest.TestCase):
rows = [dict(added=1), dict(added=3), dict(added=4)] rows = [dict(added=1), dict(added=3), dict(added=4)]
self.assertEqual(s(rows), Usage(1, 2, 4, "crowded")) self.assertEqual(s(rows), Usage(1, 2, 4, "crowded"))
def test_nameplate_disallowed(self):
db = get_db(":memory:")
a = rendezvous.AppNamespace(db, None, False, "some_app_id", False)
a.allocate_nameplate("side1", "123")
self.assertEqual([], a.get_nameplate_ids())
def test_nameplate_allowed(self):
db = get_db(":memory:")
a = rendezvous.AppNamespace(db, None, False, "some_app_id", True)
np = a.allocate_nameplate("side1", "321")
self.assertEqual(set([np]), a.get_nameplate_ids())
def test_blur(self): def test_blur(self):
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600) rv = rendezvous.Rendezvous(db, None, 3600, True)
APPID = "appid" APPID = "appid"
app = rv.get_app(APPID) app = rv.get_app(APPID)
app.claim_nameplate("npid", "side1", 10) # start time is 10 app.claim_nameplate("npid", "side1", 10) # start time is 10
@ -1253,7 +1264,7 @@ class Summary(unittest.TestCase):
def test_no_blur(self): def test_no_blur(self):
db = get_db(":memory:") db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, None) rv = rendezvous.Rendezvous(db, None, None, True)
APPID = "appid" APPID = "appid"
app = rv.get_app(APPID) app = rv.get_app(APPID)
app.claim_nameplate("npid", "side1", 10) # start time is 10 app.claim_nameplate("npid", "side1", 10) # start time is 10
@ -1292,3 +1303,22 @@ class DumpStats(unittest.TestCase):
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0) self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
self.assertEqual(data["transit"]["all_time"]["total"], 0) self.assertEqual(data["transit"]["all_time"]["total"], 0)
class Startup(unittest.TestCase):
@mock.patch('wormhole.server.server.log')
def test_empty(self, fake_log):
rs = server.RelayServer(
str("tcp:0"),
str("tcp:0"),
None,
allow_list=False,
)
rs.startService()
try:
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
self.assertTrue(
'listing of allocated nameplates disallowed' in logs
)
finally:
rs.stopService()