diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index c170b32..5c79a63 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1435,11 +1435,13 @@ class Transit(unittest.TestCase): def setUp(self): self._connectors = [] self._waiters = [] + self._descriptions = [] def _start_connector(self, ep, description, is_relay=False): d = defer.Deferred() self._connectors.append(ep) self._waiters.append(d) + self._descriptions.append(description) return d @inlineCallbacks @@ -1463,6 +1465,52 @@ class Transit(unittest.TestCase): self._waiters[0].callback("winner") self.assertEqual(results, ["winner"]) + self.assertEqual(self._descriptions, ["->tcp:direct:1234"]) + + @inlineCallbacks + def test_success_direct_tor(self): + clock = task.Clock() + s = transit.TransitSender("", tor_manager=mock.Mock(), reactor=clock) + s.set_transit_key(b"key") + hints = yield s.get_connection_hints() # start the listener + del hints + s.add_connection_hints([DIRECT_HINT_JSON]) + + s._start_connector = self._start_connector + d = s.connect() + results = [] + d.addBoth(results.append) + self.assertEqual(results, []) + self.assertEqual(len(self._waiters), 1) + self.assertIsInstance(self._waiters[0], defer.Deferred) + + self._waiters[0].callback("winner") + self.assertEqual(results, ["winner"]) + self.assertEqual(self._descriptions, ["tor->tcp:direct:1234"]) + + @inlineCallbacks + def test_success_direct_tor_relay(self): + clock = task.Clock() + s = transit.TransitSender("", tor_manager=mock.Mock(), reactor=clock) + s.set_transit_key(b"key") + hints = yield s.get_connection_hints() # start the listener + del hints + s.add_connection_hints([RELAY_HINT_JSON]) + + s._start_connector = self._start_connector + d = s.connect() + results = [] + d.addBoth(results.append) + # move the clock forward any amount, since relay connections are + # triggered starting at T+0.0 + clock.advance(1.0) + self.assertEqual(results, []) + self.assertEqual(len(self._waiters), 1) + self.assertIsInstance(self._waiters[0], defer.Deferred) + + self._waiters[0].callback("winner") + self.assertEqual(results, ["winner"]) + self.assertEqual(self._descriptions, ["tor->relay:tcp:relay:1234"]) def _endpoint_from_hint_obj(self, hint): if isinstance(hint, transit.DirectTCPV1Hint):