diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 72e3a90..ffb6d02 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -130,6 +130,9 @@ class Sender: them_d = json.loads(them_d_bytes.decode("utf-8")) #print("GOT", them_d) recognized = False + if u"error" in them_d: + raise TransferError("remote error, transfer abandoned: %s" + % them_d["error"]) if u"transit" in them_d: recognized = True yield self._handle_transit(them_d[u"transit"]) @@ -229,9 +232,6 @@ class Sender: returnValue(None) # terminates this function raise TransferError("error sending text: %r" % (them_answer,)) - if "error" in them_answer: - raise TransferError("remote error, transfer abandoned: %s" - % them_answer["error"]) if them_answer.get("file_ack") != "ok": raise TransferError("ambiguous response from remote, " "transfer abandoned: %s" % (them_answer,)) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index d3976cc..a2d042b 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -430,6 +430,93 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def test_directory_override(self): return self._do_test(mode="directory", override_filename=True) + @inlineCallbacks + def test_file_noclobber(self): + common_args = ["--hide-progress", "--no-listen", + "--relay-url", self.relayurl, + "--transit-helper", ""] + code = u"1-abc" + message = "test message" + + send_args = common_args + [ "send", "--code", code ] + receive_args = common_args + [ "receive", "--accept-file", code ] + + send_dir = self.mktemp() + os.mkdir(send_dir) + receive_dir = self.mktemp() + os.mkdir(receive_dir) + + send_filename = "testfile" + with open(os.path.join(send_dir, send_filename), "w") as f: + f.write(message) + send_args.append(send_filename) + receive_filename = send_filename + + PRESERVE = "don't clobber me\n" + clobberable = os.path.join(receive_dir, receive_filename) + with open(clobberable, "w") as f: + f.write(PRESERVE) + + sargs = runner.parser.parse_args(send_args) + sargs.cwd = send_dir + sargs.stdout = io.StringIO() + sargs.stderr = io.StringIO() + sargs.timing = DebugTiming() + rargs = runner.parser.parse_args(receive_args) + rargs.cwd = receive_dir + rargs.stdout = io.StringIO() + rargs.stderr = io.StringIO() + rargs.timing = DebugTiming() + send_d = cmd_send.send(sargs) + receive_d = cmd_receive.receive(rargs) + + # both sides will fail because of the pre-existing file + + f = yield self.assertFailure(send_d, TransferError) + self.assertEqual(str(f), "remote error, transfer abandoned: file already exists") + + f = yield self.assertFailure(receive_d, TransferError) + self.assertEqual(str(f), "file already exists") + + send_stdout = sargs.stdout.getvalue() + send_stderr = sargs.stderr.getvalue() + receive_stdout = rargs.stdout.getvalue() + receive_stderr = rargs.stderr.getvalue() + + # all output here comes from a StringIO, which uses \n for + # newlines, even if we're on windows + NL = "\n" + + self.maxDiff = None # show full output for assertion failures + + self.failUnlessEqual(send_stderr, "", + (send_stdout, send_stderr)) + self.failUnlessEqual(receive_stderr, "", + (receive_stdout, receive_stderr)) + + # check sender + self.failUnlessIn("Sending {bytes:d} byte file named '{name}'{NL}" + .format(bytes=len(message), name=send_filename, + NL=NL), send_stdout) + self.failUnlessIn("On the other computer, please run: " + "wormhole receive{NL}" + "Wormhole code is: {code}{NL}{NL}" + .format(code=code, NL=NL), + send_stdout) + self.failIfIn("File sent.. waiting for confirmation{NL}" + "Confirmation received. Transfer complete.{NL}" + .format(NL=NL), send_stdout) + + # check receiver + self.failUnlessIn("Error: " + "refusing to overwrite existing file testfile{NL}" + .format(NL=NL), receive_stdout) + self.failIfIn("Received file written to ", receive_stdout) + fn = os.path.join(receive_dir, receive_filename) + self.failUnless(os.path.exists(fn)) + with open(fn, "r") as f: + self.failUnlessEqual(f.read(), PRESERVE) + class Cleanup(ServerBase, unittest.TestCase): @inlineCallbacks def test_text(self):