Merge PR151: add --disallow-list to server

refs #150
closes #53
This commit is contained in:
Brian Warner 2017-05-16 16:44:14 -07:00
commit 66e0d86db8
7 changed files with 115 additions and 26 deletions

View File

@ -48,9 +48,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
"--signal-error", is_flag=True,
help="force all clients to fail with a message",
)
@click.option(
"--disallow-list", is_flag=True,
help="never send list of allocated nameplates",
)
@click.pass_obj
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
transit, rendezvous):
transit, rendezvous, disallow_list):
"""
Start a relay server
"""
@ -61,6 +65,7 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
cfg.transit = str(transit)
cfg.rendezvous = str(rendezvous)
cfg.signal_error = signal_error
cfg.allow_list = not disallow_list
start_server(cfg)

View File

@ -3,20 +3,26 @@ import os, time
from twisted.python import usage
from twisted.scripts import twistd
class MyPlugin:
class MyPlugin(object):
tapname = "xyznode"
def __init__(self, args):
self.args = args
def makeService(self, so):
# delay this import as late as possible, to allow twistd's code to
# accept --reactor= selection
from .server import RelayServer
return RelayServer(self.args.rendezvous, self.args.transit,
self.args.advertise_version,
"relay.sqlite", self.args.blur_usage,
signal_error=self.args.signal_error,
stats_file="stats.json",
)
return RelayServer(
self.args.rendezvous,
self.args.transit,
self.args.advertise_version,
"relay.sqlite",
self.args.blur_usage,
signal_error=self.args.signal_error,
stats_file="stats.json",
allow_list=self.args.allow_list,
)
class MyTwistdConfig(twistd.ServerOptions):
subCommands = [("XYZ", None, usage.Options, "node")]

View File

@ -159,8 +159,10 @@ class Mailbox:
stop_f()
self._listeners = {}
class AppNamespace:
def __init__(self, db, blur_usage, log_requests, app_id):
class AppNamespace(object):
def __init__(self, db, blur_usage, log_requests, app_id, allow_list):
self._db = db
self._blur_usage = blur_usage
self._log_requests = log_requests
@ -168,8 +170,14 @@ class AppNamespace:
self._mailboxes = {}
self._nameplate_counts = collections.defaultdict(int)
self._mailbox_counts = collections.defaultdict(int)
self._allow_list = allow_list
def get_nameplate_ids(self):
if not self._allow_list:
return []
return self._get_nameplate_ids()
def _get_nameplate_ids(self):
db = self._db
# TODO: filter this to numeric ids?
c = db.execute("SELECT DISTINCT `name` FROM `nameplates`"
@ -177,7 +185,7 @@ class AppNamespace:
return set([row["name"] for row in c.fetchall()])
def _find_available_nameplate_id(self):
claimed = self.get_nameplate_ids()
claimed = self._get_nameplate_ids()
for size in range(1,4): # stick to 1-999 for now
available = set()
for id_int in range(10**(size-1), 10**size):
@ -505,14 +513,17 @@ class AppNamespace:
for channel in self._mailboxes.values():
channel._shutdown()
class Rendezvous(service.MultiService):
def __init__(self, db, welcome, blur_usage):
def __init__(self, db, welcome, blur_usage, allow_list):
service.MultiService.__init__(self)
self._db = db
self._welcome = welcome
self._blur_usage = blur_usage
log_requests = blur_usage is None
self._log_requests = log_requests
self._allow_list = allow_list
self._apps = {}
def get_welcome(self):
@ -525,9 +536,13 @@ class Rendezvous(service.MultiService):
if not app_id in self._apps:
if self._log_requests:
log.msg("spawning app_id %s" % (app_id,))
self._apps[app_id] = AppNamespace(self._db,
self._blur_usage,
self._log_requests, app_id)
self._apps[app_id] = AppNamespace(
self._db,
self._blur_usage,
self._log_requests,
app_id,
self._allow_list,
)
return self._apps[app_id]
def get_all_apps(self):

View File

@ -292,6 +292,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
protocol = WebSocketRendezvous
def __init__(self, url, rendezvous):
websocket.WebSocketServerFactory.__init__(self, url)
self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)

View File

@ -32,11 +32,13 @@ class PrivacyEnhancedSite(server.Site):
return server.Site.log(self, request)
class RelayServer(service.MultiService):
def __init__(self, rendezvous_web_port, transit_port,
advertise_version, db_url=":memory:", blur_usage=None,
signal_error=None, stats_file=None):
signal_error=None, stats_file=None, allow_list=True):
service.MultiService.__init__(self)
self._blur_usage = blur_usage
self._allow_list = allow_list
db = get_db(db_url)
welcome = {
@ -58,7 +60,7 @@ class RelayServer(service.MultiService):
if signal_error:
welcome["error"] = signal_error
self._rendezvous = Rendezvous(db, welcome, blur_usage)
self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list)
self._rendezvous.setServiceParent(self) # for the pruning timer
root = Root()
@ -108,6 +110,8 @@ class RelayServer(service.MultiService):
log.msg("not logging HTTP requests or Transit connections")
else:
log.msg("not blurring access times")
if not self._allow_list:
log.msg("listing of allocated nameplates disallowed")
def timer(self):
now = time.time()

View File

