enforce bytes-vs-str in the API

The main wormhole code is str (unicode in py3, bytes in py2). Most
everything else must be passed as bytes in both py2/py3.

Keep the internal "side" string as a str, to make it easier to merge
with other URL pieces.
This commit is contained in:
Brian Warner 2015-09-27 23:09:51 -07:00
parent 6614783c43
commit a7213d9c9a
5 changed files with 61 additions and 28 deletions

View File

@ -132,6 +132,8 @@ include randomly-selected words or characters. Dice, coin flips, shuffled
cards, or repeated sampling of a high-resolution stopwatch are all useful cards, or repeated sampling of a high-resolution stopwatch are all useful
techniques. techniques.
Note that the code is a human-readable string (the python "str" type: so
unicode in python3, plain bytes in python2).
## Application Identifier ## Application Identifier
@ -210,6 +212,23 @@ To properly checkpoint the process, you should store the first message
(returned by `start()`) next to the serialized wormhole instance, so you can (returned by `start()`) next to the serialized wormhole instance, so you can
re-send it if necessary. re-send it if necessary.
## Bytes, Strings, Unicode, and Python 3
All cryptographically-sensitive parameters are passed as bytes ("str" in
python2, "bytes" in python3):
* application identifier
* verifier string
* data in
* data out
* derived-key "purpose" string
Some human-readable parameters are passed as strings: "str" in python2, "str"
(i.e. unicode) in python3:
* wormhole code
* relay/transit URLs
## Detailed Example ## Detailed Example
```python ```python

View File

@ -90,9 +90,10 @@ class Wormhole:
def get_code(self, code_length=2): def get_code(self, code_length=2):
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
self.side = hexlify(os.urandom(5)) self.side = hexlify(os.urandom(5)).decode("ascii")
channel_id = self._allocate_channel() # allocate channel channel_id = self._allocate_channel() # allocate channel
code = codes.make_code(channel_id, code_length) code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self._start() self._start()
return code return code
@ -109,10 +110,11 @@ class Wormhole:
return code return code
def set_code(self, code): # used for human-made pre-generated codes def set_code(self, code): # used for human-made pre-generated codes
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
if self.side is not None: raise UsageError if self.side is not None: raise UsageError
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self.side = hexlify(os.urandom(5)) self.side = hexlify(os.urandom(5)).decode("ascii")
self._start() self._start()
def _set_code_and_channel_id(self, code): def _set_code_and_channel_id(self, code):
@ -165,12 +167,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose) return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data): def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE) nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted): def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
@ -192,6 +198,7 @@ class Wormhole:
def get_data(self, outbound_data): def get_data(self, outbound_data):
# only call this once # only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self.channel_id is None: raise UsageError if self.channel_id is None: raise UsageError
try: try:

View File

