accept 'wss' in relay_url, use TLS for those connections

Do the same under Tor.

If the hostname is missing, use 443 when using TLS, or 80 when not.

refs #144
This commit is contained in:
Brian Warner 2020-01-16 18:57:53 -08:00
parent fe7b036027
commit 5a60b247f5
3 changed files with 82 additions and 11 deletions

View File

@ -81,8 +81,7 @@ class RendezvousConnector(object):
self._ws = None self._ws = None
f = WSFactory(self, self._url) f = WSFactory(self, self._url)
f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600)
p = urlparse(self._url) ep = self._make_endpoint(self._url)
ep = self._make_endpoint(p.hostname, p.port or 80)
self._connector = internet.ClientService(ep, f) self._connector = internet.ClientService(ep, f)
faf = None if self._have_made_a_successful_connection else 1 faf = None if self._have_made_a_successful_connection else 1
d = self._connector.whenConnected(failAfterFailures=faf) d = self._connector.whenConnected(failAfterFailures=faf)
@ -100,11 +99,16 @@ class RendezvousConnector(object):
if self._trace: if self._trace:
self._trace(old_state="", input=what, new_state="") self._trace(old_state="", input=what, new_state="")
def _make_endpoint(self, hostname, port): def _make_endpoint(self, url):
p = urlparse(url)
tls = (p.scheme == "wss")
port = p.port or (443 if tls else 80)
if self._tor: if self._tor:
# TODO: when we enable TLS, maybe add tls=True here return self._tor.stream_via(p.hostname, port, tls=tls)
return self._tor.stream_via(hostname, port) if tls:
return endpoints.HostnameEndpoint(self._reactor, hostname, port) return endpoints.clientFromString(self._reactor,
"tls:%s:%s" % (p.hostname, port))
return endpoints.HostnameEndpoint(self._reactor, p.hostname, port)
def wire(self, boss, nameplate, mailbox, allocator, lister, terminator): def wire(self, boss, nameplate, mailbox, allocator, lister, terminator):
self._B = _interfaces.IBoss(boss) self._B = _interfaces.IBoss(boss)

View File

@ -384,8 +384,8 @@ class FakeTor:
def __init__(self): def __init__(self):
self.endpoints = [] self.endpoints = []
def stream_via(self, host, port): def stream_via(self, host, port, tls=False):
self.endpoints.append((host, port)) self.endpoints.append((host, port, tls))
return endpoints.HostnameEndpoint(reactor, host, port) return endpoints.HostnameEndpoint(reactor, host, port)
@ -608,9 +608,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
yield gatherResults([send_d, receive_d], True) yield gatherResults([send_d, receive_d], True)
if fake_tor: if fake_tor:
expected_endpoints = [("127.0.0.1", self.rdv_ws_port)] expected_endpoints = [("127.0.0.1", self.rdv_ws_port, False)]
if mode in ("file", "directory"): if mode in ("file", "directory"):
expected_endpoints.append(("127.0.0.1", self.transitport)) expected_endpoints.append(("127.0.0.1", self.transitport, False))
tx_timing = mtx_tm.call_args[1]["timing"] tx_timing = mtx_tm.call_args[1]["timing"]
self.assertEqual(tx_tm.endpoints, expected_endpoints) self.assertEqual(tx_tm.endpoints, expected_endpoints)
self.assertEqual( self.assertEqual(

View File

@ -14,7 +14,8 @@ from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister,
_terminator, errors, timing) _terminator, errors, timing)
from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey,
ILister, IMailbox, INameplate, IOrder, IReceive, ILister, IMailbox, INameplate, IOrder, IReceive,
IRendezvousConnector, ISend, ITerminator, IWordlist) IRendezvousConnector, ISend, ITerminator, IWordlist,
ITorManager)
from .._key import derive_key, derive_phase_key, encrypt_data from .._key import derive_key, derive_phase_key, encrypt_data
from ..journal import ImmediateJournal from ..journal import ImmediateJournal
from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes,
@ -1621,6 +1622,72 @@ class Rendezvous(unittest.TestCase):
("a.lost", ), ("a.lost", ),
]) ])
def test_endpoints(self):
# parse different URLs and check the tls status of each
reactor = object()
journal = ImmediateJournal()
tor_manager = None
client_version = ("python", __version__)
rc = _rendezvous.RendezvousConnector(
"ws://host:4000/v1", "appid", "side", reactor, journal,
tor_manager, timing.DebugTiming(), client_version)
new_ep = object()
with mock.patch("twisted.internet.endpoints.HostnameEndpoint",
return_value=new_ep) as he:
ep = rc._make_endpoint("ws://host:4000/v1")
self.assertEqual(he.mock_calls, [mock.call(reactor, "host", 4000)])
self.assertIs(ep, new_ep)
new_ep = object()
with mock.patch("twisted.internet.endpoints.HostnameEndpoint",
return_value=new_ep) as he:
ep = rc._make_endpoint("ws://host/v1")
self.assertEqual(he.mock_calls, [mock.call(reactor, "host", 80)])
self.assertIs(ep, new_ep)
new_ep = object()
with mock.patch("twisted.internet.endpoints.clientFromString",
return_value=new_ep) as cfs:
ep = rc._make_endpoint("wss://host:4000/v1")
self.assertEqual(cfs.mock_calls, [mock.call(reactor, "tls:host:4000")])
self.assertIs(ep, new_ep)
new_ep = object()
with mock.patch("twisted.internet.endpoints.clientFromString",
return_value=new_ep) as cfs:
ep = rc._make_endpoint("wss://host/v1")
self.assertEqual(cfs.mock_calls, [mock.call(reactor, "tls:host:443")])
self.assertIs(ep, new_ep)
tor_manager = mock.Mock()
directlyProvides(tor_manager, ITorManager)
rc = _rendezvous.RendezvousConnector(
"ws://host:4000/v1", "appid", "side", reactor, journal,
tor_manager, timing.DebugTiming(), client_version)
tor_manager.mock_calls[:] = []
ep = rc._make_endpoint("ws://host:4000/v1")
self.assertEqual(tor_manager.mock_calls,
[mock.call.stream_via("host", 4000, tls=False)])
tor_manager.mock_calls[:] = []
ep = rc._make_endpoint("ws://host/v1")
self.assertEqual(tor_manager.mock_calls,
[mock.call.stream_via("host", 80, tls=False)])
tor_manager.mock_calls[:] = []
ep = rc._make_endpoint("wss://host:4000/v1")
self.assertEqual(tor_manager.mock_calls,
[mock.call.stream_via("host", 4000, tls=True)])
tor_manager.mock_calls[:] = []
ep = rc._make_endpoint("wss://host/v1")
self.assertEqual(tor_manager.mock_calls,
[mock.call.stream_via("host", 443, tls=True)])
# TODO # TODO
# #Send # #Send