minor reformatting, improve test error messages

This commit is contained in:
Brian Warner 2015-07-24 16:57:19 -07:00
parent cdeaac0ad0
commit cebfa71563
2 changed files with 39 additions and 37 deletions

View File

@ -43,8 +43,8 @@ class Basic(unittest.TestCase):
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
@ -62,8 +62,8 @@ class Basic(unittest.TestCase):
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
d.addCallback(_done)
@ -104,8 +104,8 @@ class Basic(unittest.TestCase):
def _done(dl):
((success1, dataX), (success2, dataY)) = dl
r1,r2 = dl
self.assertTrue(success1)
self.assertTrue(success2)
self.assertTrue(success1, dataX)
self.assertTrue(success2, dataY)
self.assertEqual(dataX, "data2")
self.assertEqual(dataY, "data1")
self.assertRaises(UsageError, w2.serialize) # too late

View File

@ -44,6 +44,9 @@ class DataProducer:
class SymmetricWormhole:
motd_displayed = False
version_warning_displayed = False
def __init__(self, appid, relay):
self.appid = appid
self.relay = relay
@ -53,6 +56,34 @@ class SymmetricWormhole:
self.key = None
self._started_get_code = False
def _url(self, verb, msgnum=None):
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
if msgnum is not None:
url += "/" + msgnum
return url
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self.relay, motd_formatted),
file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("-" not in __version__ and
not self.version_warning_displayed and
welcome["current_version"] != __version__):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], __version__), file=sys.stderr)
self.version_warning_displayed = True
if "error" in welcome:
raise ServerError(welcome["error"], self.relay)
def get_code(self, code_length=2):
if self.code is not None: raise UsageError
if self._started_get_code: raise UsageError
@ -141,37 +172,6 @@ class SymmetricWormhole:
self.msg1 = d["msg1"].decode("hex")
return self
motd_displayed = False
version_warning_displayed = False
def handle_welcome(self, welcome):
if ("motd" in welcome and
not self.motd_displayed):
motd_lines = welcome["motd"].splitlines()
motd_formatted = "\n ".join(motd_lines)
print("Server (at %s) says:\n %s" % (self.relay, motd_formatted),
file=sys.stderr)
self.motd_displayed = True
# Only warn if we're running a release version (e.g. 0.0.6, not
# 0.0.6-DISTANCE-gHASH). Only warn once.
if ("-" not in __version__ and
not self.version_warning_displayed and
welcome["current_version"] != __version__):
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
print("Server claims %s is current, but ours is %s"
% (welcome["current_version"], __version__), file=sys.stderr)
self.version_warning_displayed = True
if "error" in welcome:
raise ServerError(welcome["error"], self.relay)
def _url(self, verb, msgnum=None):
url = "%s%d/%s/%s" % (self.relay, self.channel_id, self.side, verb)
if msgnum is not None:
url += "/" + msgnum
return url
def _post_message(self, url, msg):
# TODO: retry on failure, with exponential backoff. We're guarding
# against the rendezvous server being temporarily offline.
@ -257,8 +257,10 @@ class SymmetricWormhole:
# for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and check for reflection.
data_key = self.derive_key(b"data-key")
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
d = self._post_message(self._url("post", "data"), outbound_encrypted)
d.addCallback(lambda msgs: self._get_message(msgs, "poll", "data"))
def _got_data(inbound_encrypted):
if inbound_encrypted == outbound_encrypted: