transcribe.py: factor out common polling code

This commit is contained in:
Brian Warner 2015-02-11 01:35:11 -08:00
parent 882644bfc1
commit 48476f0840

View File

@ -17,7 +17,21 @@ class Timeout(Exception):
# POST /CHANNEL-ID/SIDE/data/poll -> {messages: [STR..]} # POST /CHANNEL-ID/SIDE/data/poll -> {messages: [STR..]}
# POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted # POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted
class Initiator: class Common:
def url(self, suffix):
return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix)
def poll(self, msgs, url_suffix):
while not msgs:
if time.time() > (self.started + self.timeout):
raise Timeout
time.sleep(self.wait)
r = requests.post(self.url(url_suffix))
r.raise_for_status()
msgs = r.json()["messages"]
return msgs
class Initiator(Common):
def __init__(self, appid, data, relay=RELAY): def __init__(self, appid, data, relay=RELAY):
self.appid = appid self.appid = appid
self.data = data self.data = data
@ -28,9 +42,6 @@ class Initiator:
self.timeout = 3*MINUTE self.timeout = 3*MINUTE
self.side = "initiator" self.side = "initiator"
def url(self, suffix):
return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix)
def get_code(self): def get_code(self):
# allocate channel # allocate channel
r = requests.post(self.relay + "allocate") r = requests.post(self.relay + "allocate")
@ -48,15 +59,7 @@ class Initiator:
def get_data(self): def get_data(self):
# poll for PAKE response # poll for PAKE response
while True: msgs = self.poll([], "pake/poll")
r = requests.post(self.url("pake/poll"))
r.raise_for_status()
msgs = r.json()["messages"]
if msgs:
break
if time.time() > (self.started + self.timeout):
raise Timeout
time.sleep(self.wait)
pake_msg = unhexlify(msgs[0].encode("ascii")) pake_msg = unhexlify(msgs[0].encode("ascii"))
self.key = self.sp.finish(pake_msg) self.key = self.sp.finish(pake_msg)
@ -64,17 +67,10 @@ class Initiator:
post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) post_data = json.dumps({"message": hexlify(self.data).decode("ascii")})
r = requests.post(self.url("data/post"), data=post_data) r = requests.post(self.url("data/post"), data=post_data)
r.raise_for_status() r.raise_for_status()
other_msgs = r.json()["messages"]
# poll for data message # poll for data message
while True: msgs = self.poll(other_msgs, "data/poll")
r = requests.post(self.url("data/poll"))
r.raise_for_status()
msgs = r.json()["messages"]
if msgs:
break
if time.time() > (self.started + self.timeout):
raise Timeout
time.sleep(self.wait)
data = unhexlify(msgs[0].encode("ascii")) data = unhexlify(msgs[0].encode("ascii"))
# deallocate channel # deallocate channel
@ -83,7 +79,7 @@ class Initiator:
return data return data
class Receiver: class Receiver(Common):
def __init__(self, appid, data, code, relay=RELAY): def __init__(self, appid, data, code, relay=RELAY):
self.appid = appid self.appid = appid
self.data = data self.data = data
@ -99,26 +95,16 @@ class Receiver:
idA=self.appid+":Initiator", idA=self.appid+":Initiator",
idB=self.appid+":Receiver") idB=self.appid+":Receiver")
def url(self, suffix):
return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix)
def get_data(self): def get_data(self):
# post PAKE message # post PAKE message
msg = self.sp.start() msg = self.sp.start()
post_data = {"message": hexlify(msg).decode("ascii")} post_data = {"message": hexlify(msg).decode("ascii")}
r = requests.post(self.url("pake/post"), data=json.dumps(post_data)) r = requests.post(self.url("pake/post"), data=json.dumps(post_data))
r.raise_for_status() r.raise_for_status()
other_msgs = r.json()["messages"]
# poll for PAKE response # poll for PAKE response
while True: msgs = self.poll(other_msgs, "pake/poll")
r = requests.post(self.url("pake/poll"))
r.raise_for_status()
msgs = r.json()["messages"]
if msgs:
break
if time.time() > (self.started + self.timeout):
raise Timeout
time.sleep(self.wait)
pake_msg = unhexlify(msgs[0].encode("ascii")) pake_msg = unhexlify(msgs[0].encode("ascii"))
self.key = self.sp.finish(pake_msg) self.key = self.sp.finish(pake_msg)
@ -126,17 +112,10 @@ class Receiver:
post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) post_data = json.dumps({"message": hexlify(self.data).decode("ascii")})
r = requests.post(self.url("data/post"), data=post_data) r = requests.post(self.url("data/post"), data=post_data)
r.raise_for_status() r.raise_for_status()
other_msgs = r.json()["messages"]
# poll for data message # poll for data message
while True: msgs = self.poll(other_msgs, "data/poll")
r = requests.post(self.url("data/poll"))
r.raise_for_status()
msgs = r.json()["messages"]
if msgs:
break
if time.time() > (self.started + self.timeout):
raise Timeout
time.sleep(self.wait)
data = unhexlify(msgs[0].encode("ascii")) data = unhexlify(msgs[0].encode("ascii"))
# deallocate channel # deallocate channel