diff --git a/src/wormhole/transcribe.py b/src/wormhole/transcribe.py index e588198..8aef09b 100644 --- a/src/wormhole/transcribe.py +++ b/src/wormhole/transcribe.py @@ -17,7 +17,21 @@ class Timeout(Exception): # POST /CHANNEL-ID/SIDE/data/poll -> {messages: [STR..]} # POST /CHANNEL-ID/SIDE/deallocate -> waiting | deleted -class Initiator: +class Common: + def url(self, suffix): + return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix) + + def poll(self, msgs, url_suffix): + while not msgs: + if time.time() > (self.started + self.timeout): + raise Timeout + time.sleep(self.wait) + r = requests.post(self.url(url_suffix)) + r.raise_for_status() + msgs = r.json()["messages"] + return msgs + +class Initiator(Common): def __init__(self, appid, data, relay=RELAY): self.appid = appid self.data = data @@ -28,9 +42,6 @@ class Initiator: self.timeout = 3*MINUTE self.side = "initiator" - def url(self, suffix): - return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix) - def get_code(self): # allocate channel r = requests.post(self.relay + "allocate") @@ -48,15 +59,7 @@ class Initiator: def get_data(self): # poll for PAKE response - while True: - r = requests.post(self.url("pake/poll")) - r.raise_for_status() - msgs = r.json()["messages"] - if msgs: - break - if time.time() > (self.started + self.timeout): - raise Timeout - time.sleep(self.wait) + msgs = self.poll([], "pake/poll") pake_msg = unhexlify(msgs[0].encode("ascii")) self.key = self.sp.finish(pake_msg) @@ -64,17 +67,10 @@ class Initiator: post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) r = requests.post(self.url("data/post"), data=post_data) r.raise_for_status() + other_msgs = r.json()["messages"] # poll for data message - while True: - r = requests.post(self.url("data/poll")) - r.raise_for_status() - msgs = r.json()["messages"] - if msgs: - break - if time.time() > (self.started + self.timeout): - raise Timeout - time.sleep(self.wait) + msgs = self.poll(other_msgs, "data/poll") data = unhexlify(msgs[0].encode("ascii")) # deallocate channel @@ -83,7 +79,7 @@ class Initiator: return data -class Receiver: +class Receiver(Common): def __init__(self, appid, data, code, relay=RELAY): self.appid = appid self.data = data @@ -99,26 +95,16 @@ class Receiver: idA=self.appid+":Initiator", idB=self.appid+":Receiver") - def url(self, suffix): - return "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, suffix) - def get_data(self): # post PAKE message msg = self.sp.start() post_data = {"message": hexlify(msg).decode("ascii")} r = requests.post(self.url("pake/post"), data=json.dumps(post_data)) r.raise_for_status() + other_msgs = r.json()["messages"] # poll for PAKE response - while True: - r = requests.post(self.url("pake/poll")) - r.raise_for_status() - msgs = r.json()["messages"] - if msgs: - break - if time.time() > (self.started + self.timeout): - raise Timeout - time.sleep(self.wait) + msgs = self.poll(other_msgs, "pake/poll") pake_msg = unhexlify(msgs[0].encode("ascii")) self.key = self.sp.finish(pake_msg) @@ -126,17 +112,10 @@ class Receiver: post_data = json.dumps({"message": hexlify(self.data).decode("ascii")}) r = requests.post(self.url("data/post"), data=post_data) r.raise_for_status() + other_msgs = r.json()["messages"] # poll for data message - while True: - r = requests.post(self.url("data/poll")) - r.raise_for_status() - msgs = r.json()["messages"] - if msgs: - break - if time.time() > (self.started + self.timeout): - raise Timeout - time.sleep(self.wait) + msgs = self.poll(other_msgs, "data/poll") data = unhexlify(msgs[0].encode("ascii")) # deallocate channel