From 5994eb11d4398a981e2fcd4ab442d7f26b3eea0b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 18 May 2016 00:16:46 -0700 Subject: [PATCH] WIP new proto --- src/wormhole/server/db-schemas/v2.sql | 2 +- src/wormhole/server/rendezvous.py | 195 +++++++++++++++----- src/wormhole/server/rendezvous_websocket.py | 70 ++++--- 3 files changed, 187 insertions(+), 80 deletions(-) diff --git a/src/wormhole/server/db-schemas/v2.sql b/src/wormhole/server/db-schemas/v2.sql index 795d49d..6659fc6 100644 --- a/src/wormhole/server/db-schemas/v2.sql +++ b/src/wormhole/server/db-schemas/v2.sql @@ -14,7 +14,7 @@ CREATE TABLE `version` CREATE TABLE `nameplates` ( `app_id` VARCHAR, - `id` VARCHAR PRIMARY KEY, + `id` VARCHAR, `mailbox_id` VARCHAR, -- really a foreign key `side1` VARCHAR, -- side name, or NULL `side2` VARCHAR -- side name, or NULL diff --git a/src/wormhole/server/rendezvous.py b/src/wormhole/server/rendezvous.py index 2072812..1007579 100644 --- a/src/wormhole/server/rendezvous.py +++ b/src/wormhole/server/rendezvous.py @@ -15,13 +15,75 @@ EXPIRATION_CHECK_PERIOD = 2*HOUR CLAIM = u"_claim" RELEASE = u"_release" -class Channel: - def __init__(self, app, db, blur_usage, log_requests, appid, channelid): +def get_sides(row): + return set([s for s in [row["side1"], row["side2"]] if s]) +def make_sides(side1, side2): + return list(sides) + [None] * (2 - len(sides)) +def generate_mailbox_id(): + return base64.b32encode(os.urandom(8)).lower().strip("=") + +# Unlike Channels, these instances are ephemeral, and are created and +# destroyed casually. +class Nameplate: + def __init__(self, app_id, db, id, mailbox_id): + self._app_id = app_id + self._db = db + self._id = id + self._mailbox_id = mailbox_id + + def get_id(self): + return self._id + + def get_mailbox_id(self): + return self._mailbox_id + + def claim(self, side, when): + db = self._db + sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._id)).fetchone()) + old_sides = len(sides) + sides.add(side) + if len(sides) > 2: + # XXX: crowded: bail + pass + sides12 = make_sides(sides) + db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" + " WHERE `app_id`=? AND `id`=?", + (sides12[0], sides12[1], self._app_id, self._id)) + if old_sides == 0: + db.execute("UPDATE `mailboxes` SET `nameplate_started`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + else: + db.execute("UPDATE `mailboxes` SET `nameplate_second`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + db.commit() + + def release(self, side, when): + db = self._db + sides = get_sides(db.execute("SELECT `side1`, `side2` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, self._id)).fetchone()) + sides.discard(side) + sides12 = make_sides(sides) + db.execute("UPDATE `nameplates` SET `side1`=?, `side2`=?" + " WHERE `app_id`=? AND `id`=?", + (sides12[0], sides12[1], self._app_id, self._id)) + if len(sides) == 0: + db.execute("UPDATE `mailboxes` SET `nameplate_closed`=?" + " WHERE `app_id`=? AND `id`=?", + (when, self._app_id, self._mailbox_id)) + db.commit() + +class Mailbox: + def __init__(self, app, db, blur_usage, log_requests, app_id, channelid): self._app = app self._db = db self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid + self._app_id = app_id self._channelid = channelid self._listeners = {} # handle -> (send_f, stop_f) # "handle" is a hashable object, for deregistration @@ -34,9 +96,9 @@ class Channel: messages = [] db = self._db for row in db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx` ASC", - (self._appid, self._channelid)).fetchall(): + (self._app_id, self._channelid)).fetchall(): if row["phase"] in (CLAIM, RELEASE): continue messages.append({"phase": row["phase"], "body": row["body"], @@ -58,10 +120,10 @@ class Channel: def _add_message(self, side, phase, body, server_rx, msgid): db = self._db db.execute("INSERT INTO `messages`" - " (`appid`, `channelid`, `side`, `phase`, `body`," + " (`app_id`, `channelid`, `side`, `phase`, `body`," " `server_rx`, `msgid`)" " VALUES (?,?,?,?,?, ?,?)", - (self._appid, self._channelid, side, phase, body, + (self._app_id, self._channelid, side, phase, body, server_rx, msgid)) db.commit() @@ -78,13 +140,13 @@ class Channel: db = self._db seen = set([row["side"] for row in db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid))]) + " WHERE `app_id`=? AND `channelid`=?", + (self._app_id, self._channelid))]) freed = set([row["side"] for row in db.execute("SELECT `side` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " AND `phase`=?", - (self._appid, self._channelid, RELEASE))]) + (self._app_id, self._channelid, RELEASE))]) if seen - freed: return False self.delete_and_summarize() @@ -94,9 +156,9 @@ class Channel: if self._listeners: return False c = self._db.execute("SELECT `server_rx` FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx` DESC LIMIT 1", - (self._appid, self._channelid)) + (self._app_id, self._channelid)) rows = c.fetchall() if not rows: return True @@ -169,15 +231,15 @@ class Channel: def delete_and_summarize(self): db = self._db c = self._db.execute("SELECT * FROM `messages`" - " WHERE `appid`=? AND `channelid`=?" + " WHERE `app_id`=? AND `channelid`=?" " ORDER BY `server_rx`", - (self._appid, self._channelid)) + (self._app_id, self._channelid)) messages = c.fetchall() summary = self._summarize(messages, time.time()) self._store_summary(summary) db.execute("DELETE FROM `messages`" - " WHERE `appid`=? AND `channelid`=?", - (self._appid, self._channelid)) + " WHERE `app_id`=? AND `channelid`=?", + (self._app_id, self._channelid)) db.commit() # Shut down any listeners, just in case they're still lingering @@ -193,37 +255,70 @@ class Channel: stop_f() class AppNamespace: - def __init__(self, db, welcome, blur_usage, log_requests, appid): + def __init__(self, db, welcome, blur_usage, log_requests, app_id): self._db = db self._welcome = welcome self._blur_usage = blur_usage self._log_requests = log_requests - self._appid = appid + self._app_id = app_id self._channels = {} - def get_claimed(self): + def get_nameplate_ids(self): db = self._db - c = db.execute("SELECT DISTINCT `channelid` FROM `messages`" - " WHERE `appid`=?", (self._appid,)) - return set([row["channelid"] for row in c.fetchall()]) + # TODO: filter this to numeric ids? + c = db.execute("SELECT DISTINCT `id` FROM `nameplates`" + " WHERE `app_id`=?", (self._app_id,)) + return set([row["id"] for row in c.fetchall()]) - def find_available_channelid(self): - claimed = self.get_claimed() + def find_available_nameplate_id(self): + claimed = self.get_nameplate_ids() for size in range(1,4): # stick to 1-999 for now available = set() - for cid_int in range(10**(size-1), 10**size): - cid = u"%d" % cid_int - if cid not in claimed: - available.add(cid) + for id_int in range(10**(size-1), 10**size): + id = u"%d" % id_int + if id not in claimed: + available.add(id) if available: return random.choice(list(available)) # ouch, 999 currently claimed. Try random ones for a while. for tries in range(1000): - cid_int = random.randrange(1000, 1000*1000) - cid = u"%d" % cid_int - if cid not in claimed: - return cid - raise ValueError("unable to find a free channel-id") + id_int = random.randrange(1000, 1000*1000) + id = u"%d" % id_int + if id not in claimed: + return id + raise ValueError("unable to find a free nameplate-id") + + def _get_mailbox_id(self, nameplate_id): + row = self._db.execute("SELECT `mailbox_id` FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)).fetchone() + return row["mailbox_id"] + + def claim_nameplate(self, nameplate_id, side, when): + assert isinstance(nameplate_id, type(u"")), type(nameplate_id) + db = self._db + rows = db.execute("SELECT * FROM `nameplates`" + " WHERE `app_id`=? AND `id`=?", + (self._app_id, nameplate_id)) + if rows: + mailbox_id = rows[0]["mailbox_id"] + else: + if self._log_requests: + log.msg("creating nameplate#%s for app_id %s" % + (nameplate_id, self._app_id)) + mailbox_id = UUID() + db.execute("INSERT INTO `mailboxes`" + " (`app_id`, `id`)" + " VALUES(?,?)", + (self._app_id, mailbox_id)) + db.execute("INSERT INTO `nameplates`" + " (`app_id`, `id`, `mailbox_id`, `side1`, `side2`)" + " VALUES(?,?,?,?,?)", + (self._app_id, nameplate_id, mailbox_id, None, None)) + + nameplate = Nameplate(self._app_id, self._db, nameplate_id, mailbox_id) + nameplate.claim(side, when) + return nameplate def claim_channel(self, channelid, side): assert isinstance(channelid, type(u"")), type(channelid) @@ -235,11 +330,11 @@ class AppNamespace: assert isinstance(channelid, type(u"")) if not channelid in self._channels: if self._log_requests: - log.msg("spawning #%s for appid %s" % (channelid, self._appid)) + log.msg("spawning #%s for app_id %s" % (channelid, self._app_id)) self._channels[channelid] = Channel(self, self._db, self._blur_usage, self._log_requests, - self._appid, channelid) + self._app_id, channelid) return self._channels[channelid] def free_channel(self, channelid): @@ -293,28 +388,28 @@ class Rendezvous(service.MultiService): def get_log_requests(self): return self._log_requests - def get_app(self, appid): - assert isinstance(appid, type(u"")) - if not appid in self._apps: + def get_app(self, app_id): + assert isinstance(app_id, type(u"")) + if not app_id in self._apps: if self._log_requests: - log.msg("spawning appid %s" % (appid,)) - self._apps[appid] = AppNamespace(self._db, self._welcome, + log.msg("spawning app_id %s" % (app_id,)) + self._apps[app_id] = AppNamespace(self._db, self._welcome, self._blur_usage, - self._log_requests, appid) - return self._apps[appid] + self._log_requests, app_id) + return self._apps[app_id] def prune(self): # As with AppNamespace.prune_old_channels, we log for now. log.msg("beginning app prune") - c = self._db.execute("SELECT DISTINCT `appid` FROM `messages`") - apps = set([row["appid"] for row in c.fetchall()]) # these have messages + c = self._db.execute("SELECT DISTINCT `app_id` FROM `messages`") + apps = set([row["app_id"] for row in c.fetchall()]) # these have messages apps.update(self._apps) # these might have listeners - for appid in apps: - log.msg(" app prune checking %r" % (appid,)) - still_active = self.get_app(appid).prune_old_channels() + for app_id in apps: + log.msg(" app prune checking %r" % (app_id,)) + still_active = self.get_app(app_id).prune_old_channels() if not still_active: - log.msg("prune pops app %r" % (appid,)) - self._apps.pop(appid) + log.msg("prune pops app %r" % (app_id,)) + self._apps.pop(app_id) log.msg("app prune ends, %d remaining apps" % len(self._apps)) def stopService(self): diff --git a/src/wormhole/server/rendezvous_websocket.py b/src/wormhole/server/rendezvous_websocket.py index e7f1fde..b2bf3f2 100644 --- a/src/wormhole/server/rendezvous_websocket.py +++ b/src/wormhole/server/rendezvous_websocket.py @@ -102,10 +102,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): try: if "type" not in msg: raise Error("missing 'type'") - if "id" in msg: - # Only ack clients modern enough to include [id]. Older ones - # won't recognize the message, then they'll abort. - self.send("ack", id=msg["id"]) + self.send("ack", id=msg.get("id")) mtype = msg["type"] if mtype == "ping": @@ -118,15 +115,18 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): if mtype == "list": return self.handle_list() if mtype == "allocate": - return self.handle_allocate() + return self.handle_allocate(server_rx) if mtype == "claim": - return self.handle_claim(msg) - if mtype == "watch": - return self.handle_watch(msg) + return self.handle_claim(msg, server_rx) + if mtype == "release": + return self.handle_release(msg, server_rx) + + if mtype == "open": + return self.handle_open(msg) if mtype == "add": return self.handle_add(msg, server_rx) - if mtype == "release": - return self.handle_release(msg) + if mtype == "close": + return self.handle_close(msg) raise Error("Unknown type") except Error as e: @@ -147,30 +147,42 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): self._app = self.factory.rendezvous.get_app(msg["appid"]) self._side = msg["side"] - def handle_list(self): - channelids = sorted(self._app.get_claimed()) - self.send("channelids", channelids=channelids) - def handle_allocate(self): + def handle_list(self): + nameplate_ids = sorted(self._app.get_nameplate_ids()) + self.send("nameplates", nameplates=nameplate_ids) + + def handle_allocate(self, server_rx): if self._did_allocate: raise Error("You already allocated one channel, don't be greedy") - channelid = self._app.find_available_channelid() - assert isinstance(channelid, type(u"")) + nameplate_id = self._app.find_available_nameplate_id() + assert isinstance(nameplate_id, type(u"")) self._did_allocate = True - channel = self._app.claim_channel(channelid, self._side) - self._channels[channelid] = channel - self.send("allocated", channelid=channelid) + self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + self.send("nameplate", nameplate=nameplate_id) - def handle_claim(self, msg): - if "channelid" not in msg: - raise Error("claim requires 'channelid'") - channelid = msg["channelid"] - assert isinstance(channelid, type(u"")), type(channelid) - if channelid not in self._channels: - channel = self._app.claim_channel(channelid, self._side) - self._channels[channelid] = channel + def handle_claim(self, msg, server_rx): + if "nameplate" not in msg: + raise Error("claim requires 'nameplate'") + nameplate_id = msg["nameplate"] + assert isinstance(nameplate_id, type(u"")), type(nameplate) + if self._nameplate and self._nameplate.get_id() != nameplate_id: + raise Error("claimed nameplate doesn't match allocated nameplate") + self._nameplate = self._app.claim_nameplate(nameplate_id, self._side, + server_rx) + mailbox_id = self._nameplate.get_mailbox_id() + self.send("mailbox", mailbox=mailbox_id) - def handle_watch(self, msg): + def handle_release(self, server_rx): + if not self._nameplate: + raise Error("must claim a nameplate before releasing it") + + deleted = self._nameplate.release(self._side, server_rx) + self._nameplate = None + + + def handle_open(self, msg): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before watching") @@ -197,7 +209,7 @@ class WebSocketRendezvous(websocket.WebSocketServerProtocol): channel.add_message(self._side, msg["phase"], msg["body"], server_rx, msgid) - def handle_release(self, msg): + def handle_close(self, msg): channelid = msg["channelid"] if channelid not in self._channels: raise Error("must claim channel before releasing")