add Channel.get_first_of()

This allows the Wormhole code to wait for multiple messages, which will
be useful for getting Confirmation messages soon.
This commit is contained in:
Brian Warner 2015-11-16 16:47:52 -08:00
parent b709a45891
commit ae2a6c6a05
4 changed files with 131 additions and 29 deletions

View File

@ -49,10 +49,12 @@ class Channel:
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
def _find_inbound_message(self, phases):
their_messages = self._messages - self._sent_messages
for phase in phases:
for (their_phase,body) in their_messages:
if their_phase == phase:
return body
return (phase, body)
return None
def send(self, phase, msg):
@ -73,17 +75,21 @@ class Channel:
resp = r.json()
self._add_inbound_messages(resp["messages"])
def get(self, phase):
def get_first_of(self, phases):
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
for phase in phases:
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# For now, server errors cause the client to fail. TODO: don't. This
# will require changing the client to re-post messages when the
# server comes back up.
# fire with a bytestring of the first message for 'phase' that wasn't
# one of ours. It will either come from previously-received messages,
# or from an EventSource that we attach to the corresponding URL
body = self._find_inbound_message(phase)
while body is None:
# fire with a bytestring of the first message for any 'phase' that
# wasn't one of our own messages. It will either come from
# previously-received messages, or from an EventSource that we attach
# to the corresponding URL
phase_and_body = self._find_inbound_message(phases)
while phase_and_body is None:
remaining = self._started + self._timeout - time.time()
if remaining < 0:
raise Timeout
@ -98,12 +104,17 @@ class Channel:
self._handle_welcome(json.loads(data))
if eventtype == "message":
self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase)
if body:
phase_and_body = self._find_inbound_message(phases)
if phase_and_body:
f.close()
break
if not body:
if not phase_and_body:
time.sleep(self._wait)
return phase_and_body
def get(self, phase):
(got_phase, body) = self.get_first_of([phase])
assert got_phase == phase
return body
def deallocate(self, mood=None):

View File

@ -63,6 +63,46 @@ class Channel(ServerBase, unittest.TestCase):
return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"

View File

@ -61,6 +61,42 @@ class Channel(ServerBase, unittest.TestCase):
return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c1.send(u"phase2", b"msg2"))
d.addCallback(lambda _: c2.get(u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
return d
def test_appid_independence(self):
APPID_A = u"appid_A"
APPID_B = u"appid_B"

View File

@ -81,10 +81,12 @@ class Channel:
body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) )
def _find_inbound_message(self, phase):
for (their_phase,body) in self._messages - self._sent_messages:
def _find_inbound_message(self, phases):
their_messages = self._messages - self._sent_messages
for phase in phases:
for (their_phase,body) in their_messages:
if their_phase == phase:
return body
return (phase, body)
return None
def send(self, phase, msg):
@ -102,14 +104,18 @@ class Channel:
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d
def get(self, phase):
def get_first_of(self, phases):
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
for phase in phases:
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# fire with a bytestring of the first message for 'phase' that wasn't
# one of ours. It will either come from previously-received messages,
# or from an EventSource that we attach to the corresponding URL
body = self._find_inbound_message(phase)
if body is not None:
return defer.succeed(body)
# fire with a bytestring of the first message for any 'phase' that
# wasn't one of our own messages. It will either come from
# previously-received messages, or from an EventSource that we attach
# to the corresponding URL
phase_and_body = self._find_inbound_message(phases)
if phase_and_body is not None:
return defer.succeed(phase_and_body)
d = defer.Deferred()
msgs = []
def _handle(name, data):
@ -117,9 +123,9 @@ class Channel:
self._handle_welcome(json.loads(data))
if name == "message":
self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase)
if body is not None and not msgs:
msgs.append(body)
phase_and_body = self._find_inbound_message(phases)
if phase_and_body is not None and not msgs:
msgs.append(phase_and_body)
d.callback(None)
# TODO: use agent=self._agent
queryargs = urlencode([("appid", self._appid),
@ -133,6 +139,15 @@ class Channel:
d.addCallback(lambda _: msgs[0])
return d
def get(self, phase):
d = self.get_first_of([phase])
def _got(res):
(got_phase, body) = res
assert got_phase == phase
return body
d.addCallback(_got)
return d
def deallocate(self, mood=u"unknown"):
# only try once, no retries
d = post_json(self._agent, self._relay_url+"deallocate",