remove phase= from the Wormhole API

Phase are now implicit and numbered.
This commit is contained in:
Brian Warner 2016-05-12 16:16:05 -07:00
parent 501af4b4ec
commit d0ef53fc4d
2 changed files with 46 additions and 53 deletions

View File

@ -70,23 +70,43 @@ class Basic(ServerBase, unittest.TestCase):
@inlineCallbacks @inlineCallbacks
def test_phases(self): def test_multiple_messages(self):
w1 = Wormhole(APPID, self.relayurl) w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl) w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant") w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant") w2.set_code(u"123-purple-elephant")
yield self.doBoth(w1.send(b"data1", u"p1"), w2.send(b"data2", u"p1")) yield self.doBoth(w1.send(b"data1"), w2.send(b"data2"))
yield self.doBoth(w1.send(b"data3", u"p2"), w2.send(b"data4", u"p2")) yield self.doBoth(w1.send(b"data3"), w2.send(b"data4"))
dl = yield self.doBoth(w1.get(u"p2"), w2.get(u"p1")) dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get(u"p1"), w2.get(u"p2"))
(dataX, dataY) = dl (dataX, dataY) = dl
self.assertEqual(dataX, b"data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, b"data1")
dl = yield self.doBoth(w1.get(), w2.get())
(dataX, dataY) = dl
self.assertEqual(dataX, b"data4")
self.assertEqual(dataY, b"data3") self.assertEqual(dataY, b"data3")
yield self.doBoth(w1.close(), w2.close()) yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_multiple_messages_2(self):
w1 = Wormhole(APPID, self.relayurl)
w2 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2.set_code(u"123-purple-elephant")
# TODO: set_code should be sufficient to kick things off, but for now
# we must also let both sides do at least one send() or get()
yield self.doBoth(w1.send(b"data1"), w2.send(b"ignored"))
yield w1.get()
yield w1.send(b"data2")
yield w1.send(b"data3")
data = yield w2.get()
self.assertEqual(data, b"data1")
data = yield w2.get()
self.assertEqual(data, b"data2")
data = yield w2.get()
self.assertEqual(data, b"data3")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks @inlineCallbacks
def test_wrong_password(self): def test_wrong_password(self):
w1 = Wormhole(APPID, self.relayurl) w1 = Wormhole(APPID, self.relayurl)
@ -186,31 +206,6 @@ class Basic(ServerBase, unittest.TestCase):
yield self.assertFailure(w2.get_code(), UsageError) yield self.assertFailure(w2.get_code(), UsageError)
yield self.doBoth(w1.close(), w2.close()) yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks
def test_repeat_phases(self):
w1 = Wormhole(APPID, self.relayurl)
w1.set_code(u"123-purple-elephant")
w2 = Wormhole(APPID, self.relayurl)
w2.set_code(u"123-purple-elephant")
# we must let them establish a key before we can send data
yield self.doBoth(w1.get_verifier(), w2.get_verifier())
yield w1.send(b"data1", phase=u"1")
# underscore-prefixed phases are reserved
yield self.assertFailure(w1.send(b"data1", phase=u"_1"), UsageError)
yield self.assertFailure(w1.get(phase=u"_1"), UsageError)
# you can't send twice to the same phase
yield self.assertFailure(w1.send(b"data1", phase=u"1"), UsageError)
# but you can send to a different one
yield w1.send(b"data2", phase=u"2")
res = yield w2.get(phase=u"1")
self.failUnlessEqual(res, b"data1")
# and you can't read twice from the same phase
yield self.assertFailure(w2.get(phase=u"1"), UsageError)
# but you can read from a different one
res = yield w2.get(phase=u"2")
self.failUnlessEqual(res, b"data2")
yield self.doBoth(w1.close(), w2.close())
@inlineCallbacks @inlineCallbacks
def test_serialize(self): def test_serialize(self):
w1 = Wormhole(APPID, self.relayurl) w1 = Wormhole(APPID, self.relayurl)

View File

@ -75,10 +75,11 @@ class Wormhole:
self._channelid = None self._channelid = None
self._key = None self._key = None
self._started_get_code = False self._started_get_code = False
self._sent_messages = set() # (phase, body_bytes) self._next_outbound_phase = 0
self._delivered_messages = set() # (phase, body_bytes) self._sent_messages = {} # phase -> body_bytes
self._delivered_messages = set() # phase
self._next_inbound_phase = 0
self._received_messages = {} # phase -> body_bytes self._received_messages = {} # phase -> body_bytes
self._sent_phases = set() # phases, to prohibit double-send
self._got_phases = set() # phases, to prohibit double-read self._got_phases = set() # phases, to prohibit double-read
self._sleepers = [] self._sleepers = []
self._confirmation_failed = False self._confirmation_failed = False
@ -335,7 +336,7 @@ class Wormhole:
# get_verifier/get # get_verifier/get
if self._code is None: raise UsageError if self._code is None: raise UsageError
if self._key is not None: raise UsageError if self._key is not None: raise UsageError
if self._sent_phases: raise UsageError if self._sent_messages: raise UsageError
if self._got_phases: raise UsageError if self._got_phases: raise UsageError
data = { data = {
"appid": self._appid, "appid": self._appid,
@ -400,14 +401,15 @@ class Wormhole:
@inlineCallbacks @inlineCallbacks
def _msg_send(self, phase, body, wait=False): def _msg_send(self, phase, body, wait=False):
self._sent_messages.add( (phase, body) ) if phase in self._sent_messages: raise UsageError
self._sent_messages[phase] = body
# TODO: retry on failure, with exponential backoff. We're guarding # TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline. # against the rendezvous server being temporarily offline.
t = self._timing.add("add", phase=phase, wait=wait) t = self._timing.add("add", phase=phase, wait=wait)
yield self._ws_send(u"add", phase=phase, yield self._ws_send(u"add", phase=phase,
body=hexlify(body).decode("ascii")) body=hexlify(body).decode("ascii"))
if wait: if wait:
while (phase, body) not in self._delivered_messages: while phase not in self._delivered_messages:
yield self._sleep() yield self._sleep()
t.finish() t.finish()
@ -415,8 +417,8 @@ class Wormhole:
m = msg["message"] m = msg["message"]
phase = m["phase"] phase = m["phase"]
body = unhexlify(m["body"].encode("ascii")) body = unhexlify(m["body"].encode("ascii"))
if (phase, body) in self._sent_messages: if phase in self._sent_messages and self._sent_messages[phase] == body:
self._delivered_messages.add( (phase, body) ) # ack by server self._delivered_messages.add(phase) # ack by server
self._wakeup() self._wakeup()
return # ignore echoes of our outbound messages return # ignore echoes of our outbound messages
if phase in self._received_messages: if phase in self._received_messages:
@ -469,39 +471,35 @@ class Wormhole:
return data return data
@inlineCallbacks @inlineCallbacks
def send(self, outbound_data, phase=u"data", wait=False): def send(self, outbound_data, wait=False):
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError if self._closed: raise UsageError
if self._code is None: if self._code is None:
raise UsageError("You must set_code() before send()") raise UsageError("You must set_code() before send()")
if phase.startswith(u"_"): raise UsageError # reserved for internals phase = self._next_outbound_phase
if phase in self._sent_phases: raise UsageError # only call this once self._next_outbound_phase += 1
self._sent_phases.add(phase)
with self._timing.add("API send", phase=phase, wait=wait): with self._timing.add("API send", phase=phase, wait=wait):
# Without predefined roles, we can't derive predictably unique # Without predefined roles, we can't derive predictably unique
# keys for each side, so we use the same key for both. We use # keys for each side, so we use the same key for both. We use
# random nonces to keep the messages distinct, and we # random nonces to keep the messages distinct, and we
# automatically ignore reflections. # automatically ignore reflections.
yield self._get_master_key() yield self._get_master_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%d" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
yield self._msg_send(phase, outbound_encrypted, wait) yield self._msg_send(phase, outbound_encrypted, wait)
@inlineCallbacks @inlineCallbacks
def get(self, phase=u"data"): def get(self):
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError if self._closed: raise UsageError
if self._code is None: raise UsageError if self._code is None: raise UsageError
if phase.startswith(u"_"): raise UsageError # reserved for internals phase = self._next_inbound_phase
if phase in self._got_phases: raise UsageError # only call this once self._next_inbound_phase += 1
self._got_phases.add(phase)
with self._timing.add("API get", phase=phase): with self._timing.add("API get", phase=phase):
yield self._get_master_key() yield self._get_master_key()
body = yield self._msg_get(phase) # we can wait a long time here body = yield self._msg_get(phase) # we can wait a long time here
try: try:
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%d" % phase)
inbound_data = self._decrypt_data(data_key, body) inbound_data = self._decrypt_data(data_key, body)
returnValue(inbound_data) returnValue(inbound_data)
except CryptoError: except CryptoError: