diff --git a/src/wormhole/test/test_tor_manager.py b/src/wormhole/test/test_tor_manager.py index f14e97c..1f04354 100644 --- a/src/wormhole/test/test_tor_manager.py +++ b/src/wormhole/test/test_tor_manager.py @@ -51,29 +51,37 @@ class Tor(unittest.TestCase): reactor = object() my_tor = X() # object() didn't like providedBy() tcp = "port" + ep = object() connect_d = defer.Deferred() stderr = io.StringIO() with mock.patch("wormhole.tor_manager.txtorcon.connect", side_effect=connect_d) as connect: - d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) - self.assertNoResult(d) - self.assertEqual(connect.mock_calls, [mock.call(reactor, tcp)]) - connect_d.callback(my_tor) - tor = self.successResultOf(d) - self.assertIs(tor, my_tor) - self.assert_(ITorManager.providedBy(tor)) - self.assertEqual(stderr.getvalue(), " using Tor via control port\n") + with mock.patch("wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: + d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) + self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) + self.assertNoResult(d) + self.assertEqual(connect.mock_calls, [mock.call(reactor, ep)]) + connect_d.callback(my_tor) + tor = self.successResultOf(d) + self.assertIs(tor, my_tor) + self.assert_(ITorManager.providedBy(tor)) + self.assertEqual(stderr.getvalue(), " using Tor via control port\n") def test_connect_fails(self): reactor = object() tcp = "port" + ep = object() connect_d = defer.Deferred() stderr = io.StringIO() with mock.patch("wormhole.tor_manager.txtorcon.connect", side_effect=connect_d) as connect: - d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) - self.assertNoResult(d) - self.assertEqual(connect.mock_calls, [mock.call(reactor, tcp)]) + with mock.patch("wormhole.tor_manager.clientFromString", + side_effect=[ep]) as sfs: + d = get_tor(reactor, tor_control_port=tcp, stderr=stderr) + self.assertEqual(sfs.mock_calls, [mock.call(reactor, tcp)]) + self.assertNoResult(d) + self.assertEqual(connect.mock_calls, [mock.call(reactor, ep)]) connect_d.errback(ConnectError()) tor = self.successResultOf(d) diff --git a/src/wormhole/tor_manager.py b/src/wormhole/tor_manager.py index 3aa2a11..f92a627 100644 --- a/src/wormhole/tor_manager.py +++ b/src/wormhole/tor_manager.py @@ -3,6 +3,7 @@ import sys from attr import attrs, attrib from zope.interface.declarations import directlyProvides from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.endpoints import clientFromString try: import txtorcon except ImportError: @@ -86,6 +87,9 @@ def get_tor(reactor, launch_tor=False, tor_control_port=None, # If tor_control_port is None (the default), txtorcon # will look through a list of usual places. If it is set, # it will look only in the place we tell it to. + if tor_control_port is not None: + tor_control_port = clientFromString(reactor, + tor_control_port) tor = yield txtorcon.connect(reactor, tor_control_port) print(" using Tor via control port", file=stderr) except Exception: