don't hang when asked to send a zero-length file

closes #98
This commit is contained in:
Brian Warner 2017-01-16 17:29:40 -05:00
parent 360ad70667
commit 862820679c
3 changed files with 17 additions and 6 deletions

View File

@ -290,8 +290,10 @@ class Sender:
with self._timing.add("tx file"): with self._timing.add("tx file"):
with progress: with progress:
yield fs.beginFileTransfer(self._fd_to_send, record_pipe, if filesize:
transform=_count_and_hash) # don't send zero-length files
yield fs.beginFileTransfer(self._fd_to_send, record_pipe,
transform=_count_and_hash)
expected_hash = hasher.digest() expected_hash = hasher.digest()
expected_hex = bytes_to_hexstr(expected_hash) expected_hex = bytes_to_hexstr(expected_hash)

View File

@ -239,7 +239,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
def _do_test(self, as_subprocess=False, def _do_test(self, as_subprocess=False,
mode="text", addslash=False, override_filename=False, mode="text", addslash=False, override_filename=False,
fake_tor=False): fake_tor=False):
assert mode in ("text", "file", "directory", "slow-text") assert mode in ("text", "file", "empty-file", "directory", "slow-text")
if fake_tor: if fake_tor:
assert not as_subprocess assert not as_subprocess
send_cfg = config("send") send_cfg = config("send")
@ -260,10 +260,12 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
receive_dir = self.mktemp() receive_dir = self.mktemp()
os.mkdir(receive_dir) os.mkdir(receive_dir)
if mode == "text" or mode == "slow-text": if mode in ("text", "slow-text"):
send_cfg.text = message send_cfg.text = message
elif mode == "file": elif mode in ("file", "empty-file"):
if mode == "empty-file":
message = ""
send_filename = "testfile" send_filename = "testfile"
with open(os.path.join(send_dir, send_filename), "w") as f: with open(os.path.join(send_dir, send_filename), "w") as f:
f.write(message) f.write(message)
@ -503,6 +505,8 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase):
return self._do_test(mode="file", override_filename=True) return self._do_test(mode="file", override_filename=True)
def test_file_tor(self): def test_file_tor(self):
return self._do_test(mode="file", fake_tor=True) return self._do_test(mode="file", fake_tor=True)
def test_empty_file(self):
return self._do_test(mode="empty-file")
def test_directory(self): def test_directory(self):
return self._do_test(mode="directory") return self._do_test(mode="directory")

View File

@ -374,7 +374,8 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
least that number of bytes have been written. This function will then least that number of bytes have been written. This function will then
return a Deferred (that fires with the number of bytes actually return a Deferred (that fires with the number of bytes actually
received). If the connection is lost while this Deferred is received). If the connection is lost while this Deferred is
outstanding, it will errback. outstanding, it will errback. If 'expected' is 0, the Deferred will
fire right away.
If 'expected' is None, then this function returns None instead of a If 'expected' is None, then this function returns None instead of a
Deferred, and you must call disconnectConsumer() when you are done.""" Deferred, and you must call disconnectConsumer() when you are done."""
@ -402,6 +403,9 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
if expected is not None: if expected is not None:
d = defer.Deferred() d = defer.Deferred()
self._consumer_deferred = d self._consumer_deferred = d
if expected == 0:
# write empty record to kick consumer into shutdown
self._writeToConsumer(b"")
# drain any pending records # drain any pending records
while self._consumer and self._inbound_records: while self._consumer and self._inbound_records:
r = self._inbound_records.popleft() r = self._inbound_records.popleft()
@ -428,6 +432,7 @@ class Connection(protocol.Protocol, policies.TimeoutMixin):
# optional callable which will be called on each write (with the number # optional callable which will be called on each write (with the number
# of bytes written). Returns a Deferred that fires (with the number of # of bytes written). Returns a Deferred that fires (with the number of
# bytes written) when the count is reached or the RecordPipe is closed. # bytes written) when the count is reached or the RecordPipe is closed.
def writeToFile(self, f, expected, progress=None, hasher=None): def writeToFile(self, f, expected, progress=None, hasher=None):
fc = FileConsumer(f, progress, hasher) fc = FileConsumer(f, progress, hasher)
return self.connectConsumer(fc, expected) return self.connectConsumer(fc, expected)