Merge branch '103-transit-priority'
This commit is contained in:
commit
822fb212c4
|
@ -138,29 +138,29 @@ class Hints(unittest.TestCase):
|
||||||
def test_endpoint_from_hint_obj(self):
|
def test_endpoint_from_hint_obj(self):
|
||||||
c = transit.Common("")
|
c = transit.Common("")
|
||||||
efho = c._endpoint_from_hint_obj
|
efho = c._endpoint_from_hint_obj
|
||||||
self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234)),
|
self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||||
endpoints.HostnameEndpoint)
|
endpoints.HostnameEndpoint)
|
||||||
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
|
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
|
||||||
# c._tor_manager is currently None
|
# c._tor_manager is currently None
|
||||||
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port")), None)
|
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
|
||||||
c._tor_manager = mock.Mock()
|
c._tor_manager = mock.Mock()
|
||||||
def tor_ep(hostname, port):
|
def tor_ep(hostname, port):
|
||||||
if hostname == "non-public":
|
if hostname == "non-public":
|
||||||
return None
|
return None
|
||||||
return ("tor_ep", hostname, port)
|
return ("tor_ep", hostname, port)
|
||||||
c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep)
|
c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep)
|
||||||
self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234)),
|
self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||||
("tor_ep", "host", 1234))
|
("tor_ep", "host", 1234))
|
||||||
self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234)),
|
self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
|
||||||
("tor_ep", "host2.onion", 1234))
|
("tor_ep", "host2.onion", 1234))
|
||||||
self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234)),
|
self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)),
|
||||||
None)
|
None)
|
||||||
self.assertEqual(efho(UnknownHint("foo")), None)
|
self.assertEqual(efho(UnknownHint("foo")), None)
|
||||||
|
|
||||||
def test_comparable(self):
|
def test_comparable(self):
|
||||||
h1 = transit.DirectTCPV1Hint("hostname", "port1")
|
h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||||
h1b = transit.DirectTCPV1Hint("hostname", "port1")
|
h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
|
||||||
h2 = transit.DirectTCPV1Hint("hostname", "port2")
|
h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
|
||||||
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
|
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
|
||||||
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
|
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
|
||||||
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
|
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
|
||||||
|
@ -173,9 +173,15 @@ class Hints(unittest.TestCase):
|
||||||
p = c._parse_tcp_v1_hint
|
p = c._parse_tcp_v1_hint
|
||||||
self.assertEqual(p({"type": "unknown"}), None)
|
self.assertEqual(p({"type": "unknown"}), None)
|
||||||
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
|
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
|
||||||
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234))
|
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
|
||||||
|
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234,
|
||||||
|
"priority": 2.5})
|
||||||
|
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
|
||||||
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
|
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
|
||||||
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234))
|
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
|
||||||
|
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234,
|
||||||
|
"priority": 2.5})
|
||||||
|
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
|
||||||
self.assertEqual(p({"type": "direct-tcp-v1"}),
|
self.assertEqual(p({"type": "direct-tcp-v1"}),
|
||||||
None) # missing hostname
|
None) # missing hostname
|
||||||
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}),
|
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}),
|
||||||
|
@ -192,7 +198,11 @@ class Hints(unittest.TestCase):
|
||||||
value = transit.parse_hint_argv(hint, stderr=stderr)
|
value = transit.parse_hint_argv(hint, stderr=stderr)
|
||||||
return value, stderr.getvalue()
|
return value, stderr.getvalue()
|
||||||
h,stderr = p("tcp:host:1234")
|
h,stderr = p("tcp:host:1234")
|
||||||
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234))
|
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
|
||||||
|
self.assertEqual(stderr, "")
|
||||||
|
|
||||||
|
h,stderr = p("tcp:host:1234:priority=2.6")
|
||||||
|
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
|
||||||
self.assertEqual(stderr, "")
|
self.assertEqual(stderr, "")
|
||||||
|
|
||||||
h,stderr = p("$!@#^")
|
h,stderr = p("$!@#^")
|
||||||
|
@ -204,16 +214,26 @@ class Hints(unittest.TestCase):
|
||||||
self.assertEqual(stderr,
|
self.assertEqual(stderr,
|
||||||
"unknown hint type 'unknown' in 'unknown:stuff'\n")
|
"unknown hint type 'unknown' in 'unknown:stuff'\n")
|
||||||
|
|
||||||
|
h,stderr = p("tcp:just-a-hostname")
|
||||||
|
self.assertEqual(h, None)
|
||||||
|
self.assertEqual(stderr,
|
||||||
|
"unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")
|
||||||
|
|
||||||
h,stderr = p("tcp:host:number")
|
h,stderr = p("tcp:host:number")
|
||||||
self.assertEqual(h, None)
|
self.assertEqual(h, None)
|
||||||
self.assertEqual(stderr,
|
self.assertEqual(stderr,
|
||||||
"unparseable TCP hint 'tcp:host:number'\n")
|
"non-numeric port in TCP hint 'tcp:host:number'\n")
|
||||||
|
|
||||||
|
h,stderr = p("tcp:host:1234:priority=bad")
|
||||||
|
self.assertEqual(h, None)
|
||||||
|
self.assertEqual(stderr,
|
||||||
|
"non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")
|
||||||
|
|
||||||
def test_describe_hint_obj(self):
|
def test_describe_hint_obj(self):
|
||||||
d = transit.describe_hint_obj
|
d = transit.describe_hint_obj
|
||||||
self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234)),
|
self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)),
|
||||||
"tcp:host:1234")
|
"tcp:host:1234")
|
||||||
self.assertEqual(d(transit.TorTCPV1Hint("host", 1234)),
|
self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)),
|
||||||
"tor:host:1234")
|
"tor:host:1234")
|
||||||
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
|
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
|
||||||
|
|
||||||
|
@ -226,7 +246,8 @@ class Basic(unittest.TestCase):
|
||||||
self.assertEqual(hints, [{"type": "relay-v1",
|
self.assertEqual(hints, [{"type": "relay-v1",
|
||||||
"hints": [{"type": "direct-tcp-v1",
|
"hints": [{"type": "direct-tcp-v1",
|
||||||
"hostname": "host",
|
"hostname": "host",
|
||||||
"port": 1234}],
|
"port": 1234,
|
||||||
|
"priority": 0.0}],
|
||||||
}])
|
}])
|
||||||
self.assertRaises(InternalError, transit.Common, 123)
|
self.assertRaises(InternalError, transit.Common, 123)
|
||||||
|
|
||||||
|
@ -1353,6 +1374,56 @@ class Transit(unittest.TestCase):
|
||||||
self._waiters[0].callback("winner")
|
self._waiters[0].callback("winner")
|
||||||
self.assertEqual(results, ["winner"])
|
self.assertEqual(results, ["winner"])
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_priorities(self):
|
||||||
|
clock = task.Clock()
|
||||||
|
s = transit.TransitSender("", reactor=clock, no_listen=True)
|
||||||
|
s.set_transit_key(b"key")
|
||||||
|
hints = yield s.get_connection_hints()
|
||||||
|
del hints
|
||||||
|
s.add_connection_hints([
|
||||||
|
{"type": "relay-v1",
|
||||||
|
"hints": [{"type": "direct-tcp-v1",
|
||||||
|
"hostname": "relay", "port": 1234}]},
|
||||||
|
{"type": "direct-tcp-v1",
|
||||||
|
"hostname": "direct", "port": 1234},
|
||||||
|
{"type": "relay-v1",
|
||||||
|
"hints": [{"type": "direct-tcp-v1", "priority": 2.0,
|
||||||
|
"hostname": "relay2", "port": 1234},
|
||||||
|
{"type": "direct-tcp-v1", "priority": 3.0,
|
||||||
|
"hostname": "relay3", "port": 1234}]},
|
||||||
|
{"type": "relay-v1",
|
||||||
|
"hints": [{"type": "direct-tcp-v1", "priority": 2.0,
|
||||||
|
"hostname": "relay4", "port": 1234}]},
|
||||||
|
])
|
||||||
|
s._endpoint_from_hint_obj = self._endpoint_from_hint_obj
|
||||||
|
s._start_connector = self._start_connector
|
||||||
|
|
||||||
|
d = s.connect()
|
||||||
|
results = []
|
||||||
|
d.addBoth(results.append)
|
||||||
|
self.assertEqual(results, [])
|
||||||
|
# direct connector should be used first, then the priority=3.0 relay,
|
||||||
|
# then the two 2.0 relays, then the (default) 0.0 relay
|
||||||
|
|
||||||
|
self.assertEqual(self._connectors, ["direct"])
|
||||||
|
|
||||||
|
clock.advance(s.RELAY_DELAY + 1.0)
|
||||||
|
self.assertEqual(self._connectors, ["direct", "relay3"])
|
||||||
|
|
||||||
|
clock.advance(s.RELAY_DELAY)
|
||||||
|
self.assertIn(self._connectors,
|
||||||
|
(["direct", "relay3", "relay2", "relay4"],
|
||||||
|
["direct", "relay3", "relay4", "relay2"]))
|
||||||
|
|
||||||
|
clock.advance(s.RELAY_DELAY)
|
||||||
|
self.assertIn(self._connectors,
|
||||||
|
(["direct", "relay3", "relay2", "relay4", "relay"],
|
||||||
|
["direct", "relay3", "relay4", "relay2", "relay"]))
|
||||||
|
|
||||||
|
self._waiters[0].callback("winner")
|
||||||
|
self.assertEqual(results, ["winner"])
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def test_no_direct_hints(self):
|
def test_no_direct_hints(self):
|
||||||
clock = task.Clock()
|
clock = task.Clock()
|
||||||
|
|
|
@ -87,8 +87,8 @@ def build_sided_relay_handshake(key, side):
|
||||||
# * expect to see the receiver/sender handshake bytes from the other side
|
# * expect to see the receiver/sender handshake bytes from the other side
|
||||||
# * the sender writes "go\n", the receiver waits for "go\n"
|
# * the sender writes "go\n", the receiver waits for "go\n"
|
||||||
# * the rest of the connection contains transit data
|
# * the rest of the connection contains transit data
|
||||||
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"])
|
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
|
||||||
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"])
|
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
|
||||||
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
|
||||||
# use a tuple rather than a list so they'll be hashable into a set). For each
|
# use a tuple rather than a list so they'll be hashable into a set). For each
|
||||||
# one, make the TCP connection, send the relay handshake, then complete the
|
# one, make the TCP connection, send the relay handshake, then complete the
|
||||||
|
@ -106,6 +106,7 @@ def describe_hint_obj(hint):
|
||||||
def parse_hint_argv(hint, stderr=sys.stderr):
|
def parse_hint_argv(hint, stderr=sys.stderr):
|
||||||
assert isinstance(hint, type(u""))
|
assert isinstance(hint, type(u""))
|
||||||
# return tuple or None for an unparseable hint
|
# return tuple or None for an unparseable hint
|
||||||
|
priority = 0.0
|
||||||
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
|
||||||
if not mo:
|
if not mo:
|
||||||
print("unparseable hint '%s'" % (hint,), file=stderr)
|
print("unparseable hint '%s'" % (hint,), file=stderr)
|
||||||
|
@ -115,17 +116,27 @@ def parse_hint_argv(hint, stderr=sys.stderr):
|
||||||
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
|
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
|
||||||
return None
|
return None
|
||||||
hint_value = mo.group(2)
|
hint_value = mo.group(2)
|
||||||
mo = re.search(r'^(.*):(\d+)$', hint_value)
|
pieces = hint_value.split(":")
|
||||||
if not mo:
|
if len(pieces) < 2:
|
||||||
print("unparseable TCP hint '%s'" % (hint,), file=stderr)
|
print("unparseable TCP hint (need more colons) '%s'" % (hint,),
|
||||||
|
file=stderr)
|
||||||
return None
|
return None
|
||||||
hint_host = mo.group(1)
|
mo = re.search(r'^(\d+)$', pieces[1])
|
||||||
try:
|
if not mo:
|
||||||
hint_port = int(mo.group(2))
|
|
||||||
except ValueError:
|
|
||||||
print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
|
print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
|
||||||
return None
|
return None
|
||||||
return DirectTCPV1Hint(hint_host, hint_port)
|
hint_host = pieces[0]
|
||||||
|
hint_port = int(pieces[1])
|
||||||
|
for more in pieces[2:]:
|
||||||
|
if more.startswith("priority="):
|
||||||
|
more_pieces = more.split("=")
|
||||||
|
try:
|
||||||
|
priority = float(more_pieces[1])
|
||||||
|
except ValueError:
|
||||||
|
print("non-float priority= in TCP hint '%s'" % (hint,),
|
||||||
|
file=stderr)
|
||||||
|
return None
|
||||||
|
return DirectTCPV1Hint(hint_host, hint_port, priority)
|
||||||
|
|
||||||
TIMEOUT=15
|
TIMEOUT=15
|
||||||
|
|
||||||
|
@ -604,7 +615,7 @@ class Common:
|
||||||
# some test hosts, including the appveyor VMs, *only* have
|
# some test hosts, including the appveyor VMs, *only* have
|
||||||
# 127.0.0.1, and the tests will hang badly if we remove it.
|
# 127.0.0.1, and the tests will hang badly if we remove it.
|
||||||
addresses = non_loopback_addresses
|
addresses = non_loopback_addresses
|
||||||
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum)
|
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
|
||||||
for addr in addresses]
|
for addr in addresses]
|
||||||
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum)
|
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum)
|
||||||
return direct_hints, ep
|
return direct_hints, ep
|
||||||
|
@ -620,6 +631,7 @@ class Common:
|
||||||
direct_hints = yield self._get_direct_hints()
|
direct_hints = yield self._get_direct_hints()
|
||||||
for dh in direct_hints:
|
for dh in direct_hints:
|
||||||
hints.append({u"type": u"direct-tcp-v1",
|
hints.append({u"type": u"direct-tcp-v1",
|
||||||
|
u"priority": dh.priority,
|
||||||
u"hostname": dh.hostname,
|
u"hostname": dh.hostname,
|
||||||
u"port": dh.port, # integer
|
u"port": dh.port, # integer
|
||||||
})
|
})
|
||||||
|
@ -627,6 +639,7 @@ class Common:
|
||||||
rhint = {u"type": u"relay-v1", u"hints": []}
|
rhint = {u"type": u"relay-v1", u"hints": []}
|
||||||
for rh in relay.hints:
|
for rh in relay.hints:
|
||||||
rhint[u"hints"].append({u"type": u"direct-tcp-v1",
|
rhint[u"hints"].append({u"type": u"direct-tcp-v1",
|
||||||
|
u"priority": rh.priority,
|
||||||
u"hostname": rh.hostname,
|
u"hostname": rh.hostname,
|
||||||
u"port": rh.port})
|
u"port": rh.port})
|
||||||
hints.append(rhint)
|
hints.append(rhint)
|
||||||
|
@ -686,10 +699,11 @@ class Common:
|
||||||
if not(u"port" in hint and isinstance(hint[u"port"], int)):
|
if not(u"port" in hint and isinstance(hint[u"port"], int)):
|
||||||
log.msg("invalid port in hint: %r" % (hint,))
|
log.msg("invalid port in hint: %r" % (hint,))
|
||||||
return None
|
return None
|
||||||
|
priority = hint.get(u"priority", 0.0)
|
||||||
if hint_type == u"direct-tcp-v1":
|
if hint_type == u"direct-tcp-v1":
|
||||||
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"])
|
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
|
||||||
else:
|
else:
|
||||||
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"])
|
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
|
||||||
|
|
||||||
def add_connection_hints(self, hints):
|
def add_connection_hints(self, hints):
|
||||||
for h in hints: # hint structs
|
for h in hints: # hint structs
|
||||||
|
@ -807,8 +821,17 @@ class Common:
|
||||||
# resolve quickly. Many direct hints will be to unused local-network
|
# resolve quickly. Many direct hints will be to unused local-network
|
||||||
# IP addresses, which won't answer, and would take the full TCP
|
# IP addresses, which won't answer, and would take the full TCP
|
||||||
# timeout (30s or more) to fail.
|
# timeout (30s or more) to fail.
|
||||||
|
|
||||||
|
prioritized_relays = {}
|
||||||
for rh in self._our_relay_hints:
|
for rh in self._our_relay_hints:
|
||||||
for hint_obj in rh.hints:
|
for hint_obj in rh.hints:
|
||||||
|
priority = hint_obj.priority
|
||||||
|
if priority not in prioritized_relays:
|
||||||
|
prioritized_relays[priority] = set()
|
||||||
|
prioritized_relays[priority].add(hint_obj)
|
||||||
|
|
||||||
|
for priority in sorted(prioritized_relays, reverse=True):
|
||||||
|
for hint_obj in prioritized_relays[priority]:
|
||||||
ep = self._endpoint_from_hint_obj(hint_obj)
|
ep = self._endpoint_from_hint_obj(hint_obj)
|
||||||
if not ep:
|
if not ep:
|
||||||
continue
|
continue
|
||||||
|
@ -817,6 +840,7 @@ class Common:
|
||||||
self._start_connector, ep, description,
|
self._start_connector, ep, description,
|
||||||
is_relay=True)
|
is_relay=True)
|
||||||
contenders.append(d)
|
contenders.append(d)
|
||||||
|
relay_delay += self.RELAY_DELAY
|
||||||
|
|
||||||
if not contenders:
|
if not contenders:
|
||||||
raise TransitError("No contenders for connection")
|
raise TransitError("No contenders for connection")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user