add Channel.get_first_of()
This allows the Wormhole code to wait for multiple messages, which will be useful for getting Confirmation messages soon.
This commit is contained in:
parent
b709a45891
commit
ae2a6c6a05
|
@ -49,10 +49,12 @@ class Channel:
|
|||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phase):
|
||||
for (their_phase,body) in self._messages - self._sent_messages:
|
||||
def _find_inbound_message(self, phases):
|
||||
their_messages = self._messages - self._sent_messages
|
||||
for phase in phases:
|
||||
for (their_phase,body) in their_messages:
|
||||
if their_phase == phase:
|
||||
return body
|
||||
return (phase, body)
|
||||
return None
|
||||
|
||||
def send(self, phase, msg):
|
||||
|
@ -73,17 +75,21 @@ class Channel:
|
|||
resp = r.json()
|
||||
self._add_inbound_messages(resp["messages"])
|
||||
|
||||
def get(self, phase):
|
||||
def get_first_of(self, phases):
|
||||
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
|
||||
for phase in phases:
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
|
||||
# For now, server errors cause the client to fail. TODO: don't. This
|
||||
# will require changing the client to re-post messages when the
|
||||
# server comes back up.
|
||||
|
||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||
# one of ours. It will either come from previously-received messages,
|
||||
# or from an EventSource that we attach to the corresponding URL
|
||||
body = self._find_inbound_message(phase)
|
||||
while body is None:
|
||||
# fire with a bytestring of the first message for any 'phase' that
|
||||
# wasn't one of our own messages. It will either come from
|
||||
# previously-received messages, or from an EventSource that we attach
|
||||
# to the corresponding URL
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
while phase_and_body is None:
|
||||
remaining = self._started + self._timeout - time.time()
|
||||
if remaining < 0:
|
||||
raise Timeout
|
||||
|
@ -98,12 +104,17 @@ class Channel:
|
|||
self._handle_welcome(json.loads(data))
|
||||
if eventtype == "message":
|
||||
self._add_inbound_messages([json.loads(data)])
|
||||
body = self._find_inbound_message(phase)
|
||||
if body:
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body:
|
||||
f.close()
|
||||
break
|
||||
if not body:
|
||||
if not phase_and_body:
|
||||
time.sleep(self._wait)
|
||||
return phase_and_body
|
||||
|
||||
def get(self, phase):
|
||||
(got_phase, body) = self.get_first_of([phase])
|
||||
assert got_phase == phase
|
||||
return body
|
||||
|
||||
def deallocate(self, mood=None):
|
||||
|
|
|
@ -63,6 +63,46 @@ class Channel(ServerBase, unittest.TestCase):
|
|||
|
||||
return d
|
||||
|
||||
def test_get_multiple_phases(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
|
||||
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
|
||||
u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
|
||||
u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2"))
|
||||
d.addCallback(lambda _: deferToThread(c2.get, u"phase2"))
|
||||
|
||||
# if both are present, it should prefer the first one we asked for
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
|
||||
u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
|
||||
u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase2", b"msg2")))
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
|
|
|
@ -61,6 +61,42 @@ class Channel(ServerBase, unittest.TestCase):
|
|||
|
||||
return d
|
||||
|
||||
def test_get_multiple_phases(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
|
||||
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
|
||||
d.addCallback(lambda _: c1.send(u"phase2", b"msg2"))
|
||||
d.addCallback(lambda _: c2.get(u"phase2"))
|
||||
|
||||
# if both are present, it should prefer the first one we asked for
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase2", b"msg2")))
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
|
|
|
@ -81,10 +81,12 @@ class Channel:
|
|||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phase):
|
||||
for (their_phase,body) in self._messages - self._sent_messages:
|
||||
def _find_inbound_message(self, phases):
|
||||
their_messages = self._messages - self._sent_messages
|
||||
for phase in phases:
|
||||
for (their_phase,body) in their_messages:
|
||||
if their_phase == phase:
|
||||
return body
|
||||
return (phase, body)
|
||||
return None
|
||||
|
||||
def send(self, phase, msg):
|
||||
|
@ -102,14 +104,18 @@ class Channel:
|
|||
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
|
||||
return d
|
||||
|
||||
def get(self, phase):
|
||||
def get_first_of(self, phases):
|
||||
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
|
||||
for phase in phases:
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
# fire with a bytestring of the first message for 'phase' that wasn't
|
||||
# one of ours. It will either come from previously-received messages,
|
||||
# or from an EventSource that we attach to the corresponding URL
|
||||
body = self._find_inbound_message(phase)
|
||||
if body is not None:
|
||||
return defer.succeed(body)
|
||||
|
||||
# fire with a bytestring of the first message for any 'phase' that
|
||||
# wasn't one of our own messages. It will either come from
|
||||
# previously-received messages, or from an EventSource that we attach
|
||||
# to the corresponding URL
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body is not None:
|
||||
return defer.succeed(phase_and_body)
|
||||
d = defer.Deferred()
|
||||
msgs = []
|
||||
def _handle(name, data):
|
||||
|
@ -117,9 +123,9 @@ class Channel:
|
|||
self._handle_welcome(json.loads(data))
|
||||
if name == "message":
|
||||
self._add_inbound_messages([json.loads(data)])
|
||||
body = self._find_inbound_message(phase)
|
||||
if body is not None and not msgs:
|
||||
msgs.append(body)
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body is not None and not msgs:
|
||||
msgs.append(phase_and_body)
|
||||
d.callback(None)
|
||||
# TODO: use agent=self._agent
|
||||
queryargs = urlencode([("appid", self._appid),
|
||||
|
@ -133,6 +139,15 @@ class Channel:
|
|||
d.addCallback(lambda _: msgs[0])
|
||||
return d
|
||||
|
||||
def get(self, phase):
|
||||
d = self.get_first_of([phase])
|
||||
def _got(res):
|
||||
(got_phase, body) = res
|
||||
assert got_phase == phase
|
||||
return body
|
||||
d.addCallback(_got)
|
||||
return d
|
||||
|
||||
def deallocate(self, mood=u"unknown"):
|
||||
# only try once, no retries
|
||||
d = post_json(self._agent, self._relay_url+"deallocate",
|
||||
|
|
Loading…
Reference in New Issue
Block a user