Clean up error handling, close channel upon error
This commit is contained in:
commit
1511f96c66
|
@ -1,9 +1,9 @@
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import os, sys, json, binascii, six, tempfile, zipfile
|
import os, sys, json, binascii, six, tempfile, zipfile
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from twisted.internet import reactor, defer
|
from twisted.internet import reactor
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole
|
||||||
from ..twisted.transit import TransitReceiver
|
from ..twisted.transit import TransitReceiver
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
|
|
||||||
|
@ -14,6 +14,13 @@ class RespondError(Exception):
|
||||||
self.response = response
|
self.response = response
|
||||||
|
|
||||||
def receive(args, reactor=reactor):
|
def receive(args, reactor=reactor):
|
||||||
|
"""I implement 'wormhole receive'. I return a Deferred that fires with
|
||||||
|
None (for success), or signals one of the following errors:
|
||||||
|
* WrongPasswordError: the two sides didn't use matching passwords
|
||||||
|
* Timeout: something didn't happen fast enough for our tastes
|
||||||
|
* TransferError: the sender rejected the transfer: verifier mismatch
|
||||||
|
* any other error: something unexpected happened
|
||||||
|
"""
|
||||||
return TwistedReceiver(args, reactor).go()
|
return TwistedReceiver(args, reactor).go()
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,9 +33,8 @@ class TwistedReceiver:
|
||||||
def msg(self, *args, **kwargs):
|
def msg(self, *args, **kwargs):
|
||||||
print(*args, file=self.args.stdout, **kwargs)
|
print(*args, file=self.args.stdout, **kwargs)
|
||||||
|
|
||||||
# TODO: @handle_server_error
|
@inlineCallbacks
|
||||||
def go(self):
|
def go(self):
|
||||||
d = defer.succeed(None)
|
|
||||||
tor_manager = None
|
tor_manager = None
|
||||||
if self.args.tor:
|
if self.args.tor:
|
||||||
_start = self.args.timing.add_event("import TorManager")
|
_start = self.args.timing.add_event("import TorManager")
|
||||||
|
@ -39,18 +45,10 @@ class TwistedReceiver:
|
||||||
# tor in parallel with everything else, make sure the TorManager
|
# tor in parallel with everything else, make sure the TorManager
|
||||||
# can lazy-provide an endpoint, and overlap the startup process
|
# can lazy-provide an endpoint, and overlap the startup process
|
||||||
# with the user handing off the wormhole code
|
# with the user handing off the wormhole code
|
||||||
d.addCallback(lambda _: tor_manager.start())
|
yield tor_manager.start()
|
||||||
def _make_wormhole(_):
|
w = Wormhole(APPID, self.args.relay_url, tor_manager,
|
||||||
self._w = Wormhole(APPID, self.args.relay_url, tor_manager,
|
|
||||||
timing=self.args.timing,
|
timing=self.args.timing,
|
||||||
reactor=self._reactor)
|
reactor=self._reactor)
|
||||||
d.addCallback(_make_wormhole)
|
|
||||||
d.addCallback(lambda _: self._go(self._w, tor_manager))
|
|
||||||
def _always_close(res):
|
|
||||||
d2 = self._w.close()
|
|
||||||
d2.addBoth(lambda _: res)
|
|
||||||
return d2
|
|
||||||
d.addBoth(_always_close)
|
|
||||||
# I wanted to do this instead:
|
# I wanted to do this instead:
|
||||||
#
|
#
|
||||||
# try:
|
# try:
|
||||||
|
@ -61,7 +59,9 @@ class TwistedReceiver:
|
||||||
# but when _go had a UsageError, the stacktrace was always displayed
|
# but when _go had a UsageError, the stacktrace was always displayed
|
||||||
# as coming from the "yield self._go" line, which wasn't very useful
|
# as coming from the "yield self._go" line, which wasn't very useful
|
||||||
# for tracking it down.
|
# for tracking it down.
|
||||||
return d
|
d = self._go(w, tor_manager)
|
||||||
|
d.addBoth(w.close)
|
||||||
|
yield d
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _go(self, w, tor_manager):
|
def _go(self, w, tor_manager):
|
||||||
|
@ -72,7 +72,7 @@ class TwistedReceiver:
|
||||||
try:
|
try:
|
||||||
if "message" in them_d:
|
if "message" in them_d:
|
||||||
yield self.handle_text(them_d, w)
|
yield self.handle_text(them_d, w)
|
||||||
returnValue(0)
|
returnValue(None)
|
||||||
if "file" in them_d:
|
if "file" in them_d:
|
||||||
f = self.handle_file(them_d)
|
f = self.handle_file(them_d)
|
||||||
rp = yield self.establish_transit(w, them_d, tor_manager)
|
rp = yield self.establish_transit(w, them_d, tor_manager)
|
||||||
|
@ -92,8 +92,8 @@ class TwistedReceiver:
|
||||||
except RespondError as r:
|
except RespondError as r:
|
||||||
data = json.dumps(r.response).encode("utf-8")
|
data = json.dumps(r.response).encode("utf-8")
|
||||||
yield w.send_data(data)
|
yield w.send_data(data)
|
||||||
raise SystemExit(1)
|
raise TransferError(r["error"])
|
||||||
returnValue(0)
|
returnValue(None)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def handle_code(self, w):
|
def handle_code(self, w):
|
||||||
|
@ -113,10 +113,8 @@ class TwistedReceiver:
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def get_data(self, w):
|
def get_data(self, w):
|
||||||
try:
|
# this may raise WrongPasswordError
|
||||||
them_bytes = yield w.get_data()
|
them_bytes = yield w.get_data()
|
||||||
except WrongPasswordError as e:
|
|
||||||
raise TransferError(u"ERROR: " + e.explain())
|
|
||||||
them_d = json.loads(them_bytes.decode("utf-8"))
|
them_d = json.loads(them_bytes.decode("utf-8"))
|
||||||
if "error" in them_d:
|
if "error" in them_d:
|
||||||
raise TransferError(u"ERROR: " + them_d["error"])
|
raise TransferError(u"ERROR: " + them_d["error"])
|
||||||
|
@ -229,7 +227,7 @@ class TwistedReceiver:
|
||||||
self.msg()
|
self.msg()
|
||||||
self.msg(u"Connection dropped before full file received")
|
self.msg(u"Connection dropped before full file received")
|
||||||
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
|
self.msg(u"got %d bytes, wanted %d" % (received, self.xfersize))
|
||||||
returnValue(1) # TODO: exit properly
|
raise TransferError("Connection dropped before full file received")
|
||||||
assert received == self.xfersize
|
assert received == self.xfersize
|
||||||
|
|
||||||
def write_file(self, f):
|
def write_file(self, f):
|
||||||
|
|
|
@ -5,13 +5,21 @@ from twisted.protocols import basic
|
||||||
from twisted.internet import reactor
|
from twisted.internet import reactor
|
||||||
from twisted.internet.defer import inlineCallbacks, returnValue
|
from twisted.internet.defer import inlineCallbacks, returnValue
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError
|
||||||
from ..twisted.transcribe import Wormhole, WrongPasswordError
|
from ..twisted.transcribe import Wormhole
|
||||||
from ..twisted.transit import TransitSender
|
from ..twisted.transit import TransitSender
|
||||||
|
|
||||||
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
APPID = u"lothar.com/wormhole/text-or-file-xfer"
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def send(args, reactor=reactor):
|
def send(args, reactor=reactor):
|
||||||
|
"""I implement 'wormhole send'. I return a Deferred that fires with None
|
||||||
|
(for success), or signals one of the following errors:
|
||||||
|
* WrongPasswordError: the two sides didn't use matching passwords
|
||||||
|
* Timeout: something didn't happen fast enough for our tastes
|
||||||
|
* TransferError: the receiver rejected the transfer: verifier mismatch,
|
||||||
|
permission not granted, ack not successful.
|
||||||
|
* any other error: something unexpected happened
|
||||||
|
"""
|
||||||
assert isinstance(args.relay_url, type(u""))
|
assert isinstance(args.relay_url, type(u""))
|
||||||
if args.zeromode:
|
if args.zeromode:
|
||||||
assert not args.code
|
assert not args.code
|
||||||
|
@ -43,6 +51,13 @@ def send(args, reactor=reactor):
|
||||||
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
|
w = Wormhole(APPID, args.relay_url, tor_manager, timing=args.timing,
|
||||||
reactor=reactor)
|
reactor=reactor)
|
||||||
|
|
||||||
|
d = _send(reactor, w, args, phase1, fd_to_send, tor_manager)
|
||||||
|
d.addBoth(w.close)
|
||||||
|
yield d
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def _send(reactor, w, args, phase1, fd_to_send, tor_manager):
|
||||||
|
transit_sender = None
|
||||||
if fd_to_send:
|
if fd_to_send:
|
||||||
transit_sender = TransitSender(args.transit_helper,
|
transit_sender = TransitSender(args.transit_helper,
|
||||||
no_listen=args.no_listen,
|
no_listen=args.no_listen,
|
||||||
|
@ -87,18 +102,15 @@ def send(args, reactor=reactor):
|
||||||
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
my_phase1_bytes = json.dumps(phase1).encode("utf-8")
|
||||||
yield w.send_data(my_phase1_bytes)
|
yield w.send_data(my_phase1_bytes)
|
||||||
|
|
||||||
try:
|
# this may raise WrongPasswordError
|
||||||
them_phase1_bytes = yield w.get_data()
|
them_phase1_bytes = yield w.get_data()
|
||||||
except WrongPasswordError as e:
|
|
||||||
raise TransferError(e.explain())
|
|
||||||
|
|
||||||
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
them_phase1 = json.loads(them_phase1_bytes.decode("utf-8"))
|
||||||
|
|
||||||
if fd_to_send is None:
|
if fd_to_send is None:
|
||||||
if them_phase1["message_ack"] == "ok":
|
if them_phase1["message_ack"] == "ok":
|
||||||
print(u"text message sent", file=args.stdout)
|
print(u"text message sent", file=args.stdout)
|
||||||
yield w.close()
|
returnValue(None) # terminates this function
|
||||||
returnValue(0) # terminates this function
|
|
||||||
raise TransferError("error sending text: %r" % (them_phase1,))
|
raise TransferError("error sending text: %r" % (them_phase1,))
|
||||||
|
|
||||||
if "error" in them_phase1:
|
if "error" in them_phase1:
|
||||||
|
@ -108,10 +120,12 @@ def send(args, reactor=reactor):
|
||||||
raise TransferError("ambiguous response from remote, "
|
raise TransferError("ambiguous response from remote, "
|
||||||
"transfer abandoned: %s" % (them_phase1,))
|
"transfer abandoned: %s" % (them_phase1,))
|
||||||
tdata = them_phase1["transit"]
|
tdata = them_phase1["transit"]
|
||||||
yield w.close()
|
# XXX the downside of closing above, rather than here, is that it leaves
|
||||||
|
# the channel claimed for a longer time
|
||||||
|
#yield w.close()
|
||||||
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
yield _send_file_twisted(tdata, transit_sender, fd_to_send,
|
||||||
args.stdout, args.hide_progress, args.timing)
|
args.stdout, args.hide_progress, args.timing)
|
||||||
returnValue(0)
|
returnValue(None)
|
||||||
|
|
||||||
def build_phase1_data(args):
|
def build_phase1_data(args):
|
||||||
phase1 = {}
|
phase1 = {}
|
||||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function
|
||||||
import os, sys
|
import os, sys
|
||||||
from twisted.internet.defer import maybeDeferred
|
from twisted.internet.defer import maybeDeferred
|
||||||
from twisted.internet.task import react
|
from twisted.internet.task import react
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError, WrongPasswordError, Timeout
|
||||||
from ..timing import DebugTiming
|
from ..timing import DebugTiming
|
||||||
from .cli_args import parser
|
from .cli_args import parser
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
|
||||||
args.timing = timing = DebugTiming()
|
args.timing = timing = DebugTiming()
|
||||||
|
|
||||||
timing.add_event("command dispatch")
|
timing.add_event("command dispatch")
|
||||||
|
# fires with None, or raises an error
|
||||||
d = maybeDeferred(dispatch, args)
|
d = maybeDeferred(dispatch, args)
|
||||||
def _maybe_dump_timing(res):
|
def _maybe_dump_timing(res):
|
||||||
timing.add_event("exit")
|
timing.add_event("exit")
|
||||||
|
@ -43,13 +44,12 @@ def run(reactor, argv, cwd, stdout, stderr, executable=None):
|
||||||
return res
|
return res
|
||||||
d.addBoth(_maybe_dump_timing)
|
d.addBoth(_maybe_dump_timing)
|
||||||
def _explain_error(f):
|
def _explain_error(f):
|
||||||
f.trap(TransferError)
|
# these three errors don't print a traceback, just an explanation
|
||||||
|
f.trap(TransferError, WrongPasswordError, Timeout)
|
||||||
print("ERROR:", f.value, file=stderr)
|
print("ERROR:", f.value, file=stderr)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
d.addErrback(_explain_error)
|
d.addErrback(_explain_error)
|
||||||
def _rc(rc):
|
d.addCallback(lambda _: 0)
|
||||||
raise SystemExit(rc)
|
|
||||||
d.addCallback(_rc)
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def entry():
|
def entry():
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import functools, textwrap
|
import functools
|
||||||
|
|
||||||
class ServerError(Exception):
|
class ServerError(Exception):
|
||||||
def __init__(self, message, relay):
|
def __init__(self, message, relay):
|
||||||
|
@ -28,8 +28,6 @@ class WrongPasswordError(Exception):
|
||||||
chance.
|
chance.
|
||||||
"""
|
"""
|
||||||
# or the data blob was corrupted, and that's why decrypt failed
|
# or the data blob was corrupted, and that's why decrypt failed
|
||||||
def explain(self):
|
|
||||||
return textwrap.dedent(self.__doc__)
|
|
||||||
|
|
||||||
class ReflectionAttack(Exception):
|
class ReflectionAttack(Exception):
|
||||||
"""An attacker (or bug) reflected our outgoing message back to us."""
|
"""An attacker (or bug) reflected our outgoing message back to us."""
|
||||||
|
|
|
@ -3,12 +3,12 @@ import os, sys, re, io, zipfile, six
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
from twisted.python import procutils, log
|
from twisted.python import procutils, log
|
||||||
from twisted.internet.utils import getProcessOutputAndValue
|
from twisted.internet.utils import getProcessOutputAndValue
|
||||||
from twisted.internet.defer import inlineCallbacks
|
from twisted.internet.defer import gatherResults, inlineCallbacks
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .common import ServerBase
|
from .common import ServerBase
|
||||||
from ..cli import runner, cmd_send, cmd_receive
|
from ..cli import runner, cmd_send, cmd_receive
|
||||||
from ..cli.cmd_send import build_phase1_data
|
from ..cli.cmd_send import build_phase1_data
|
||||||
from ..errors import TransferError
|
from ..errors import TransferError, WrongPasswordError
|
||||||
from ..timing import DebugTiming
|
from ..timing import DebugTiming
|
||||||
|
|
||||||
class Phase1Data(unittest.TestCase):
|
class Phase1Data(unittest.TestCase):
|
||||||
|
@ -306,15 +306,17 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
path=send_dir)
|
path=send_dir)
|
||||||
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args,
|
receive_d = getProcessOutputAndValue(wormhole_bin, receive_args,
|
||||||
path=receive_dir)
|
path=receive_dir)
|
||||||
send_res = yield send_d
|
(send_res, receive_res) = yield gatherResults([send_d, receive_d],
|
||||||
|
True)
|
||||||
send_stdout = send_res[0].decode("utf-8")
|
send_stdout = send_res[0].decode("utf-8")
|
||||||
send_stderr = send_res[1].decode("utf-8")
|
send_stderr = send_res[1].decode("utf-8")
|
||||||
send_rc = send_res[2]
|
send_rc = send_res[2]
|
||||||
receive_res = yield receive_d
|
|
||||||
receive_stdout = receive_res[0].decode("utf-8")
|
receive_stdout = receive_res[0].decode("utf-8")
|
||||||
receive_stderr = receive_res[1].decode("utf-8")
|
receive_stderr = receive_res[1].decode("utf-8")
|
||||||
receive_rc = receive_res[2]
|
receive_rc = receive_res[2]
|
||||||
NL = os.linesep
|
NL = os.linesep
|
||||||
|
self.assertEqual((send_rc, receive_rc), (0, 0),
|
||||||
|
(send_res, receive_res))
|
||||||
else:
|
else:
|
||||||
sargs = runner.parser.parse_args(send_args)
|
sargs = runner.parser.parse_args(send_args)
|
||||||
sargs.cwd = send_dir
|
sargs.cwd = send_dir
|
||||||
|
@ -330,22 +332,11 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
receive_d = cmd_receive.receive(rargs)
|
receive_d = cmd_receive.receive(rargs)
|
||||||
|
|
||||||
# The sender might fail, leaving the receiver hanging, or vice
|
# The sender might fail, leaving the receiver hanging, or vice
|
||||||
# versa. If either side fails, cancel the other, so it won't
|
# versa. Make sure we don't wait on one side exclusively
|
||||||
# matter which Deferred we wait upon first.
|
|
||||||
|
|
||||||
def _oops(f, which):
|
yield gatherResults([send_d, receive_d], True)
|
||||||
log.msg("test_scripts: %s failed, cancelling both" % which)
|
|
||||||
send_d.cancel()
|
|
||||||
receive_d.cancel()
|
|
||||||
return f
|
|
||||||
send_d.addErrback(_oops, "send_d")
|
|
||||||
receive_d.addErrback(_oops, "receive_d")
|
|
||||||
|
|
||||||
send_rc = yield send_d
|
|
||||||
send_stdout = sargs.stdout.getvalue()
|
send_stdout = sargs.stdout.getvalue()
|
||||||
send_stderr = sargs.stderr.getvalue()
|
send_stderr = sargs.stderr.getvalue()
|
||||||
|
|
||||||
receive_rc = yield receive_d
|
|
||||||
receive_stdout = rargs.stdout.getvalue()
|
receive_stdout = rargs.stdout.getvalue()
|
||||||
receive_stderr = rargs.stderr.getvalue()
|
receive_stderr = rargs.stderr.getvalue()
|
||||||
|
|
||||||
|
@ -355,8 +346,10 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
|
|
||||||
self.maxDiff = None # show full output for assertion failures
|
self.maxDiff = None # show full output for assertion failures
|
||||||
|
|
||||||
self.failUnlessEqual(send_stderr, "")
|
self.failUnlessEqual(send_stderr, "",
|
||||||
self.failUnlessEqual(receive_stderr, "")
|
(send_stdout, send_stderr))
|
||||||
|
self.failUnlessEqual(receive_stderr, "",
|
||||||
|
(receive_stdout, receive_stderr))
|
||||||
|
|
||||||
# check sender
|
# check sender
|
||||||
if mode == "text":
|
if mode == "text":
|
||||||
|
@ -417,9 +410,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
with open(fn, "r") as f:
|
with open(fn, "r") as f:
|
||||||
self.failUnlessEqual(f.read(), message(i))
|
self.failUnlessEqual(f.read(), message(i))
|
||||||
|
|
||||||
self.failUnlessEqual(send_rc, 0)
|
|
||||||
self.failUnlessEqual(receive_rc, 0)
|
|
||||||
|
|
||||||
def test_text(self):
|
def test_text(self):
|
||||||
return self._do_test()
|
return self._do_test()
|
||||||
def test_text_subprocess(self):
|
def test_text_subprocess(self):
|
||||||
|
@ -436,3 +426,62 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
|
||||||
return self._do_test(mode="directory", addslash=True)
|
return self._do_test(mode="directory", addslash=True)
|
||||||
def test_directory_override(self):
|
def test_directory_override(self):
|
||||||
return self._do_test(mode="directory", override_filename=True)
|
return self._do_test(mode="directory", override_filename=True)
|
||||||
|
|
||||||
|
class Cleanup(ServerBase, unittest.TestCase):
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_text(self):
|
||||||
|
# the rendezvous channel should be deleted after success
|
||||||
|
code = u"1-abc"
|
||||||
|
common_args = ["--hide-progress",
|
||||||
|
"--relay-url", self.relayurl,
|
||||||
|
"--transit-helper", ""]
|
||||||
|
sargs = runner.parser.parse_args(common_args +
|
||||||
|
["send",
|
||||||
|
"--text", "secret message",
|
||||||
|
"--code", code])
|
||||||
|
sargs.stdout = io.StringIO()
|
||||||
|
sargs.stderr = io.StringIO()
|
||||||
|
sargs.timing = DebugTiming()
|
||||||
|
rargs = runner.parser.parse_args(common_args +
|
||||||
|
["receive", code])
|
||||||
|
rargs.stdout = io.StringIO()
|
||||||
|
rargs.stderr = io.StringIO()
|
||||||
|
rargs.timing = DebugTiming()
|
||||||
|
send_d = cmd_send.send(sargs)
|
||||||
|
receive_d = cmd_receive.receive(rargs)
|
||||||
|
|
||||||
|
yield send_d
|
||||||
|
yield receive_d
|
||||||
|
|
||||||
|
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
|
||||||
|
self.assertEqual(len(cids), 0)
|
||||||
|
|
||||||
|
@inlineCallbacks
|
||||||
|
def test_text_wrong_password(self):
|
||||||
|
# if the password was wrong, the rendezvous channel should still be
|
||||||
|
# deleted
|
||||||
|
common_args = ["--hide-progress",
|
||||||
|
"--relay-url", self.relayurl,
|
||||||
|
"--transit-helper", ""]
|
||||||
|
sargs = runner.parser.parse_args(common_args +
|
||||||
|
["send",
|
||||||
|
"--text", "secret message",
|
||||||
|
"--code", u"1-abc"])
|
||||||
|
sargs.stdout = io.StringIO()
|
||||||
|
sargs.stderr = io.StringIO()
|
||||||
|
sargs.timing = DebugTiming()
|
||||||
|
rargs = runner.parser.parse_args(common_args +
|
||||||
|
["receive", u"1-WRONG"])
|
||||||
|
rargs.stdout = io.StringIO()
|
||||||
|
rargs.stderr = io.StringIO()
|
||||||
|
rargs.timing = DebugTiming()
|
||||||
|
send_d = cmd_send.send(sargs)
|
||||||
|
receive_d = cmd_receive.receive(rargs)
|
||||||
|
|
||||||
|
# both sides should be capable of detecting the mismatch
|
||||||
|
yield self.assertFailure(send_d, WrongPasswordError)
|
||||||
|
yield self.assertFailure(receive_d, WrongPasswordError)
|
||||||
|
|
||||||
|
cids = self._rendezvous.get_app(cmd_send.APPID).get_allocated()
|
||||||
|
self.assertEqual(len(cids), 0)
|
||||||
|
|
||||||
|
|
|
@ -27,30 +27,6 @@ def make_confmsg(confkey, nonce):
|
||||||
def to_bytes(u):
|
def to_bytes(u):
|
||||||
return unicodedata.normalize("NFC", u).encode("utf-8")
|
return unicodedata.normalize("NFC", u).encode("utf-8")
|
||||||
|
|
||||||
def close_on_error(meth): # method decorator
|
|
||||||
# Clients report certain errors as "moods", so the server can make a
|
|
||||||
# rough count failed connections (due to mismatched passwords, attacks,
|
|
||||||
# or timeouts). We don't report precondition failures, as those are the
|
|
||||||
# responsibility/fault of the local application code. We count
|
|
||||||
# non-precondition errors in case they represent server-side problems.
|
|
||||||
def _wrapper(self, *args, **kwargs):
|
|
||||||
d = defer.maybeDeferred(meth, self, *args, **kwargs)
|
|
||||||
def _onerror(f):
|
|
||||||
if f.check(Timeout):
|
|
||||||
d2 = self.close(u"lonely")
|
|
||||||
elif f.check(WrongPasswordError):
|
|
||||||
d2 = self.close(u"scary")
|
|
||||||
elif f.check(TypeError, UsageError):
|
|
||||||
# preconditions don't warrant _close_with_error()
|
|
||||||
d2 = defer.succeed(None)
|
|
||||||
else:
|
|
||||||
d2 = self.close(u"errory")
|
|
||||||
d2.addBoth(lambda _: f)
|
|
||||||
return d2
|
|
||||||
d.addErrback(_onerror)
|
|
||||||
return d
|
|
||||||
return _wrapper
|
|
||||||
|
|
||||||
class WSClient(websocket.WebSocketClientProtocol):
|
class WSClient(websocket.WebSocketClientProtocol):
|
||||||
def onOpen(self):
|
def onOpen(self):
|
||||||
self.wormhole_open = True
|
self.wormhole_open = True
|
||||||
|
@ -181,13 +157,16 @@ class Wormhole:
|
||||||
return self._signal_error(welcome["error"])
|
return self._signal_error(welcome["error"])
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _sleep(self):
|
def _sleep(self, wake_on_error=True):
|
||||||
if self._error: # don't sleep if the bed's already on fire
|
if wake_on_error and self._error:
|
||||||
|
# don't sleep if the bed's already on fire, unless we're waiting
|
||||||
|
# for the fire department to respond, in which case sure, keep on
|
||||||
|
# sleeping
|
||||||
raise self._error
|
raise self._error
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
self._sleepers.append(d)
|
self._sleepers.append(d)
|
||||||
yield d
|
yield d
|
||||||
if self._error:
|
if wake_on_error and self._error:
|
||||||
raise self._error
|
raise self._error
|
||||||
|
|
||||||
def _wakeup(self):
|
def _wakeup(self):
|
||||||
|
@ -366,7 +345,6 @@ class Wormhole:
|
||||||
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
|
self._msg1 = unhexlify(d["msg1"].encode("ascii"))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@close_on_error
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def get_verifier(self):
|
def get_verifier(self):
|
||||||
if self._closed: raise UsageError
|
if self._closed: raise UsageError
|
||||||
|
@ -458,7 +436,6 @@ class Wormhole:
|
||||||
data = box.decrypt(encrypted)
|
data = box.decrypt(encrypted)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@close_on_error
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def send_data(self, outbound_data, phase=u"data", wait=False):
|
def send_data(self, outbound_data, phase=u"data", wait=False):
|
||||||
if not isinstance(outbound_data, type(b"")):
|
if not isinstance(outbound_data, type(b"")):
|
||||||
|
@ -481,7 +458,6 @@ class Wormhole:
|
||||||
yield self._msg_send(phase, outbound_encrypted, wait)
|
yield self._msg_send(phase, outbound_encrypted, wait)
|
||||||
self._timing.finish_event(_sent)
|
self._timing.finish_event(_sent)
|
||||||
|
|
||||||
@close_on_error
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def get_data(self, phase=u"data"):
|
def get_data(self, phase=u"data"):
|
||||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||||
|
@ -506,26 +482,42 @@ class Wormhole:
|
||||||
# TODO: schedule reconnect, unless we're done
|
# TODO: schedule reconnect, unless we're done
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def close(self, res=None, mood=u"happy"):
|
def close(self, f=None, mood=None):
|
||||||
if not isinstance(mood, (type(None), type(u""))):
|
"""Do d.addBoth(w.close) at the end of your chain."""
|
||||||
raise TypeError(type(mood))
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
returnValue(None)
|
returnValue(None)
|
||||||
self._closed = True
|
self._closed = True
|
||||||
if not self._ws:
|
if not self._ws:
|
||||||
returnValue(None)
|
returnValue(None)
|
||||||
|
|
||||||
|
if mood is None:
|
||||||
|
mood = u"happy"
|
||||||
|
if f:
|
||||||
|
if f.check(Timeout):
|
||||||
|
mood = u"lonely"
|
||||||
|
elif f.check(WrongPasswordError):
|
||||||
|
mood = u"scary"
|
||||||
|
elif f.check(TypeError, UsageError):
|
||||||
|
# preconditions don't warrant reporting mood
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
mood = u"errory" # other errors do
|
||||||
|
if not isinstance(mood, (type(None), type(u""))):
|
||||||
|
raise TypeError(type(mood))
|
||||||
|
|
||||||
self._timing.finish_event(self._timing_started, mood=mood)
|
self._timing.finish_event(self._timing_started, mood=mood)
|
||||||
yield self._deallocate(mood)
|
yield self._deallocate(mood)
|
||||||
# TODO: mark WebSocket as don't-reconnect
|
# TODO: mark WebSocket as don't-reconnect
|
||||||
self._ws.transport.loseConnection() # probably flushes
|
self._ws.transport.loseConnection() # probably flushes
|
||||||
del self._ws
|
del self._ws
|
||||||
|
returnValue(f)
|
||||||
|
|
||||||
@inlineCallbacks
|
@inlineCallbacks
|
||||||
def _deallocate(self, mood=None):
|
def _deallocate(self, mood):
|
||||||
_sent = self._timing.add_event("close")
|
_sent = self._timing.add_event("close")
|
||||||
yield self._ws_send(u"deallocate", mood=mood)
|
yield self._ws_send(u"deallocate", mood=mood)
|
||||||
while self._deallocated_status is None:
|
while self._deallocated_status is None:
|
||||||
yield self._sleep()
|
yield self._sleep(wake_on_error=False)
|
||||||
self._timing.finish_event(_sent)
|
self._timing.finish_event(_sent)
|
||||||
# TODO: set a timeout, don't wait forever for an ack
|
# TODO: set a timeout, don't wait forever for an ack
|
||||||
# TODO: if the connection is lost, let it go
|
# TODO: if the connection is lost, let it go
|
||||||
|
|
Loading…
Reference in New Issue
Block a user