Transit: send new (sided) handshakes

This commit is contained in:
Brian Warner 2016-12-22 18:17:05 -05:00
parent b64f27fdad
commit e1546bf03f
2 changed files with 45 additions and 4 deletions

View File

@ -1,6 +1,7 @@
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import io import io
import gc import gc
import mock
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, task, endpoints, protocol, address, error from twisted.internet import defer, task, endpoints, protocol, address, error
@ -9,6 +10,7 @@ from twisted.python import log, failure
from twisted.test import proto_helpers from twisted.test import proto_helpers
from ..errors import InternalError from ..errors import InternalError
from .. import transit from .. import transit
from ..server import transit_server
from .common import ServerBase from .common import ServerBase
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
@ -1320,6 +1322,38 @@ class Transit(unittest.TestCase):
relay_connectors[0].callback("winner") relay_connectors[0].callback("winner")
self.assertEqual(results, ["winner"]) self.assertEqual(results, ["winner"])
class RelayHandshake(unittest.TestCase):
def old_build_relay_handshake(self, key):
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (token, b"please relay "+hexlify(token)+b"\n")
def test_old(self):
key = b"\x00"
token, old_handshake = self.old_build_relay_handshake(key)
tc = transit_server.TransitConnection()
tc.factory = mock.Mock()
tc.factory.connection_got_token = mock.Mock()
tc.dataReceived(old_handshake[:-1])
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
tc.dataReceived(old_handshake[-1:])
self.assertEqual(tc.factory.connection_got_token.mock_calls,
[mock.call(hexlify(token), None, tc)])
def test_new(self):
c = transit.Common(None)
c.set_transit_key(b"\x00")
new_handshake = c._build_relay_handshake()
token, old_handshake = self.old_build_relay_handshake(b"\x00")
tc = transit_server.TransitConnection()
tc.factory = mock.Mock()
tc.factory.connection_got_token = mock.Mock()
tc.dataReceived(new_handshake[:-1])
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
tc.dataReceived(new_handshake[-1:])
self.assertEqual(tc.factory.connection_got_token.mock_calls,
[mock.call(hexlify(token), c._side.encode("ascii"), tc)])
class Full(ServerBase, unittest.TestCase): class Full(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2): def doBoth(self, d1, d2):

View File

@ -1,6 +1,6 @@
# no unicode_literals, revisit after twisted patch # no unicode_literals, revisit after twisted patch
from __future__ import print_function, absolute_import from __future__ import print_function, absolute_import
import re, sys, time, socket import os, re, sys, time, socket
from collections import namedtuple, deque from collections import namedtuple, deque
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
import six import six
@ -15,6 +15,7 @@ from nacl.secret import SecretBox
from hkdf import Hkdf from hkdf import Hkdf
from .errors import InternalError from .errors import InternalError
from .timing import DebugTiming from .timing import DebugTiming
from .util import bytes_to_hexstr
from . import ipaddrs from . import ipaddrs
def HKDF(skm, outlen, salt=None, CTXinfo=b""): def HKDF(skm, outlen, salt=None, CTXinfo=b""):
@ -76,9 +77,11 @@ def build_sender_handshake(key):
hexid = HKDF(key, 32, CTXinfo=b"transit_sender") hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
return b"transit sender "+hexlify(hexid)+b" ready\n\n" return b"transit sender "+hexlify(hexid)+b" ready\n\n"
def build_relay_handshake(key): def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8*2
token = HKDF(key, 32, CTXinfo=b"transit_relay_token") token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return b"please relay "+hexlify(token)+b"\n" return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
# These namedtuples are "hint objects". The JSON-serializable dictionaries # These namedtuples are "hint objects". The JSON-serializable dictionaries
@ -575,6 +578,7 @@ class Common:
def __init__(self, transit_relay, no_listen=False, tor_manager=None, def __init__(self, transit_relay, no_listen=False, tor_manager=None,
reactor=reactor, timing=None): reactor=reactor, timing=None):
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
if transit_relay: if transit_relay:
if not isinstance(transit_relay, type(u"")): if not isinstance(transit_relay, type(u"")):
raise InternalError raise InternalError
@ -830,11 +834,14 @@ class Common:
d.addBoth(_done) d.addBoth(_done)
return d return d
def _build_relay_handshake(self):
return build_sided_relay_handshake(self._transit_key, self._side)
def _start_connector(self, ep, description, is_relay=False): def _start_connector(self, ep, description, is_relay=False):
relay_handshake = None relay_handshake = None
if is_relay: if is_relay:
assert self._transit_key assert self._transit_key
relay_handshake = build_relay_handshake(self._transit_key) relay_handshake = self._build_relay_handshake()
f = OutboundConnectionFactory(self, relay_handshake, description) f = OutboundConnectionFactory(self, relay_handshake, description)
d = ep.connect(f) d = ep.connect(f)
# fires with protocol, or ConnectError # fires with protocol, or ConnectError