diff --git a/src/wormhole/cli/cmd_send.py b/src/wormhole/cli/cmd_send.py index 159cee6..d0a41bf 100644 --- a/src/wormhole/cli/cmd_send.py +++ b/src/wormhole/cli/cmd_send.py @@ -290,8 +290,10 @@ class Sender: with self._timing.add("tx file"): with progress: - yield fs.beginFileTransfer(self._fd_to_send, record_pipe, - transform=_count_and_hash) + if filesize: + # don't send zero-length files + yield fs.beginFileTransfer(self._fd_to_send, record_pipe, + transform=_count_and_hash) expected_hash = hasher.digest() expected_hex = bytes_to_hexstr(expected_hash) diff --git a/src/wormhole/test/test_scripts.py b/src/wormhole/test/test_scripts.py index 007af62..c5e2ede 100644 --- a/src/wormhole/test/test_scripts.py +++ b/src/wormhole/test/test_scripts.py @@ -239,7 +239,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): def _do_test(self, as_subprocess=False, mode="text", addslash=False, override_filename=False, fake_tor=False): - assert mode in ("text", "file", "directory", "slow-text") + assert mode in ("text", "file", "empty-file", "directory", "slow-text") if fake_tor: assert not as_subprocess send_cfg = config("send") @@ -260,10 +260,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): receive_dir = self.mktemp() os.mkdir(receive_dir) - if mode == "text" or mode == "slow-text": + if mode in ("text", "slow-text"): send_cfg.text = message - elif mode == "file": + elif mode in ("file", "empty-file"): + if mode == "empty-file": + message = "" send_filename = "testfile" with open(os.path.join(send_dir, send_filename), "w") as f: f.write(message) @@ -503,6 +505,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): return self._do_test(mode="file", override_filename=True) def test_file_tor(self): return self._do_test(mode="file", fake_tor=True) + def test_empty_file(self): + return self._do_test(mode="empty-file") def test_directory(self): return self._do_test(mode="directory") diff --git a/src/wormhole/transit.py b/src/wormhole/transit.py index eca9a85..baae431 100644 --- a/src/wormhole/transit.py +++ b/src/wormhole/transit.py @@ -374,7 +374,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): least that number of bytes have been written. This function will then return a Deferred (that fires with the number of bytes actually received). If the connection is lost while this Deferred is - outstanding, it will errback. + outstanding, it will errback. If 'expected' is 0, the Deferred will + fire right away. If 'expected' is None, then this function returns None instead of a Deferred, and you must call disconnectConsumer() when you are done.""" @@ -402,6 +403,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): if expected is not None: d = defer.Deferred() self._consumer_deferred = d + if expected == 0: + # write empty record to kick consumer into shutdown + self._writeToConsumer(b"") # drain any pending records while self._consumer and self._inbound_records: r = self._inbound_records.popleft() @@ -428,6 +432,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin): # optional callable which will be called on each write (with the number # of bytes written). Returns a Deferred that fires (with the number of # bytes written) when the count is reached or the RecordPipe is closed. + def writeToFile(self, f, expected, progress=None, hasher=None): fc = FileConsumer(f, progress, hasher) return self.connectConsumer(fc, expected)