parse/transmit/record hint priorities

Use --transit-helper=tcp:HOST:PORT:priority=1.3 to override the default 0.0 .
Larger (positive) priority numbers will be attempted first.
This commit is contained in:
Brian Warner 2016-12-20 20:34:42 -05:00
parent bc17047983
commit 8b864c3eae
2 changed files with 64 additions and 29 deletions

View File

@ -138,29 +138,29 @@ class Hints(unittest.TestCase):
def test_endpoint_from_hint_obj(self): def test_endpoint_from_hint_obj(self):
c = transit.Common("") c = transit.Common("")
efho = c._endpoint_from_hint_obj efho = c._endpoint_from_hint_obj
self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234)), self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
endpoints.HostnameEndpoint) endpoints.HostnameEndpoint)
self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None)
# c._tor_manager is currently None # c._tor_manager is currently None
self.assertEqual(efho(transit.TorTCPV1Hint("host", "port")), None) self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None)
c._tor_manager = mock.Mock() c._tor_manager = mock.Mock()
def tor_ep(hostname, port): def tor_ep(hostname, port):
if hostname == "non-public": if hostname == "non-public":
return None return None
return ("tor_ep", hostname, port) return ("tor_ep", hostname, port)
c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep) c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep)
self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234)), self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)),
("tor_ep", "host", 1234)) ("tor_ep", "host", 1234))
self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234)), self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)),
("tor_ep", "host2.onion", 1234)) ("tor_ep", "host2.onion", 1234))
self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234)), self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)),
None) None)
self.assertEqual(efho(UnknownHint("foo")), None) self.assertEqual(efho(UnknownHint("foo")), None)
def test_comparable(self): def test_comparable(self):
h1 = transit.DirectTCPV1Hint("hostname", "port1") h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h1b = transit.DirectTCPV1Hint("hostname", "port1") h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0)
h2 = transit.DirectTCPV1Hint("hostname", "port2") h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0)
r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) r1 = transit.RelayV1Hint(tuple(sorted([h1, h2])))
r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) r2 = transit.RelayV1Hint(tuple(sorted([h2, h1])))
r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2])))
@ -173,9 +173,15 @@ class Hints(unittest.TestCase):
p = c._parse_tcp_v1_hint p = c._parse_tcp_v1_hint
self.assertEqual(p({"type": "unknown"}), None) self.assertEqual(p({"type": "unknown"}), None)
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234)) self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0))
h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234,
"priority": 2.5})
self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234)) self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0))
h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234,
"priority": 2.5})
self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5))
self.assertEqual(p({"type": "direct-tcp-v1"}), self.assertEqual(p({"type": "direct-tcp-v1"}),
None) # missing hostname None) # missing hostname
self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}), self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}),
@ -192,7 +198,11 @@ class Hints(unittest.TestCase):
value = transit.parse_hint_argv(hint, stderr=stderr) value = transit.parse_hint_argv(hint, stderr=stderr)
return value, stderr.getvalue() return value, stderr.getvalue()
h,stderr = p("tcp:host:1234") h,stderr = p("tcp:host:1234")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234)) self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0))
self.assertEqual(stderr, "")
h,stderr = p("tcp:host:1234:priority=2.6")
self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6))
self.assertEqual(stderr, "") self.assertEqual(stderr, "")
h,stderr = p("$!@#^") h,stderr = p("$!@#^")
@ -204,16 +214,26 @@ class Hints(unittest.TestCase):
self.assertEqual(stderr, self.assertEqual(stderr,
"unknown hint type 'unknown' in 'unknown:stuff'\n") "unknown hint type 'unknown' in 'unknown:stuff'\n")
h,stderr = p("tcp:just-a-hostname")
self.assertEqual(h, None)
self.assertEqual(stderr,
"unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n")
h,stderr = p("tcp:host:number") h,stderr = p("tcp:host:number")
self.assertEqual(h, None) self.assertEqual(h, None)
self.assertEqual(stderr, self.assertEqual(stderr,
"unparseable TCP hint 'tcp:host:number'\n") "non-numeric port in TCP hint 'tcp:host:number'\n")
h,stderr = p("tcp:host:1234:priority=bad")
self.assertEqual(h, None)
self.assertEqual(stderr,
"non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n")
def test_describe_hint_obj(self): def test_describe_hint_obj(self):
d = transit.describe_hint_obj d = transit.describe_hint_obj
self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234)), self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)),
"tcp:host:1234") "tcp:host:1234")
self.assertEqual(d(transit.TorTCPV1Hint("host", 1234)), self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)),
"tor:host:1234") "tor:host:1234")
self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff")))
@ -226,7 +246,8 @@ class Basic(unittest.TestCase):
self.assertEqual(hints, [{"type": "relay-v1", self.assertEqual(hints, [{"type": "relay-v1",
"hints": [{"type": "direct-tcp-v1", "hints": [{"type": "direct-tcp-v1",
"hostname": "host", "hostname": "host",
"port": 1234}], "port": 1234,
"priority": 0.0}],
}]) }])
self.assertRaises(InternalError, transit.Common, 123) self.assertRaises(InternalError, transit.Common, 123)

