remove phase= from the Wormhole API

Phase are now implicit and numbered.
This commit is contained in:
Brian Warner 2016-05-12 16:16:05 -07:00
parent 501af4b4ec
commit d0ef53fc4d
2 changed files with 46 additions and 53 deletions

View File

@ -70,23 +70,43 @@ class Basic(ServerBase, unittest.TestCase):
@inlineCallbacks
def test_phases(self):
def test_multiple_messages(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
yield self.doBoth(w1.send(b"data1", u"p1"), w2.send(b"data2", u"p1"))
yield self.doBoth(w1.send(b"data3", u"p2"), w2.send(b"data4", u"p2"))
dl = yield self.doBoth(w1.get(u"p2"), w2.get(u"p1"))
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get(u"p1"), w2.get(u"p2"))
yield self.doBoth(w1.send(b"data1"), w2.send(b"data2"))
yield self.doBoth(w1.send(b"data3"), w2.send(b"data4"))
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data3")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_multiple_messages_2(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
# TODO: set_code should be sufficient to kick things off, but for now
# we must also let both sides do at least one send() or get()
yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored"))
yield w1.get()
yield w1.send(b"data2")
yield w1.send(b"data3")
data = yield w2.get()
self.assertEqual(data, b"data1")
data = yield w2.get()
self.assertEqual(data, b"data2")
data = yield w2.get()
self.assertEqual(data, b"data3")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_wrong_password(self):
w1 = Wormhole(APPID, self.relayurl)
@ -186,31 +206,6 @@ class Basic(ServerBase, unittest.TestCase):
yield self.assertFailure(w2.get_code(), UsageError)
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_repeat_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2 = Wormhole(APPID, self.relayurl)
w2.set_code(u"123-purple-elephant")
# we must let them establish a key before we can send data
yield self.doBoth(w1.get_verifier(), w2.get_verifier())
yield w1.send(b"data1", phase=u"1")
# underscore-prefixed phases are reserved
yield self.assertFailure(w1.send(b"data1", phase=u"_1"), UsageError)
yield self.assertFailure(w1.get(phase=u"_1"), UsageError)
# you can't send twice to the same phase
yield self.assertFailure(w1.send(b"data1", phase=u"1"), UsageError)
# but you can send to a different one
yield w1.send(b"data2", phase=u"2")
res = yield w2.get(phase=u"1")
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase
yield self.assertFailure(w2.get(phase=u"1"), UsageError)
# but you can read from a different one
res = yield w2.get(phase=u"2")
self.failUnlessEqual(res, b"data2")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_serialize(self):
w1 = Wormhole(APPID, self.relayurl)

View File

@ -75,10 +75,11 @@ class Wormhole:
self._channelid = None
self._key = None
self._started_get_code = False
self._sent_messages = set() # (phase, body_bytes)
self._delivered_messages = set() # (phase, body_bytes)
self._next_outbound_phase = 0
self._sent_messages = {} # phase -> body_bytes
self._delivered_messages = set() # phase
self._next_inbound_phase = 0
self._received_messages = {} # phase -> body_bytes
self._sent_phases = set() # phases, to prohibit double-send
self._got_phases = set() # phases, to prohibit double-read
self._sleepers = []
self._confirmation_failed = False
@ -335,7 +336,7 @@ class Wormhole:
# get_verifier/get
if self._code is None: raise UsageError
if self._key is not None: raise UsageError
if self._sent_phases: raise UsageError
if self._sent_messages: raise UsageError
if self._got_phases: raise UsageError
data = {
"appid": self._appid,
@ -400,14 +401,15 @@ class Wormhole:
@inlineCallbacks
def _msg_send(self, phase, body, wait=False):
self._sent_messages.add( (phase, body) )
if phase in self._sent_messages: raise UsageError
self._sent_messages[phase] = body
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
t = self._timing.add("add", phase=phase, wait=wait)
yield self._ws_send(u"add", phase=phase,
body=hexlify(body).decode("ascii"))
if wait:
while (phase, body) not in self._delivered_messages:
while phase not in self._delivered_messages:
yield self._sleep()
t.finish()
@ -415,8 +417,8 @@ class Wormhole:
m = msg["message"]
phase = m["phase"]
body = unhexlify(m["body"].encode("ascii"))
if (phase, body) in self._sent_messages:
self._delivered_messages.add( (phase, body) ) # ack by server
if phase in self._sent_messages and self._sent_messages[phase] == body:
self._delivered_messages.add(phase) # ack by server
self._wakeup()
return # ignore echoes of our outbound messages
if phase in self._received_messages:
@ -469,39 +471,35 @@ class Wormhole:
return data
@inlineCallbacks
def send(self, outbound_data, phase=u"data", wait=False):
def send(self, outbound_data, wait=False):
if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if self._code is None:
raise UsageError("You must set_code() before send()")
if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._sent_phases: raise UsageError # only call this once
self._sent_phases.add(phase)
phase = self._next_outbound_phase
self._next_outbound_phase += 1
with self._timing.add("API send", phase=phase, wait=wait):
# Without predefined roles, we can't derive predictably unique
# keys for each side, so we use the same key for both. We use
# random nonces to keep the messages distinct, and we
# automatically ignore reflections.
yield self._get_master_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
data_key = self.derive_key(u"wormhole:phase:%d" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._msg_send(phase, outbound_encrypted, wait)
@inlineCallbacks
def get(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
def get(self):
if self._closed: raise UsageError
if self._code is None: raise UsageError
if phase.startswith(u"_"): raise UsageError # reserved for internals
if phase in self._got_phases: raise UsageError # only call this once
self._got_phases.add(phase)
phase = self._next_inbound_phase
self._next_inbound_phase += 1
with self._timing.add("API get", phase=phase):
yield self._get_master_key()
body = yield self._msg_get(phase) # we can wait a long time here
try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
data_key = self.derive_key(u"wormhole:phase:%d" % phase)
inbound_data = self._decrypt_data(data_key, body)
returnValue(inbound_data)
except CryptoError: