Merge PR197: add --websocket-protocol-option= to server

closes #197
closes #196
This commit is contained in:
Brian Warner 2017-07-15 17:40:22 -07:00
commit 62382b7ac4
4 changed files with 103 additions and 10 deletions

View File

@ -1,4 +1,5 @@
from __future__ import print_function
import json
import click
from ..cli.cli import Config, _compose
@ -19,6 +20,21 @@ def server(ctx): # this is the setuptools entrypoint for bin/wormhole-server
# server commands don't use
ctx.obj = Config()
def _validate_websocket_protocol_options(ctx, param, value):
return list(_validate_websocket_protocol_option(option) for option in value)
def _validate_websocket_protocol_option(option):
try:
key, value = option.split("=", 1)
except ValueError:
raise click.BadParameter("format options as OPTION=VALUE")
try:
value = json.loads(value)
except:
raise click.BadParameter("could not parse JSON value for {}".format(key))
return (key, value)
LaunchArgs = _compose(
click.option(
@ -58,6 +74,11 @@ LaunchArgs = _compose(
"--stats-json-path", default="stats.json", metavar="PATH",
help="location to write the relay stats file",
),
click.option(
"--websocket-protocol-option", multiple=True, metavar="OPTION=VALUE",
callback=_validate_websocket_protocol_options,
help="a websocket server protocol option to configure",
),
)

View File

@ -34,7 +34,8 @@ class RelayServer(service.MultiService):
def __init__(self, rendezvous_web_port, transit_port,
advertise_version, db_url=":memory:", blur_usage=None,
signal_error=None, stats_file=None, allow_list=True):
signal_error=None, stats_file=None, allow_list=True,
websocket_protocol_options=()):
service.MultiService.__init__(self)
self._blur_usage = blur_usage
self._allow_list = allow_list
@ -64,6 +65,7 @@ class RelayServer(service.MultiService):
root = Root()
wsrf = WebSocketRendezvousFactory(None, self._rendezvous)
_set_options(websocket_protocol_options, wsrf)
root.putChild(b"v1", WebSocketResource(wsrf))
site = PrivacyEnhancedSite(root)
@ -137,3 +139,7 @@ class RelayServer(service.MultiService):
f.write(json.dumps(data, indent=1).encode("utf-8"))
f.write(b"\n")
os.rename(tmpfn, self._stats_file)
def _set_options(options, factory):
factory.setProtocolOptions(**dict(options))

View File

@ -1238,3 +1238,48 @@ class Server(unittest.TestCase):
relay = plugin.makeService(None)
self.assertEqual('relay.sqlite', relay._db_url)
self.assertEqual('stats.json', relay._stats_file)
@mock.patch("wormhole.server.cmd_server.start_server")
def test_websocket_protocol_options(self, fake_start_server):
result = self.runner.invoke(
server, [
'start',
'--websocket-protocol-option=a=3',
'--websocket-protocol-option=b=true',
'--websocket-protocol-option=c=3.5',
'--websocket-protocol-option=d=["foo","bar"]',
'--websocket-protocol-option', 'e=["foof","barf"]',
])
self.assertEqual(0, result.exit_code)
cfg = fake_start_server.mock_calls[0][1][0]
self.assertEqual(
cfg.websocket_protocol_option,
[("a", 3), ("b", True), ("c", 3.5), ("d", ['foo', 'bar']),
("e", ['foof', 'barf']),
],
)
def test_broken_websocket_protocol_options(self):
result = self.runner.invoke(
server, [
'start',
'--websocket-protocol-option=a',
])
self.assertNotEqual(0, result.exit_code)
self.assertIn(
'Error: Invalid value for "--websocket-protocol-option": '
'format options as OPTION=VALUE',
result.output,
)
result = self.runner.invoke(
server, [
'start',
'--websocket-protocol-option=a=foo',
])
self.assertNotEqual(0, result.exit_code)
self.assertIn(
'Error: Invalid value for "--websocket-protocol-option": '
'could not parse JSON value for a',
result.output,
)

View File

@ -11,6 +11,19 @@ from ..server import server, rendezvous
from ..server.rendezvous import Usage, SidedMessage
from ..server.database import get_db
def easy_relay(
rendezvous_web_port=str("tcp:0"),
transit_port=str("tcp:0"),
advertise_version=None,
**kwargs
):
return server.RelayServer(
rendezvous_web_port,
transit_port,
advertise_version,
**kwargs
)
class _Util:
def _nameplate(self, app, name):
np_row = app._db.execute("SELECT * FROM `nameplates`"
@ -1313,7 +1326,7 @@ class Summary(unittest.TestCase):
class DumpStats(unittest.TestCase):
def test_nostats(self):
rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None)
rs = easy_relay()
# with no ._stats_file, this should do nothing
rs.dump_stats(1, 1)
@ -1321,8 +1334,7 @@ class DumpStats(unittest.TestCase):
basedir = self.mktemp()
os.mkdir(basedir)
fn = os.path.join(basedir, "stats.json")
rs = server.RelayServer(str("tcp:0"), str("tcp:0"), None,
stats_file=fn)
rs = easy_relay(stats_file=fn)
now = 1234
validity = 500
rs.dump_stats(now, validity)
@ -1339,12 +1351,7 @@ class Startup(unittest.TestCase):
@mock.patch('wormhole.server.server.log')
def test_empty(self, fake_log):
rs = server.RelayServer(
str("tcp:0"),
str("tcp:0"),
None,
allow_list=False,
)
rs = easy_relay(allow_list=False)
rs.startService()
try:
logs = '\n'.join([call[1][0] for call in fake_log.mock_calls])
@ -1353,3 +1360,17 @@ class Startup(unittest.TestCase):
)
finally:
rs.stopService()
class WebSocketProtocolOptions(unittest.TestCase):
@mock.patch('wormhole.server.server.WebSocketRendezvousFactory')
def test_set(self, fake_factory):
easy_relay(
websocket_protocol_options=[
("foo", "bar"),
]
)
self.assertEqual(
mock.call().setProtocolOptions(foo="bar"),
fake_factory.mock_calls[1],
)