Transit: send new (sided) handshakes
This commit is contained in:
parent
b64f27fdad
commit
e1546bf03f
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import print_function, unicode_literals
|
from __future__ import print_function, unicode_literals
|
||||||
import io
|
import io
|
||||||
import gc
|
import gc
|
||||||
|
import mock
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.internet import defer, task, endpoints, protocol, address, error
|
from twisted.internet import defer, task, endpoints, protocol, address, error
|
||||||
|
@ -9,6 +10,7 @@ from twisted.python import log, failure
|
||||||
from twisted.test import proto_helpers
|
from twisted.test import proto_helpers
|
||||||
from ..errors import InternalError
|
from ..errors import InternalError
|
||||||
from .. import transit
|
from .. import transit
|
||||||
|
from ..server import transit_server
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
from nacl.exceptions import CryptoError
|
from nacl.exceptions import CryptoError
|
||||||
|
@ -1320,6 +1322,38 @@ class Transit(unittest.TestCase):
|
||||||
relay_connectors[0].callback("winner")
|
relay_connectors[0].callback("winner")
|
||||||
self.assertEqual(results, ["winner"])
|
self.assertEqual(results, ["winner"])
|
||||||
|
|
||||||
|
class RelayHandshake(unittest.TestCase):
|
||||||
|
def old_build_relay_handshake(self, key):
|
||||||
|
token = transit.HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||||
|
return (token, b"please relay "+hexlify(token)+b"\n")
|
||||||
|
|
||||||
|
def test_old(self):
|
||||||
|
key = b"\x00"
|
||||||
|
token, old_handshake = self.old_build_relay_handshake(key)
|
||||||
|
tc = transit_server.TransitConnection()
|
||||||
|
tc.factory = mock.Mock()
|
||||||
|
tc.factory.connection_got_token = mock.Mock()
|
||||||
|
tc.dataReceived(old_handshake[:-1])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
|
||||||
|
tc.dataReceived(old_handshake[-1:])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls,
|
||||||
|
[mock.call(hexlify(token), None, tc)])
|
||||||
|
|
||||||
|
def test_new(self):
|
||||||
|
c = transit.Common(None)
|
||||||
|
c.set_transit_key(b"\x00")
|
||||||
|
new_handshake = c._build_relay_handshake()
|
||||||
|
token, old_handshake = self.old_build_relay_handshake(b"\x00")
|
||||||
|
|
||||||
|
tc = transit_server.TransitConnection()
|
||||||
|
tc.factory = mock.Mock()
|
||||||
|
tc.factory.connection_got_token = mock.Mock()
|
||||||
|
tc.dataReceived(new_handshake[:-1])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls, [])
|
||||||
|
tc.dataReceived(new_handshake[-1:])
|
||||||
|
self.assertEqual(tc.factory.connection_got_token.mock_calls,
|
||||||
|
[mock.call(hexlify(token), c._side.encode("ascii"), tc)])
|
||||||
|
|
||||||
|
|
||||||
class Full(ServerBase, unittest.TestCase):
|
class Full(ServerBase, unittest.TestCase):
|
||||||
def doBoth(self, d1, d2):
|
def doBoth(self, d1, d2):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# no unicode_literals, revisit after twisted patch
|
# no unicode_literals, revisit after twisted patch
|
||||||
from __future__ import print_function, absolute_import
|
from __future__ import print_function, absolute_import
|
||||||
import re, sys, time, socket
|
import os, re, sys, time, socket
|
||||||
from collections import namedtuple, deque
|
from collections import namedtuple, deque
|
||||||
from binascii import hexlify, unhexlify
|
from binascii import hexlify, unhexlify
|
||||||
import six
|
import six
|
||||||
|
@ -15,6 +15,7 @@ from nacl.secret import SecretBox
|
||||||
from hkdf import Hkdf
|
from hkdf import Hkdf
|
||||||
from .errors import InternalError
|
from .errors import InternalError
|
||||||
from .timing import DebugTiming
|
from .timing import DebugTiming
|
||||||
|
from .util import bytes_to_hexstr
|
||||||
from . import ipaddrs
|
from . import ipaddrs
|
||||||
|
|
||||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||||
|
@ -76,9 +77,11 @@ def build_sender_handshake(key):
|
||||||
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
hexid = HKDF(key, 32, CTXinfo=b"transit_sender")
|
||||||
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
return b"transit sender "+hexlify(hexid)+b" ready\n\n"
|
||||||
|
|
||||||
def build_relay_handshake(key):
|
def build_sided_relay_handshake(key, side):
|
||||||
|
assert isinstance(side, type(u""))
|
||||||
|
assert len(side) == 8*2
|
||||||
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
|
||||||
return b"please relay "+hexlify(token)+b"\n"
|
return b"please relay "+hexlify(token)+b" for side "+side.encode("ascii")+b"\n"
|
||||||
|
|
||||||
|
|
||||||
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
# These namedtuples are "hint objects". The JSON-serializable dictionaries
|
||||||
|
@ -575,6 +578,7 @@ class Common:
|
||||||
|
|
||||||
def __init__(self, transit_relay, no_listen=False, tor_manager=None,
|
def __init__(self, transit_relay, no_listen=False, tor_manager=None,
|
||||||
reactor=reactor, timing=None):
|
reactor=reactor, timing=None):
|
||||||
|
self._side = bytes_to_hexstr(os.urandom(8)) # unicode
|
||||||
if transit_relay:
|
if transit_relay:
|
||||||
if not isinstance(transit_relay, type(u"")):
|
if not isinstance(transit_relay, type(u"")):
|
||||||
raise InternalError
|
raise InternalError
|
||||||
|
@ -830,11 +834,14 @@ class Common:
|
||||||
d.addBoth(_done)
|
d.addBoth(_done)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def _build_relay_handshake(self):
|
||||||
|
return build_sided_relay_handshake(self._transit_key, self._side)
|
||||||
|
|
||||||
def _start_connector(self, ep, description, is_relay=False):
|
def _start_connector(self, ep, description, is_relay=False):
|
||||||
relay_handshake = None
|
relay_handshake = None
|
||||||
if is_relay:
|
if is_relay:
|
||||||
assert self._transit_key
|
assert self._transit_key
|
||||||
relay_handshake = build_relay_handshake(self._transit_key)
|
relay_handshake = self._build_relay_handshake()
|
||||||
f = OutboundConnectionFactory(self, relay_handshake, description)
|
f = OutboundConnectionFactory(self, relay_handshake, description)
|
||||||
d = ep.connect(f)
|
d = ep.connect(f)
|
||||||
# fires with protocol, or ConnectError
|
# fires with protocol, or ConnectError
|
||||||
|
|
Loading…
Reference in New Issue
Block a user