Merge: enhance transit server
This adds a new kind of handshake message, which lets the Transit Relay server tell when two connections (for the same channel) are really from the same client, and therefore should not be connected to each other. The 0.8.2 client speaks the old handshake, but a future client will speak the new handshake. refs #115
This commit is contained in:
commit
b64f27fdad
|
@ -25,6 +25,7 @@ def blur_size(size):
|
|||
class TransitConnection(protocol.Protocol):
|
||||
def __init__(self):
|
||||
self._got_token = False
|
||||
self._got_side = False
|
||||
self._token_buffer = b""
|
||||
self._sent_ok = False
|
||||
self._buddy = None
|
||||
|
@ -32,9 +33,14 @@ class TransitConnection(protocol.Protocol):
|
|||
self._total_sent = 0
|
||||
|
||||
def describeToken(self):
|
||||
d = "-"
|
||||
if self._got_token:
|
||||
return self._got_token[:16].decode("ascii")
|
||||
return "-"
|
||||
d = self._got_token[:16].decode("ascii")
|
||||
if self._got_side:
|
||||
d += "-" + self._got_side.decode("ascii")
|
||||
else:
|
||||
d += "-<unsided>"
|
||||
return d
|
||||
|
||||
def connectionMade(self):
|
||||
self._started = time.time()
|
||||
|
@ -59,26 +65,69 @@ class TransitConnection(protocol.Protocol):
|
|||
# else this should be (part of) the token
|
||||
self._token_buffer += data
|
||||
buf = self._token_buffer
|
||||
wanted = len("please relay \n")+32*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
self.transport.write(b"bad handshake\n")
|
||||
log.msg("transit handshake early failure")
|
||||
return self.disconnect()
|
||||
if len(buf) < wanted:
|
||||
return
|
||||
if len(buf) > wanted:
|
||||
|
||||
# old: "please relay {64}\n"
|
||||
# new: "please relay {64} for side {16}\n"
|
||||
(old, handshake_len, token) = self._check_old_handshake(buf)
|
||||
assert old in ("yes", "waiting", "no")
|
||||
if old == "yes":
|
||||
# remember they aren't supposed to send anything past their
|
||||
# handshake until we've said go
|
||||
if len(buf) > handshake_len:
|
||||
self.transport.write(b"impatient\n")
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect() # impatience yields failure
|
||||
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
|
||||
if not mo:
|
||||
return self._got_handshake(token, None)
|
||||
(new, handshake_len, token, side) = self._check_new_handshake(buf)
|
||||
assert new in ("yes", "waiting", "no")
|
||||
if new == "yes":
|
||||
if len(buf) > handshake_len:
|
||||
self.transport.write(b"impatient\n")
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect() # impatience yields failure
|
||||
return self._got_handshake(token, side)
|
||||
if (old == "no" and new == "no"):
|
||||
self.transport.write(b"bad handshake\n")
|
||||
log.msg("transit handshake failure")
|
||||
return self.disconnect() # incorrectness yields failure
|
||||
token = mo.group(1)
|
||||
# else we'll keep waiting
|
||||
|
||||
def _check_old_handshake(self, buf):
|
||||
# old: "please relay {64}\n"
|
||||
# return ("yes", handshake, token) if buf contains an old-style handshake
|
||||
# return ("waiting", None, None) if it might eventually contain one
|
||||
# return ("no", None, None) if it could never contain one
|
||||
wanted = len("please relay \n")+32*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
return ("no", None, None)
|
||||
if len(buf) < wanted:
|
||||
return ("waiting", None, None)
|
||||
|
||||
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
|
||||
if mo:
|
||||
token = mo.group(1)
|
||||
return ("yes", wanted, token)
|
||||
return ("no", None, None)
|
||||
|
||||
def _check_new_handshake(self, buf):
|
||||
# new: "please relay {64} for side {16}\n"
|
||||
wanted = len("please relay for side \n")+32*2+8*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
return ("no", None, None, None)
|
||||
if len(buf) < wanted:
|
||||
return ("waiting", None, None, None)
|
||||
|
||||
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
|
||||
if mo:
|
||||
token = mo.group(1)
|
||||
side = mo.group(2)
|
||||
return ("yes", wanted, token, side)
|
||||
return ("no", None, None, None)
|
||||
|
||||
def _got_handshake(self, token, side):
|
||||
self._got_token = token
|
||||
self.factory.connection_got_token(token, self)
|
||||
self._got_side = side
|
||||
self.factory.connection_got_token(token, side, self)
|
||||
|
||||
def buddy_connected(self, them):
|
||||
self._buddy = them
|
||||
|
@ -100,7 +149,7 @@ class TransitConnection(protocol.Protocol):
|
|||
def connectionLost(self, reason):
|
||||
if self._buddy:
|
||||
self._buddy.buddy_disconnected()
|
||||
self.factory.transitFinished(self, self._got_token,
|
||||
self.factory.transitFinished(self, self._got_token, self._got_side,
|
||||
self.describeToken())
|
||||
|
||||
# Record usage. There are four cases:
|
||||
|
@ -136,12 +185,18 @@ class TransitConnection(protocol.Protocol):
|
|||
class Transit(protocol.ServerFactory, service.MultiService):
|
||||
# I manage pairs of simultaneous connections to a secondary TCP port,
|
||||
# both forwarded to the other. Clients must begin each connection with
|
||||
# "please relay TOKEN\n". I will send "ok\n" when the matching connection
|
||||
# is established, or disconnect if no matching connection is made within
|
||||
# MAX_WAIT_TIME seconds. I will disconnect if you send data before the
|
||||
# "ok\n". All data you get after the "ok\n" will be from the other side.
|
||||
# You will not receive "ok\n" until the other side has also connected and
|
||||
# submitted a matching token. The token is the same for each side.
|
||||
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for
|
||||
# SIDE"). Two connections match if they use the same TOKEN and have
|
||||
# different SIDEs (the redundant connections are dropped when a match is
|
||||
# made). Legacy connections match any with the same TOKEN, ignoring SIDE
|
||||
# (so two legacy connections will match each other).
|
||||
|
||||
# I will send "ok\n" when the matching connection is established, or
|
||||
# disconnect if no matching connection is made within MAX_WAIT_TIME
|
||||
# seconds. I will disconnect if you send data before the "ok\n". All data
|
||||
# you get after the "ok\n" will be from the other side. You will not
|
||||
# receive "ok\n" until the other side has also connected and submitted a
|
||||
# matching token (and differing SIDE).
|
||||
|
||||
# In addition, the connections will be dropped after MAXLENGTH bytes have
|
||||
# been sent by either side, or MAXTIME seconds have elapsed after the
|
||||
|
@ -164,22 +219,37 @@ class Transit(protocol.ServerFactory, service.MultiService):
|
|||
service.MultiService.__init__(self)
|
||||
self._db = db
|
||||
self._blur_usage = blur_usage
|
||||
self._pending_requests = {} # token -> TransitConnection
|
||||
self._pending_requests = {} # token -> set((side, TransitConnection))
|
||||
self._active_connections = set() # TransitConnection
|
||||
self._counts = collections.defaultdict(int)
|
||||
self._count_bytes = 0
|
||||
|
||||
def connection_got_token(self, token, p):
|
||||
if token in self._pending_requests:
|
||||
log.msg("transit relay 2: %s" % p.describeToken())
|
||||
buddy = self._pending_requests.pop(token)
|
||||
self._active_connections.add(p)
|
||||
self._active_connections.add(buddy)
|
||||
p.buddy_connected(buddy)
|
||||
buddy.buddy_connected(p)
|
||||
else:
|
||||
self._pending_requests[token] = p
|
||||
log.msg("transit relay 1: %s" % p.describeToken())
|
||||
def connection_got_token(self, token, new_side, new_tc):
|
||||
if token not in self._pending_requests:
|
||||
self._pending_requests[token] = set()
|
||||
potentials = self._pending_requests[token]
|
||||
for old in potentials:
|
||||
(old_side, old_tc) = old
|
||||
if ((old_side is None)
|
||||
or (new_side is None)
|
||||
or (old_side != new_side)):
|
||||
# we found a match
|
||||
log.msg("transit relay 2: %s" % new_tc.describeToken())
|
||||
|
||||
# drop and stop tracking the rest
|
||||
potentials.remove(old)
|
||||
for (_, leftover_tc) in potentials:
|
||||
leftover_tc.disconnect() # TODO: not "errory"?
|
||||
self._pending_requests.pop(token)
|
||||
|
||||
# glue the two ends together
|
||||
self._active_connections.add(new_tc)
|
||||
self._active_connections.add(old_tc)
|
||||
new_tc.buddy_connected(old_tc)
|
||||
old_tc.buddy_connected(new_tc)
|
||||
return
|
||||
log.msg("transit relay 1: %s" % new_tc.describeToken())
|
||||
potentials.add((new_side, new_tc))
|
||||
# TODO: timer
|
||||
|
||||
def recordUsage(self, started, result, total_bytes,
|
||||
|
@ -198,13 +268,15 @@ class Transit(protocol.ServerFactory, service.MultiService):
|
|||
self._counts[result] += 1
|
||||
self._count_bytes += total_bytes
|
||||
|
||||
def transitFinished(self, p, token, description):
|
||||
for token,tc in self._pending_requests.items():
|
||||
if tc is p:
|
||||
def transitFinished(self, tc, token, side, description):
|
||||
if token in self._pending_requests:
|
||||
side_tc = (side, tc)
|
||||
if side_tc in self._pending_requests[token]:
|
||||
self._pending_requests[token].remove(side_tc)
|
||||
if not self._pending_requests[token]: # set is now empty
|
||||
del self._pending_requests[token]
|
||||
break
|
||||
log.msg("transitFinished %s" % (description,))
|
||||
self._active_connections.discard(p)
|
||||
self._active_connections.discard(tc)
|
||||
|
||||
def transitFailed(self, p):
|
||||
log.msg("transitFailed %r" % p)
|
||||
|
|
|
@ -1072,6 +1072,7 @@ class Accumulator(protocol.Protocol):
|
|||
self.data = b""
|
||||
self.count = 0
|
||||
self._wait = None
|
||||
self._disconnect = defer.Deferred()
|
||||
def waitForBytes(self, more):
|
||||
assert self._wait is None
|
||||
self.count = more
|
||||
|
@ -1089,6 +1090,7 @@ class Accumulator(protocol.Protocol):
|
|||
def connectionLost(self, why):
|
||||
if self._wait:
|
||||
self._wait.errback(RuntimeError("closed"))
|
||||
self._disconnect.callback(None)
|
||||
|
||||
class Transit(ServerBase, unittest.TestCase):
|
||||
def test_blur_size(self):
|
||||
|
@ -1115,7 +1117,32 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_basic(self):
|
||||
def test_register(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
# let that get removed
|
||||
while self.count() > 0:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 0)
|
||||
|
||||
# the token should be removed too
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
@ -1143,6 +1170,133 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sided_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_unsided_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
side2 = b"\x02"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
def count(self):
|
||||
return sum([len(potentials)
|
||||
for potentials
|
||||
in self._transit_server._pending_requests.values()])
|
||||
def wait(self):
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
return d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_ignore_same_side(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield self.wait()
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 1:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
|
@ -1186,7 +1340,7 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
|
|
|
@ -9,6 +9,7 @@ from twisted.python import log, failure
|
|||
from twisted.test import proto_helpers
|
||||
from ..errors import InternalError
|
||||
from .. import transit
|
||||
from .common import ServerBase
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
|
||||
|
@ -1320,12 +1321,12 @@ class Transit(unittest.TestCase):
|
|||
self.assertEqual(results, ["winner"])
|
||||
|
||||
|
||||
class Full(unittest.TestCase):
|
||||
class Full(ServerBase, unittest.TestCase):
|
||||
def doBoth(self, d1, d2):
|
||||
return gatherResults([d1, d2], True)
|
||||
|
||||
@inlineCallbacks
|
||||
def test_full(self):
|
||||
def test_direct(self):
|
||||
KEY = b"k"*32
|
||||
s = transit.TransitSender(None)
|
||||
r = transit.TransitReceiver(None)
|
||||
|
@ -1351,3 +1352,31 @@ class Full(unittest.TestCase):
|
|||
|
||||
yield x.close()
|
||||
yield y.close()
|
||||
|
||||
@inlineCallbacks
|
||||
def test_relay(self):
|
||||
KEY = b"k"*32
|
||||
s = transit.TransitSender(self.transit, no_listen=True)
|
||||
r = transit.TransitReceiver(self.transit, no_listen=True)
|
||||
|
||||
s.set_transit_key(KEY)
|
||||
r.set_transit_key(KEY)
|
||||
|
||||
shints = yield s.get_connection_hints()
|
||||
rhints = yield r.get_connection_hints()
|
||||
|
||||
s.add_connection_hints(rhints)
|
||||
r.add_connection_hints(shints)
|
||||
|
||||
(x,y) = yield self.doBoth(s.connect(), r.connect())
|
||||
self.assertIsInstance(x, transit.Connection)
|
||||
self.assertIsInstance(y, transit.Connection)
|
||||
|
||||
d = y.receive_record()
|
||||
|
||||
x.send_record(b"record1")
|
||||
r = yield d
|
||||
self.assertEqual(r, b"record1")
|
||||
|
||||
yield x.close()
|
||||
yield y.close()
|
||||
|
|
Loading…
Reference in New Issue
Block a user