View File

@ -87,8 +87,8 @@ def build_sided_relay_handshake(key, side):
# * expect to see the receiver/sender handshake bytes from the other side # * expect to see the receiver/sender handshake bytes from the other side
# * the sender writes "go\n", the receiver waits for "go\n" # * the sender writes "go\n", the receiver waits for "go\n"
# * the rest of the connection contains transit data # * the rest of the connection contains transit data
DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"]) DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"])
TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"]) TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"])
# RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we
# use a tuple rather than a list so they'll be hashable into a set). For each # use a tuple rather than a list so they'll be hashable into a set). For each
# one, make the TCP connection, send the relay handshake, then complete the # one, make the TCP connection, send the relay handshake, then complete the
@ -106,6 +106,7 @@ def describe_hint_obj(hint):
def parse_hint_argv(hint, stderr=sys.stderr): def parse_hint_argv(hint, stderr=sys.stderr):
assert isinstance(hint, type(u"")) assert isinstance(hint, type(u""))
# return tuple or None for an unparseable hint # return tuple or None for an unparseable hint
priority = 0.0
mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint)
if not mo: if not mo:
print("unparseable hint '%s'" % (hint,), file=stderr) print("unparseable hint '%s'" % (hint,), file=stderr)
@ -115,17 +116,27 @@ def parse_hint_argv(hint, stderr=sys.stderr):
print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr)
return None return None
hint_value = mo.group(2) hint_value = mo.group(2)
mo = re.search(r'^(.*):(\d+)$', hint_value) pieces = hint_value.split(":")
if not mo: if len(pieces) < 2:
print("unparseable TCP hint '%s'" % (hint,), file=stderr) print("unparseable TCP hint (need more colons) '%s'" % (hint,),
file=stderr)
return None return None
hint_host = mo.group(1) mo = re.search(r'^(\d+)$', pieces[1])
try: if not mo:
hint_port = int(mo.group(2))
except ValueError:
print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr)
return None return None
return DirectTCPV1Hint(hint_host, hint_port) hint_host = pieces[0]
hint_port = int(pieces[1])
for more in pieces[2:]:
if more.startswith("priority="):
more_pieces = more.split("=")
try:
priority = float(more_pieces[1])
except ValueError:
print("non-float priority= in TCP hint '%s'" % (hint,),
file=stderr)
return None
return DirectTCPV1Hint(hint_host, hint_port, priority)
TIMEOUT=15 TIMEOUT=15
@ -604,7 +615,7 @@ class Common:
# some test hosts, including the appveyor VMs, *only* have # some test hosts, including the appveyor VMs, *only* have
# 127.0.0.1, and the tests will hang badly if we remove it. # 127.0.0.1, and the tests will hang badly if we remove it.
addresses = non_loopback_addresses addresses = non_loopback_addresses
direct_hints = [DirectTCPV1Hint(six.u(addr), portnum) direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0)
for addr in addresses] for addr in addresses]
ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum)
return direct_hints, ep return direct_hints, ep
@ -620,6 +631,7 @@ class Common:
direct_hints = yield self._get_direct_hints() direct_hints = yield self._get_direct_hints()
for dh in direct_hints: for dh in direct_hints:
hints.append({u"type": u"direct-tcp-v1", hints.append({u"type": u"direct-tcp-v1",
u"priority": dh.priority,
u"hostname": dh.hostname, u"hostname": dh.hostname,
u"port": dh.port, # integer u"port": dh.port, # integer
}) })
@ -627,6 +639,7 @@ class Common:
rhint = {u"type": u"relay-v1", u"hints": []} rhint = {u"type": u"relay-v1", u"hints": []}
for rh in relay.hints: for rh in relay.hints:
rhint[u"hints"].append({u"type": u"direct-tcp-v1", rhint[u"hints"].append({u"type": u"direct-tcp-v1",
u"priority": rh.priority,
u"hostname": rh.hostname, u"hostname": rh.hostname,
u"port": rh.port}) u"port": rh.port})
hints.append(rhint) hints.append(rhint)
@ -686,10 +699,11 @@ class Common:
if not(u"port" in hint and isinstance(hint[u"port"], int)): if not(u"port" in hint and isinstance(hint[u"port"], int)):
log.msg("invalid port in hint: %r" % (hint,)) log.msg("invalid port in hint: %r" % (hint,))
return None return None
priority = hint.get(u"priority", 0.0)
if hint_type == u"direct-tcp-v1": if hint_type == u"direct-tcp-v1":
return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"]) return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
else: else:
return TorTCPV1Hint(hint[u"hostname"], hint[u"port"]) return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority)
def add_connection_hints(self, hints): def add_connection_hints(self, hints):
for h in hints: # hint structs for h in hints: # hint structs