Merge branch 'error-handling'

This commit is contained in:
Brian Warner 2015-11-11 22:02:51 -08:00
commit dc581d34f2
4 changed files with 221 additions and 174 deletions

View File

@ -56,6 +56,11 @@ suffer longer invitation codes as a result. To encourage `close()`, the
library will log an error if a Wormhole object is destroyed before being library will log an error if a Wormhole object is destroyed before being
closed. closed.
To make it easier to call `close()`, the blocking Wormhole objects can be
used as a context manager. Just put your code in the body of a `with
Wormhole(ARGS) as w:` statement, and `close()` will automatically be called
when the block exits (either successfully or due to an exception).
## Examples ## Examples
The synchronous+blocking flow looks like this: The synchronous+blocking flow looks like this:
@ -64,12 +69,11 @@ The synchronous+blocking flow looks like this:
from wormhole.blocking.transcribe import Wormhole from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"initiator's data" mydata = b"initiator's data"
i = Wormhole(u"appid", RENDEZVOUS_RELAY) with Wormhole(u"appid", RENDEZVOUS_RELAY) as i:
code = i.get_code() code = i.get_code()
print("Invitation Code: %s" % code) print("Invitation Code: %s" % code)
i.send_data(mydata) i.send_data(mydata)
theirdata = i.get_data() theirdata = i.get_data()
i.close()
print("Their data: %s" % theirdata.decode("ascii")) print("Their data: %s" % theirdata.decode("ascii"))
``` ```
@ -79,11 +83,10 @@ from wormhole.blocking.transcribe import Wormhole
from wormhole.public_relay import RENDEZVOUS_RELAY from wormhole.public_relay import RENDEZVOUS_RELAY
mydata = b"receiver's data" mydata = b"receiver's data"
code = sys.argv[1] code = sys.argv[1]
r = Wormhole(u"appid", RENDEZVOUS_RELAY) with Wormhole(u"appid", RENDEZVOUS_RELAY) as r:
r.set_code(code) r.set_code(code)
r.send_data(mydata) r.send_data(mydata)
theirdata = r.get_data() theirdata = r.get_data()
r.close()
print("Their data: %s" % theirdata.decode("ascii")) print("Their data: %s" % theirdata.decode("ascii"))
``` ```

View File

