Merge: enhance transit server

This adds a new kind of handshake message, which lets the Transit Relay
server tell when two connections (for the same channel) are really from the
same client, and therefore should not be connected to each other. The 0.8.2
client speaks the old handshake, but a future client will speak the new
handshake.

refs #115
This commit is contained in:
Brian Warner 2016-12-23 22:18:26 -05:00
commit b64f27fdad
3 changed files with 301 additions and 46 deletions

View File

@ -25,6 +25,7 @@ def blur_size(size):
class TransitConnection(protocol.Protocol): class TransitConnection(protocol.Protocol):
def __init__(self): def __init__(self):
self._got_token = False self._got_token = False
self._got_side = False
self._token_buffer = b"" self._token_buffer = b""
self._sent_ok = False self._sent_ok = False
self._buddy = None self._buddy = None
@ -32,9 +33,14 @@ class TransitConnection(protocol.Protocol):
self._total_sent = 0 self._total_sent = 0
def describeToken(self): def describeToken(self):
d = "-"
if self._got_token: if self._got_token:
return self._got_token[:16].decode("ascii") d = self._got_token[:16].decode("ascii")
return "-" if self._got_side:
d += "-" + self._got_side.decode("ascii")
else:
d += "-<unsided>"
return d
def connectionMade(self): def connectionMade(self):
self._started = time.time() self._started = time.time()
@ -59,26 +65,69 @@ class TransitConnection(protocol.Protocol):
# else this should be (part of) the token # else this should be (part of) the token
self._token_buffer += data self._token_buffer += data
buf = self._token_buffer buf = self._token_buffer
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf: # old: "please relay {64}\n"
self.transport.write(b"bad handshake\n") # new: "please relay {64} for side {16}\n"
log.msg("transit handshake early failure") (old, handshake_len, token) = self._check_old_handshake(buf)
return self.disconnect() assert old in ("yes", "waiting", "no")
if len(buf) < wanted: if old == "yes":
return # remember they aren't supposed to send anything past their
if len(buf) > wanted: # handshake until we've said go
if len(buf) > handshake_len:
self.transport.write(b"impatient\n") self.transport.write(b"impatient\n")
log.msg("transit impatience failure") log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure return self.disconnect() # impatience yields failure
mo = re.search(br"^please relay (\w{64})\n", buf, re.M) return self._got_handshake(token, None)
if not mo: (new, handshake_len, token, side) = self._check_new_handshake(buf)
assert new in ("yes", "waiting", "no")
if new == "yes":
if len(buf) > handshake_len:
self.transport.write(b"impatient\n")
log.msg("transit impatience failure")
return self.disconnect() # impatience yields failure
return self._got_handshake(token, side)
if (old == "no" and new == "no"):
self.transport.write(b"bad handshake\n") self.transport.write(b"bad handshake\n")
log.msg("transit handshake failure") log.msg("transit handshake failure")
return self.disconnect() # incorrectness yields failure return self.disconnect() # incorrectness yields failure
token = mo.group(1) # else we'll keep waiting
def _check_old_handshake(self, buf):
# old: "please relay {64}\n"
# return ("yes", handshake, token) if buf contains an old-style handshake
# return ("waiting", None, None) if it might eventually contain one
# return ("no", None, None) if it could never contain one
wanted = len("please relay \n")+32*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None)
if len(buf) < wanted:
return ("waiting", None, None)
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
if mo:
token = mo.group(1)
return ("yes", wanted, token)
return ("no", None, None)
def _check_new_handshake(self, buf):
# new: "please relay {64} for side {16}\n"
wanted = len("please relay for side \n")+32*2+8*2
if len(buf) < wanted-1 and b"\n" in buf:
return ("no", None, None, None)
if len(buf) < wanted:
return ("waiting", None, None, None)
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
if mo:
token = mo.group(1)
side = mo.group(2)
return ("yes", wanted, token, side)
return ("no", None, None, None)
def _got_handshake(self, token, side):
self._got_token = token self._got_token = token
self.factory.connection_got_token(token, self) self._got_side = side
self.factory.connection_got_token(token, side, self)
def buddy_connected(self, them): def buddy_connected(self, them):
self._buddy = them self._buddy = them
@ -100,7 +149,7 @@ class TransitConnection(protocol.Protocol):
def connectionLost(self, reason): def connectionLost(self, reason):
if self._buddy: if self._buddy:
self._buddy.buddy_disconnected() self._buddy.buddy_disconnected()
self.factory.transitFinished(self, self._got_token, self.factory.transitFinished(self, self._got_token, self._got_side,
self.describeToken()) self.describeToken())
# Record usage. There are four cases: # Record usage. There are four cases:
@ -136,12 +185,18 @@ class TransitConnection(protocol.Protocol):
class Transit(protocol.ServerFactory, service.MultiService): class Transit(protocol.ServerFactory, service.MultiService):
# I manage pairs of simultaneous connections to a secondary TCP port, # I manage pairs of simultaneous connections to a secondary TCP port,
# both forwarded to the other. Clients must begin each connection with # both forwarded to the other. Clients must begin each connection with
# "please relay TOKEN\n". I will send "ok\n" when the matching connection # "please relay TOKEN for SIDE\n" (or a legacy form without the "for
# is established, or disconnect if no matching connection is made within # SIDE"). Two connections match if they use the same TOKEN and have
# MAX_WAIT_TIME seconds. I will disconnect if you send data before the # different SIDEs (the redundant connections are dropped when a match is
# "ok\n". All data you get after the "ok\n" will be from the other side. # made). Legacy connections match any with the same TOKEN, ignoring SIDE
# You will not receive "ok\n" until the other side has also connected and # (so two legacy connections will match each other).
# submitted a matching token. The token is the same for each side.
# I will send "ok\n" when the matching connection is established, or
# disconnect if no matching connection is made within MAX_WAIT_TIME
# seconds. I will disconnect if you send data before the "ok\n". All data
# you get after the "ok\n" will be from the other side. You will not
# receive "ok\n" until the other side has also connected and submitted a
# matching token (and differing SIDE).
# In addition, the connections will be dropped after MAXLENGTH bytes have # In addition, the connections will be dropped after MAXLENGTH bytes have
# been sent by either side, or MAXTIME seconds have elapsed after the # been sent by either side, or MAXTIME seconds have elapsed after the
@ -164,22 +219,37 @@ class Transit(protocol.ServerFactory, service.MultiService):
service.MultiService.__init__(self) service.MultiService.__init__(self)
self._db = db self._db = db
self._blur_usage = blur_usage self._blur_usage = blur_usage
self._pending_requests = {} # token -> TransitConnection self._pending_requests = {} # token -> set((side, TransitConnection))
self._active_connections = set() # TransitConnection self._active_connections = set() # TransitConnection
self._counts = collections.defaultdict(int) self._counts = collections.defaultdict(int)
self._count_bytes = 0 self._count_bytes = 0
def connection_got_token(self, token, p): def connection_got_token(self, token, new_side, new_tc):
if token in self._pending_requests: if token not in self._pending_requests:
log.msg("transit relay 2: %s" % p.describeToken()) self._pending_requests[token] = set()
buddy = self._pending_requests.pop(token) potentials = self._pending_requests[token]
self._active_connections.add(p) for old in potentials:
self._active_connections.add(buddy) (old_side, old_tc) = old
p.buddy_connected(buddy) if ((old_side is None)
buddy.buddy_connected(p) or (new_side is None)
else: or (old_side != new_side)):
self._pending_requests[token] = p # we found a match
log.msg("transit relay 1: %s" % p.describeToken()) log.msg("transit relay 2: %s" % new_tc.describeToken())
# drop and stop tracking the rest
potentials.remove(old)
for (_, leftover_tc) in potentials:
leftover_tc.disconnect() # TODO: not "errory"?
self._pending_requests.pop(token)
# glue the two ends together
self._active_connections.add(new_tc)
self._active_connections.add(old_tc)
new_tc.buddy_connected(old_tc)
old_tc.buddy_connected(new_tc)
return
log.msg("transit relay 1: %s" % new_tc.describeToken())
potentials.add((new_side, new_tc))
# TODO: timer # TODO: timer
def recordUsage(self, started, result, total_bytes, def recordUsage(self, started, result, total_bytes,
@ -198,13 +268,15 @@ class Transit(protocol.ServerFactory, service.MultiService):
self._counts[result] += 1 self._counts[result] += 1
self._count_bytes += total_bytes self._count_bytes += total_bytes
def transitFinished(self, p, token, description): def transitFinished(self, tc, token, side, description):
for token,tc in self._pending_requests.items(): if token in self._pending_requests:
if tc is p: side_tc = (side, tc)
if side_tc in self._pending_requests[token]:
self._pending_requests[token].remove(side_tc)
if not self._pending_requests[token]: # set is now empty
del self._pending_requests[token] del self._pending_requests[token]
break
log.msg("transitFinished %s" % (description,)) log.msg("transitFinished %s" % (description,))
self._active_connections.discard(p) self._active_connections.discard(tc)
def transitFailed(self, p): def transitFailed(self, p):
log.msg("transitFailed %r" % p) log.msg("transitFailed %r" % p)

View File

@ -1072,6 +1072,7 @@ class Accumulator(protocol.Protocol):
self.data = b"" self.data = b""
self.count = 0 self.count = 0
self._wait = None self._wait = None
self._disconnect = defer.Deferred()
def waitForBytes(self, more): def waitForBytes(self, more):
assert self._wait is None assert self._wait is None
self.count = more self.count = more
@ -1089,6 +1090,7 @@ class Accumulator(protocol.Protocol):
def connectionLost(self, why): def connectionLost(self, why):
if self._wait: if self._wait:
self._wait.errback(RuntimeError("closed")) self._wait.errback(RuntimeError("closed"))
self._disconnect.callback(None)
class Transit(ServerBase, unittest.TestCase): class Transit(ServerBase, unittest.TestCase):
def test_blur_size(self): def test_blur_size(self):
@ -1115,7 +1117,32 @@ class Transit(ServerBase, unittest.TestCase):
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
@defer.inlineCallbacks @defer.inlineCallbacks
def test_basic(self): def test_register(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
self.assertEqual(self.count(), 1)
a1.transport.loseConnection()
# let that get removed
while self.count() > 0:
yield self.wait()
self.assertEqual(self.count(), 0)
# the token should be removed too
self.assertEqual(len(self._transit_server._pending_requests), 0)
@defer.inlineCallbacks
def test_both_unsided(self):
ep = clientFromString(reactor, self.transit) ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator()) a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator()) a2 = yield connectProtocol(ep, Accumulator())
@ -1143,6 +1170,133 @@ class Transit(ServerBase, unittest.TestCase):
a1.transport.loseConnection() a1.transport.loseConnection()
a2.transport.loseConnection() a2.transport.loseConnection()
@defer.inlineCallbacks
def test_sided_unsided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_unsided_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks
def test_both_sided(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
side2 = b"\x02"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side2) + b"\n")
# a correct handshake yields an ack, after which we can send
exp = b"ok\n"
yield a1.waitForBytes(len(exp))
self.assertEqual(a1.data, exp)
s1 = b"data1"
a1.transport.write(s1)
exp = b"ok\n"
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
# all data they sent after the handshake should be given to us
exp = b"ok\n"+s1
yield a2.waitForBytes(len(exp))
self.assertEqual(a2.data, exp)
a1.transport.loseConnection()
a2.transport.loseConnection()
def count(self):
return sum([len(potentials)
for potentials
in self._transit_server._pending_requests.values()])
def wait(self):
d = defer.Deferred()
reactor.callLater(0.001, d.callback, None)
return d
@defer.inlineCallbacks
def test_ignore_same_side(self):
ep = clientFromString(reactor, self.transit)
a1 = yield connectProtocol(ep, Accumulator())
a2 = yield connectProtocol(ep, Accumulator())
token1 = b"\x00"*32
side1 = b"\x01"*8
a1.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 0:
yield self.wait()
a2.transport.write(b"please relay " + hexlify(token1) +
b" for side " + hexlify(side1) + b"\n")
# let that arrive
while self.count() == 1:
yield self.wait()
self.assertEqual(self.count(), 2) # same-side connections don't match
a1.transport.loseConnection()
a2.transport.loseConnection()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_bad_handshake(self): def test_bad_handshake(self):
ep = clientFromString(reactor, self.transit) ep = clientFromString(reactor, self.transit)
@ -1186,7 +1340,7 @@ class Transit(ServerBase, unittest.TestCase):
token1 = b"\x00"*32 token1 = b"\x00"*32
# sending too many bytes is impatience. # sending too many bytes is impatience.
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
exp = b"impatient\n" exp = b"impatient\n"
yield a1.waitForBytes(len(exp)) yield a1.waitForBytes(len(exp))

