diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 967f035..97580e0 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -49,10 +49,12 @@ class Channel: body = unhexlify(msg["body"].encode("ascii")) self._messages.add( (phase, body) ) - def _find_inbound_message(self, phase): - for (their_phase,body) in self._messages - self._sent_messages: - if their_phase == phase: - return body + def _find_inbound_message(self, phases): + their_messages = self._messages - self._sent_messages + for phase in phases: + for (their_phase,body) in their_messages: + if their_phase == phase: + return (phase, body) return None def send(self, phase, msg): @@ -73,17 +75,21 @@ class Channel: resp = r.json() self._add_inbound_messages(resp["messages"]) - def get(self, phase): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + def get_first_of(self, phases): + if not isinstance(phases, (list, set)): raise TypeError(type(phases)) + for phase in phases: + if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + # For now, server errors cause the client to fail. TODO: don't. This # will require changing the client to re-post messages when the # server comes back up. - # fire with a bytestring of the first message for 'phase' that wasn't - # one of ours. It will either come from previously-received messages, - # or from an EventSource that we attach to the corresponding URL - body = self._find_inbound_message(phase) - while body is None: + # fire with a bytestring of the first message for any 'phase' that + # wasn't one of our own messages. It will either come from + # previously-received messages, or from an EventSource that we attach + # to the corresponding URL + phase_and_body = self._find_inbound_message(phases) + while phase_and_body is None: remaining = self._started + self._timeout - time.time() if remaining < 0: raise Timeout @@ -98,12 +104,17 @@ class Channel: self._handle_welcome(json.loads(data)) if eventtype == "message": self._add_inbound_messages([json.loads(data)]) - body = self._find_inbound_message(phase) - if body: + phase_and_body = self._find_inbound_message(phases) + if phase_and_body: f.close() break - if not body: + if not phase_and_body: time.sleep(self._wait) + return phase_and_body + + def get(self, phase): + (got_phase, body) = self.get_first_of([phase]) + assert got_phase == phase return body def deallocate(self, mood=None): diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index e6a2ba8..5801e99 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -63,6 +63,46 @@ class Channel(ServerBase, unittest.TestCase): return d + def test_get_multiple_phases(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1") + self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7]) + + d = succeed(None) + d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1")) + + d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", + u"phase2"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", + u"phase1"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + + d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2")) + d.addCallback(lambda _: deferToThread(c2.get, u"phase2")) + + # if both are present, it should prefer the first one we asked for + d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1", + u"phase2"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2", + u"phase1"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase2", b"msg2"))) + + return d + def test_appid_independence(self): APPID_A = u"appid_A" APPID_B = u"appid_B" diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 46b352b..fd20498 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -61,6 +61,42 @@ class Channel(ServerBase, unittest.TestCase): return d + def test_get_multiple_phases(self): + cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore) + cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore) + c1 = cm1.connect(1) + c2 = cm2.connect(1) + + self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1") + self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7]) + + d = succeed(None) + d.addCallback(lambda _: c1.send(u"phase1", b"msg1")) + + d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + + d.addCallback(lambda _: c1.send(u"phase2", b"msg2")) + d.addCallback(lambda _: c2.get(u"phase2")) + + # if both are present, it should prefer the first one we asked for + d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase1", b"msg1"))) + d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"])) + d.addCallback(lambda phase_and_body: + self.failUnlessEqual(phase_and_body, + (u"phase2", b"msg2"))) + + return d + def test_appid_independence(self): APPID_A = u"appid_A" APPID_B = u"appid_B" diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index b9e328d..797c3f4 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -81,10 +81,12 @@ class Channel: body = unhexlify(msg["body"].encode("ascii")) self._messages.add( (phase, body) ) - def _find_inbound_message(self, phase): - for (their_phase,body) in self._messages - self._sent_messages: - if their_phase == phase: - return body + def _find_inbound_message(self, phases): + their_messages = self._messages - self._sent_messages + for phase in phases: + for (their_phase,body) in their_messages: + if their_phase == phase: + return (phase, body) return None def send(self, phase, msg): @@ -102,14 +104,18 @@ class Channel: d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) return d - def get(self, phase): - if not isinstance(phase, type(u"")): raise TypeError(type(phase)) - # fire with a bytestring of the first message for 'phase' that wasn't - # one of ours. It will either come from previously-received messages, - # or from an EventSource that we attach to the corresponding URL - body = self._find_inbound_message(phase) - if body is not None: - return defer.succeed(body) + def get_first_of(self, phases): + if not isinstance(phases, (list, set)): raise TypeError(type(phases)) + for phase in phases: + if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + + # fire with a bytestring of the first message for any 'phase' that + # wasn't one of our own messages. It will either come from + # previously-received messages, or from an EventSource that we attach + # to the corresponding URL + phase_and_body = self._find_inbound_message(phases) + if phase_and_body is not None: + return defer.succeed(phase_and_body) d = defer.Deferred() msgs = [] def _handle(name, data): @@ -117,9 +123,9 @@ class Channel: self._handle_welcome(json.loads(data)) if name == "message": self._add_inbound_messages([json.loads(data)]) - body = self._find_inbound_message(phase) - if body is not None and not msgs: - msgs.append(body) + phase_and_body = self._find_inbound_message(phases) + if phase_and_body is not None and not msgs: + msgs.append(phase_and_body) d.callback(None) # TODO: use agent=self._agent queryargs = urlencode([("appid", self._appid), @@ -133,6 +139,15 @@ class Channel: d.addCallback(lambda _: msgs[0]) return d + def get(self, phase): + d = self.get_first_of([phase]) + def _got(res): + (got_phase, body) = res + assert got_phase == phase + return body + d.addCallback(_got) + return d + def deallocate(self, mood=u"unknown"): # only try once, no retries d = post_json(self._agent, self._relay_url+"deallocate",