diff --git a/src/wormhole/test/test_twisted.py b/src/wormhole/test/test_twisted.py index fe016f4..60f4ee2 100644 --- a/src/wormhole/test/test_twisted.py +++ b/src/wormhole/test/test_twisted.py @@ -43,8 +43,8 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") d.addCallback(_done) @@ -62,8 +62,8 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") d.addCallback(_done) @@ -104,8 +104,8 @@ class Basic(unittest.TestCase): def _done(dl): ((success1, dataX), (success2, dataY)) = dl r1,r2 = dl - self.assertTrue(success1) - self.assertTrue(success2) + self.assertTrue(success1, dataX) + self.assertTrue(success2, dataY) self.assertEqual(dataX, "data2") self.assertEqual(dataY, "data1") self.assertRaises(UsageError, w2.serialize) # too late diff --git a/src/wormhole/twisted/transcribe.py b/src/wormhole/twisted/transcribe.py index 712bd3a..d1143c7 100644 --- a/src/wormhole/twisted/transcribe.py +++ b/src/wormhole/twisted/transcribe.py @@ -44,6 +44,9 @@ class DataProducer: class SymmetricWormhole: + motd_displayed = False + version_warning_displayed = False + def __init__(self, appid, relay): self.appid = appid self.relay = relay @@ -53,6 +56,34 @@ class SymmetricWormhole: self.key = None self._started_get_code = False + def _url(self, verb, msgnum=None): + url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) + if msgnum is not None: + url += "/" + msgnum + return url + + def handle_welcome(self, welcome): + if ("motd" in welcome and + not self.motd_displayed): + motd_lines = welcome["motd"].splitlines() + motd_formatted = "\n ".join(motd_lines) + print("Server (at %s) says:\n %s" % (self.relay, motd_formatted), + file=sys.stderr) + self.motd_displayed = True + + # Only warn if we're running a release version (e.g. 0.0.6, not + # 0.0.6-DISTANCE-gHASH). Only warn once. + if ("-" not in __version__ and + not self.version_warning_displayed and + welcome["current_version"] != __version__): + print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) + print("Server claims %s is current, but ours is %s" + % (welcome["current_version"], __version__), file=sys.stderr) + self.version_warning_displayed = True + + if "error" in welcome: + raise ServerError(welcome["error"], self.relay) + def get_code(self, code_length=2): if self.code is not None: raise UsageError if self._started_get_code: raise UsageError @@ -141,37 +172,6 @@ class SymmetricWormhole: self.msg1 = d["msg1"].decode("hex") return self - motd_displayed = False - version_warning_displayed = False - - def handle_welcome(self, welcome): - if ("motd" in welcome and - not self.motd_displayed): - motd_lines = welcome["motd"].splitlines() - motd_formatted = "\n ".join(motd_lines) - print("Server (at %s) says:\n %s" % (self.relay, motd_formatted), - file=sys.stderr) - self.motd_displayed = True - - # Only warn if we're running a release version (e.g. 0.0.6, not - # 0.0.6-DISTANCE-gHASH). Only warn once. - if ("-" not in __version__ and - not self.version_warning_displayed and - welcome["current_version"] != __version__): - print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr) - print("Server claims %s is current, but ours is %s" - % (welcome["current_version"], __version__), file=sys.stderr) - self.version_warning_displayed = True - - if "error" in welcome: - raise ServerError(welcome["error"], self.relay) - - def _url(self, verb, msgnum=None): - url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb) - if msgnum is not None: - url += "/" + msgnum - return url - def _post_message(self, url, msg): # TODO: retry on failure, with exponential backoff. We're guarding # against the rendezvous server being temporarily offline. @@ -257,8 +257,10 @@ class SymmetricWormhole: # for each side, so we use the same key for both. We use random # nonces to keep the messages distinct, and check for reflection. data_key = self.derive_key(b"data-key") + outbound_encrypted = self._encrypt_data(data_key, outbound_data) d = self._post_message(self._url("post", "data"), outbound_encrypted) + d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data")) def _got_data(inbound_encrypted): if inbound_encrypted == outbound_encrypted: