From 80beb206315d489dd336492d9eaa75cf23aa304c Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Wed, 11 Nov 2015 21:59:16 -0800 Subject: [PATCH] make blocking.Wormhole into a context manager --- docs/api.md | 29 ++++++++++++++++------------- src/wormhole/blocking/transcribe.py | 6 ++++++ src/wormhole/scripts/cmd_receive.py | 11 +---------- src/wormhole/scripts/cmd_send.py | 10 +--------- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/docs/api.md b/docs/api.md index e443c41..2166d6c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -56,6 +56,11 @@ suffer longer invitation codes as a result. To encourage `close()`, the library will log an error if a Wormhole object is destroyed before being closed. +To make it easier to call `close()`, the blocking Wormhole objects can be +used as a context manager. Just put your code in the body of a `with +Wormhole(ARGS) as w:` statement, and `close()` will automatically be called +when the block exits (either successfully or due to an exception). + ## Examples The synchronous+blocking flow looks like this: @@ -64,13 +69,12 @@ The synchronous+blocking flow looks like this: from wormhole.blocking.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"initiator's data" -i = Wormhole(u"appid", RENDEZVOUS_RELAY) -code = i.get_code() -print("Invitation Code: %s" % code) -i.send_data(mydata) -theirdata = i.get_data() -i.close() -print("Their data: %s" % theirdata.decode("ascii")) +with Wormhole(u"appid", RENDEZVOUS_RELAY) as i: + code = i.get_code() + print("Invitation Code: %s" % code) + i.send_data(mydata) + theirdata = i.get_data() + print("Their data: %s" % theirdata.decode("ascii")) ``` ```python @@ -79,12 +83,11 @@ from wormhole.blocking.transcribe import Wormhole from wormhole.public_relay import RENDEZVOUS_RELAY mydata = b"receiver's data" code = sys.argv[1] -r = Wormhole(u"appid", RENDEZVOUS_RELAY) -r.set_code(code) -r.send_data(mydata) -theirdata = r.get_data() -r.close() -print("Their data: %s" % theirdata.decode("ascii")) +with Wormhole(u"appid", RENDEZVOUS_RELAY) as r: + r.set_code(code) + r.send_data(mydata) + theirdata = r.get_data() + print("Their data: %s" % theirdata.decode("ascii")) ``` ## Twisted diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index 8b9004f..c2d3407 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -203,6 +203,12 @@ class Wormhole: self._got_data = set() self._closed = False + def __enter__(self): + return self + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + def handle_welcome(self, welcome): if ("motd" in welcome and not self.motd_displayed): diff --git a/src/wormhole/scripts/cmd_receive.py b/src/wormhole/scripts/cmd_receive.py index cfc98ad..2936638 100644 --- a/src/wormhole/scripts/cmd_receive.py +++ b/src/wormhole/scripts/cmd_receive.py @@ -12,8 +12,7 @@ def receive(args): from .progress import start_progress, update_progress, finish_progress assert isinstance(args.relay_url, type(u"")) - if True: - w = Wormhole(APPID, args.relay_url) + with Wormhole(APPID, args.relay_url) as w: if args.zeromode: assert not args.code args.code = u"0-" @@ -30,12 +29,10 @@ def receive(args): them_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) - w.close() return 1 them_d = json.loads(them_bytes.decode("utf-8")) if "error" in them_d: print("ERROR: " + them_d["error"], file=sys.stderr) - w.close() return 1 if "message" in them_d: @@ -43,18 +40,15 @@ def receive(args): print(them_d["message"]) data = json.dumps({"message_ack": "ok"}).encode("utf-8") w.send_data(data) - w.close() return 0 if not "file" in them_d: print("I don't know what they're offering\n") print(them_d) - w.close() return 1 if "error" in them_d: print("ERROR: " + data["error"], file=sys.stderr) - w.close() return 1 file_data = them_d["file"] @@ -68,7 +62,6 @@ def receive(args): print("Error: refusing to overwrite existing file %s" % (filename,)) data = json.dumps({"error": "file already exists"}).encode("utf-8") w.send_data(data) - w.close() return 1 print("Receiving file (%d bytes) into: %s" % (filesize, filename)) @@ -79,7 +72,6 @@ def receive(args): print("transfer rejected", file=sys.stderr) data = json.dumps({"error": "transfer rejected"}).encode("utf-8") w.send_data(data) - w.close() return 1 transit_receiver = TransitReceiver(args.transit_helper) @@ -91,7 +83,6 @@ def receive(args): }, }).encode("utf-8") w.send_data(data) - w.close() # now receive the rest of the owl tdata = them_d["transit"] diff --git a/src/wormhole/scripts/cmd_send.py b/src/wormhole/scripts/cmd_send.py index cb3d2f4..9daf6cb 100644 --- a/src/wormhole/scripts/cmd_send.py +++ b/src/wormhole/scripts/cmd_send.py @@ -46,8 +46,7 @@ def send(args): }, } - if True: - w = Wormhole(APPID, args.relay_url) + with Wormhole(APPID, args.relay_url) as w: if args.zeromode: assert not args.code args.code = u"0-" @@ -78,7 +77,6 @@ def send(args): reject_data = json.dumps({"error": "verification rejected", }).encode("utf-8") w.send_data(reject_data) - w.close() return 1 my_phase1_bytes = json.dumps(phase1).encode("utf-8") @@ -87,30 +85,24 @@ def send(args): them_phase1_bytes = w.get_data() except WrongPasswordError as e: print("ERROR: " + e.explain(), file=sys.stderr) - w.close() return 1 them_phase1 = json.loads(them_phase1_bytes.decode("utf-8")) if sending_message: if them_phase1["message_ack"] == "ok": print("text message sent") - w.close() return 0 print("error sending text: %r" % (them_phase1,)) - w.close() return 1 if "error" in them_phase1: print("remote error: %s" % them_phase1["error"]) print("transfer abandoned") - w.close() return 1 if them_phase1.get("file_ack") != "ok": print("ambiguous response from remote: %s" % (them_phase1,)) print("transfer abandoned") - w.close() return 1 - w.close() tdata = them_phase1["transit"] transit_key = w.derive_key(APPID+"/transit-key")