Merge PR197: add --websocket-protocol-option= to server
closes #197 closes #196
This commit is contained in:
commit
62382b7ac4
|
@ -1,4 +1,5 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
import click
|
||||
from ..cli.cli import Config, _compose
|
||||
|
||||
|
@ -19,6 +20,21 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
|
|||
# server commands don't use
|
||||
ctx.obj = Config()
|
||||
|
||||
def _validate_websocket_protocol_options(ctx, param, value):
|
||||
return list(_validate_websocket_protocol_option(option) for option in value)
|
||||
|
||||
def _validate_websocket_protocol_option(option):
|
||||
try:
|
||||
key, value = option.split("=", 1)
|
||||
except ValueError:
|
||||
raise click.BadParameter("format options as OPTION=VALUE")
|
||||
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except:
|
||||
raise click.BadParameter("could not parse JSON value for {}".format(key))
|
||||
|
||||
return (key, value)
|
||||
|
||||
LaunchArgs = _compose(
|
||||
click.option(
|
||||
|
@ -58,6 +74,11 @@ LaunchArgs = _compose(
|
|||
"--stats-json-path", default="stats.json", metavar="PATH",
|
||||
help="location to write the relay stats file",
|
||||
),
|
||||
click.option(
|
||||
"--websocket-protocol-option", multiple=True, metavar="OPTION=VALUE",
|
||||
callback=_validate_websocket_protocol_options,
|
||||
help="a websocket server protocol option to configure",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -34,7 +34,8 @@ class RelayServer(service.MultiService):
|
|||
|
||||
def __init__(self, rendezvous_web_port, transit_port,
|
||||
advertise_version, db_url=":memory:", blur_usage=None,
|
||||
signal_error=None, stats_file=None, allow_list=True):
|
||||
signal_error=None, stats_file=None, allow_list=True,
|
||||
websocket_protocol_options=()):
|
||||
service.MultiService.__init__(self)
|
||||
self._blur_usage = blur_usage
|
||||
self._allow_list = allow_list
|
||||
|
@ -64,6 +65,7 @@ class RelayServer(service.MultiService):
|
|||
|
||||
root = Root()
|
||||
wsrf = WebSocketRendezvousFactory(None, self._rendezvous)
|
||||
_set_options(websocket_protocol_options, wsrf)
|
||||
root.putChild(b"v1", WebSocketResource(wsrf))
|
||||
|
||||
site = PrivacyEnhancedSite(root)
|
||||
|
@ -137,3 +139,7 @@ class RelayServer(service.MultiService):
|
|||
f.write(json.dumps(data, indent=1).encode("utf-8"))
|
||||
f.write(b"\n")
|
||||
os.rename(tmpfn, self._stats_file)
|
||||
|
||||
|
||||
def _set_options(options, factory):
|
||||
factory.setProtocolOptions(**dict(options))
|
||||
|
|
|
@ -1238,3 +1238,48 @@ class Server(unittest.TestCase):
|
|||
relay = plugin.makeService(None)
|
||||
self.assertEqual('relay.sqlite', relay._db_url)
|
||||
self.assertEqual('stats.json', relay._stats_file)
|
||||
|
||||
@mock.patch("wormhole.server.cmd_server.start_server")
|
||||
def test_websocket_protocol_options(self, fake_start_server):
|
||||
result = self.runner.invoke(
|
||||
server, [
|
||||
'start',
|
||||
'--websocket-protocol-option=a=3',
|
||||
'--websocket-protocol-option=b=true',
|
||||
'--websocket-protocol-option=c=3.5',
|
||||
'--websocket-protocol-option=d=["foo","bar"]',
|
||||
'--websocket-protocol-option', 'e=["foof","barf"]',
|
||||
])
|
||||
self.assertEqual(0, result.exit_code)
|
||||
cfg = fake_start_server.mock_calls[0][1][0]
|
||||
self.assertEqual(
|
||||
cfg.websocket_protocol_option,
|
||||
[("a", 3), ("b", True), ("c", 3.5), ("d", ['foo', 'bar']),
|
||||
("e", ['foof', 'barf']),
|
||||
],
|
||||
)
|
||||
|
||||
def test_broken_websocket_protocol_options(self):
|
||||
result = self.runner.invoke(
|
||||
server, [
|
||||
'start',
|
||||
'--websocket-protocol-option=a',
|
||||
])
|
||||
self.assertNotEqual(0, result.exit_code)
|
||||
self.assertIn(
|
||||
'Error: Invalid value for "--websocket-protocol-option": '
|
||||
'format options as OPTION=VALUE',
|
||||
result.output,
|
||||
)
|
||||
|
||||
result = self.runner.invoke(
|
||||
server, [
|
||||
'start',
|
||||
'--websocket-protocol-option=a=foo',
|
||||
])
|
||||
self.assertNotEqual(0, result.exit_code)
|
||||
self.assertIn(
|
||||
'Error: Invalid value for "--websocket-protocol-option": '
|
||||
'could not parse JSON value for a',
|
||||
result.output,
|
||||
)
|
||||
|
|
|
@ -11,6 +11,19 @@ from ..server import server, rendezvous
|
|||
from ..server.rendezvous import Usage, SidedMessage
|
||||
from ..server.database import get_db
|
||||
|
||||
def easy_relay(
|
||||
rendezvous_web_port=str("tcp:0"),
|
||||
transit_port=str("tcp:0"),
|
||||
advertise_version=None,
|
||||
**kwargs
|
||||
):
|
||||
return server.RelayServer(
|
||||
rendezvous_web_port,
|
||||
transit_port,
|
||||
advertise_version,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
class _Util:
|
||||
def _nameplate(self, app, name):
|
||||
np_row = app._db.execute("SELECT * FROM `nameplates`"
|
||||
|
@ -1313,7 +1326,7 @@ class Summary(unittest.TestCase):
|
|||
|
||||
class DumpStats(unittest.TestCase):
|
||||
def test_nostats(self):
|
||||
rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None)
|
||||
rs = easy_relay()
|
||||
# with no ._stats_file, this should do nothing
|
||||
rs.dump_stats(1, 1)
|
||||
|
||||
|
@ -1321,8 +1334,7 @@ class DumpStats(unittest.TestCase):
|
|||
basedir = self.mktemp()
|
||||
os.mkdir(basedir)
|
||||
fn = os.path.join(basedir, "stats.json")
|
||||
rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None,
|
||||
stats_file=fn)
|
||||
rs = easy_relay(stats_file=fn)
|
||||
now = 1234
|
||||
validity = 500
|
||||
rs.dump_stats(now, validity)
|
||||
|
@ -1339,12 +1351,7 @@ class Startup(unittest.TestCase):
|
|||
|
||||
@mock.patch('wormhole.server.server.log')
|
||||
def test_empty(self, fake_log):
|
||||
rs = server.RelayServer(
|
||||
str("tcp:0"),
|
||||
str("tcp:0"),
|
||||
None,
|
||||
allow_list=False,
|
||||
)
|
||||
rs = easy_relay(allow_list=False)
|
||||
rs.startService()
|
||||
try:
|
||||
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
|
||||
|
@ -1353,3 +1360,17 @@ class Startup(unittest.TestCase):
|
|||
)
|
||||
finally:
|
||||
rs.stopService()
|
||||
|
||||
|
||||
class WebSocketProtocolOptions(unittest.TestCase):
|
||||
@mock.patch('wormhole.server.server.WebSocketRendezvousFactory')
|
||||
def test_set(self, fake_factory):
|
||||
easy_relay(
|
||||
websocket_protocol_options=[
|
||||
("foo", "bar"),
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
mock.call().setProtocolOptions(foo="bar"),
|
||||
fake_factory.mock_calls[1],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user