transit server: accept both new (sided) and old (unsided) handshakes
This commit is contained in:
parent
c7e4d57405
commit
5fcea701bb
|
@ -25,6 +25,7 @@ def blur_size(size):
|
|||
class TransitConnection(protocol.Protocol):
|
||||
def __init__(self):
|
||||
self._got_token = False
|
||||
self._got_side = False
|
||||
self._token_buffer = b""
|
||||
self._sent_ok = False
|
||||
self._buddy = None
|
||||
|
@ -32,9 +33,14 @@ class TransitConnection(protocol.Protocol):
|
|||
self._total_sent = 0
|
||||
|
||||
def describeToken(self):
|
||||
d = "-"
|
||||
if self._got_token:
|
||||
return self._got_token[:16].decode("ascii")
|
||||
return "-"
|
||||
d = self._got_token[:16].decode("ascii")
|
||||
if self._got_side:
|
||||
d += "-" + self._got_side.decode("ascii")
|
||||
else:
|
||||
d += "-<unsided>"
|
||||
return d
|
||||
|
||||
def connectionMade(self):
|
||||
self._started = time.time()
|
||||
|
@ -59,26 +65,69 @@ class TransitConnection(protocol.Protocol):
|
|||
# else this should be (part of) the token
|
||||
self._token_buffer += data
|
||||
buf = self._token_buffer
|
||||
wanted = len("please relay \n")+32*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
self.transport.write(b"bad handshake\n")
|
||||
log.msg("transit handshake early failure")
|
||||
return self.disconnect()
|
||||
if len(buf) < wanted:
|
||||
return
|
||||
if len(buf) > wanted:
|
||||
self.transport.write(b"impatient\n")
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect() # impatience yields failure
|
||||
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
|
||||
if not mo:
|
||||
|
||||
# old: "please relay {64}\n"
|
||||
# new: "please relay {64} for side {16}\n"
|
||||
(old, handshake_len, token) = self._check_old_handshake(buf)
|
||||
assert old in ("yes", "waiting", "no")
|
||||
if old == "yes":
|
||||
# remember they aren't supposed to send anything past their
|
||||
# handshake until we've said go
|
||||
if len(buf) > handshake_len:
|
||||
self.transport.write(b"impatient\n")
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect() # impatience yields failure
|
||||
return self._got_handshake(token, None)
|
||||
(new, handshake_len, token, side) = self._check_new_handshake(buf)
|
||||
assert new in ("yes", "waiting", "no")
|
||||
if new == "yes":
|
||||
if len(buf) > handshake_len:
|
||||
self.transport.write(b"impatient\n")
|
||||
log.msg("transit impatience failure")
|
||||
return self.disconnect() # impatience yields failure
|
||||
return self._got_handshake(token, side)
|
||||
if (old == "no" and new == "no"):
|
||||
self.transport.write(b"bad handshake\n")
|
||||
log.msg("transit handshake failure")
|
||||
return self.disconnect() # incorrectness yields failure
|
||||
token = mo.group(1)
|
||||
# else we'll keep waiting
|
||||
|
||||
def _check_old_handshake(self, buf):
|
||||
# old: "please relay {64}\n"
|
||||
# return ("yes", handshake, token) if buf contains an old-style handshake
|
||||
# return ("waiting", None, None) if it might eventually contain one
|
||||
# return ("no", None, None) if it could never contain one
|
||||
wanted = len("please relay \n")+32*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
return ("no", None, None)
|
||||
if len(buf) < wanted:
|
||||
return ("waiting", None, None)
|
||||
|
||||
mo = re.search(br"^please relay (\w{64})\n", buf, re.M)
|
||||
if mo:
|
||||
token = mo.group(1)
|
||||
return ("yes", wanted, token)
|
||||
return ("no", None, None)
|
||||
|
||||
def _check_new_handshake(self, buf):
|
||||
# new: "please relay {64} for side {16}\n"
|
||||
wanted = len("please relay for side \n")+32*2+8*2
|
||||
if len(buf) < wanted-1 and b"\n" in buf:
|
||||
return ("no", None, None, None)
|
||||
if len(buf) < wanted:
|
||||
return ("waiting", None, None, None)
|
||||
|
||||
mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M)
|
||||
if mo:
|
||||
token = mo.group(1)
|
||||
side = mo.group(2)
|
||||
return ("yes", wanted, token, side)
|
||||
return ("no", None, None, None)
|
||||
|
||||
def _got_handshake(self, token, side):
|
||||
self._got_token = token
|
||||
self.factory.connection_got_token(token, self)
|
||||
self._got_side = side
|
||||
self.factory.connection_got_token(token, side, self)
|
||||
|
||||
def buddy_connected(self, them):
|
||||
self._buddy = them
|
||||
|
@ -100,7 +149,7 @@ class TransitConnection(protocol.Protocol):
|
|||
def connectionLost(self, reason):
|
||||
if self._buddy:
|
||||
self._buddy.buddy_disconnected()
|
||||
self.factory.transitFinished(self, self._got_token,
|
||||
self.factory.transitFinished(self, self._got_token, self._got_side,
|
||||
self.describeToken())
|
||||
|
||||
# Record usage. There are four cases:
|
||||
|
@ -136,12 +185,18 @@ class TransitConnection(protocol.Protocol):
|
|||
class Transit(protocol.ServerFactory, service.MultiService):
|
||||
# I manage pairs of simultaneous connections to a secondary TCP port,
|
||||
# both forwarded to the other. Clients must begin each connection with
|
||||
# "please relay TOKEN\n". I will send "ok\n" when the matching connection
|
||||
# is established, or disconnect if no matching connection is made within
|
||||
# MAX_WAIT_TIME seconds. I will disconnect if you send data before the
|
||||
# "ok\n". All data you get after the "ok\n" will be from the other side.
|
||||
# You will not receive "ok\n" until the other side has also connected and
|
||||
# submitted a matching token. The token is the same for each side.
|
||||
# "please relay TOKEN for SIDE\n" (or a legacy form without the "for
|
||||
# SIDE"). Two connections match if they use the same TOKEN and have
|
||||
# different SIDEs (the redundant connections are dropped when a match is
|
||||
# made). Legacy connections match any with the same TOKEN, ignoring SIDE
|
||||
# (so two legacy connections will match each other).
|
||||
|
||||
# I will send "ok\n" when the matching connection is established, or
|
||||
# disconnect if no matching connection is made within MAX_WAIT_TIME
|
||||
# seconds. I will disconnect if you send data before the "ok\n". All data
|
||||
# you get after the "ok\n" will be from the other side. You will not
|
||||
# receive "ok\n" until the other side has also connected and submitted a
|
||||
# matching token (and differing SIDE).
|
||||
|
||||
# In addition, the connections will be dropped after MAXLENGTH bytes have
|
||||
# been sent by either side, or MAXTIME seconds have elapsed after the
|
||||
|
@ -164,23 +219,38 @@ class Transit(protocol.ServerFactory, service.MultiService):
|
|||
service.MultiService.__init__(self)
|
||||
self._db = db
|
||||
self._blur_usage = blur_usage
|
||||
self._pending_requests = {} # token -> TransitConnection
|
||||
self._pending_requests = {} # token -> set((side, TransitConnection))
|
||||
self._active_connections = set() # TransitConnection
|
||||
self._counts = collections.defaultdict(int)
|
||||
self._count_bytes = 0
|
||||
|
||||
def connection_got_token(self, token, p):
|
||||
if token in self._pending_requests:
|
||||
log.msg("transit relay 2: %s" % p.describeToken())
|
||||
buddy = self._pending_requests.pop(token)
|
||||
self._active_connections.add(p)
|
||||
self._active_connections.add(buddy)
|
||||
p.buddy_connected(buddy)
|
||||
buddy.buddy_connected(p)
|
||||
else:
|
||||
self._pending_requests[token] = p
|
||||
log.msg("transit relay 1: %s" % p.describeToken())
|
||||
# TODO: timer
|
||||
def connection_got_token(self, token, new_side, new_tc):
|
||||
if token not in self._pending_requests:
|
||||
self._pending_requests[token] = set()
|
||||
potentials = self._pending_requests[token]
|
||||
for old in potentials:
|
||||
(old_side, old_tc) = old
|
||||
if ((old_side is None)
|
||||
or (new_side is None)
|
||||
or (old_side != new_side)):
|
||||
# we found a match
|
||||
log.msg("transit relay 2: %s" % new_tc.describeToken())
|
||||
|
||||
# drop and stop tracking the rest
|
||||
potentials.remove(old)
|
||||
for (_, leftover_tc) in potentials:
|
||||
leftover_tc.disconnect() # TODO: not "errory"?
|
||||
self._pending_requests.pop(token)
|
||||
|
||||
# glue the two ends together
|
||||
self._active_connections.add(new_tc)
|
||||
self._active_connections.add(old_tc)
|
||||
new_tc.buddy_connected(old_tc)
|
||||
old_tc.buddy_connected(new_tc)
|
||||
return
|
||||
log.msg("transit relay 1: %s" % new_tc.describeToken())
|
||||
potentials.add((new_side, new_tc))
|
||||
# TODO: timer
|
||||
|
||||
def recordUsage(self, started, result, total_bytes,
|
||||
total_time, waiting_time):
|
||||
|
@ -198,13 +268,15 @@ class Transit(protocol.ServerFactory, service.MultiService):
|
|||
self._counts[result] += 1
|
||||
self._count_bytes += total_bytes
|
||||
|
||||
def transitFinished(self, p, token, description):
|
||||
for token,tc in self._pending_requests.items():
|
||||
if tc is p:
|
||||
def transitFinished(self, tc, token, side, description):
|
||||
if token in self._pending_requests:
|
||||
side_tc = (side, tc)
|
||||
if side_tc in self._pending_requests[token]:
|
||||
self._pending_requests[token].remove(side_tc)
|
||||
if not self._pending_requests[token]: # set is now empty
|
||||
del self._pending_requests[token]
|
||||
break
|
||||
log.msg("transitFinished %s" % (description,))
|
||||
self._active_connections.discard(p)
|
||||
self._active_connections.discard(tc)
|
||||
|
||||
def transitFailed(self, p):
|
||||
log.msg("transitFailed %r" % p)
|
||||
|
|
|
@ -1072,6 +1072,7 @@ class Accumulator(protocol.Protocol):
|
|||
self.data = b""
|
||||
self.count = 0
|
||||
self._wait = None
|
||||
self._disconnect = defer.Deferred()
|
||||
def waitForBytes(self, more):
|
||||
assert self._wait is None
|
||||
self.count = more
|
||||
|
@ -1089,6 +1090,7 @@ class Accumulator(protocol.Protocol):
|
|||
def connectionLost(self, why):
|
||||
if self._wait:
|
||||
self._wait.errback(RuntimeError("closed"))
|
||||
self._disconnect.callback(None)
|
||||
|
||||
class Transit(ServerBase, unittest.TestCase):
|
||||
def test_blur_size(self):
|
||||
|
@ -1115,7 +1117,32 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_basic(self):
|
||||
def test_register(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 1)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
|
||||
# let that get removed
|
||||
while self.count() > 0:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 0)
|
||||
|
||||
# the token should be removed too
|
||||
self.assertEqual(len(self._transit_server._pending_requests), 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
@ -1143,6 +1170,133 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_sided_unsided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_unsided_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_both_sided(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
side2 = b"\x02"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side2) + b"\n")
|
||||
|
||||
# a correct handshake yields an ack, after which we can send
|
||||
exp = b"ok\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
self.assertEqual(a1.data, exp)
|
||||
s1 = b"data1"
|
||||
a1.transport.write(s1)
|
||||
|
||||
exp = b"ok\n"
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
# all data they sent after the handshake should be given to us
|
||||
exp = b"ok\n"+s1
|
||||
yield a2.waitForBytes(len(exp))
|
||||
self.assertEqual(a2.data, exp)
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
def count(self):
|
||||
return sum([len(potentials)
|
||||
for potentials
|
||||
in self._transit_server._pending_requests.values()])
|
||||
def wait(self):
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0.001, d.callback, None)
|
||||
return d
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_ignore_same_side(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
a1 = yield connectProtocol(ep, Accumulator())
|
||||
a2 = yield connectProtocol(ep, Accumulator())
|
||||
|
||||
token1 = b"\x00"*32
|
||||
side1 = b"\x01"*8
|
||||
a1.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 0:
|
||||
yield self.wait()
|
||||
a2.transport.write(b"please relay " + hexlify(token1) +
|
||||
b" for side " + hexlify(side1) + b"\n")
|
||||
# let that arrive
|
||||
while self.count() == 1:
|
||||
yield self.wait()
|
||||
self.assertEqual(self.count(), 2) # same-side connections don't match
|
||||
|
||||
a1.transport.loseConnection()
|
||||
a2.transport.loseConnection()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_bad_handshake(self):
|
||||
ep = clientFromString(reactor, self.transit)
|
||||
|
@ -1186,7 +1340,7 @@ class Transit(ServerBase, unittest.TestCase):
|
|||
|
||||
token1 = b"\x00"*32
|
||||
# sending too many bytes is impatience.
|
||||
a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n")
|
||||
a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW")
|
||||
|
||||
exp = b"impatient\n"
|
||||
yield a1.waitForBytes(len(exp))
|
||||
|
|
Loading…
Reference in New Issue
Block a user