diff --git a/docs/api.md b/docs/api.md index 8d160d8..3548c7c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -45,6 +45,16 @@ Transit class currently distinguishes "Sender" from "Receiver", so the programs on each side must have some way to decide (ahead of time) which is which. +Each side gets to do one `send_data()` call and one `get_data()` call. +`get_data` will wait until the other side has done `send_data`, so the +application developer must be careful to avoid deadlocks (don't get before +you send on both sides in the same protocol). When both sides are done, they +must call `close()`, to let the library know that the connection is complete +and it can deallocate the channel. If you forget to call `close()`, the +server will not free the channel, and other users will suffer longer +invitation codes as a result. To encourage `close()`, the library will log an +error if a Wormhole object is destroyed before being closed. + ## Examples The synchronous+blocking flow looks like this: @@ -56,7 +66,9 @@ mydata = b"initiator's data" i = Wormhole(b"appid", RENDEZVOUS_RELAY) code = i.get_code() print("Invitation Code: %s" % code) -theirdata = i.get_data(mydata) +i.send_data(mydata) +theirdata = i.get_data() +i.close() print("Their data: %s" % theirdata.decode("ascii")) ``` @@ -68,7 +80,9 @@ mydata = b"receiver's data" code = sys.argv[1] r = Wormhole(b"appid", RENDEZVOUS_RELAY) r.set_code(code) -theirdata = r.get_data(mydata) +r.send_data(mydata) +theirdata = r.get_data() +r.close() print("Their data: %s" % theirdata.decode("ascii")) ``` @@ -85,11 +99,13 @@ w1 = Wormhole(b"appid", RENDEZVOUS_RELAY) d = w1.get_code() def _got_code(code): print "Invitation Code:", code - return w1.get_data(outbound_message) + return w1.send_data(outbound_message) d.addCallback(_got_code) +d.addCallback(lambda _: w1.get_data()) def _got_data(inbound_message): print "Inbound message:", inbound_message d.addCallback(_got_data) +d.addCallback(w1.close) d.addBoth(lambda _: reactor.stop()) reactor.run() ``` @@ -99,17 +115,26 @@ On the other side, you call `set_code()` instead of waiting for `get_code()`: ```python w2 = Wormhole(b"appid", RENDEZVOUS_RELAY) w2.set_code(code) -d = w2.get_data(my_message) +d = w2.send_data(my_message) ... ``` -You can call `d=w.get_verifier()` before `get_data()`: this will perform the -first half of the PAKE negotiation, then fire the Deferred with a verifier -object (bytes) which can be converted into a printable representation and -manually compared. When the users are convinced that `get_verifier()` from -both sides are the same, call `d=get_data()` to continue the transfer. If you -call `get_data()` first, it will perform the complete transfer without -pausing. +Note that the Twisted-form `close()` accepts (and returns) an optional +argument, so you can use `d.addCallback(w.close)` instead of +`d.addCallback(lambda _: w.close())`. + +## Verifier + +You can call `w.get_verifier()` before `send_data()/get_data()`: this will +perform the first half of the PAKE negotiation, then return a verifier object +(bytes) which can be converted into a printable representation and manually +compared. When the users are convinced that `get_verifier()` from both sides +are the same, call `send_data()/get_data()` to continue the transfer. If you +call `send_data()/get_data()` before `get_verifier()`, it will perform the +complete transfer without pausing. + +The Twisted form of `get_verifier()` returns a Deferred that fires with the +verifier bytes. ## Generating the Invitation Code @@ -204,9 +229,10 @@ Wormhole.from_serialized(data)`). There is exactly one point at which you can serialize the wormhole: *after* establishing the invitation code, but before waiting for `get_verifier()` or -`get_data()`. If you are creating a new code, the correct time is during the -callback fired by `get_code()`. If you are accepting a pre-generated code, -the time is just after calling `set_code()`. +`get_data()`, or calling `send_data()`. If you are creating a new invitation +code, the correct time is during the callback fired by `get_code()`. If you +are accepting a pre-generated code, the time is just after calling +`set_code()`. To properly checkpoint the process, you should store the first message (returned by `start()`) next to the serialized wormhole instance, so you can diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index e580253..e9d62d2 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -139,6 +139,9 @@ class Wormhole: self.code = None self.key = None self.verifier = None + self._sent_data = False + self._got_data = False + self._closed = False def handle_welcome(self, welcome): if ("motd" in welcome and @@ -232,32 +235,37 @@ class Wormhole: self._get_key() return self.verifier - def get_data(self, outbound_data): - # only call this once + def send_data(self, outbound_data): + if self._sent_data: raise UsageError # only call this once if not isinstance(outbound_data, type(b"")): raise UsageError if self.code is None: raise UsageError if self.channel is None: raise UsageError - try: - self._get_key() - return self._get_data2(outbound_data) - finally: - self.channel.deallocate() - - def _get_data2(self, outbound_data): # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random - # nonces to keep the messages distinct, and check for reflection. + # nonces to keep the messages distinct, and the Channel automatically + # ignores reflections. + self._get_key() data_key = self.derive_key(b"data-key") - outbound_encrypted = self._encrypt_data(data_key, outbound_data) self.channel.send(u"data", outbound_encrypted) + def get_data(self): + if self._got_data: raise UsageError # only call this once + if self.code is None: raise UsageError + if self.channel is None: raise UsageError + self._get_key() + data_key = self.derive_key(b"data-key") inbound_encrypted = self.channel.get(u"data") - # _find_inbound_message() ignores any inbound message that matches - # something we previously sent out, so we don't need to explicitly - # check for reflection. A reflection attack will just not progress. try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data except CryptoError: raise WrongPasswordError + + def close(self): + self.channel.deallocate() + self._closed = True + + def __del__(self): + if not self._closed: + print("Error: a Wormhole instance was not closed", file=sys.stderr) diff --git a/src/wormhole/scripts/cmd_receive_file.py b/src/wormhole/scripts/cmd_receive_file.py index 5914558..796b9ff 100644 --- a/src/wormhole/scripts/cmd_receive_file.py +++ b/src/wormhole/scripts/cmd_receive_file.py @@ -33,12 +33,14 @@ def receive_file(args): "relay_connection_hints": transit_receiver.get_relay_hints(), }, }).encode("utf-8") + w.send_data(mydata) try: - data = json.loads(w.get_data(mydata).decode("utf-8")) + data = json.loads(w.get_data().decode("utf-8")) except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 #print("their data: %r" % (data,)) + w.close() if "error" in data: print("ERROR: " + data["error"], file=sys.stderr) diff --git a/src/wormhole/scripts/cmd_receive_text.py b/src/wormhole/scripts/cmd_receive_text.py index 68bee99..e474bdb 100644 --- a/src/wormhole/scripts/cmd_receive_text.py +++ b/src/wormhole/scripts/cmd_receive_text.py @@ -24,11 +24,13 @@ def receive_text(args): print("Verifier %s." % verifier) data = json.dumps({"message": "ok"}).encode("utf-8") + w.send_data(data) try: - them_bytes = w.get_data(data) + them_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 + w.close() them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: print("ERROR: " + them_d["error"], file=sys.stderr) diff --git a/src/wormhole/scripts/cmd_send_file.py b/src/wormhole/scripts/cmd_send_file.py index db419ec..cd56ccd 100644 --- a/src/wormhole/scripts/cmd_send_file.py +++ b/src/wormhole/scripts/cmd_send_file.py @@ -45,7 +45,8 @@ def send_file(args): file=sys.stderr) reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") - w.get_data(reject_data) + w.send_data(reject_data) + w.close() return 1 filesize = os.stat(filename).st_size @@ -59,12 +60,13 @@ def send_file(args): "relay_connection_hints": transit_sender.get_relay_hints(), }, }).encode("utf-8") - + w.send_data(data) try: - them_bytes = w.get_data(data) + them_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 + w.close() them_d = json.loads(them_bytes.decode("utf-8")) #print("them: %r" % (them_d,)) diff --git a/src/wormhole/scripts/cmd_send_text.py b/src/wormhole/scripts/cmd_send_text.py index 6ea25b0..1fff009 100644 --- a/src/wormhole/scripts/cmd_send_text.py +++ b/src/wormhole/scripts/cmd_send_text.py @@ -39,17 +39,20 @@ def send_text(args): file=sys.stderr) reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") - w.get_data(reject_data) + w.send_data(reject_data) + w.close() return 1 message = args.text data = json.dumps({"message": message, }).encode("utf-8") + w.send_data(data) try: - them_bytes = w.get_data(data) + them_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) return 1 + w.close() them_d = json.loads(them_bytes.decode("utf-8")) if them_d["message"] == "ok": print("text sent") diff --git a/src/wormhole/test/test_blocking.py b/src/wormhole/test/test_blocking.py index 6e0554f..2f2c929 100644 --- a/src/wormhole/test/test_blocking.py +++ b/src/wormhole/test/test_blocking.py @@ -9,6 +9,14 @@ class Blocking(ServerBase, unittest.TestCase): # we need Twisted to run the server, but we run the sender and receiver # with deferToThread() + def doBoth(self, call1, call2): + f1 = call1[0] + f1args = call1[1:] + f2 = call2[0] + f2args = call2[1:] + return gatherResults([deferToThread(f1, *f1args), + deferToThread(f2, *f2args)], True) + def test_basic(self): appid = b"appid" w1 = BlockingWormhole(appid, self.relayurl) @@ -16,13 +24,39 @@ class Blocking(ServerBase, unittest.TestCase): d = deferToThread(w1.get_code) def _got_code(code): w2.set_code(code) - return gatherResults([deferToThread(w1.get_data, b"data1"), - deferToThread(w2.get_data, b"data2")], True) + return self.doBoth([w1.send_data, b"data1"], + [w2.send_data, b"data2"]) d.addCallback(_got_code) + def _sent(res): + return self.doBoth([w1.get_data], [w2.get_data]) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth([w1.close], [w2.close]) + d.addCallback(_done) + return d + + def test_interleaved(self): + appid = b"appid" + w1 = BlockingWormhole(appid, self.relayurl) + w2 = BlockingWormhole(appid, self.relayurl) + d = deferToThread(w1.get_code) + def _got_code(code): + w2.set_code(code) + return self.doBoth([w1.send_data, b"data1"], + [w2.get_data]) + d.addCallback(_got_code) + def _sent(res): + (_, dataY) = res + self.assertEqual(dataY, b"data1") + return self.doBoth([w1.get_data], [w2.send_data, b"data2"]) + d.addCallback(_sent) + def _done(dl): + (dataX, _) = dl + self.assertEqual(dataX, b"data2") + return self.doBoth([w1.close], [w2.close]) d.addCallback(_done) return d @@ -32,12 +66,15 @@ class Blocking(ServerBase, unittest.TestCase): w2 = BlockingWormhole(appid, self.relayurl) w1.set_code("123-purple-elephant") w2.set_code("123-purple-elephant") - d = gatherResults([deferToThread(w1.get_data, b"data1"), - deferToThread(w2.get_data, b"data2")], True) + d = self.doBoth([w1.send_data, b"data1"], [w2.send_data, b"data2"]) + def _sent(res): + return self.doBoth([w1.get_data], [w2.get_data]) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth([w1.close], [w2.close]) d.addCallback(_done) return d @@ -48,20 +85,23 @@ class Blocking(ServerBase, unittest.TestCase): d = deferToThread(w1.get_code) def _got_code(code): w2.set_code(code) - return gatherResults([deferToThread(w1.get_verifier), - deferToThread(w2.get_verifier)], True) + return self.doBoth([w1.get_verifier], [w2.get_verifier]) d.addCallback(_got_code) def _check_verifier(res): v1, v2 = res self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) - return gatherResults([deferToThread(w1.get_data, b"data1"), - deferToThread(w2.get_data, b"data2")], True) + return self.doBoth([w1.send_data, b"data1"], + [w2.send_data, b"data2"]) d.addCallback(_check_verifier) + def _sent(res): + return self.doBoth([w1.get_data], [w2.get_data]) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth([w1.close], [w2.close]) d.addCallback(_done) return d @@ -72,13 +112,13 @@ class Blocking(ServerBase, unittest.TestCase): d = deferToThread(w1.get_code) def _got_code(code): w2.set_code(code+"not") - return gatherResults([deferToThread(w1.get_verifier), - deferToThread(w2.get_verifier)], True) + return self.doBoth([w1.get_verifier], [w2.get_verifier]) d.addCallback(_got_code) def _check_verifier(res): v1, v2 = res self.failUnlessEqual(type(v1), type(b"")) self.failIfEqual(v1, v2) + return self.doBoth([w1.close], [w2.close]) d.addCallback(_check_verifier) return d @@ -86,7 +126,8 @@ class Blocking(ServerBase, unittest.TestCase): appid = b"appid" w1 = BlockingWormhole(appid, self.relayurl) self.assertRaises(UsageError, w1.get_verifier) - self.assertRaises(UsageError, w1.get_data, b"data") + self.assertRaises(UsageError, w1.get_data) + self.assertRaises(UsageError, w1.send_data, b"data") w1.set_code("123-purple-elephant") self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.get_code) @@ -94,6 +135,7 @@ class Blocking(ServerBase, unittest.TestCase): d = deferToThread(w2.get_code) def _done(code): self.assertRaises(UsageError, w2.get_code) + return self.doBoth([w1.close], [w2.close]) d.addCallback(_done) return d @@ -111,15 +153,19 @@ class Blocking(ServerBase, unittest.TestCase): self.assertEqual(type(s), type("")) unpacked = json.loads(s) # this is supposed to be JSON self.assertEqual(type(unpacked), dict) - new_w1 = BlockingWormhole.from_serialized(s) - return gatherResults([deferToThread(new_w1.get_data, b"data1"), - deferToThread(w2.get_data, b"data2")], True) + self.new_w1 = BlockingWormhole.from_serialized(s) + return self.doBoth([self.new_w1.send_data, b"data1"], + [w2.send_data, b"data2"]) d.addCallback(_got_code) + def _sent(res): + return self.doBoth(self.new_w1.get_data(), w2.get_data()) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") self.assertRaises(UsageError, w2.serialize) # too late + return self.doBoth([w1.close], [w2.close]) d.addCallback(_done) return d test_serialize.skip = "not yet implemented for the blocking flavor" diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index 676a2d0..63367d2 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -5,6 +5,10 @@ from ..twisted.transcribe import Wormhole, UsageError from .common import ServerBase class Basic(ServerBase, unittest.TestCase): + + def doBoth(self, d1, d2): + return gatherResults([d1, d2], True) + def test_basic(self): appid = b"appid" w1 = Wormhole(appid, self.relayurl) @@ -12,13 +16,37 @@ class Basic(ServerBase, unittest.TestCase): d = w1.get_code() def _got_code(code): w2.set_code(code) - return gatherResults([w1.get_data(b"data1"), - w2.get_data(b"data2")], True) + return self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) d.addCallback(_got_code) + def _sent(res): + return self.doBoth(w1.get_data(), w2.get_data()) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth(w1.close(), w2.close()) + d.addCallback(_done) + return d + + def test_interleaved(self): + appid = b"appid" + w1 = Wormhole(appid, self.relayurl) + w2 = Wormhole(appid, self.relayurl) + d = w1.get_code() + def _got_code(code): + w2.set_code(code) + return self.doBoth(w1.send_data(b"data1"), w2.get_data()) + d.addCallback(_got_code) + def _sent(res): + (_, dataY) = res + self.assertEqual(dataY, b"data1") + return self.doBoth(w1.get_data(), w2.send_data(b"data2")) + d.addCallback(_sent) + def _done(dl): + (dataX, _) = dl + self.assertEqual(dataX, b"data2") + return self.doBoth(w1.close(), w2.close()) d.addCallback(_done) return d @@ -28,12 +56,15 @@ class Basic(ServerBase, unittest.TestCase): w2 = Wormhole(appid, self.relayurl) w1.set_code("123-purple-elephant") w2.set_code("123-purple-elephant") - d = gatherResults([w1.get_data(b"data1"), - w2.get_data(b"data2")], True) + d = self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) + def _sent(res): + return self.doBoth(w1.get_data(), w2.get_data()) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth(w1.close(), w2.close()) d.addCallback(_done) return d @@ -44,19 +75,22 @@ class Basic(ServerBase, unittest.TestCase): d = w1.get_code() def _got_code(code): w2.set_code(code) - return gatherResults([w1.get_verifier(), w2.get_verifier()], True) + return self.doBoth(w1.get_verifier(), w2.get_verifier()) d.addCallback(_got_code) def _check_verifier(res): v1, v2 = res self.failUnlessEqual(type(v1), type(b"")) self.failUnlessEqual(v1, v2) - return gatherResults([w1.get_data(b"data1"), - w2.get_data(b"data2")], True) + return self.doBoth(w1.send_data(b"data1"), w2.send_data(b"data2")) d.addCallback(_check_verifier) + def _sent(res): + return self.doBoth(w1.get_data(), w2.get_data()) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl self.assertEqual(dataX, b"data2") self.assertEqual(dataY, b"data1") + return self.doBoth(w1.close(), w2.close()) d.addCallback(_done) return d @@ -67,12 +101,13 @@ class Basic(ServerBase, unittest.TestCase): d = w1.get_code() def _got_code(code): w2.set_code(code+"not") - return gatherResults([w1.get_verifier(), w2.get_verifier()], True) + return self.doBoth(w1.get_verifier(), w2.get_verifier()) d.addCallback(_got_code) def _check_verifier(res): v1, v2 = res self.failUnlessEqual(type(v1), type(b"")) self.failIfEqual(v1, v2) + return self.doBoth(w1.close(), w2.close()) d.addCallback(_check_verifier) return d @@ -80,13 +115,17 @@ class Basic(ServerBase, unittest.TestCase): appid = b"appid" w1 = Wormhole(appid, self.relayurl) self.assertRaises(UsageError, w1.get_verifier) - self.assertRaises(UsageError, w1.get_data, b"data") + self.assertRaises(UsageError, w1.send_data, b"data") + self.assertRaises(UsageError, w1.get_data) w1.set_code("123-purple-elephant") self.assertRaises(UsageError, w1.set_code, "123-nope") self.assertRaises(UsageError, w1.get_code) w2 = Wormhole(appid, self.relayurl) d = w2.get_code() self.assertRaises(UsageError, w2.get_code) + def _got_code(code): + return self.doBoth(w1.close(), w2.close()) + d.addCallback(_got_code) return d def test_serialize(self): @@ -103,15 +142,19 @@ class Basic(ServerBase, unittest.TestCase): self.assertEqual(type(s), type("")) unpacked = json.loads(s) # this is supposed to be JSON self.assertEqual(type(unpacked), dict) - new_w1 = Wormhole.from_serialized(s) - return gatherResults([new_w1.get_data(b"data1"), - w2.get_data(b"data2")], True) + self.new_w1 = Wormhole.from_serialized(s) + return self.doBoth(self.new_w1.send_data(b"data1"), + w2.send_data(b"data2")) d.addCallback(_got_code) + def _sent(res): + return self.doBoth(self.new_w1.get_data(), w2.get_data()) + d.addCallback(_sent) def _done(dl): (dataX, dataY) = dl - self.assertEqual(dataX, b"data2") - self.assertEqual(dataY, b"data1") + self.assertEqual((dataX, dataY), (b"data2", b"data1")) self.assertRaises(UsageError, w2.serialize) # too late + return gatherResults([w1.close(), w2.close(), self.new_w1.close()], + True) d.addCallback(_done) return d diff --git a/src/wormhole/twisted/demo.py b/src/wormhole/twisted/demo.py index 3f8c1b3..790440d 100644 --- a/src/wormhole/twisted/demo.py +++ b/src/wormhole/twisted/demo.py @@ -1,5 +1,6 @@ from __future__ import print_function import sys, json +from twisted.python import log from twisted.internet import reactor from .transcribe import Wormhole from .. import public_relay @@ -14,8 +15,11 @@ if sys.argv[1] == "send-text": d = w.get_code() def _got_code(code): print("code is:", code) - return w.get_data(data) + return w.send_data(data) d.addCallback(_got_code) + def _sent(_): + return w.get_data() + d.addCallback(_sent) def _got_data(them_bytes): them_d = json.loads(them_bytes.decode("utf-8")) if them_d["message"] == "ok": @@ -26,16 +30,19 @@ if sys.argv[1] == "send-text": elif sys.argv[1] == "receive-text": code = sys.argv[2] w.set_code(code) - data = json.dumps({"message": "ok"}).encode("utf-8") - d = w.get_data(data) + d = w.get_data() def _got_data(them_bytes): them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: print("ERROR: " + them_d["error"], file=sys.stderr) return 1 print(them_d["message"]) + data = json.dumps({"message": "ok"}).encode("utf-8") + return w.send_data(data) d.addCallback(_got_data) else: raise ValueError("bad command") +d.addCallback(w.close) d.addCallback(lambda _: reactor.stop()) +d.addErrback(log.err) reactor.run() diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index eaa479f..78f2c95 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -107,11 +107,11 @@ class Channel: d.addCallback(lambda _: msgs[0]) return d - def deallocate(self, res): + def deallocate(self): # only try once, no retries d = post_json(self._agent, self._channel_url+"/deallocate", {"side": self._side}) - d.addBoth(lambda _: res) # ignore POST failure, pass-through result + d.addBoth(lambda _: None) # ignore POST failure return d class ChannelManager: @@ -150,6 +150,9 @@ class Wormhole: self.code = None self.key = None self._started_get_code = False + self._sent_data = False + self._got_data = False + self._closed = False def _set_side(self, side): self._side = side @@ -218,6 +221,8 @@ class Wormhole: # get_verifier/get_data if self.code is None: raise UsageError if self.key is not None: raise UsageError + if self._sent_data: raise UsageError + if self._got_data: raise UsageError data = { "appid": self.appid, "relay": self.relay, @@ -282,32 +287,51 @@ class Wormhole: d.addCallback(lambda _: self.verifier) return d - def get_data(self, outbound_data): - # only call this once + def send_data(self, outbound_data): + if self._sent_data: raise UsageError # only call this once if not isinstance(outbound_data, type(b"")): raise UsageError if self.code is None: raise UsageError - d = self._get_key() - d.addCallback(self._get_data2, outbound_data) - d.addBoth(self.channel.deallocate) - return d - - def _get_data2(self, key, outbound_data): + if self.channel is None: raise UsageError # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random - # nonces to keep the messages distinct, and check for reflection. - data_key = self.derive_key(b"data-key") - - outbound_encrypted = self._encrypt_data(data_key, outbound_data) - d = self.channel.send(u"data", outbound_encrypted) - - d.addCallback(lambda _: self.channel.get(u"data")) - def _got_data(inbound_encrypted): - #if inbound_encrypted == outbound_encrypted: - # raise ReflectionAttack - try: - inbound_data = self._decrypt_data(data_key, inbound_encrypted) - return inbound_data - except CryptoError: - raise WrongPasswordError - d.addCallback(_got_data) + # nonces to keep the messages distinct, and the Channel automatically + # ignores reflections. + d = self._get_key() + def _send(key): + data_key = self.derive_key(b"data-key") + outbound_encrypted = self._encrypt_data(data_key, outbound_data) + return self.channel.send(u"data", outbound_encrypted) + d.addCallback(_send) return d + + def get_data(self): + if self._got_data: raise UsageError # only call this once + if self.code is None: raise UsageError + if self.channel is None: raise UsageError + d = self._get_key() + def _get(key): + data_key = self.derive_key(b"data-key") + d1 = self.channel.get(u"data") + def _decrypt(inbound_encrypted): + try: + inbound_data = self._decrypt_data(data_key, + inbound_encrypted) + return inbound_data + except CryptoError: + raise WrongPasswordError + d1.addCallback(_decrypt) + return d1 + d.addCallback(_get) + return d + + def close(self, res=None): + d = self.channel.deallocate() + def _closed(_): + self._closed = True + return res + d.addCallback(_closed) + return d + + def __del__(self): + if not self._closed: + print("Error: a Wormhole instance was not closed", file=sys.stderr)