@ -30,7 +30,8 @@ def to_bytes(u):
# all JSON responses include a "welcome:{..}" key # all JSON responses include a "welcome:{..}" key
class Channel: class Channel:
def __init__(self, relay_url, appid, channelid, side, handle_welcome): def __init__(self, relay_url, appid, channelid, side, handle_welcome,
wait, timeout):
self._relay_url = relay_url self._relay_url = relay_url
self._appid = appid self._appid = appid
self._channelid = channelid self._channelid = channelid
@ -39,8 +40,8 @@ class Channel:
self._messages = set() # (phase,body) , body is bytes self._messages = set() # (phase,body) , body is bytes
self._sent_messages = set() # (phase,body) self._sent_messages = set() # (phase,body)
self._started = time.time() self._started = time.time()
self._wait = 0.5*SECOND self._wait = wait
self._timeout = 3*MINUTE self._timeout = timeout
def _add_inbound_messages(self, messages): def _add_inbound_messages(self, messages):
for msg in messages: for msg in messages:
@ -66,7 +67,8 @@ class Channel:
"phase": phase, "phase": phase,
"body": hexlify(msg).decode("ascii")} "body": hexlify(msg).decode("ascii")}
data = json.dumps(payload).encode("utf-8") data = json.dumps(payload).encode("utf-8")
r = requests.post(self._relay_url+"add", data=data) r = requests.post(self._relay_url+"add", data=data,
timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
resp = r.json() resp = r.json()
self._add_inbound_messages(resp["messages"]) self._add_inbound_messages(resp["messages"])
@ -84,7 +86,7 @@ class Channel:
while body is None: while body is None:
remaining = self._started + self._timeout - time.time() remaining = self._started + self._timeout - time.time()
if remaining < 0: if remaining < 0:
return Timeout raise Timeout
queryargs = urlencode([("appid", self._appid), queryargs = urlencode([("appid", self._appid),
("channelid", self._channelid)]) ("channelid", self._channelid)])
f = EventSourceFollower(self._relay_url+"get?%s" % queryargs, f = EventSourceFollower(self._relay_url+"get?%s" % queryargs,
@ -104,25 +106,35 @@ class Channel:
time.sleep(self._wait) time.sleep(self._wait)
return body return body
def deallocate(self, mood=u"unknown"): def deallocate(self, mood=None):
# only try once, no retries # only try once, no retries
data = json.dumps({"appid": self._appid, data = json.dumps({"appid": self._appid,
"channelid": self._channelid, "channelid": self._channelid,
"side": self._side, "side": self._side,
"mood": mood}).encode("utf-8") "mood": mood}).encode("utf-8")
requests.post(self._relay_url+"deallocate", data=data) try:
# ignore POST failure, don't call r.raise_for_status() # ignore POST failure, don't call r.raise_for_status(), set a
# short timeout and ignore failures
requests.post(self._relay_url+"deallocate", data=data,
timeout=5)
except (requests.exceptions.ConnectionError,
requests.exceptions.Timeout):
pass
class ChannelManager: class ChannelManager:
def __init__(self, relay_url, appid, side, handle_welcome): def __init__(self, relay_url, appid, side, handle_welcome,
wait=0.5*SECOND, timeout=3*MINUTE):
self._relay_url = relay_url self._relay_url = relay_url
self._appid = appid self._appid = appid
self._side = side self._side = side
self._handle_welcome = handle_welcome self._handle_welcome = handle_welcome
self._wait = wait
self._timeout = timeout
def list_channels(self): def list_channels(self):
queryargs = urlencode([("appid", self._appid)]) queryargs = urlencode([("appid", self._appid)])
r = requests.get(self._relay_url+"list?%s" % queryargs) r = requests.get(self._relay_url+"list?%s" % queryargs,
timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
channelids = r.json()["channelids"] channelids = r.json()["channelids"]
return channelids return channelids
@ -130,7 +142,8 @@ class ChannelManager:
def allocate(self): def allocate(self):
data = json.dumps({"appid": self._appid, data = json.dumps({"appid": self._appid,
"side": self._side}).encode("utf-8") "side": self._side}).encode("utf-8")
r = requests.post(self._relay_url+"allocate", data=data) r = requests.post(self._relay_url+"allocate", data=data,
timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "welcome" in data: if "welcome" in data:
@ -140,27 +153,61 @@ class ChannelManager:
def connect(self, channelid): def connect(self, channelid):
return Channel(self._relay_url, self._appid, channelid, self._side, return Channel(self._relay_url, self._appid, channelid, self._side,
self._handle_welcome) self._handle_welcome, self._wait, self._timeout)
def close_on_error(f): # method decorator
# Clients report certain errors as "moods", so the server can make a
# rough count failed connections (due to mismatched passwords, attacks,
# or timeouts). We don't report precondition failures, as those are the
# responsibility/fault of the local application code. We count
# non-precondition errors in case they represent server-side problems.
def _f(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except Timeout:
self.close(u"lonely")
raise
except WrongPasswordError:
self.close(u"scared")
raise
except (TypeError, UsageError):
# preconditions don't warrant _close_with_error()
raise
except:
self.close(u"other-error")
raise
return _f
class Wormhole: class Wormhole:
motd_displayed = False motd_displayed = False
version_warning_displayed = False version_warning_displayed = False
def __init__(self, appid, relay_url): def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE):
if not isinstance(appid, type(u"")): raise TypeError(type(appid)) if not isinstance(appid, type(u"")): raise TypeError(type(appid))
if not isinstance(relay_url, type(u"")): if not isinstance(relay_url, type(u"")):
raise TypeError(type(relay_url)) raise TypeError(type(relay_url))
if not relay_url.endswith(u"/"): raise UsageError if not relay_url.endswith(u"/"): raise UsageError
self._appid = appid self._appid = appid
self._relay_url = relay_url self._relay_url = relay_url
self._wait = wait
self._timeout = timeout
side = hexlify(os.urandom(5)).decode("ascii") side = hexlify(os.urandom(5)).decode("ascii")
self._channel_manager = ChannelManager(relay_url, appid, side, self._channel_manager = ChannelManager(relay_url, appid, side,
self.handle_welcome) self.handle_welcome,
self._wait, self._timeout)
self._channel = None
self.code = None self.code = None
self.key = None self.key = None
self.verifier = None self.verifier = None
self._sent_data = set() # phases self._sent_data = set() # phases
self._got_data = set() self._got_data = set()
self._closed = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def handle_welcome(self, welcome): def handle_welcome(self, welcome):
if ("motd" in welcome and if ("motd" in welcome and
@ -212,8 +259,8 @@ class Wormhole:
raise ValueError("code (%s) must start with NN-" % code) raise ValueError("code (%s) must start with NN-" % code)
self.code = code self.code = code
channelid = int(mo.group(1)) channelid = int(mo.group(1))
self.channel = self._channel_manager.connect(channelid) self._channel = self._channel_manager.connect(channelid)
monitor.add(self.channel) monitor.add(self._channel)
def _start(self): def _start(self):
# allocate the rest now too, so it can be serialized # allocate the rest now too, so it can be serialized
@ -221,6 +268,7 @@ class Wormhole:
idSymmetric=to_bytes(self._appid)) idSymmetric=to_bytes(self._appid))
self.msg1 = self.sp.start() self.msg1 = self.sp.start()
@close_on_error
def derive_key(self, purpose, length=SecretBox.KEY_SIZE): def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) return HKDF(self.key, length, CTXinfo=to_bytes(purpose))
@ -244,24 +292,28 @@ class Wormhole:
def _get_key(self): def _get_key(self):
if not self.key: if not self.key:
self.channel.send(u"pake", self.msg1) self._channel.send(u"pake", self.msg1)
pake_msg = self.channel.get(u"pake") pake_msg = self._channel.get(u"pake")
self.key = self.sp.finish(pake_msg) self.key = self.sp.finish(pake_msg)
self.verifier = self.derive_key(u"wormhole:verifier") self.verifier = self.derive_key(u"wormhole:verifier")
@close_on_error
def get_verifier(self): def get_verifier(self):
if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self.channel is None: raise UsageError if self._channel is None: raise UsageError
self._get_key() self._get_key()
return self.verifier return self.verifier
@close_on_error
def send_data(self, outbound_data, phase=u"data"): def send_data(self, outbound_data, phase=u"data"):
if not isinstance(outbound_data, type(b"")): if not isinstance(outbound_data, type(b"")):
raise TypeError(type(outbound_data)) raise TypeError(type(outbound_data))
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if self._closed: raise UsageError
if phase in self._sent_data: raise UsageError # only call this once if phase in self._sent_data: raise UsageError # only call this once
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self.channel is None: raise UsageError if self._channel is None: raise UsageError
# Without predefined roles, we can't derive predictably unique keys # Without predefined roles, we can't derive predictably unique keys
# for each side, so we use the same key for both. We use random # for each side, so we use the same key for both. We use random
# nonces to keep the messages distinct, and the Channel automatically # nonces to keep the messages distinct, and the Channel automatically
@ -270,23 +322,30 @@ class Wormhole:
self._get_key() self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
outbound_encrypted = self._encrypt_data(data_key, outbound_data) outbound_encrypted = self._encrypt_data(data_key, outbound_data)
self.channel.send(phase, outbound_encrypted) self._channel.send(phase, outbound_encrypted)
@close_on_error
def get_data(self, phase=u"data"): def get_data(self, phase=u"data"):
if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if not isinstance(phase, type(u"")): raise TypeError(type(phase))
if phase in self._got_data: raise UsageError # only call this once if phase in self._got_data: raise UsageError # only call this once
if self._closed: raise UsageError
if self.code is None: raise UsageError if self.code is None: raise UsageError
if self.channel is None: raise UsageError if self._channel is None: raise UsageError
self._got_data.add(phase) self._got_data.add(phase)
self._get_key() self._get_key()
data_key = self.derive_key(u"wormhole:phase:%s" % phase) data_key = self.derive_key(u"wormhole:phase:%s" % phase)
inbound_encrypted = self.channel.get(phase) inbound_encrypted = self._channel.get(phase)
try: try:
inbound_data = self._decrypt_data(data_key, inbound_encrypted) inbound_data = self._decrypt_data(data_key, inbound_encrypted)
return inbound_data return inbound_data
except CryptoError: except CryptoError:
raise WrongPasswordError raise WrongPasswordError
def close(self): def close(self, mood=None):
monitor.close(self.channel) if not isinstance(mood, (type(None), type(u""))):
self.channel.deallocate() raise TypeError(type(mood))
self._closed = True
if self._channel:
c, self._channel = self._channel, None
monitor.close(c)
c.deallocate(mood)

View File

@ -12,7 +12,7 @@ def receive(args):
from .progress import start_progress, update_progress, finish_progress from .progress import start_progress, update_progress, finish_progress
assert isinstance(args.relay_url, type(u"")) assert isinstance(args.relay_url, type(u""))
w = Wormhole(APPID, args.relay_url) with Wormhole(APPID, args.relay_url) as w:
if args.zeromode: if args.zeromode:
assert not args.code assert not args.code
args.code = u"0-" args.code = u"0-"
@ -29,12 +29,10 @@ def receive(args):
them_bytes = w.get_data() them_bytes = w.get_data()
except WrongPasswordError as e: except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr) print("ERROR: " + e.explain(), file=sys.stderr)
w.close()
return 1 return 1
them_d = json.loads(them_bytes.decode("utf-8")) them_d = json.loads(them_bytes.decode("utf-8"))
if "error" in them_d: if "error" in them_d:
print("ERROR: " + them_d["error"], file=sys.stderr) print("ERROR: " + them_d["error"], file=sys.stderr)
w.close()
return 1 return 1
if "message" in them_d: if "message" in them_d:
@ -42,18 +40,15 @@ def receive(args):
print(them_d["message"]) print(them_d["message"])
data = json.dumps({"message_ack": "ok"}).encode("utf-8") data = json.dumps({"message_ack": "ok"}).encode("utf-8")
w.send_data(data) w.send_data(data)
w.close()
return 0 return 0
if not "file" in them_d: if not "file" in them_d:
print("I don't know what they're offering\n") print("I don't know what they're offering\n")
print(them_d) print(them_d)
w.close()
return 1 return 1
if "error" in them_d: if "error" in them_d:
print("ERROR: " + data["error"], file=sys.stderr) print("ERROR: " + data["error"], file=sys.stderr)
w.close()
return 1 return 1
file_data = them_d["file"] file_data = them_d["file"]
@ -67,7 +62,6 @@ def receive(args):
print("Error: refusing to overwrite existing file %s" % (filename,)) print("Error: refusing to overwrite existing file %s" % (filename,))
data = json.dumps({"error": "file already exists"}).encode("utf-8") data = json.dumps({"error": "file already exists"}).encode("utf-8")
w.send_data(data) w.send_data(data)
w.close()
return 1 return 1
print("Receiving file (%d bytes) into: %s" % (filesize, filename)) print("Receiving file (%d bytes) into: %s" % (filesize, filename))
@ -78,7 +72,6 @@ def receive(args):
print("transfer rejected", file=sys.stderr) print("transfer rejected", file=sys.stderr)
data = json.dumps({"error": "transfer rejected"}).encode("utf-8") data = json.dumps({"error": "transfer rejected"}).encode("utf-8")
w.send_data(data) w.send_data(data)
w.close()
return 1 return 1
transit_receiver = TransitReceiver(args.transit_helper) transit_receiver = TransitReceiver(args.transit_helper)
@ -90,7 +83,6 @@ def receive(args):
}, },
}).encode("utf-8") }).encode("utf-8")
w.send_data(data) w.send_data(data)
w.close()
# now receive the rest of the owl # now receive the rest of the owl
tdata = them_d["transit"] tdata = them_d["transit"]

View File

@ -46,7 +46,7 @@ def send(args):
}, },
} }
w = Wormhole(APPID, args.relay_url) with Wormhole(APPID, args.relay_url) as w:
if args.zeromode: if args.zeromode:
assert not args.code assert not args.code
args.code = u"0-" args.code = u"0-"
@ -77,7 +77,6 @@ def send(args):
reject_data = json.dumps({"error": "verification rejected", reject_data = json.dumps({"error": "verification rejected",
}).encode("utf-8") }).encode("utf-8")
w.send_data(reject_data) w.send_data(reject_data)
w.close()
return 1 return 1
my_phase1_bytes = json.dumps(phase1).encode("utf-8") my_phase1_bytes = json.dumps(phase1).encode("utf-8")
@ -86,30 +85,24 @@ def send(args):
them_phase1_bytes = w.get_data() them_phase1_bytes = w.get_data()
except WrongPasswordError as e: except WrongPasswordError as e:
print("ERROR: " + e.explain(), file=sys.stderr) print("ERROR: " + e.explain(), file=sys.stderr)
w.close()
return 1 return 1
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
if sending_message: if sending_message:
if them_phase1["message_ack"] == "ok": if them_phase1["message_ack"] == "ok":
print("text message sent") print("text message sent")
w.close()
return 0 return 0
print("error sending text: %r" % (them_phase1,)) print("error sending text: %r" % (them_phase1,))
w.close()
return 1 return 1
if "error" in them_phase1: if "error" in them_phase1:
print("remote error: %s" % them_phase1["error"]) print("remote error: %s" % them_phase1["error"])
print("transfer abandoned") print("transfer abandoned")
w.close()
return 1 return 1
if them_phase1.get("file_ack") != "ok": if them_phase1.get("file_ack") != "ok":
print("ambiguous response from remote: %s" % (them_phase1,)) print("ambiguous response from remote: %s" % (them_phase1,))
print("transfer abandoned") print("transfer abandoned")
w.close()
return 1 return 1
w.close()
tdata = them_phase1["transit"] tdata = them_phase1["transit"]
transit_key = w.derive_key(APPID+"/transit-key") transit_key = w.derive_key(APPID+"/transit-key")