Merge: enhance transit server

This adds a new kind of handshake message, which lets the Transit Relay
server tell when two connections (for the same channel) are really from the
same client, and therefore should not be connected to each other. The 0.8.2
client speaks the old handshake, but a future client will speak the new
handshake.

refs #115
This commit is contained in:
Brian Warner 2016-12-23 22:18:26 -05:00
commit b64f27fdad
3 changed files with 301 additions and 46 deletions

View File

@ -25,6 +25,7 @@ def blur_size(size):
class TransitConnection(protocol.Protocol):
def __init__(self):
self._got_token = False
self._got_side = False
self._token_buffer = b""
self._sent_ok = False
self._buddy = None
@ -32,9 +33,14 @@ class TransitConnection(protocol.Protocol):
self._total_sent = 0
def describeToken(self):
d = "-"
if self._got_token:
return self._got_token[:16].decode("ascii")
return "-"
d = self._got_token[:16].decode("ascii")
if self._got_side:
d += "-" + self._got_side.decode("ascii")
else:
d += "-<unsided>"
return d
def connectionMade(self):
self._started = time.time()
@ -59,26 +65,69 @@ class TransitConnection(protocol.Protocol):
# else this should be (part of) the token
self._token_buffer += data
buf = self._token_buffer
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
self.transport.write(b"bad handshake\n")
log.msg("transit handshake early failure")
return self.disconnect()
if len(buf) < wanted:
return
if len(buf) > wanted:
# old: "please relay {64}\n"
# new: "please relay {64} for side {16}\n"
(old, handshake_len, token) = self._check_old_handshake(buf)
assert old in ("yes", "waiting", "no")
if old == "yes":
# remember they aren't supposed to send anything past their
# handshake until we've said go
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if not mo:
return self._got_handshake(token, None)
(new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no")
if new == "yes":
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, side)
if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n")
log.msg("transit handshake failure")
return self.disconnect() # incorrectness yields failure
token = mo.group(1)
# else we'll keep waiting
def _check_old_handshake(self, buf):
# old: "please relay {64}\n"
# return ("yes", handshake, token) if buf contains an old-style handshake
# return ("waiting", None, None) if it might eventually contain one
# return ("no", None, None) if it could never contain one
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None)
if len(buf) < wanted:
return ("waiting", None, None)
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if mo:
token = mo.group(1)
return ("yes", wanted, token)
return ("no", None, None)
def _check_new_handshake(self, buf):
# new: "please relay {64} for side {16}\n"
wanted = len("please relay for side \n")+32*2+8*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None, None)
if len(buf) < wanted:
return ("waiting", None, None, None)
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
if mo:
token = mo.group(1)
side = mo.group(2)
return ("yes", wanted, token, side)
return ("no", None, None, None)
def _got_handshake(self, token, side):
self._got_token = token
self.factory.connection_got_token(token, self)
self._got_side = side
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them):
self._buddy = them
@ -100,7 +149,7 @@ class TransitConnection(protocol.Protocol):
def connectionLost(self, reason):
if self._buddy:
self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token,
self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken())
# Record usage. There are four cases:
@ -136,12 +185,18 @@ class TransitConnection(protocol.Protocol):
class Transit(protocol.ServerFactory, service.MultiService):
# I manage pairs of simultaneous connections to a secondary TCP port,
# both forwarded to the other. Clients must begin each connection with
# "please relay TOKEN\n". I will send "ok\n" when the matching connection
# is established, or disconnect if no matching connection is made within
# MAX_WAIT_TIME seconds. I will disconnect if you send data before the
# "ok\n". All data you get after the "ok\n" will be from the other side.
# You will not receive "ok\n" until the other side has also connected and
# submitted a matching token. The token is the same for each side.
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for
# SIDE"). Two connections match if they use the same TOKEN and have
# different SIDEs (the redundant connections are dropped when a match is
# made). Legacy connections match any with the same TOKEN, ignoring SIDE
# (so two legacy connections will match each other).
# I will send "ok\n" when the matching connection is established, or
# disconnect if no matching connection is made within MAX_WAIT_TIME
# seconds. I will disconnect if you send data before the "ok\n". All data
# you get after the "ok\n" will be from the other side. You will not
# receive "ok\n" until the other side has also connected and submitted a
# matching token (and differing SIDE).
# In addition, the connections will be dropped after MAXLENGTH bytes have
# been sent by either side, or MAXTIME seconds have elapsed after the
@ -164,22 +219,37 @@ class Transit(protocol.ServerFactory, service.MultiService):
service.MultiService.__init__(self)
self._db = db
self._blur_usage = blur_usage
self._pending_requests = {} # token -> TransitConnection
self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection
self._counts = collections.defaultdict(int)
self._count_bytes = 0
def connection_got_token(self, token, p):
if token in self._pending_requests:
log.msg("transit relay 2: %s" % p.describeToken())
buddy = self._pending_requests.pop(token)
self._active_connections.add(p)
self._active_connections.add(buddy)
p.buddy_connected(buddy)
buddy.buddy_connected(p)
else:
self._pending_requests[token] = p
log.msg("transit relay 1: %s" % p.describeToken())
def connection_got_token(self, token, new_side, new_tc):
if token not in self._pending_requests:
self._pending_requests[token] = set()
potentials = self._pending_requests[token]
for old in potentials:
(old_side, old_tc) = old
if ((old_side is None)
or (new_side is None)
or (old_side != new_side)):
# we found a match
log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"?
self._pending_requests.pop(token)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer
def recordUsage(self, started, result, total_bytes,
@ -198,13 +268,15 @@ class Transit(protocol.ServerFactory, service.MultiService):
self._counts[result] += 1
self._count_bytes += total_bytes
def transitFinished(self, p, token, description):
for token,tc in self._pending_requests.items():
if tc is p:
def transitFinished(self, tc, token, side, description):
if token in self._pending_requests:
side_tc = (side, tc)
if side_tc in self._pending_requests[token]:
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token]
break
log.msg("transitFinished %s" % (description,))
self._active_connections.discard(p)
self._active_connections.discard(tc)
def transitFailed(self, p):
log.msg("transitFailed %r" % p)

