diff --git a/src/wormhole/_rendezvous.py b/src/wormhole/_rendezvous.py index b27f318..3606aff 100644 --- a/src/wormhole/_rendezvous.py +++ b/src/wormhole/_rendezvous.py @@ -81,8 +81,7 @@ class RendezvousConnector(object): self._ws = None f = WSFactory(self, self._url) f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) - p = urlparse(self._url) - ep = self._make_endpoint(p.hostname, p.port or 80) + ep = self._make_endpoint(self._url) self._connector = internet.ClientService(ep, f) faf = None if self._have_made_a_successful_connection else 1 d = self._connector.whenConnected(failAfterFailures=faf) @@ -100,11 +99,16 @@ class RendezvousConnector(object): if self._trace: self._trace(old_state="", input=what, new_state="") - def _make_endpoint(self, hostname, port): + def _make_endpoint(self, url): + p = urlparse(url) + tls = (p.scheme == "wss") + port = p.port or (443 if tls else 80) if self._tor: - # TODO: when we enable TLS, maybe add tls=True here - return self._tor.stream_via(hostname, port) - return endpoints.HostnameEndpoint(self._reactor, hostname, port) + return self._tor.stream_via(p.hostname, port, tls=tls) + if tls: + return endpoints.clientFromString(self._reactor, + "tls:%s:%s" % (p.hostname, port)) + return endpoints.HostnameEndpoint(self._reactor, p.hostname, port) def wire(self, boss, nameplate, mailbox, allocator, lister, terminator): self._B = _interfaces.IBoss(boss) diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 6dc1751..3b92c34 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -384,8 +384,8 @@ class FakeTor: def __init__(self): self.endpoints = [] - def stream_via(self, host, port): - self.endpoints.append((host, port)) + def stream_via(self, host, port, tls=False): + self.endpoints.append((host, port, tls)) return endpoints.HostnameEndpoint(reactor, host, port) @@ -608,9 +608,9 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): yield gatherResults([send_d, receive_d], True) if fake_tor: - expected_endpoints = [("127.0.0.1", self.rdv_ws_port)] + expected_endpoints = [("127.0.0.1", self.rdv_ws_port, False)] if mode in ("file", "directory"): - expected_endpoints.append(("127.0.0.1", self.transitport)) + expected_endpoints.append(("127.0.0.1", self.transitport, False)) tx_timing = mtx_tm.call_args[1]["timing"] self.assertEqual(tx_tm.endpoints, expected_endpoints) self.assertEqual( diff --git a/src/wormhole/test/test_machines.py b/src/wormhole/test/test_machines.py index 9e417d6..6256e0a 100644 --- a/src/wormhole/test/test_machines.py +++ b/src/wormhole/test/test_machines.py @@ -14,7 +14,8 @@ from .. import (__version__, _allocator, _boss, _code, _input, _key, _lister, _terminator, errors, timing) from .._interfaces import (IAllocator, IBoss, ICode, IDilator, IInput, IKey, ILister, IMailbox, INameplate, IOrder, IReceive, - IRendezvousConnector, ISend, ITerminator, IWordlist) + IRendezvousConnector, ISend, ITerminator, IWordlist, + ITorManager) from .._key import derive_key, derive_phase_key, encrypt_data from ..journal import ImmediateJournal from ..util import (bytes_to_dict, bytes_to_hexstr, dict_to_bytes, @@ -1621,6 +1622,72 @@ class Rendezvous(unittest.TestCase): ("a.lost", ), ]) + def test_endpoints(self): + # parse different URLs and check the tls status of each + reactor = object() + journal = ImmediateJournal() + tor_manager = None + client_version = ("python", __version__) + rc = _rendezvous.RendezvousConnector( + "ws://host:4000/v1", "appid", "side", reactor, journal, + tor_manager, timing.DebugTiming(), client_version) + + new_ep = object() + with mock.patch("twisted.internet.endpoints.HostnameEndpoint", + return_value=new_ep) as he: + ep = rc._make_endpoint("ws://host:4000/v1") + self.assertEqual(he.mock_calls, [mock.call(reactor, "host", 4000)]) + self.assertIs(ep, new_ep) + + new_ep = object() + with mock.patch("twisted.internet.endpoints.HostnameEndpoint", + return_value=new_ep) as he: + ep = rc._make_endpoint("ws://host/v1") + self.assertEqual(he.mock_calls, [mock.call(reactor, "host", 80)]) + self.assertIs(ep, new_ep) + + new_ep = object() + with mock.patch("twisted.internet.endpoints.clientFromString", + return_value=new_ep) as cfs: + ep = rc._make_endpoint("wss://host:4000/v1") + self.assertEqual(cfs.mock_calls, [mock.call(reactor, "tls:host:4000")]) + self.assertIs(ep, new_ep) + + new_ep = object() + with mock.patch("twisted.internet.endpoints.clientFromString", + return_value=new_ep) as cfs: + ep = rc._make_endpoint("wss://host/v1") + self.assertEqual(cfs.mock_calls, [mock.call(reactor, "tls:host:443")]) + self.assertIs(ep, new_ep) + + tor_manager = mock.Mock() + directlyProvides(tor_manager, ITorManager) + rc = _rendezvous.RendezvousConnector( + "ws://host:4000/v1", "appid", "side", reactor, journal, + tor_manager, timing.DebugTiming(), client_version) + + tor_manager.mock_calls[:] = [] + ep = rc._make_endpoint("ws://host:4000/v1") + self.assertEqual(tor_manager.mock_calls, + [mock.call.stream_via("host", 4000, tls=False)]) + + tor_manager.mock_calls[:] = [] + ep = rc._make_endpoint("ws://host/v1") + self.assertEqual(tor_manager.mock_calls, + [mock.call.stream_via("host", 80, tls=False)]) + + tor_manager.mock_calls[:] = [] + ep = rc._make_endpoint("wss://host:4000/v1") + self.assertEqual(tor_manager.mock_calls, + [mock.call.stream_via("host", 4000, tls=True)]) + + tor_manager.mock_calls[:] = [] + ep = rc._make_endpoint("wss://host/v1") + self.assertEqual(tor_manager.mock_calls, + [mock.call.stream_via("host", 443, tls=True)]) + + + # TODO # #Send