From 07686f3de751aa38d386f818303b99ef536224f9 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 17:52:13 -0800 Subject: [PATCH 1/8] make self.channel internal --- src/wormhole/blocking/transcribe.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 4386c59..965738f 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -156,6 +156,7 @@ class Wormhole: side = hexlify(os.urandom(5)).decode("ascii") self._channel_manager = ChannelManager(relay_url, appid, side, self.handle_welcome) + self._channel = None self.code = None self.key = None self.verifier = None @@ -212,8 +213,8 @@ class Wormhole: raise ValueError("code (%s) must start with NN-" % code) self.code = code channelid = int(mo.group(1)) - self.channel = self._channel_manager.connect(channelid) - monitor.add(self.channel) + self._channel = self._channel_manager.connect(channelid) + monitor.add(self._channel) def _start(self): # allocate the rest now too, so it can be serialized @@ -244,14 +245,14 @@ class Wormhole: def _get_key(self): if not self.key: - self.channel.send(u"pake", self.msg1) - pake_msg = self.channel.get(u"pake") + self._channel.send(u"pake", self.msg1) + pake_msg = self._channel.get(u"pake") self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(u"wormhole:verifier") def get_verifier(self): if self.code is None: raise UsageError - if self.channel is None: raise UsageError + if self._channel is None: raise UsageError self._get_key() return self.verifier @@ -261,7 +262,7 @@ class Wormhole: if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._sent_data: raise UsageError # only call this once if self.code is None: raise UsageError - if self.channel is None: raise UsageError + if self._channel is None: raise UsageError # Without predefined roles, we can't derive predictably unique keys # for each side, so we use the same key for both. We use random # nonces to keep the messages distinct, and the Channel automatically @@ -270,17 +271,17 @@ class Wormhole: self._get_key() data_key = self.derive_key(u"wormhole:phase:%s" % phase) outbound_encrypted = self._encrypt_data(data_key, outbound_data) - self.channel.send(phase, outbound_encrypted) + self._channel.send(phase, outbound_encrypted) def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._got_data: raise UsageError # only call this once if self.code is None: raise UsageError - if self.channel is None: raise UsageError + if self._channel is None: raise UsageError self._got_data.add(phase) self._get_key() data_key = self.derive_key(u"wormhole:phase:%s" % phase) - inbound_encrypted = self.channel.get(phase) + inbound_encrypted = self._channel.get(phase) try: inbound_data = self._decrypt_data(data_key, inbound_encrypted) return inbound_data @@ -288,5 +289,5 @@ class Wormhole: raise WrongPasswordError def close(self): - monitor.close(self.channel) - self.channel.deallocate() + monitor.close(self._channel) + self._channel.deallocate() From fa3be3523deedeef02f98c8084fcd0e188947ebd Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 17:56:08 -0800 Subject: [PATCH 2/8] pass timeouts down --- src/wormhole/blocking/transcribe.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 965738f..94de91b 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -30,7 +30,8 @@ def to_bytes(u): # all JSON responses include a "welcome:{..}" key class Channel: - def __init__(self, relay_url, appid, channelid, side, handle_welcome): + def __init__(self, relay_url, appid, channelid, side, handle_welcome, + wait, timeout): self._relay_url = relay_url self._appid = appid self._channelid = channelid @@ -39,8 +40,8 @@ class Channel: self._messages = set() # (phase,body) , body is bytes self._sent_messages = set() # (phase,body) self._started = time.time() - self._wait = 0.5*SECOND - self._timeout = 3*MINUTE + self._wait = wait + self._timeout = timeout def _add_inbound_messages(self, messages): for msg in messages: @@ -114,11 +115,14 @@ class Channel: # ignore POST failure, don't call r.raise_for_status() class ChannelManager: - def __init__(self, relay_url, appid, side, handle_welcome): + def __init__(self, relay_url, appid, side, handle_welcome, + wait=0.5*SECOND, timeout=3*MINUTE): self._relay_url = relay_url self._appid = appid self._side = side self._handle_welcome = handle_welcome + self._wait = wait + self._timeout = timeout def list_channels(self): queryargs = urlencode([("appid", self._appid)]) @@ -140,22 +144,25 @@ class ChannelManager: def connect(self, channelid): return Channel(self._relay_url, self._appid, channelid, self._side, - self._handle_welcome) + self._handle_welcome, self._wait, self._timeout) class Wormhole: motd_displayed = False version_warning_displayed = False - def __init__(self, appid, relay_url): + def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE): if not isinstance(appid, type(u"")): raise TypeError(type(appid)) if not isinstance(relay_url, type(u"")): raise TypeError(type(relay_url)) if not relay_url.endswith(u"/"): raise UsageError self._appid = appid self._relay_url = relay_url + self._wait = wait + self._timeout = timeout side = hexlify(os.urandom(5)).decode("ascii") self._channel_manager = ChannelManager(relay_url, appid, side, - self.handle_welcome) + self.handle_welcome, + self._wait, self._timeout) self._channel = None self.code = None self.key = None From 6de677c1dff09fed03913c60c8fd9cd4caee040e Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 18:00:06 -0800 Subject: [PATCH 3/8] use timeouts for allocate and list_channels too --- src/wormhole/blocking/transcribe.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 94de91b..52ebc49 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -67,7 +67,8 @@ class Channel: "phase": phase, "body": hexlify(msg).decode("ascii")} data = json.dumps(payload).encode("utf-8") - r = requests.post(self._relay_url+"add", data=data) + r = requests.post(self._relay_url+"add", data=data, + timeout=self._timeout) r.raise_for_status() resp = r.json() self._add_inbound_messages(resp["messages"]) @@ -111,8 +112,13 @@ class Channel: "channelid": self._channelid, "side": self._side, "mood": mood}).encode("utf-8") - requests.post(self._relay_url+"deallocate", data=data) - # ignore POST failure, don't call r.raise_for_status() + try: + # ignore POST failure, don't call r.raise_for_status(), set a + # short timeout and ignore failures + requests.post(self._relay_url+"deallocate", data=data, + timeout=5) + except requests.exceptions.Timeout: + pass class ChannelManager: def __init__(self, relay_url, appid, side, handle_welcome, @@ -126,7 +132,8 @@ class ChannelManager: def list_channels(self): queryargs = urlencode([("appid", self._appid)]) - r = requests.get(self._relay_url+"list?%s" % queryargs) + r = requests.get(self._relay_url+"list?%s" % queryargs, + timeout=self._timeout) r.raise_for_status() channelids = r.json()["channelids"] return channelids @@ -134,7 +141,8 @@ class ChannelManager: def allocate(self): data = json.dumps({"appid": self._appid, "side": self._side}).encode("utf-8") - r = requests.post(self._relay_url+"allocate", data=data) + r = requests.post(self._relay_url+"allocate", data=data, + timeout=self._timeout) r.raise_for_status() data = r.json() if "welcome" in data: From cb5ad8ced12b36ed4711e1f995558c34d707b8e4 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 18:01:22 -0800 Subject: [PATCH 4/8] Use exception for Timeout, not return value --- src/wormhole/blocking/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 52ebc49..f028ad7 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -86,7 +86,7 @@ class Channel: while body is None: remaining = self._started + self._timeout - time.time() if remaining < 0: - return Timeout + raise Timeout queryargs = urlencode([("appid", self._appid), ("channelid", self._channelid)]) f = EventSourceFollower(self._relay_url+"get?%s" % queryargs, From 3daef13ac0ede2c82eb7b4bc17d81f9576ec6380 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 18:10:18 -0800 Subject: [PATCH 5/8] indent commands: no functional changes --- src/wormhole/scripts/cmd_receive.py | 149 ++++++++++++++-------------- src/wormhole/scripts/cmd_send.py | 119 +++++++++++----------- 2 files changed, 135 insertions(+), 133 deletions(-) diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index d369b09..cfc98ad 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -12,85 +12,86 @@ def receive(args): from .progress import start_progress, update_progress, finish_progress assert isinstance(args.relay_url, type(u"")) - w = Wormhole(APPID, args.relay_url) - if args.zeromode: - assert not args.code - args.code = u"0-" - code = args.code - if not code: - code = w.input_code("Enter receive wormhole code: ", args.code_length) - w.set_code(code) + if True: + w = Wormhole(APPID, args.relay_url) + if args.zeromode: + assert not args.code + args.code = u"0-" + code = args.code + if not code: + code = w.input_code("Enter receive wormhole code: ", args.code_length) + w.set_code(code) - if args.verify: - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - print("Verifier %s." % verifier) + if args.verify: + verifier = binascii.hexlify(w.get_verifier()).decode("ascii") + print("Verifier %s." % verifier) - try: - them_bytes = w.get_data() - except WrongPasswordError as e: - print("ERROR: " + e.explain(), file=sys.stderr) - w.close() - return 1 - them_d = json.loads(them_bytes.decode("utf-8")) - if "error" in them_d: - print("ERROR: " + them_d["error"], file=sys.stderr) - w.close() - return 1 + try: + them_bytes = w.get_data() + except WrongPasswordError as e: + print("ERROR: " + e.explain(), file=sys.stderr) + w.close() + return 1 + them_d = json.loads(them_bytes.decode("utf-8")) + if "error" in them_d: + print("ERROR: " + them_d["error"], file=sys.stderr) + w.close() + return 1 - if "message" in them_d: - # we're receiving a text message - print(them_d["message"]) - data = json.dumps({"message_ack": "ok"}).encode("utf-8") + if "message" in them_d: + # we're receiving a text message + print(them_d["message"]) + data = json.dumps({"message_ack": "ok"}).encode("utf-8") + w.send_data(data) + w.close() + return 0 + + if not "file" in them_d: + print("I don't know what they're offering\n") + print(them_d) + w.close() + return 1 + + if "error" in them_d: + print("ERROR: " + data["error"], file=sys.stderr) + w.close() + return 1 + + file_data = them_d["file"] + # the basename() is intended to protect us against + # "~/.ssh/authorized_keys" and other attacks + filename = os.path.basename(file_data["filename"]) # unicode + filesize = file_data["filesize"] + + # get confirmation from the user before writing to the local directory + if os.path.exists(filename): + print("Error: refusing to overwrite existing file %s" % (filename,)) + data = json.dumps({"error": "file already exists"}).encode("utf-8") + w.send_data(data) + w.close() + return 1 + + print("Receiving file (%d bytes) into: %s" % (filesize, filename)) + while True and not args.accept_file: + ok = six.moves.input("ok? (y/n): ") + if ok.lower().startswith("y"): + break + print("transfer rejected", file=sys.stderr) + data = json.dumps({"error": "transfer rejected"}).encode("utf-8") + w.send_data(data) + w.close() + return 1 + + transit_receiver = TransitReceiver(args.transit_helper) + data = json.dumps({ + "file_ack": "ok", + "transit": { + "direct_connection_hints": transit_receiver.get_direct_hints(), + "relay_connection_hints": transit_receiver.get_relay_hints(), + }, + }).encode("utf-8") w.send_data(data) w.close() - return 0 - - if not "file" in them_d: - print("I don't know what they're offering\n") - print(them_d) - w.close() - return 1 - - if "error" in them_d: - print("ERROR: " + data["error"], file=sys.stderr) - w.close() - return 1 - - file_data = them_d["file"] - # the basename() is intended to protect us against - # "~/.ssh/authorized_keys" and other attacks - filename = os.path.basename(file_data["filename"]) # unicode - filesize = file_data["filesize"] - - # get confirmation from the user before writing to the local directory - if os.path.exists(filename): - print("Error: refusing to overwrite existing file %s" % (filename,)) - data = json.dumps({"error": "file already exists"}).encode("utf-8") - w.send_data(data) - w.close() - return 1 - - print("Receiving file (%d bytes) into: %s" % (filesize, filename)) - while True and not args.accept_file: - ok = six.moves.input("ok? (y/n): ") - if ok.lower().startswith("y"): - break - print("transfer rejected", file=sys.stderr) - data = json.dumps({"error": "transfer rejected"}).encode("utf-8") - w.send_data(data) - w.close() - return 1 - - transit_receiver = TransitReceiver(args.transit_helper) - data = json.dumps({ - "file_ack": "ok", - "transit": { - "direct_connection_hints": transit_receiver.get_direct_hints(), - "relay_connection_hints": transit_receiver.get_relay_hints(), - }, - }).encode("utf-8") - w.send_data(data) - w.close() # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/scripts/cmd_send.py b/src/wormhole/scripts/cmd_send.py index 6f10c0f..cb3d2f4 100644 --- a/src/wormhole/scripts/cmd_send.py +++ b/src/wormhole/scripts/cmd_send.py @@ -46,70 +46,71 @@ def send(args): }, } - w = Wormhole(APPID, args.relay_url) - if args.zeromode: - assert not args.code - args.code = u"0-" - if args.code: - w.set_code(args.code) - code = args.code - else: - code = w.get_code(args.code_length) - other_cmd = "wormhole receive" - if args.verify: - other_cmd = "wormhole --verify receive" - if args.zeromode: - other_cmd += " -0" - print("On the other computer, please run: %s" % other_cmd) - if not args.zeromode: - print("Wormhole code is: %s" % code) - print("") + if True: + w = Wormhole(APPID, args.relay_url) + if args.zeromode: + assert not args.code + args.code = u"0-" + if args.code: + w.set_code(args.code) + code = args.code + else: + code = w.get_code(args.code_length) + other_cmd = "wormhole receive" + if args.verify: + other_cmd = "wormhole --verify receive" + if args.zeromode: + other_cmd += " -0" + print("On the other computer, please run: %s" % other_cmd) + if not args.zeromode: + print("Wormhole code is: %s" % code) + print("") - if args.verify: - verifier = binascii.hexlify(w.get_verifier()).decode("ascii") - while True: - ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier) - if ok.lower() == "yes": - break - if ok.lower() == "no": - print("verification rejected, abandoning transfer", - file=sys.stderr) - reject_data = json.dumps({"error": "verification rejected", - }).encode("utf-8") - w.send_data(reject_data) - w.close() - return 1 + if args.verify: + verifier = binascii.hexlify(w.get_verifier()).decode("ascii") + while True: + ok = six.moves.input("Verifier %s. ok? (yes/no): " % verifier) + if ok.lower() == "yes": + break + if ok.lower() == "no": + print("verification rejected, abandoning transfer", + file=sys.stderr) + reject_data = json.dumps({"error": "verification rejected", + }).encode("utf-8") + w.send_data(reject_data) + w.close() + return 1 - my_phase1_bytes = json.dumps(phase1).encode("utf-8") - w.send_data(my_phase1_bytes) - try: - them_phase1_bytes = w.get_data() - except WrongPasswordError as e: - print("ERROR: " + e.explain(), file=sys.stderr) - w.close() - return 1 - them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) - - if sending_message: - if them_phase1["message_ack"] == "ok": - print("text message sent") + my_phase1_bytes = json.dumps(phase1).encode("utf-8") + w.send_data(my_phase1_bytes) + try: + them_phase1_bytes = w.get_data() + except WrongPasswordError as e: + print("ERROR: " + e.explain(), file=sys.stderr) w.close() - return 0 - print("error sending text: %r" % (them_phase1,)) - w.close() - return 1 + return 1 + them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) - if "error" in them_phase1: - print("remote error: %s" % them_phase1["error"]) - print("transfer abandoned") + if sending_message: + if them_phase1["message_ack"] == "ok": + print("text message sent") + w.close() + return 0 + print("error sending text: %r" % (them_phase1,)) + w.close() + return 1 + + if "error" in them_phase1: + print("remote error: %s" % them_phase1["error"]) + print("transfer abandoned") + w.close() + return 1 + if them_phase1.get("file_ack") != "ok": + print("ambiguous response from remote: %s" % (them_phase1,)) + print("transfer abandoned") + w.close() + return 1 w.close() - return 1 - if them_phase1.get("file_ack") != "ok": - print("ambiguous response from remote: %s" % (them_phase1,)) - print("transfer abandoned") - w.close() - return 1 - w.close() tdata = them_phase1["transit"] transit_key = w.derive_key(APPID+"/transit-key") From 0748647049c8c9c7f4463c5b2b338b600e7d8b5d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 18:17:52 -0800 Subject: [PATCH 6/8] allow multiple close() calls, throw error when using a closed Wormhole --- src/wormhole/blocking/transcribe.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index f028ad7..927869c 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -106,7 +106,7 @@ class Channel: time.sleep(self._wait) return body - def deallocate(self, mood=u"unknown"): + def deallocate(self, mood=None): # only try once, no retries data = json.dumps({"appid": self._appid, "channelid": self._channelid, @@ -177,6 +177,7 @@ class Wormhole: self.verifier = None self._sent_data = set() # phases self._got_data = set() + self._closed = False def handle_welcome(self, welcome): if ("motd" in welcome and @@ -266,6 +267,7 @@ class Wormhole: self.verifier = self.derive_key(u"wormhole:verifier") def get_verifier(self): + if self._closed: raise UsageError if self.code is None: raise UsageError if self._channel is None: raise UsageError self._get_key() @@ -275,6 +277,7 @@ class Wormhole: if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) if not isinstance(phase, type(u"")): raise TypeError(type(phase)) + if self._closed: raise UsageError if phase in self._sent_data: raise UsageError # only call this once if self.code is None: raise UsageError if self._channel is None: raise UsageError @@ -291,6 +294,7 @@ class Wormhole: def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._got_data: raise UsageError # only call this once + if self._closed: raise UsageError if self.code is None: raise UsageError if self._channel is None: raise UsageError self._got_data.add(phase) @@ -303,6 +307,9 @@ class Wormhole: except CryptoError: raise WrongPasswordError - def close(self): - monitor.close(self._channel) - self._channel.deallocate() + def close(self, mood=None): + self._closed = True + if self._channel: + c, self._channel = self._channel, None + monitor.close(c) + c.deallocate(mood) From a881d6055f1da43ff21d6440ffd37f36c81ec21b Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 21:54:45 -0800 Subject: [PATCH 7/8] auto-close Channel (with a "mood") upon server or crypto error --- src/wormhole/blocking/transcribe.py | 32 ++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 927869c..8b9004f 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -117,7 +117,8 @@ class Channel: # short timeout and ignore failures requests.post(self._relay_url+"deallocate", data=data, timeout=5) - except requests.exceptions.Timeout: + except (requests.exceptions.ConnectionError, + requests.exceptions.Timeout): pass class ChannelManager: @@ -154,6 +155,29 @@ class ChannelManager: return Channel(self._relay_url, self._appid, channelid, self._side, self._handle_welcome, self._wait, self._timeout) +def close_on_error(f): # method decorator + # Clients report certain errors as "moods", so the server can make a + # rough count failed connections (due to mismatched passwords, attacks, + # or timeouts). We don't report precondition failures, as those are the + # responsibility/fault of the local application code. We count + # non-precondition errors in case they represent server-side problems. + def _f(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except Timeout: + self.close(u"lonely") + raise + except WrongPasswordError: + self.close(u"scared") + raise + except (TypeError, UsageError): + # preconditions don't warrant _close_with_error() + raise + except: + self.close(u"other-error") + raise + return _f + class Wormhole: motd_displayed = False version_warning_displayed = False @@ -238,6 +262,7 @@ class Wormhole: idSymmetric=to_bytes(self._appid)) self.msg1 = self.sp.start() + @close_on_error def derive_key(self, purpose, length=SecretBox.KEY_SIZE): if not isinstance(purpose, type(u"")): raise TypeError(type(purpose)) return HKDF(self.key, length, CTXinfo=to_bytes(purpose)) @@ -266,6 +291,7 @@ class Wormhole: self.key = self.sp.finish(pake_msg) self.verifier = self.derive_key(u"wormhole:verifier") + @close_on_error def get_verifier(self): if self._closed: raise UsageError if self.code is None: raise UsageError @@ -273,6 +299,7 @@ class Wormhole: self._get_key() return self.verifier + @close_on_error def send_data(self, outbound_data, phase=u"data"): if not isinstance(outbound_data, type(b"")): raise TypeError(type(outbound_data)) @@ -291,6 +318,7 @@ class Wormhole: outbound_encrypted = self._encrypt_data(data_key, outbound_data) self._channel.send(phase, outbound_encrypted) + @close_on_error def get_data(self, phase=u"data"): if not isinstance(phase, type(u"")): raise TypeError(type(phase)) if phase in self._got_data: raise UsageError # only call this once @@ -308,6 +336,8 @@ class Wormhole: raise WrongPasswordError def close(self, mood=None): + if not isinstance(mood, (type(None), type(u""))): + raise TypeError(type(mood)) self._closed = True if self._channel: c, self._channel = self._channel, None From 80beb206315d489dd336492d9eaa75cf23aa304c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 21:59:16 -0800 Subject: [PATCH 8/8] make blocking.Wormhole into a context manager --- docs/api.md | 29 ++++++++++++++++------------- src/wormhole/blocking/transcribe.py | 6 ++++++ src/wormhole/scripts/cmd_receive.py | 11 +---------- src/wormhole/scripts/cmd_send.py | 10 +--------- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/docs/api.md b/docs/api.md index e443c41..2166d6c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -56,6 +56,11 @@ suffer longer invitation codes as a result. To encourage `close()`, the library will log an error if a Wormhole object is destroyed before being closed. +To make it easier to call `close()`, the blocking Wormhole objects can be +used as a context manager. Just put your code in the body of a `with +Wormhole(ARGS) as w:` statement, and `close()` will automatically be called +when the block exits (either successfully or due to an exception). + ## Examples The synchronous+blocking flow looks like this: @@ -64,13 +69,12 @@ The synchronous+blocking flow looks like this: from wormhole.blocking.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"initiator's data" -i = Wormhole(u"appid", RENDEZVOUS_RELAY) -code = i.get_code() -print("Invitation Code: %s" % code) -i.send_data(mydata) -theirdata = i.get_data() -i.close() -print("Their data: %s" % theirdata.decode("ascii")) +with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: + code = i.get_code() + print("Invitation Code: %s" % code) + i.send_data(mydata) + theirdata = i.get_data() + print("Their data: %s" % theirdata.decode("ascii")) ``` ```python @@ -79,12 +83,11 @@ from wormhole.blocking.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"receiver's data" code = sys.argv[1] -r = Wormhole(u"appid", RENDEZVOUS_RELAY) -r.set_code(code) -r.send_data(mydata) -theirdata = r.get_data() -r.close() -print("Their data: %s" % theirdata.decode("ascii")) +with Wormhole(u"appid", RENDEZVOUS_RELAY) as r: + r.set_code(code) + r.send_data(mydata) + theirdata = r.get_data() + print("Their data: %s" % theirdata.decode("ascii")) ``` ## Twisted diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 8b9004f..c2d3407 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -203,6 +203,12 @@ class Wormhole: self._got_data = set() self._closed = False + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + def handle_welcome(self, welcome): if ("motd" in welcome and not self.motd_displayed): diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index cfc98ad..2936638 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -12,8 +12,7 @@ def receive(args): from .progress import start_progress, update_progress, finish_progress assert isinstance(args.relay_url, type(u"")) - if True: - w = Wormhole(APPID, args.relay_url) + with Wormhole(APPID, args.relay_url) as w: if args.zeromode: assert not args.code args.code = u"0-" @@ -30,12 +29,10 @@ def receive(args): them_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) - w.close() return 1 them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: print("ERROR: " + them_d["error"], file=sys.stderr) - w.close() return 1 if "message" in them_d: @@ -43,18 +40,15 @@ def receive(args): print(them_d["message"]) data = json.dumps({"message_ack": "ok"}).encode("utf-8") w.send_data(data) - w.close() return 0 if not "file" in them_d: print("I don't know what they're offering\n") print(them_d) - w.close() return 1 if "error" in them_d: print("ERROR: " + data["error"], file=sys.stderr) - w.close() return 1 file_data = them_d["file"] @@ -68,7 +62,6 @@ def receive(args): print("Error: refusing to overwrite existing file %s" % (filename,)) data = json.dumps({"error": "file already exists"}).encode("utf-8") w.send_data(data) - w.close() return 1 print("Receiving file (%d bytes) into: %s" % (filesize, filename)) @@ -79,7 +72,6 @@ def receive(args): print("transfer rejected", file=sys.stderr) data = json.dumps({"error": "transfer rejected"}).encode("utf-8") w.send_data(data) - w.close() return 1 transit_receiver = TransitReceiver(args.transit_helper) @@ -91,7 +83,6 @@ def receive(args): }, }).encode("utf-8") w.send_data(data) - w.close() # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/scripts/cmd_send.py b/src/wormhole/scripts/cmd_send.py index cb3d2f4..9daf6cb 100644 --- a/src/wormhole/scripts/cmd_send.py +++ b/src/wormhole/scripts/cmd_send.py @@ -46,8 +46,7 @@ def send(args): }, } - if True: - w = Wormhole(APPID, args.relay_url) + with Wormhole(APPID, args.relay_url) as w: if args.zeromode: assert not args.code args.code = u"0-" @@ -78,7 +77,6 @@ def send(args): reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") w.send_data(reject_data) - w.close() return 1 my_phase1_bytes = json.dumps(phase1).encode("utf-8") @@ -87,30 +85,24 @@ def send(args): them_phase1_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) - w.close() return 1 them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) if sending_message: if them_phase1["message_ack"] == "ok": print("text message sent") - w.close() return 0 print("error sending text: %r" % (them_phase1,)) - w.close() return 1 if "error" in them_phase1: print("remote error: %s" % them_phase1["error"]) print("transfer abandoned") - w.close() return 1 if them_phase1.get("file_ack") != "ok": print("ambiguous response from remote: %s" % (them_phase1,)) print("transfer abandoned") - w.close() return 1 - w.close() tdata = them_phase1["transit"] transit_key = w.derive_key(APPID+"/transit-key")