diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 4d1e53f..dee0e63 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -138,29 +138,29 @@ class Hints(unittest.TestCase): def test_endpoint_from_hint_obj(self): c = transit.Common("") efho = c._endpoint_from_hint_obj - self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234)), + self.assertIsInstance(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), endpoints.HostnameEndpoint) self.assertEqual(efho("unknown:stuff:yowza:pivlor"), None) # c._tor_manager is currently None - self.assertEqual(efho(transit.TorTCPV1Hint("host", "port")), None) + self.assertEqual(efho(transit.TorTCPV1Hint("host", "port", 0)), None) c._tor_manager = mock.Mock() def tor_ep(hostname, port): if hostname == "non-public": return None return ("tor_ep", hostname, port) c._tor_manager.get_endpoint_for = mock.Mock(side_effect=tor_ep) - self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234)), + self.assertEqual(efho(transit.DirectTCPV1Hint("host", 1234, 0.0)), ("tor_ep", "host", 1234)) - self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234)), + self.assertEqual(efho(transit.TorTCPV1Hint("host2.onion", 1234, 0.0)), ("tor_ep", "host2.onion", 1234)) - self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234)), + self.assertEqual(efho(transit.DirectTCPV1Hint("non-public", 1234, 0.0)), None) self.assertEqual(efho(UnknownHint("foo")), None) def test_comparable(self): - h1 = transit.DirectTCPV1Hint("hostname", "port1") - h1b = transit.DirectTCPV1Hint("hostname", "port1") - h2 = transit.DirectTCPV1Hint("hostname", "port2") + h1 = transit.DirectTCPV1Hint("hostname", "port1", 0.0) + h1b = transit.DirectTCPV1Hint("hostname", "port1", 0.0) + h2 = transit.DirectTCPV1Hint("hostname", "port2", 0.0) r1 = transit.RelayV1Hint(tuple(sorted([h1, h2]))) r2 = transit.RelayV1Hint(tuple(sorted([h2, h1]))) r3 = transit.RelayV1Hint(tuple(sorted([h1b, h2]))) @@ -173,9 +173,15 @@ class Hints(unittest.TestCase): p = c._parse_tcp_v1_hint self.assertEqual(p({"type": "unknown"}), None) h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234)) + self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 0.0)) + h = p({"type": "direct-tcp-v1", "hostname": "foo", "port": 1234, + "priority": 2.5}) + self.assertEqual(h, transit.DirectTCPV1Hint("foo", 1234, 2.5)) h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234}) - self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234)) + self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 0.0)) + h = p({"type": "tor-tcp-v1", "hostname": "foo", "port": 1234, + "priority": 2.5}) + self.assertEqual(h, transit.TorTCPV1Hint("foo", 1234, 2.5)) self.assertEqual(p({"type": "direct-tcp-v1"}), None) # missing hostname self.assertEqual(p({"type": "direct-tcp-v1", "hostname": 12}), @@ -192,7 +198,11 @@ class Hints(unittest.TestCase): value = transit.parse_hint_argv(hint, stderr=stderr) return value, stderr.getvalue() h,stderr = p("tcp:host:1234") - self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234)) + self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 0.0)) + self.assertEqual(stderr, "") + + h,stderr = p("tcp:host:1234:priority=2.6") + self.assertEqual(h, transit.DirectTCPV1Hint("host", 1234, 2.6)) self.assertEqual(stderr, "") h,stderr = p("$!@#^") @@ -204,16 +214,26 @@ class Hints(unittest.TestCase): self.assertEqual(stderr, "unknown hint type 'unknown' in 'unknown:stuff'\n") + h,stderr = p("tcp:just-a-hostname") + self.assertEqual(h, None) + self.assertEqual(stderr, + "unparseable TCP hint (need more colons) 'tcp:just-a-hostname'\n") + h,stderr = p("tcp:host:number") self.assertEqual(h, None) self.assertEqual(stderr, - "unparseable TCP hint 'tcp:host:number'\n") + "non-numeric port in TCP hint 'tcp:host:number'\n") + + h,stderr = p("tcp:host:1234:priority=bad") + self.assertEqual(h, None) + self.assertEqual(stderr, + "non-float priority= in TCP hint 'tcp:host:1234:priority=bad'\n") def test_describe_hint_obj(self): d = transit.describe_hint_obj - self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234)), + self.assertEqual(d(transit.DirectTCPV1Hint("host", 1234, 0.0)), "tcp:host:1234") - self.assertEqual(d(transit.TorTCPV1Hint("host", 1234)), + self.assertEqual(d(transit.TorTCPV1Hint("host", 1234, 0.0)), "tor:host:1234") self.assertEqual(d(UnknownHint("stuff")), str(UnknownHint("stuff"))) @@ -225,8 +245,9 @@ class Basic(unittest.TestCase): hints = yield c.get_connection_hints() self.assertEqual(hints, [{"type": "relay-v1", "hints": [{"type": "direct-tcp-v1", - "hostname": "host", - "port": 1234}], + "hostname": "host", + "port": 1234, + "priority": 0.0}], }]) self.assertRaises(InternalError, transit.Common, 123) diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 284fb1b..20b8d88 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -87,8 +87,8 @@ def build_sided_relay_handshake(key, side): # * expect to see the receiver/sender handshake bytes from the other side # * the sender writes "go\n", the receiver waits for "go\n" # * the rest of the connection contains transit data -DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port"]) -TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port"]) +DirectTCPV1Hint = namedtuple("DirectTCPV1Hint", ["hostname", "port", "priority"]) +TorTCPV1Hint = namedtuple("TorTCPV1Hint", ["hostname", "port", "priority"]) # RelayV1Hint contains a tuple of DirectTCPV1Hint and TorTCPV1Hint hints (we # use a tuple rather than a list so they'll be hashable into a set). For each # one, make the TCP connection, send the relay handshake, then complete the @@ -106,6 +106,7 @@ def describe_hint_obj(hint): def parse_hint_argv(hint, stderr=sys.stderr): assert isinstance(hint, type(u"")) # return tuple or None for an unparseable hint + priority = 0.0 mo = re.search(r'^([a-zA-Z0-9]+):(.*)$', hint) if not mo: print("unparseable hint '%s'" % (hint,), file=stderr) @@ -115,17 +116,27 @@ def parse_hint_argv(hint, stderr=sys.stderr): print("unknown hint type '%s' in '%s'" % (hint_type, hint), file=stderr) return None hint_value = mo.group(2) - mo = re.search(r'^(.*):(\d+)$', hint_value) - if not mo: - print("unparseable TCP hint '%s'" % (hint,), file=stderr) + pieces = hint_value.split(":") + if len(pieces) < 2: + print("unparseable TCP hint (need more colons) '%s'" % (hint,), + file=stderr) return None - hint_host = mo.group(1) - try: - hint_port = int(mo.group(2)) - except ValueError: + mo = re.search(r'^(\d+)$', pieces[1]) + if not mo: print("non-numeric port in TCP hint '%s'" % (hint,), file=stderr) return None - return DirectTCPV1Hint(hint_host, hint_port) + hint_host = pieces[0] + hint_port = int(pieces[1]) + for more in pieces[2:]: + if more.startswith("priority="): + more_pieces = more.split("=") + try: + priority = float(more_pieces[1]) + except ValueError: + print("non-float priority= in TCP hint '%s'" % (hint,), + file=stderr) + return None + return DirectTCPV1Hint(hint_host, hint_port, priority) TIMEOUT=15 @@ -604,7 +615,7 @@ class Common: # some test hosts, including the appveyor VMs, *only* have # 127.0.0.1, and the tests will hang badly if we remove it. addresses = non_loopback_addresses - direct_hints = [DirectTCPV1Hint(six.u(addr), portnum) + direct_hints = [DirectTCPV1Hint(six.u(addr), portnum, 0.0) for addr in addresses] ep = endpoints.serverFromString(reactor, "tcp:%d" % portnum) return direct_hints, ep @@ -620,6 +631,7 @@ class Common: direct_hints = yield self._get_direct_hints() for dh in direct_hints: hints.append({u"type": u"direct-tcp-v1", + u"priority": dh.priority, u"hostname": dh.hostname, u"port": dh.port, # integer }) @@ -627,6 +639,7 @@ class Common: rhint = {u"type": u"relay-v1", u"hints": []} for rh in relay.hints: rhint[u"hints"].append({u"type": u"direct-tcp-v1", + u"priority": rh.priority, u"hostname": rh.hostname, u"port": rh.port}) hints.append(rhint) @@ -686,10 +699,11 @@ class Common: if not(u"port" in hint and isinstance(hint[u"port"], int)): log.msg("invalid port in hint: %r" % (hint,)) return None + priority = hint.get(u"priority", 0.0) if hint_type == u"direct-tcp-v1": - return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"]) + return DirectTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) else: - return TorTCPV1Hint(hint[u"hostname"], hint[u"port"]) + return TorTCPV1Hint(hint[u"hostname"], hint[u"port"], priority) def add_connection_hints(self, hints): for h in hints: # hint structs