Merge PR378: allow custom reactor selection better

This commit is contained in:
Brian Warner 2020-02-06 22:49:54 -08:00
commit f566fae439
2 changed files with 7 additions and 5 deletions

View File

@ -1312,8 +1312,8 @@ class Transit(unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_success_direct(self): def test_success_direct(self):
clock = task.Clock() reactor = mock.Mock()
s = transit.TransitSender("", reactor=clock) s = transit.TransitSender("", reactor=reactor)
s.set_transit_key(b"key") s.set_transit_key(b"key")
hints = yield s.get_connection_hints() # start the listener hints = yield s.get_connection_hints() # start the listener
del hints del hints

View File

@ -11,7 +11,7 @@ from collections import deque
import six import six
from nacl.secret import SecretBox from nacl.secret import SecretBox
from twisted.internet import (address, defer, endpoints, error, interfaces, from twisted.internet import (address, defer, endpoints, error, interfaces,
protocol, reactor, task) protocol, task)
from twisted.internet.defer import inlineCallbacks, returnValue from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.protocols import policies from twisted.protocols import policies
from twisted.python import log from twisted.python import log
@ -559,7 +559,7 @@ class Common:
transit_relay, transit_relay,
no_listen=False, no_listen=False,
tor=None, tor=None,
reactor=reactor, reactor=None,
timing=None): timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode self._side = bytes_to_hexstr(os.urandom(8)) # unicode
if transit_relay: if transit_relay:
@ -579,6 +579,8 @@ class Common:
self._waiting_for_transit_key = [] self._waiting_for_transit_key = []
self._listener = None self._listener = None
self._winner = None self._winner = None
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor self._reactor = reactor
self._timing = timing or DebugTiming() self._timing = timing or DebugTiming()
self._timing.add("transit") self._timing.add("transit")
@ -596,7 +598,7 @@ class Common:
direct_hints = [ direct_hints = [
DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses
] ]
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) ep = endpoints.serverFromString(self._reactor, "tcp:%d" % portnum)
return direct_hints, ep return direct_hints, ep
def get_connection_abilities(self): def get_connection_abilities(self):