commit
66e0d86db8
|
@ -48,9 +48,13 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
|
||||||
"--signal-error", is_flag=True,
|
"--signal-error", is_flag=True,
|
||||||
help="force all clients to fail with a message",
|
help="force all clients to fail with a message",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--disallow-list", is_flag=True,
|
||||||
|
help="never send list of allocated nameplates",
|
||||||
|
)
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
|
def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
|
||||||
transit, rendezvous):
|
transit, rendezvous, disallow_list):
|
||||||
"""
|
"""
|
||||||
Start a relay server
|
Start a relay server
|
||||||
"""
|
"""
|
||||||
|
@ -61,6 +65,7 @@ def start(cfg, signal_error, no_daemon, blur_usage, advertise_version,
|
||||||
cfg.transit = str(transit)
|
cfg.transit = str(transit)
|
||||||
cfg.rendezvous = str(rendezvous)
|
cfg.rendezvous = str(rendezvous)
|
||||||
cfg.signal_error = signal_error
|
cfg.signal_error = signal_error
|
||||||
|
cfg.allow_list = not disallow_list
|
||||||
|
|
||||||
start_server(cfg)
|
start_server(cfg)
|
||||||
|
|
||||||
|
|
|
@ -3,20 +3,26 @@ import os, time
|
||||||
from twisted.python import usage
|
from twisted.python import usage
|
||||||
from twisted.scripts import twistd
|
from twisted.scripts import twistd
|
||||||
|
|
||||||
class MyPlugin:
|
class MyPlugin(object):
|
||||||
tapname = "xyznode"
|
tapname = "xyznode"
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def makeService(self, so):
|
def makeService(self, so):
|
||||||
# delay this import as late as possible, to allow twistd's code to
|
# delay this import as late as possible, to allow twistd's code to
|
||||||
# accept --reactor= selection
|
# accept --reactor= selection
|
||||||
from .server import RelayServer
|
from .server import RelayServer
|
||||||
return RelayServer(self.args.rendezvous, self.args.transit,
|
return RelayServer(
|
||||||
self.args.advertise_version,
|
self.args.rendezvous,
|
||||||
"relay.sqlite", self.args.blur_usage,
|
self.args.transit,
|
||||||
signal_error=self.args.signal_error,
|
self.args.advertise_version,
|
||||||
stats_file="stats.json",
|
"relay.sqlite",
|
||||||
)
|
self.args.blur_usage,
|
||||||
|
signal_error=self.args.signal_error,
|
||||||
|
stats_file="stats.json",
|
||||||
|
allow_list=self.args.allow_list,
|
||||||
|
)
|
||||||
|
|
||||||
class MyTwistdConfig(twistd.ServerOptions):
|
class MyTwistdConfig(twistd.ServerOptions):
|
||||||
subCommands = [("XYZ", None, usage.Options, "node")]
|
subCommands = [("XYZ", None, usage.Options, "node")]
|
||||||
|
|
|
@ -159,8 +159,10 @@ class Mailbox:
|
||||||
stop_f()
|
stop_f()
|
||||||
self._listeners = {}
|
self._listeners = {}
|
||||||
|
|
||||||
class AppNamespace:
|
|
||||||
def __init__(self, db, blur_usage, log_requests, app_id):
|
class AppNamespace(object):
|
||||||
|
|
||||||
|
def __init__(self, db, blur_usage, log_requests, app_id, allow_list):
|
||||||
self._db = db
|
self._db = db
|
||||||
self._blur_usage = blur_usage
|
self._blur_usage = blur_usage
|
||||||
self._log_requests = log_requests
|
self._log_requests = log_requests
|
||||||
|
@ -168,8 +170,14 @@ class AppNamespace:
|
||||||
self._mailboxes = {}
|
self._mailboxes = {}
|
||||||
self._nameplate_counts = collections.defaultdict(int)
|
self._nameplate_counts = collections.defaultdict(int)
|
||||||
self._mailbox_counts = collections.defaultdict(int)
|
self._mailbox_counts = collections.defaultdict(int)
|
||||||
|
self._allow_list = allow_list
|
||||||
|
|
||||||
def get_nameplate_ids(self):
|
def get_nameplate_ids(self):
|
||||||
|
if not self._allow_list:
|
||||||
|
return []
|
||||||
|
return self._get_nameplate_ids()
|
||||||
|
|
||||||
|
def _get_nameplate_ids(self):
|
||||||
db = self._db
|
db = self._db
|
||||||
# TODO: filter this to numeric ids?
|
# TODO: filter this to numeric ids?
|
||||||
c = db.execute("SELECT DISTINCT `name` FROM `nameplates`"
|
c = db.execute("SELECT DISTINCT `name` FROM `nameplates`"
|
||||||
|
@ -177,7 +185,7 @@ class AppNamespace:
|
||||||
return set([row["name"] for row in c.fetchall()])
|
return set([row["name"] for row in c.fetchall()])
|
||||||
|
|
||||||
def _find_available_nameplate_id(self):
|
def _find_available_nameplate_id(self):
|
||||||
claimed = self.get_nameplate_ids()
|
claimed = self._get_nameplate_ids()
|
||||||
for size in range(1,4): # stick to 1-999 for now
|
for size in range(1,4): # stick to 1-999 for now
|
||||||
available = set()
|
available = set()
|
||||||
for id_int in range(10**(size-1), 10**size):
|
for id_int in range(10**(size-1), 10**size):
|
||||||
|
@ -505,14 +513,17 @@ class AppNamespace:
|
||||||
for channel in self._mailboxes.values():
|
for channel in self._mailboxes.values():
|
||||||
channel._shutdown()
|
channel._shutdown()
|
||||||
|
|
||||||
|
|
||||||
class Rendezvous(service.MultiService):
|
class Rendezvous(service.MultiService):
|
||||||
def __init__(self, db, welcome, blur_usage):
|
|
||||||
|
def __init__(self, db, welcome, blur_usage, allow_list):
|
||||||
service.MultiService.__init__(self)
|
service.MultiService.__init__(self)
|
||||||
self._db = db
|
self._db = db
|
||||||
self._welcome = welcome
|
self._welcome = welcome
|
||||||
self._blur_usage = blur_usage
|
self._blur_usage = blur_usage
|
||||||
log_requests = blur_usage is None
|
log_requests = blur_usage is None
|
||||||
self._log_requests = log_requests
|
self._log_requests = log_requests
|
||||||
|
self._allow_list = allow_list
|
||||||
self._apps = {}
|
self._apps = {}
|
||||||
|
|
||||||
def get_welcome(self):
|
def get_welcome(self):
|
||||||
|
@ -525,9 +536,13 @@ class Rendezvous(service.MultiService):
|
||||||
if not app_id in self._apps:
|
if not app_id in self._apps:
|
||||||
if self._log_requests:
|
if self._log_requests:
|
||||||
log.msg("spawning app_id %s" % (app_id,))
|
log.msg("spawning app_id %s" % (app_id,))
|
||||||
self._apps[app_id] = AppNamespace(self._db,
|
self._apps[app_id] = AppNamespace(
|
||||||
self._blur_usage,
|
self._db,
|
||||||
self._log_requests, app_id)
|
self._blur_usage,
|
||||||
|
self._log_requests,
|
||||||
|
app_id,
|
||||||
|
self._allow_list,
|
||||||
|
)
|
||||||
return self._apps[app_id]
|
return self._apps[app_id]
|
||||||
|
|
||||||
def get_all_apps(self):
|
def get_all_apps(self):
|
||||||
|
|
|
@ -292,6 +292,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol):
|
||||||
|
|
||||||
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
|
class WebSocketRendezvousFactory(websocket.WebSocketServerFactory):
|
||||||
protocol = WebSocketRendezvous
|
protocol = WebSocketRendezvous
|
||||||
|
|
||||||
def __init__(self, url, rendezvous):
|
def __init__(self, url, rendezvous):
|
||||||
websocket.WebSocketServerFactory.__init__(self, url)
|
websocket.WebSocketServerFactory.__init__(self, url)
|
||||||
self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
self.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
|
||||||
|
|
|
@ -32,11 +32,13 @@ class PrivacyEnhancedSite(server.Site):
|
||||||
return server.Site.log(self, request)
|
return server.Site.log(self, request)
|
||||||
|
|
||||||
class RelayServer(service.MultiService):
|
class RelayServer(service.MultiService):
|
||||||
|
|
||||||
def __init__(self, rendezvous_web_port, transit_port,
|
def __init__(self, rendezvous_web_port, transit_port,
|
||||||
advertise_version, db_url=":memory:", blur_usage=None,
|
advertise_version, db_url=":memory:", blur_usage=None,
|
||||||
signal_error=None, stats_file=None):
|
signal_error=None, stats_file=None, allow_list=True):
|
||||||
service.MultiService.__init__(self)
|
service.MultiService.__init__(self)
|
||||||
self._blur_usage = blur_usage
|
self._blur_usage = blur_usage
|
||||||
|
self._allow_list = allow_list
|
||||||
|
|
||||||
db = get_db(db_url)
|
db = get_db(db_url)
|
||||||
welcome = {
|
welcome = {
|
||||||
|
@ -58,7 +60,7 @@ class RelayServer(service.MultiService):
|
||||||
if signal_error:
|
if signal_error:
|
||||||
welcome["error"] = signal_error
|
welcome["error"] = signal_error
|
||||||
|
|
||||||
self._rendezvous = Rendezvous(db, welcome, blur_usage)
|
self._rendezvous = Rendezvous(db, welcome, blur_usage, self._allow_list)
|
||||||
self._rendezvous.setServiceParent(self) # for the pruning timer
|
self._rendezvous.setServiceParent(self) # for the pruning timer
|
||||||
|
|
||||||
root = Root()
|
root = Root()
|
||||||
|
@ -108,6 +110,8 @@ class RelayServer(service.MultiService):
|
||||||
log.msg("not logging HTTP requests or Transit connections")
|
log.msg("not logging HTTP requests or Transit connections")
|
||||||
else:
|
else:
|
||||||
log.msg("not blurring access times")
|
log.msg("not blurring access times")
|
||||||
|
if not self._allow_list:
|
||||||
|
log.msg("listing of allocated nameplates disallowed")
|
||||||
|
|
||||||
def timer(self):
|
def timer(self):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
|
@ -3,6 +3,7 @@ import os, sys, re, io, zipfile, six, stat
|
||||||
from textwrap import fill, dedent
|
from textwrap import fill, dedent
|
||||||
from humanize import naturalsize
|
from humanize import naturalsize
|
||||||
import mock
|
import mock
|
||||||
|
import click.testing
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.python import procutils, log
|
from twisted.python import procutils, log
|
||||||
from twisted.internet import defer, endpoints, reactor
|
from twisted.internet import defer, endpoints, reactor
|
||||||
|
@ -12,6 +13,8 @@ from .. import __version__
|
||||||
from .common import ServerBase, config
|
from .common import ServerBase, config
|
||||||
from ..cli import cmd_send, cmd_receive, welcome, cli
|
from ..cli import cmd_send, cmd_receive, welcome, cli
|
||||||
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
from ..errors import TransferError, WrongPasswordError, WelcomeError
|
||||||
|
from wormhole.server.cmd_server import MyPlugin
|
||||||
|
from wormhole.server.cli import server
|
||||||
|
|
||||||
|
|
||||||
def build_offer(args):
|
def build_offer(args):
|
||||||
|
@ -1065,3 +1068,28 @@ class Dispatch(unittest.TestCase):
|
||||||
expected = "<TRACEBACK>\nERROR: abcd\n"
|
expected = "<TRACEBACK>\nERROR: abcd\n"
|
||||||
self.assertEqual(cfg.stderr.getvalue(), expected)
|
self.assertEqual(cfg.stderr.getvalue(), expected)
|
||||||
|
|
||||||
|
|
||||||
|
class Server(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.runner = click.testing.CliRunner()
|
||||||
|
|
||||||
|
@mock.patch('wormhole.server.cmd_server.twistd')
|
||||||
|
def test_server_disallow_list(self, fake_twistd):
|
||||||
|
result = self.runner.invoke(server, ['start', '--no-daemon', '--disallow-list'])
|
||||||
|
self.assertEqual(0, result.exit_code)
|
||||||
|
|
||||||
|
def test_server_plugin(self):
|
||||||
|
class FakeConfig(object):
|
||||||
|
no_daemon = True
|
||||||
|
blur_usage = True
|
||||||
|
advertise_version = u"fake.version.1"
|
||||||
|
transit = str('tcp:4321')
|
||||||
|
rendezvous = str('tcp:1234')
|
||||||
|
signal_error = True
|
||||||
|
allow_list = False
|
||||||
|
|
||||||
|
cfg = FakeConfig()
|
||||||
|
plugin = MyPlugin(cfg)
|
||||||
|
relay = plugin.makeService(None)
|
||||||
|
self.assertEqual(False, relay._allow_list)
|
||||||
|
|
|
@ -301,7 +301,7 @@ class Prune(unittest.TestCase):
|
||||||
|
|
||||||
def test_update(self):
|
def test_update(self):
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, None)
|
rv = rendezvous.Rendezvous(db, None, None, True)
|
||||||
app = rv.get_app("appid")
|
app = rv.get_app("appid")
|
||||||
mbox_id = "mbox1"
|
mbox_id = "mbox1"
|
||||||
app.open_mailbox(mbox_id, "side1", 1)
|
app.open_mailbox(mbox_id, "side1", 1)
|
||||||
|
@ -315,7 +315,7 @@ class Prune(unittest.TestCase):
|
||||||
self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3)
|
self.assertEqual(self._get_mailbox_updated(app, mbox_id), 3)
|
||||||
|
|
||||||
def test_apps(self):
|
def test_apps(self):
|
||||||
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None)
|
rv = rendezvous.Rendezvous(get_db(":memory:"), None, None, True)
|
||||||
app = rv.get_app("appid")
|
app = rv.get_app("appid")
|
||||||
app.allocate_nameplate("side", 121)
|
app.allocate_nameplate("side", 121)
|
||||||
app.prune = mock.Mock()
|
app.prune = mock.Mock()
|
||||||
|
@ -324,7 +324,7 @@ class Prune(unittest.TestCase):
|
||||||
|
|
||||||
def test_nameplates(self):
|
def test_nameplates(self):
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||||
|
|
||||||
# timestamps <=50 are "old", >=51 are "new"
|
# timestamps <=50 are "old", >=51 are "new"
|
||||||
#OLD = "old"; NEW = "new"
|
#OLD = "old"; NEW = "new"
|
||||||
|
@ -358,7 +358,7 @@ class Prune(unittest.TestCase):
|
||||||
|
|
||||||
def test_mailboxes(self):
|
def test_mailboxes(self):
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||||
|
|
||||||
# timestamps <=50 are "old", >=51 are "new"
|
# timestamps <=50 are "old", >=51 are "new"
|
||||||
#OLD = "old"; NEW = "new"
|
#OLD = "old"; NEW = "new"
|
||||||
|
@ -404,7 +404,7 @@ class Prune(unittest.TestCase):
|
||||||
log.msg(desc)
|
log.msg(desc)
|
||||||
|
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||||
APPID = "appid"
|
APPID = "appid"
|
||||||
app = rv.get_app(APPID)
|
app = rv.get_app(APPID)
|
||||||
|
|
||||||
|
@ -1174,7 +1174,7 @@ class WebSocketAPI(_Util, ServerBase, unittest.TestCase):
|
||||||
|
|
||||||
class Summary(unittest.TestCase):
|
class Summary(unittest.TestCase):
|
||||||
def test_mailbox(self):
|
def test_mailbox(self):
|
||||||
app = rendezvous.AppNamespace(None, None, False, None)
|
app = rendezvous.AppNamespace(None, None, False, None, True)
|
||||||
# starts at time 1, maybe gets second open at time 3, closes at 5
|
# starts at time 1, maybe gets second open at time 3, closes at 5
|
||||||
def s(rows, pruned=False):
|
def s(rows, pruned=False):
|
||||||
return app._summarize_mailbox(rows, 5, pruned)
|
return app._summarize_mailbox(rows, 5, pruned)
|
||||||
|
@ -1217,7 +1217,7 @@ class Summary(unittest.TestCase):
|
||||||
self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded"))
|
self.assertEqual(s(rows, pruned=True), Usage(1, 2, 4, "crowded"))
|
||||||
|
|
||||||
def test_nameplate(self):
|
def test_nameplate(self):
|
||||||
a = rendezvous.AppNamespace(None, None, False, None)
|
a = rendezvous.AppNamespace(None, None, False, None, True)
|
||||||
# starts at time 1, maybe gets second open at time 3, closes at 5
|
# starts at time 1, maybe gets second open at time 3, closes at 5
|
||||||
def s(rows, pruned=False):
|
def s(rows, pruned=False):
|
||||||
return a._summarize_nameplate_usage(rows, 5, pruned)
|
return a._summarize_nameplate_usage(rows, 5, pruned)
|
||||||
|
@ -1233,10 +1233,21 @@ class Summary(unittest.TestCase):
|
||||||
rows = [dict(added=1), dict(added=3), dict(added=4)]
|
rows = [dict(added=1), dict(added=3), dict(added=4)]
|
||||||
self.assertEqual(s(rows), Usage(1, 2, 4, "crowded"))
|
self.assertEqual(s(rows), Usage(1, 2, 4, "crowded"))
|
||||||
|
|
||||||
|
def test_nameplate_disallowed(self):
|
||||||
|
db = get_db(":memory:")
|
||||||
|
a = rendezvous.AppNamespace(db, None, False, "some_app_id", False)
|
||||||
|
a.allocate_nameplate("side1", "123")
|
||||||
|
self.assertEqual([], a.get_nameplate_ids())
|
||||||
|
|
||||||
|
def test_nameplate_allowed(self):
|
||||||
|
db = get_db(":memory:")
|
||||||
|
a = rendezvous.AppNamespace(db, None, False, "some_app_id", True)
|
||||||
|
np = a.allocate_nameplate("side1", "321")
|
||||||
|
self.assertEqual(set([np]), a.get_nameplate_ids())
|
||||||
|
|
||||||
def test_blur(self):
|
def test_blur(self):
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, 3600)
|
rv = rendezvous.Rendezvous(db, None, 3600, True)
|
||||||
APPID = "appid"
|
APPID = "appid"
|
||||||
app = rv.get_app(APPID)
|
app = rv.get_app(APPID)
|
||||||
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
||||||
|
@ -1253,7 +1264,7 @@ class Summary(unittest.TestCase):
|
||||||
|
|
||||||
def test_no_blur(self):
|
def test_no_blur(self):
|
||||||
db = get_db(":memory:")
|
db = get_db(":memory:")
|
||||||
rv = rendezvous.Rendezvous(db, None, None)
|
rv = rendezvous.Rendezvous(db, None, None, True)
|
||||||
APPID = "appid"
|
APPID = "appid"
|
||||||
app = rv.get_app(APPID)
|
app = rv.get_app(APPID)
|
||||||
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
app.claim_nameplate("npid", "side1", 10) # start time is 10
|
||||||
|
@ -1292,3 +1303,22 @@ class DumpStats(unittest.TestCase):
|
||||||
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
|
self.assertEqual(data["rendezvous"]["all_time"]["mailboxes_total"], 0)
|
||||||
self.assertEqual(data["transit"]["all_time"]["total"], 0)
|
self.assertEqual(data["transit"]["all_time"]["total"], 0)
|
||||||
|
|
||||||
|
|
||||||
|
class Startup(unittest.TestCase):
|
||||||
|
|
||||||
|
@mock.patch('wormhole.server.server.log')
|
||||||
|
def test_empty(self, fake_log):
|
||||||
|
rs = server.RelayServer(
|
||||||
|
str("tcp:0"),
|
||||||
|
str("tcp:0"),
|
||||||
|
None,
|
||||||
|
allow_list=False,
|
||||||
|
)
|
||||||
|
rs.startService()
|
||||||
|
try:
|
||||||
|
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
|
||||||
|
self.assertTrue(
|
||||||
|
'listing of allocated nameplates disallowed' in logs
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
rs.stopService()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user