diff --git a/src/wormhole/test/test_transit.py b/src/wormhole/test/test_transit.py index 9f76e63..65bdf18 100644 --- a/src/wormhole/test/test_transit.py +++ b/src/wormhole/test/test_transit.py @@ -1,6 +1,7 @@ from __future__ import print_function, unicode_literals import io import gc +import mock from binascii import hexlify, unhexlify from twisted.trial import unittest from twisted.internet import defer, task, endpoints, protocol, address, error @@ -9,6 +10,7 @@ from twisted.python import log, failure from twisted.test import proto_helpers from ..errors import InternalError from .. import transit +from ..server import transit_server from .common import ServerBase from nacl.secret import SecretBox from nacl.exceptions import CryptoError @@ -1320,6 +1322,38 @@ class Transit(unittest.TestCase): relay_connectors[0].callback("winner") self.assertEqual(results, ["winner"]) +class RelayHandshake(unittest.TestCase): + def old_build_relay_handshake(self, key): + token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token") + return (token, b"please relay "+hexlify(token)+b"\n") + + def test_old(self): + key = b"\x00" + token, old_handshake = self.old_build_relay_handshake(key) + tc = transit_server.TransitConnection() + tc.factory = mock.Mock() + tc.factory.connection_got_token = mock.Mock() + tc.dataReceived(old_handshake[:-1]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, []) + tc.dataReceived(old_handshake[-1:]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), None, tc)]) + + def test_new(self): + c = transit.Common(None) + c.set_transit_key(b"\x00") + new_handshake = c._build_relay_handshake() + token, old_handshake = self.old_build_relay_handshake(b"\x00") + + tc = transit_server.TransitConnection() + tc.factory = mock.Mock() + tc.factory.connection_got_token = mock.Mock() + tc.dataReceived(new_handshake[:-1]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, []) + tc.dataReceived(new_handshake[-1:]) + self.assertEqual(tc.factory.connection_got_token.mock_calls, + [mock.call(hexlify(token), c._side.encode("ascii"), tc)]) + class Full(ServerBase, unittest.TestCase): def doBoth(self, d1, d2): diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index 308e92f..d851070 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -1,6 +1,6 @@ # no unicode_literals, revisit after twisted patch from __future__ import print_function, absolute_import -import re, sys, time, socket +import os, re, sys, time, socket from collections import namedtuple, deque from binascii import hexlify, unhexlify import six @@ -15,6 +15,7 @@ from nacl.secret import SecretBox from hkdf import Hkdf from .errors import InternalError from .timing import DebugTiming +from .util import bytes_to_hexstr from . import ipaddrs def HKDF(skm, outlen, salt=None, CTXinfo=b""): @@ -76,9 +77,11 @@ def build_sender_handshake(key): hexid = HKDF(key, 32, CTXinfo=b"transit_sender") return b"transit sender "+hexlify(hexid)+b" ready\n\n" -def build_relay_handshake(key): +def build_sided_relay_handshake(key, side): + assert isinstance(side, type(u"")) + assert len(side) == 8*2 token = HKDF(key, 32, CTXinfo=b"transit_relay_token") - return b"please relay "+hexlify(token)+b"\n" + return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n" # These namedtuples are "hint objects". The JSON-serializable dictionaries @@ -575,6 +578,7 @@ class Common: def __init__(self, transit_relay, no_listen=False, tor_manager=None, reactor=reactor, timing=None): + self._side = bytes_to_hexstr(os.urandom(8)) # unicode if transit_relay: if not isinstance(transit_relay, type(u"")): raise InternalError @@ -830,11 +834,14 @@ class Common: d.addBoth(_done) return d + def _build_relay_handshake(self): + return build_sided_relay_handshake(self._transit_key, self._side) + def _start_connector(self, ep, description, is_relay=False): relay_handshake = None if is_relay: assert self._transit_key - relay_handshake = build_relay_handshake(self._transit_key) + relay_handshake = self._build_relay_handshake() f = OutboundConnectionFactory(self, relay_handshake, description) d = ep.connect(f) # fires with protocol, or ConnectError