diff --git a/fs/sshfs/sshfs.py b/fs/sshfs/sshfs.py index aaa4cda..89f59da 100644 --- a/fs/sshfs/sshfs.py +++ b/fs/sshfs/sshfs.py @@ -27,6 +27,13 @@ from .error_tools import convert_sshfs_errors +def _get_version_number(): + """ + Returns version number of paramiko as a tuple of ints. + """ + return tuple(int(x if x.isdigit() else "0") for x in getattr(paramiko, "__version__", "0.0.0").split(".")) + + class SSHFS(FS): """A SSH filesystem using SFTP. @@ -408,6 +415,13 @@ def download(self, path, file, chunk_size=None, callback=None, **options): so far and the total bytes to be transferred. Passed transparently to `~paramiko.SFTP.getfo`. + Keyword Arguments: + prefetch (bool): Controls whether prefetching is performed + Defaults to ``True``. Applicable only for paramiko >= 2.8.0 + max_concurrent_prefetch_requests (int): The maximum number of concurrent read requests to prefetch. See + `.SFTPClient.get` (its ``max_concurrent_prefetch_requests`` param) + for details. Applicable only for paramiko >= 3.3.0 + Note that the file object ``file`` will *not* be closed by this method. Take care to close it after this method completes (ideally with a context manager). @@ -423,8 +437,24 @@ def download(self, path, file, chunk_size=None, callback=None, **options): raise errors.ResourceNotFound(path) elif self.isdir(_path): raise errors.FileExpected(path) + + _options = options.copy() + _version = _get_version_number() + if _version[0] >= 2: + if _version[0] == 2 and _version[1] < 8: + _options.pop('prefetch', None) + if bool(_version[0] == 3 and _version[1] < 3) or _version[0] < 3: + _options.pop('max_concurrent_prefetch_requests', None) + else: + _options = {} + with convert_sshfs_errors('download', path): - self._sftp.getfo(_path, file, callback=callback) + self._sftp.getfo( + _path, + file, + callback=callback, + **_options + ) def upload(self, path, file, chunk_size=None, callback=None, file_size=None, confirm=True, **options): """Set a file to the contents of a binary file object. diff --git a/tests/test_sshfs.py b/tests/test_sshfs.py index 77d48c2..cd65616 100644 --- a/tests/test_sshfs.py +++ b/tests/test_sshfs.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import unicode_literals +import io import stat import sys import time @@ -202,3 +203,20 @@ def test_setinfo(self): now = int(time.time()) with utils.mock.patch("time.time", lambda: now): super(TestSSHFS, self).test_setinfo() + + def test_download_prefetch(self): + # SSHFS does not support prefetching + test_bytes = b"Hello, World" + self.fs.writebytes("hello.bin", test_bytes) + + write_file = io.BytesIO() + self.fs.download("hello.bin", write_file, prefetch=False) + self.assertEqual(write_file.getvalue(), test_bytes) + + write_file = io.BytesIO() + self.fs.download("hello.bin", write_file, prefetch=True) + self.assertEqual(write_file.getvalue(), test_bytes) + + write_file = io.BytesIO() + self.fs.download("hello.bin", write_file, prefetch=True, max_concurrent_prefetch_requests=2) + self.assertEqual(write_file.getvalue(), test_bytes)