@ -16,8 +16,8 @@ class Blocking(ServerBase, unittest.TestCase):
d = deferToThread(w1.get_code) d = deferToThread(w1.get_code)
def _got_code(code): def _got_code(code):
w2.set_code(code) w2.set_code(code)
d1 = deferToThread(w1.get_data, "data1") d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -25,8 +25,8 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
@ -36,16 +36,16 @@ class Blocking(ServerBase, unittest.TestCase):
w2 = BlockingWormhole(appid, self.relayurl) w2 = BlockingWormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant") w2.set_code("123-purple-elephant")
d1 = deferToThread(w1.get_data, "data1") d1 = deferToThread(w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False) d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl): def _done(dl):
((success1, dataX), (success2, dataY)) = dl ((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
@ -53,7 +53,7 @@ class Blocking(ServerBase, unittest.TestCase):
appid = b"appid" appid = b"appid"
w1 = BlockingWormhole(appid, self.relayurl) w1 = BlockingWormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier) self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data") self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code) self.assertRaises(UsageError, w1.get_code)
@ -79,8 +79,8 @@ class Blocking(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict) self.assertEqual(type(unpacked), dict)
new_w1 = BlockingWormhole.from_serialized(s) new_w1 = BlockingWormhole.from_serialized(s)
d1 = deferToThread(new_w1.get_data, "data1") d1 = deferToThread(new_w1.get_data, b"data1")
d2 = deferToThread(w2.get_data, "data2") d2 = deferToThread(w2.get_data, b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -88,8 +88,8 @@ class Blocking(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done) d.addCallback(_done)
return d return d

View File

@ -15,8 +15,8 @@ class Basic(ServerBase, unittest.TestCase):
d = w1.get_code() d = w1.get_code()
def _got_code(code): def _got_code(code):
w2.set_code(code) w2.set_code(code)
d1 = w1.get_data("data1") d1 = w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -24,8 +24,8 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
@ -35,16 +35,16 @@ class Basic(ServerBase, unittest.TestCase):
w2 = Wormhole(appid, self.relayurl) w2 = Wormhole(appid, self.relayurl)
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
w2.set_code("123-purple-elephant") w2.set_code("123-purple-elephant")
d1 = w1.get_data("data1") d1 = w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
d = defer.DeferredList([d1,d2], fireOnOneErrback=False) d = defer.DeferredList([d1,d2], fireOnOneErrback=False)
def _done(dl): def _done(dl):
((success1, dataX), (success2, dataY)) = dl ((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
d.addCallback(_done) d.addCallback(_done)
return d return d
@ -52,7 +52,7 @@ class Basic(ServerBase, unittest.TestCase):
appid = b"appid" appid = b"appid"
w1 = Wormhole(appid, self.relayurl) w1 = Wormhole(appid, self.relayurl)
self.assertRaises(UsageError, w1.get_verifier) self.assertRaises(UsageError, w1.get_verifier)
self.assertRaises(UsageError, w1.get_data, "data") self.assertRaises(UsageError, w1.get_data, b"data")
w1.set_code("123-purple-elephant") w1.set_code("123-purple-elephant")
self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.set_code, "123-nope")
self.assertRaises(UsageError, w1.get_code) self.assertRaises(UsageError, w1.get_code)
@ -76,8 +76,8 @@ class Basic(ServerBase, unittest.TestCase):
unpacked = json.loads(s) # this is supposed to be JSON unpacked = json.loads(s) # this is supposed to be JSON
self.assertEqual(type(unpacked), dict) self.assertEqual(type(unpacked), dict)
new_w1 = Wormhole.from_serialized(s) new_w1 = Wormhole.from_serialized(s)
d1 = new_w1.get_data("data1") d1 = new_w1.get_data(b"data1")
d2 = w2.get_data("data2") d2 = w2.get_data(b"data2")
return defer.DeferredList([d1,d2], fireOnOneErrback=False) return defer.DeferredList([d1,d2], fireOnOneErrback=False)
d.addCallback(_got_code) d.addCallback(_got_code)
def _done(dl): def _done(dl):
@ -85,8 +85,8 @@ class Basic(ServerBase, unittest.TestCase):
r1,r2 = dl r1,r2 = dl
self.assertTrue(success1, dataX) self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY) self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2") self.assertEqual(dataX, b"data2")
self.assertEqual(dataY, "data1") self.assertEqual(dataY, b"data1")
self.assertRaises(UsageError, w2.serialize) # too late self.assertRaises(UsageError, w2.serialize) # too late
d.addCallback(_done) d.addCallback(_done)
return d return d

View File

@ -110,6 +110,7 @@ class Wormhole:
d = self._allocate_channel() d = self._allocate_channel()
def _got_channel_id(channel_id): def _got_channel_id(channel_id):
code = codes.make_code(channel_id, code_length) code = codes.make_code(channel_id, code_length)
assert isinstance(code, str), type(code)
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
self._start() self._start()
return code return code
@ -117,6 +118,7 @@ class Wormhole:
return d return d
def set_code(self, code): def set_code(self, code):
if not isinstance(code, str): raise UsageError
if self.code is not None: raise UsageError if self.code is not None: raise UsageError
if self.side is not None: raise UsageError if self.side is not None: raise UsageError
self._set_code_and_channel_id(code) self._set_code_and_channel_id(code)
@ -202,12 +204,16 @@ class Wormhole:
return HKDF(self.key, length, CTXinfo=purpose) return HKDF(self.key, length, CTXinfo=purpose)
def _encrypt_data(self, key, data): def _encrypt_data(self, key, data):
assert isinstance(key, type(b"")), type(key)
assert isinstance(data, type(b"")), type(data)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
nonce = utils.random(SecretBox.NONCE_SIZE) nonce = utils.random(SecretBox.NONCE_SIZE)
return box.encrypt(data, nonce) return box.encrypt(data, nonce)
def _decrypt_data(self, key, encrypted): def _decrypt_data(self, key, encrypted):
assert isinstance(key, type(b"")), type(key)
assert isinstance(encrypted, type(b"")), type(encrypted)
if len(key) != SecretBox.KEY_SIZE: raise UsageError if len(key) != SecretBox.KEY_SIZE: raise UsageError
box = SecretBox(key) box = SecretBox(key)
data = box.decrypt(encrypted) data = box.decrypt(encrypted)
@ -236,6 +242,7 @@ class Wormhole:
def get_data(self, outbound_data): def get_data(self, outbound_data):
# only call this once # only call this once
if not isinstance(outbound_data, type(b"")): raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
d = self._get_key() d = self._get_key()
d.addCallback(self._get_data2, outbound_data) d.addCallback(self._get_data2, outbound_data)