View File

@ -9,6 +9,7 @@ from twisted.python import log, failure
from twisted.test import proto_helpers from twisted.test import proto_helpers
from ..errors import InternalError from ..errors import InternalError
from .. import transit from .. import transit
from .common import ServerBase
from nacl.secret import SecretBox from nacl.secret import SecretBox
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
@ -1320,12 +1321,12 @@ class Transit(unittest.TestCase):
self.assertEqual(results, ["winner"]) self.assertEqual(results, ["winner"])
class Full(unittest.TestCase): class Full(ServerBase, unittest.TestCase):
def doBoth(self, d1, d2): def doBoth(self, d1, d2):
return gatherResults([d1, d2], True) return gatherResults([d1, d2], True)
@inlineCallbacks @inlineCallbacks
def test_full(self): def test_direct(self):
KEY = b"k"*32 KEY = b"k"*32
s = transit.TransitSender(None) s = transit.TransitSender(None)
r = transit.TransitReceiver(None) r = transit.TransitReceiver(None)
@ -1351,3 +1352,31 @@ class Full(unittest.TestCase):
yield x.close() yield x.close()
yield y.close() yield y.close()
@inlineCallbacks
def test_relay(self):
KEY = b"k"*32
s = transit.TransitSender(self.transit, no_listen=True)
r = transit.TransitReceiver(self.transit, no_listen=True)
s.set_transit_key(KEY)
r.set_transit_key(KEY)
shints = yield s.get_connection_hints()
rhints = yield r.get_connection_hints()
s.add_connection_hints(rhints)
r.add_connection_hints(shints)
(x,y) = yield self.doBoth(s.connect(), r.connect())
self.assertIsInstance(x, transit.Connection)
self.assertIsInstance(y, transit.Connection)
d = y.receive_record()
x.send_record(b"record1")
r = yield d
self.assertEqual(r, b"record1")
yield x.close()
yield y.close()