add Channel.get_first_of()

This allows the Wormhole code to wait for multiple messages, which will
be useful for getting Confirmation messages soon.
This commit is contained in:
Brian Warner 2015-11-16 16:47:52 -08:00
parent b709a45891
commit ae2a6c6a05
4 changed files with 131 additions and 29 deletions

View File

@ -49,10 +49,12 @@ class Channel:
body = unhexlify(msg["body"].encode("ascii")) body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) ) self._messages.add( (phase, body) )
def _find_inbound_message(self, phase): def _find_inbound_message(self, phases):
for (their_phase,body) in self._messages - self._sent_messages: their_messages = self._messages - self._sent_messages
if their_phase == phase: for phase in phases:
return body for (their_phase,body) in their_messages:
if their_phase == phase:
return (phase, body)
return None return None
def send(self, phase, msg): def send(self, phase, msg):
@ -73,17 +75,21 @@ class Channel:
resp = r.json() resp = r.json()
self._add_inbound_messages(resp["messages"]) self._add_inbound_messages(resp["messages"])
def get(self, phase): def get_first_of(self, phases):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phases, (list, set)): raise TypeError(type(phases))
for phase in phases:
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# For now, server errors cause the client to fail. TODO: don't. This # For now, server errors cause the client to fail. TODO: don't. This
# will require changing the client to re-post messages when the # will require changing the client to re-post messages when the
# server comes back up. # server comes back up.
# fire with a bytestring of the first message for 'phase' that wasn't # fire with a bytestring of the first message for any 'phase' that
# one of ours. It will either come from previously-received messages, # wasn't one of our own messages. It will either come from
# or from an EventSource that we attach to the corresponding URL # previously-received messages, or from an EventSource that we attach
body = self._find_inbound_message(phase) # to the corresponding URL
while body is None: phase_and_body = self._find_inbound_message(phases)
while phase_and_body is None:
remaining = self._started + self._timeout - time.time() remaining = self._started + self._timeout - time.time()
if remaining < 0: if remaining < 0:
raise Timeout raise Timeout
@ -98,12 +104,17 @@ class Channel:
self._handle_welcome(json.loads(data)) self._handle_welcome(json.loads(data))
if eventtype == "message": if eventtype == "message":
self._add_inbound_messages([json.loads(data)]) self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase) phase_and_body = self._find_inbound_message(phases)
if body: if phase_and_body:
f.close() f.close()
break break
if not body: if not phase_and_body:
time.sleep(self._wait) time.sleep(self._wait)
return phase_and_body
def get(self, phase):
(got_phase, body) = self.get_first_of([phase])
assert got_phase == phase
return body return body
def deallocate(self, mood=None): def deallocate(self, mood=None):

View File

@ -63,6 +63,46 @@ class Channel(ServerBase, unittest.TestCase):
return d return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2"))
d.addCallback(lambda _: deferToThread(c2.get, u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
return d
def test_appid_independence(self): def test_appid_independence(self):
APPID_A = u"appid_A" APPID_A = u"appid_A"
APPID_B = u"appid_B" APPID_B = u"appid_B"

View File

@ -61,6 +61,42 @@ class Channel(ServerBase, unittest.TestCase):
return d return d
def test_get_multiple_phases(self):
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
c1 = cm1.connect(1)
c2 = cm2.connect(1)
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
d = succeed(None)
d.addCallback(lambda _: c1.send(u"phase1", b"msg1"))
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c1.send(u"phase2", b"msg2"))
d.addCallback(lambda _: c2.get(u"phase2"))
# if both are present, it should prefer the first one we asked for
d.addCallback(lambda _: c2.get_first_of([u"phase1", u"phase2"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase1", b"msg1")))
d.addCallback(lambda _: c2.get_first_of([u"phase2", u"phase1"]))
d.addCallback(lambda phase_and_body:
self.failUnlessEqual(phase_and_body,
(u"phase2", b"msg2")))
return d
def test_appid_independence(self): def test_appid_independence(self):
APPID_A = u"appid_A" APPID_A = u"appid_A"
APPID_B = u"appid_B" APPID_B = u"appid_B"

View File

@ -81,10 +81,12 @@ class Channel:
body = unhexlify(msg["body"].encode("ascii")) body = unhexlify(msg["body"].encode("ascii"))
self._messages.add( (phase, body) ) self._messages.add( (phase, body) )
def _find_inbound_message(self, phase): def _find_inbound_message(self, phases):
for (their_phase,body) in self._messages - self._sent_messages: their_messages = self._messages - self._sent_messages
if their_phase == phase: for phase in phases:
return body for (their_phase,body) in their_messages:
if their_phase == phase:
return (phase, body)
return None return None
def send(self, phase, msg): def send(self, phase, msg):
@ -102,14 +104,18 @@ class Channel:
d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"])) d.addCallback(lambda resp: self._add_inbound_messages(resp["messages"]))
return d return d
def get(self, phase): def get_first_of(self, phases):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phases, (list, set)): raise TypeError(type(phases))
# fire with a bytestring of the first message for 'phase' that wasn't for phase in phases:
# one of ours. It will either come from previously-received messages, if not isinstance(phase, type(u"")): raise TypeError(type(phase))
# or from an EventSource that we attach to the corresponding URL
body = self._find_inbound_message(phase) # fire with a bytestring of the first message for any 'phase' that
if body is not None: # wasn't one of our own messages. It will either come from
return defer.succeed(body) # previously-received messages, or from an EventSource that we attach
# to the corresponding URL
phase_and_body = self._find_inbound_message(phases)
if phase_and_body is not None:
return defer.succeed(phase_and_body)
d = defer.Deferred() d = defer.Deferred()
msgs = [] msgs = []
def _handle(name, data): def _handle(name, data):
@ -117,9 +123,9 @@ class Channel:
self._handle_welcome(json.loads(data)) self._handle_welcome(json.loads(data))
if name == "message": if name == "message":
self._add_inbound_messages([json.loads(data)]) self._add_inbound_messages([json.loads(data)])
body = self._find_inbound_message(phase) phase_and_body = self._find_inbound_message(phases)
if body is not None and not msgs: if phase_and_body is not None and not msgs:
msgs.append(body) msgs.append(phase_and_body)
d.callback(None) d.callback(None)
# TODO: use agent=self._agent # TODO: use agent=self._agent
queryargs = urlencode([("appid", self._appid), queryargs = urlencode([("appid", self._appid),
@ -133,6 +139,15 @@ class Channel:
d.addCallback(lambda _: msgs[0]) d.addCallback(lambda _: msgs[0])
return d return d
def get(self, phase):
d = self.get_first_of([phase])
def _got(res):
(got_phase, body) = res
assert got_phase == phase
return body
d.addCallback(_got)
return d
def deallocate(self, mood=u"unknown"): def deallocate(self, mood=u"unknown"):
# only try once, no retries # only try once, no retries
d = post_json(self._agent, self._relay_url+"deallocate", d = post_json(self._agent, self._relay_url+"deallocate",