diff --git a/src/wormhole/server/transit_server.py b/src/wormhole/server/transit_server.py index 4798e27..d2e5ffa 100644 --- a/src/wormhole/server/transit_server.py +++ b/src/wormhole/server/transit_server.py @@ -25,6 +25,7 @@ def blur_size(size): class TransitConnection(protocol.Protocol): def __init__(self): self._got_token = False + self._got_side = False self._token_buffer = b"" self._sent_ok = False self._buddy = None @@ -32,9 +33,14 @@ class TransitConnection(protocol.Protocol): self._total_sent = 0 def describeToken(self): + d = "-" if self._got_token: - return self._got_token[:16].decode("ascii") - return "-" + d = self._got_token[:16].decode("ascii") + if self._got_side: + d += "-" + self._got_side.decode("ascii") + else: + d += "-" + return d def connectionMade(self): self._started = time.time() @@ -59,26 +65,69 @@ class TransitConnection(protocol.Protocol): # else this should be (part of) the token self._token_buffer += data buf = self._token_buffer - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and b"\n" in buf: - self.transport.write(b"bad handshake\n") - log.msg("transit handshake early failure") - return self.disconnect() - if len(buf) < wanted: - return - if len(buf) > wanted: - self.transport.write(b"impatient\n") - log.msg("transit impatience failure") - return self.disconnect() # impatience yields failure - mo = re.search(br"^please relay (\w{64})\n", buf, re.M) - if not mo: + + # old: "please relay {64}\n" + # new: "please relay {64} for side {16}\n" + (old, handshake_len, token) = self._check_old_handshake(buf) + assert old in ("yes", "waiting", "no") + if old == "yes": + # remember they aren't supposed to send anything past their + # handshake until we've said go + if len(buf) > handshake_len: + self.transport.write(b"impatient\n") + log.msg("transit impatience failure") + return self.disconnect() # impatience yields failure + return self._got_handshake(token, None) + (new, handshake_len, token, side) = self._check_new_handshake(buf) + assert new in ("yes", "waiting", "no") + if new == "yes": + if len(buf) > handshake_len: + self.transport.write(b"impatient\n") + log.msg("transit impatience failure") + return self.disconnect() # impatience yields failure + return self._got_handshake(token, side) + if (old == "no" and new == "no"): self.transport.write(b"bad handshake\n") log.msg("transit handshake failure") return self.disconnect() # incorrectness yields failure - token = mo.group(1) + # else we'll keep waiting + def _check_old_handshake(self, buf): + # old: "please relay {64}\n" + # return ("yes", handshake, token) if buf contains an old-style handshake + # return ("waiting", None, None) if it might eventually contain one + # return ("no", None, None) if it could never contain one + wanted = len("please relay \n")+32*2 + if len(buf) < wanted-1 and b"\n" in buf: + return ("no", None, None) + if len(buf) < wanted: + return ("waiting", None, None) + + mo = re.search(br"^please relay (\w{64})\n", buf, re.M) + if mo: + token = mo.group(1) + return ("yes", wanted, token) + return ("no", None, None) + + def _check_new_handshake(self, buf): + # new: "please relay {64} for side {16}\n" + wanted = len("please relay for side \n")+32*2+8*2 + if len(buf) < wanted-1 and b"\n" in buf: + return ("no", None, None, None) + if len(buf) < wanted: + return ("waiting", None, None, None) + + mo = re.search(br"^please relay (\w{64}) for side (\w{16})\n", buf, re.M) + if mo: + token = mo.group(1) + side = mo.group(2) + return ("yes", wanted, token, side) + return ("no", None, None, None) + + def _got_handshake(self, token, side): self._got_token = token - self.factory.connection_got_token(token, self) + self._got_side = side + self.factory.connection_got_token(token, side, self) def buddy_connected(self, them): self._buddy = them @@ -100,7 +149,7 @@ class TransitConnection(protocol.Protocol): def connectionLost(self, reason): if self._buddy: self._buddy.buddy_disconnected() - self.factory.transitFinished(self, self._got_token, + self.factory.transitFinished(self, self._got_token, self._got_side, self.describeToken()) # Record usage. There are four cases: @@ -136,12 +185,18 @@ class TransitConnection(protocol.Protocol): class Transit(protocol.ServerFactory, service.MultiService): # I manage pairs of simultaneous connections to a secondary TCP port, # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN\n". I will send "ok\n" when the matching connection - # is established, or disconnect if no matching connection is made within - # MAX_WAIT_TIME seconds. I will disconnect if you send data before the - # "ok\n". All data you get after the "ok\n" will be from the other side. - # You will not receive "ok\n" until the other side has also connected and - # submitted a matching token. The token is the same for each side. + # "please relay TOKEN for SIDE\n" (or a legacy form without the "for + # SIDE"). Two connections match if they use the same TOKEN and have + # different SIDEs (the redundant connections are dropped when a match is + # made). Legacy connections match any with the same TOKEN, ignoring SIDE + # (so two legacy connections will match each other). + + # I will send "ok\n" when the matching connection is established, or + # disconnect if no matching connection is made within MAX_WAIT_TIME + # seconds. I will disconnect if you send data before the "ok\n". All data + # you get after the "ok\n" will be from the other side. You will not + # receive "ok\n" until the other side has also connected and submitted a + # matching token (and differing SIDE). # In addition, the connections will be dropped after MAXLENGTH bytes have # been sent by either side, or MAXTIME seconds have elapsed after the @@ -164,23 +219,38 @@ class Transit(protocol.ServerFactory, service.MultiService): service.MultiService.__init__(self) self._db = db self._blur_usage = blur_usage - self._pending_requests = {} # token -> TransitConnection + self._pending_requests = {} # token -> set((side, TransitConnection)) self._active_connections = set() # TransitConnection self._counts = collections.defaultdict(int) self._count_bytes = 0 - def connection_got_token(self, token, p): - if token in self._pending_requests: - log.msg("transit relay 2: %s" % p.describeToken()) - buddy = self._pending_requests.pop(token) - self._active_connections.add(p) - self._active_connections.add(buddy) - p.buddy_connected(buddy) - buddy.buddy_connected(p) - else: - self._pending_requests[token] = p - log.msg("transit relay 1: %s" % p.describeToken()) - # TODO: timer + def connection_got_token(self, token, new_side, new_tc): + if token not in self._pending_requests: + self._pending_requests[token] = set() + potentials = self._pending_requests[token] + for old in potentials: + (old_side, old_tc) = old + if ((old_side is None) + or (new_side is None) + or (old_side != new_side)): + # we found a match + log.msg("transit relay 2: %s" % new_tc.describeToken()) + + # drop and stop tracking the rest + potentials.remove(old) + for (_, leftover_tc) in potentials: + leftover_tc.disconnect() # TODO: not "errory"? + self._pending_requests.pop(token) + + # glue the two ends together + self._active_connections.add(new_tc) + self._active_connections.add(old_tc) + new_tc.buddy_connected(old_tc) + old_tc.buddy_connected(new_tc) + return + log.msg("transit relay 1: %s" % new_tc.describeToken()) + potentials.add((new_side, new_tc)) + # TODO: timer def recordUsage(self, started, result, total_bytes, total_time, waiting_time): @@ -198,13 +268,15 @@ class Transit(protocol.ServerFactory, service.MultiService): self._counts[result] += 1 self._count_bytes += total_bytes - def transitFinished(self, p, token, description): - for token,tc in self._pending_requests.items(): - if tc is p: + def transitFinished(self, tc, token, side, description): + if token in self._pending_requests: + side_tc = (side, tc) + if side_tc in self._pending_requests[token]: + self._pending_requests[token].remove(side_tc) + if not self._pending_requests[token]: # set is now empty del self._pending_requests[token] - break log.msg("transitFinished %s" % (description,)) - self._active_connections.discard(p) + self._active_connections.discard(tc) def transitFailed(self, p): log.msg("transitFailed %r" % p) diff --git a/src/wormhole/test/test_server.py b/src/wormhole/test/test_server.py index 9bc7f91..2d67ab2 100644 --- a/src/wormhole/test/test_server.py +++ b/src/wormhole/test/test_server.py @@ -1072,6 +1072,7 @@ class Accumulator(protocol.Protocol): self.data = b"" self.count = 0 self._wait = None + self._disconnect = defer.Deferred() def waitForBytes(self, more): assert self._wait is None self.count = more @@ -1089,6 +1090,7 @@ class Accumulator(protocol.Protocol): def connectionLost(self, why): if self._wait: self._wait.errback(RuntimeError("closed")) + self._disconnect.callback(None) class Transit(ServerBase, unittest.TestCase): def test_blur_size(self): @@ -1115,7 +1117,32 @@ class Transit(ServerBase, unittest.TestCase): self.assertEqual('Wormhole Relay'.encode('ascii'), resp.strip()) @defer.inlineCallbacks - def test_basic(self): + def test_register(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + # let that arrive + while self.count() == 0: + yield self.wait() + self.assertEqual(self.count(), 1) + + a1.transport.loseConnection() + + # let that get removed + while self.count() > 0: + yield self.wait() + self.assertEqual(self.count(), 0) + + # the token should be removed too + self.assertEqual(len(self._transit_server._pending_requests), 0) + + @defer.inlineCallbacks + def test_both_unsided(self): ep = clientFromString(reactor, self.transit) a1 = yield connectProtocol(ep, Accumulator()) a2 = yield connectProtocol(ep, Accumulator()) @@ -1143,6 +1170,133 @@ class Transit(ServerBase, unittest.TestCase): a1.transport.loseConnection() a2.transport.loseConnection() + @defer.inlineCallbacks + def test_sided_unsided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_unsided_sided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + @defer.inlineCallbacks + def test_both_sided(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + side2 = b"\x02"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side2) + b"\n") + + # a correct handshake yields an ack, after which we can send + exp = b"ok\n" + yield a1.waitForBytes(len(exp)) + self.assertEqual(a1.data, exp) + s1 = b"data1" + a1.transport.write(s1) + + exp = b"ok\n" + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + # all data they sent after the handshake should be given to us + exp = b"ok\n"+s1 + yield a2.waitForBytes(len(exp)) + self.assertEqual(a2.data, exp) + + a1.transport.loseConnection() + a2.transport.loseConnection() + + def count(self): + return sum([len(potentials) + for potentials + in self._transit_server._pending_requests.values()]) + def wait(self): + d = defer.Deferred() + reactor.callLater(0.001, d.callback, None) + return d + + @defer.inlineCallbacks + def test_ignore_same_side(self): + ep = clientFromString(reactor, self.transit) + a1 = yield connectProtocol(ep, Accumulator()) + a2 = yield connectProtocol(ep, Accumulator()) + + token1 = b"\x00"*32 + side1 = b"\x01"*8 + a1.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + # let that arrive + while self.count() == 0: + yield self.wait() + a2.transport.write(b"please relay " + hexlify(token1) + + b" for side " + hexlify(side1) + b"\n") + # let that arrive + while self.count() == 1: + yield self.wait() + self.assertEqual(self.count(), 2) # same-side connections don't match + + a1.transport.loseConnection() + a2.transport.loseConnection() + @defer.inlineCallbacks def test_bad_handshake(self): ep = clientFromString(reactor, self.transit) @@ -1186,7 +1340,7 @@ class Transit(ServerBase, unittest.TestCase): token1 = b"\x00"*32 # sending too many bytes is impatience. - a1.transport.write(b"please RELAY NOWNOW " + hexlify(token1) + b"\n") + a1.transport.write(b"please relay " + hexlify(token1) + b"\nNOWNOWNOW") exp = b"impatient\n" yield a1.waitForBytes(len(exp))