View File

@ -1072,6 +1072,7 @@ class Accumulator(protocol.Protocol):
self.data = b""
self.count = 0
self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more):
assert self._wait is None
self.count = more
@ -1089,6 +1090,7 @@ class Accumulator(protocol.Protocol):
def connectionLost(self, why):
if self._wait:
self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self):
@ -1115,7 +1117,32 @@ class Transit(ServerBase, unittest.TestCase):
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
@defer.inlineCallbacks
def test_basic(self):
def test_register(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
self.assertEqual(self.count(), 1)
a1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield self.wait()
self.assertEqual(self.count(), 0)
# the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
@ -1143,6 +1170,133 @@ class Transit(ServerBase, unittest.TestCase):
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def wait(self):
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
return d
@defer.inlineCallbacks
def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield self.wait()
self.assertEqual(self.count(), 2) # same-side connections don't match
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit)
@ -1186,7 +1340,7 @@ class Transit(ServerBase, unittest.TestCase):
token1 = b"\x00"*32
# sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n"
yield a1.waitForBytes(len(exp))

View File

@ -9,6 +9,7 @@ from twisted.python import log, failure
from twisted.test import proto_helpers
from ..errors import InternalError
from .. import transit
from .common import ServerBase
from nacl.secret import SecretBox
from nacl.exceptions import CryptoError
@ -1320,12 +1321,12 @@ class Transit(unittest.TestCase):
self.assertEqual(results, ["winner"])
class Full(unittest.TestCase):
class Full(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2):
return gatherResults([d1, d2], True)
@inlineCallbacks
def test_full(self):
def test_direct(self):
KEY = b"k"*32
s = transit.TransitSender(None)
r = transit.TransitReceiver(None)
@ -1351,3 +1352,31 @@ class Full(unittest.TestCase):
yield x.close()
yield y.close()
@inlineCallbacks
def test_relay(self):
KEY = b"k"*32
s = transit.TransitSender(self.transit, no_listen=True)
r = transit.TransitReceiver(self.transit, no_listen=True)
s.set_transit_key(KEY)
r.set_transit_key(KEY)
shints = yield s.get_connection_hints()
rhints = yield r.get_connection_hints()
s.add_connection_hints(rhints)
r.add_connection_hints(shints)
(x,y) = yield self.doBoth(s.connect(), r.connect())
self.assertIsInstance(x, transit.Connection)
self.assertIsInstance(y, transit.Connection)
d = y.receive_record()
x.send_record(b"record1")
r = yield d
self.assertEqual(r, b"record1")
yield x.close()
yield y.close()