@ -3,6 +3,7 @@ import os, sys, re, io, zipfile, six, stat
from textwrap import fill, dedent
from humanize import naturalsize
import mock
import click.testing
from twisted.trial import unittest
from twisted.python import procutils, log
from twisted.internet import defer, endpoints, reactor
@ -12,6 +13,8 @@ from .. import __version__
from .common import ServerBase, config
from ..cli import cmd_send, cmd_receive, welcome, cli
from ..errors import TransferError, WrongPasswordError, WelcomeError
from wormhole.server.cmd_server import MyPlugin
from wormhole.server.cli import server
def build_offer(args):
@ -1065,3 +1068,28 @@ class Dispatch(unittest.TestCase):
expected = "<TRACEBACK>\nERROR: abcd\n"
self.assertEqual(cfg.stderr.getvalue(), expected)
class Server(unittest.TestCase):
def setUp(self):
self.runner = click.testing.CliRunner()
@mock.patch('wormhole.server.cmd_server.twistd')
def test_server_disallow_list(self, fake_twistd):
result = self.runner.invoke(server, ['start', '--no-daemon', '--disallow-list'])
self.assertEqual(0, result.exit_code)
def test_server_plugin(self):
class FakeConfig(object):
no_daemon = True
blur_usage = True
advertise_version = u"fake.version.1"
transit = str('tcp:4321')
rendezvous = str('tcp:1234')
signal_error = True
allow_list = False
cfg = FakeConfig()
plugin = MyPlugin(cfg)
relay = plugin.makeService(None)
self.assertEqual(False, relay._allow_list)

View File

@ -301,7 +301,7 @@ class Prune(unittest.TestCase):
def test_update(self):
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, None)
rv = rendezvous.Rendezvous(db, None, None, True)
app = rv.get_app("appid")
mbox_id = "mbox1"
app.open_mailbox(mbox_id, "side1", 1)
@ -315,7 +315,7 @@ class Prune(unittest.TestCase):
self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3)
def test_apps(self):
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None)
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None, True)
app = rv.get_app("appid")
app.allocate_nameplate("side", 121)
app.prune = mock.Mock()
@ -324,7 +324,7 @@ class Prune(unittest.TestCase):
def test_nameplates(self):
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600)
rv = rendezvous.Rendezvous(db, None, 3600, True)
# timestamps <=50 are "old", >=51 are "new"
#OLD = "old"; NEW = "new"
@ -358,7 +358,7 @@ class Prune(unittest.TestCase):
def test_mailboxes(self):
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600)
rv = rendezvous.Rendezvous(db, None, 3600, True)
# timestamps <=50 are "old", >=51 are "new"
#OLD = "old"; NEW = "new"
@ -404,7 +404,7 @@ class Prune(unittest.TestCase):
log.msg(desc)
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600)
rv = rendezvous.Rendezvous(db, None, 3600, True)
APPID = "appid"
app = rv.get_app(APPID)
@ -1174,7 +1174,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase):
class Summary(unittest.TestCase):
def test_mailbox(self):
app = rendezvous.AppNamespace(None, None, False, None)
app = rendezvous.AppNamespace(None, None, False, None, True)
# starts at time 1, maybe gets second open at time 3, closes at 5
def s(rows, pruned=False):
return app._summarize_mailbox(rows, 5, pruned)
@ -1217,7 +1217,7 @@ class Summary(unittest.TestCase):
self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded"))
def test_nameplate(self):
a = rendezvous.AppNamespace(None, None, False, None)
a = rendezvous.AppNamespace(None, None, False, None, True)
# starts at time 1, maybe gets second open at time 3, closes at 5
def s(rows, pruned=False):
return a._summarize_nameplate_usage(rows, 5, pruned)
@ -1233,10 +1233,21 @@ class Summary(unittest.TestCase):
rows = [dict(added=1), dict(added=3), dict(added=4)]
self.assertEqual(s(rows), Usage(1, 2, 4, "crowded"))
def test_nameplate_disallowed(self):
db = get_db(":memory:")
a = rendezvous.AppNamespace(db, None, False, "some_app_id", False)
a.allocate_nameplate("side1", "123")
self.assertEqual([], a.get_nameplate_ids())
def test_nameplate_allowed(self):
db = get_db(":memory:")
a = rendezvous.AppNamespace(db, None, False, "some_app_id", True)
np = a.allocate_nameplate("side1", "321")
self.assertEqual(set([np]), a.get_nameplate_ids())
def test_blur(self):
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, 3600)
rv = rendezvous.Rendezvous(db, None, 3600, True)
APPID = "appid"
app = rv.get_app(APPID)
app.claim_nameplate("npid", "side1", 10) # start time is 10
@ -1253,7 +1264,7 @@ class Summary(unittest.TestCase):
def test_no_blur(self):
db = get_db(":memory:")
rv = rendezvous.Rendezvous(db, None, None)
rv = rendezvous.Rendezvous(db, None, None, True)
APPID = "appid"
app = rv.get_app(APPID)
app.claim_nameplate("npid", "side1", 10) # start time is 10
@ -1292,3 +1303,22 @@ class DumpStats(unittest.TestCase):
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
self.assertEqual(data["transit"]["all_time"]["total"], 0)
class Startup(unittest.TestCase):
@mock.patch('wormhole.server.server.log')
def test_empty(self, fake_log):
rs = server.RelayServer(
str("tcp:0"),
str("tcp:0"),
None,
allow_list=False,
)
rs.startService()
try:
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
self.assertTrue(
'listing of allocated nameplates disallowed' in logs
)
finally:
rs.stopService()