feat: add File.to_tempfile method and optimize range requests #5226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Commits (10)
64149f8  Add File.to_tempfile method and optimize range requests (universalmind303)
66c572d  Update daft/file.py (universalmind303)
33be053  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into cory… (universalmind303)
af6f15a  more tests (universalmind303)
7004048  Merge branch 'cory/file-tempfile' of https://github.com/Eventual-Inc/… (universalmind303)
968a2c0  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into cory… (universalmind303)
9463fbd  sort imports (universalmind303)
affd2a2  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into cory… (universalmind303)
e9cb0eb  simplify supports_range (universalmind303)
ada415c  use reqwest acceptranges instead of own (universalmind303)
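Note: the diff shown below covers only the Rust-side range-request changes; the File.to_tempfile method itself presumably lands in daft/file.py (per the "Update daft/file.py" commit) and is not shown here. As a rough, hypothetical illustration of what "spill a reader to a temporary file" means conceptually (using the tempfile crate; this is not code from this PR):

use std::io::{self, copy, Read, Seek, SeekFrom};
use tempfile::NamedTempFile;

// Illustrative sketch only: stream any reader into a named temporary file and
// rewind it, roughly what a to_tempfile() convenience is expected to do.
fn to_tempfile<R: Read>(mut reader: R) -> io::Result<NamedTempFile> {
    let mut tmp = NamedTempFile::new()?;
    copy(&mut reader, tmp.as_file_mut())?;
    tmp.as_file_mut().seek(SeekFrom::Start(0))?;
    Ok(tmp)
}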
@@ -186,6 +186,57 @@ impl PyDaftFile {
    fn closed(&self) -> PyResult<bool> {
        Ok(self.cursor.is_none())
    }

    fn supports_range_requests(&mut self) -> PyResult<bool> {
        let cursor = self
            .cursor
            .as_mut()
            .ok_or_else(|| PyIOError::new_err("File not open"))?;

        // Try to read a single byte from the beginning
        let supports_range = match cursor {
            FileCursor::ObjectReader(reader) => {
                let rt = common_runtime::get_io_runtime(true);
                let inner_reader = reader.get_ref();
                let uri = inner_reader.uri.clone();
                let source = inner_reader.source.clone();

                rt.block_within_async_context(async move {
                    source.supports_range(&uri).await.map_err(DaftError::from)
                })??
            }
            FileCursor::Memory(_) => true,
        };

        Ok(supports_range)
    }

    fn size(&mut self) -> PyResult<usize> {
        let cursor = self
            .cursor
            .as_mut()
            .ok_or_else(|| PyIOError::new_err("File not open"))?;

        match cursor {
            FileCursor::ObjectReader(reader) => {
                let reader = reader.get_ref();
                let source = reader.source.clone();
                let uri = reader.uri.clone();
                let io_stats = reader.io_stats.clone();

                let rt = common_runtime::get_io_runtime(true);

                let size = rt.block_within_async_context(async move {
                    source
                        .get_size(&uri, io_stats)
                        .await
                        .map_err(|e| PyIOError::new_err(e.to_string()))
                })??;
                Ok(size)
            }
            FileCursor::Memory(mem_cursor) => Ok(mem_cursor.get_ref().len()),
        }
    }
}

#[cfg(feature = "python")]
@@ -201,6 +252,10 @@ struct ObjectSourceReader {
    uri: String,
    position: usize,
    io_stats: Option<IOStatsRef>,
    // Cache for full file content when range requests aren't supported
    cached_content: Option<Vec<u8>>,
    // Flag to track if range requests are supported
    supports_range: Option<bool>,
}

impl ObjectSourceReader {
@@ -210,89 +265,195 @@ impl ObjectSourceReader {
            uri,
            position: 0,
            io_stats,
            cached_content: None,
            supports_range: None,
        }
    }

    // Helper to read the entire file content
    fn read_full_content(&self) -> io::Result<Vec<u8>> {
        let rt = common_runtime::get_io_runtime(true);

        let source = self.source.clone();
        let uri = self.uri.clone();
        let io_stats = self.io_stats.clone();

        rt.block_within_async_context(async move {
            let result = source
                .get(&uri, None, io_stats)
                .await
                .map_err(map_get_error)?;

            result
                .bytes()
                .await
                .map(|b| b.to_vec())
                .map_err(map_bytes_error)
        })
        .map_err(map_async_error)
        .flatten()
    }
}

// Implement Read for synchronous reading

impl Read for ObjectSourceReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // If we have cached content, serve from cache
        if let Some(ref content) = self.cached_content {
            if self.position >= content.len() {
                return Ok(0); // EOF
            }

            let available = content.len() - self.position;
            let bytes_to_read = std::cmp::min(buf.len(), available);

            buf[..bytes_to_read]
                .copy_from_slice(&content[self.position..self.position + bytes_to_read]);
            self.position += bytes_to_read;

            return Ok(bytes_to_read);
        }

        // First time reading, or range support is known
        let rt = common_runtime::get_io_runtime(true);
        let start = self.position;
        let end = start + buf.len();
        let range = Some(GetRange::Bounded(start..end));

        // If we already know range requests aren't supported, read full content
        if self.supports_range == Some(false) {
            // Read entire file and cache it
            let content = self.read_full_content()?;
[Inline review comment] Admittedly, I'm a bit on the fence about this. I think it could cause some unexpected memory usage, but it does make the API easier to use. I've been thinking about whether this should be configurable, or opt-in/opt-out, something like:

daft.set_execution_config(file_on_unsupported_range_request="download" | "error")
daft.set_execution_config(file_on_unsupported_range_request_max_download_size=1024 * 1024 * 50)  # 50MB max
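As a rough illustration of that idea, the Rust side of such a knob might look like the sketch below. This is purely hypothetical: the option names come from the comment above, and nothing like OnUnsupportedRangeRequest exists in Daft's execution config.

// Hypothetical sketch only -- not part of this PR or of Daft's actual config.
/// Policy for sources that do not support HTTP range requests.
#[derive(Clone, Copy, Debug)]
enum OnUnsupportedRangeRequest {
    /// Download and cache the whole object, up to a size cap.
    Download { max_download_size: usize },
    /// Fail instead of silently buffering the whole file in memory.
    Error,
}

/// Decide whether falling back to a full download is permitted for an object
/// of the given size under the configured policy.
fn check_full_download(policy: OnUnsupportedRangeRequest, object_size: usize) -> std::io::Result<()> {
    match policy {
        OnUnsupportedRangeRequest::Download { max_download_size } if object_size <= max_download_size => Ok(()),
        OnUnsupportedRangeRequest::Download { max_download_size } => Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("object is {object_size} bytes, above the {max_download_size} byte download cap"),
        )),
        OnUnsupportedRangeRequest::Error => Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "source does not support range requests and full downloads are disabled",
        )),
    }
}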
            // Determine how many bytes to return from the full content
            let bytes_to_read = if start < content.len() {
                let end = std::cmp::min(end, content.len());
                let bytes_to_read = end - start;

                // Copy the requested portion to the output buffer
                buf[..bytes_to_read].copy_from_slice(&content[start..end]);

                bytes_to_read
            } else {
                0 // Position is beyond EOF
            };

            // Update position and cache the content
            self.position += bytes_to_read;
            self.cached_content = Some(content);

            return Ok(bytes_to_read);
        }

        // Try range request if support is unknown or known to be supported
        let range = Some(GetRange::Bounded(start..end));
        let source = self.source.clone();
        let uri = self.uri.clone();
        let io_stats = self.io_stats.clone();

        let bytes = rt
        let range_result = rt
            .block_within_async_context(async move {
                let result = source
                    .get(&uri, range, io_stats)
                    .await
                    .map_err(map_get_error)?;
                result.bytes().await.map_err(map_bytes_error)
                match source.get(&uri, range, io_stats.clone()).await {
                    Ok(result) => {
                        let bytes = result.bytes().await.map_err(map_bytes_error)?;
                        Ok((bytes.to_vec(), true)) // Range request succeeded
                    }
                    Err(e) => {
                        // EOF
                        if let daft_io::Error::InvalidRangeRequest {
                            source: daft_io::range::InvalidGetRange::StartTooLarge { .. },
                        } = e
                        {
                            Ok((Vec::new(), true))
                        } else {
                            let error_str = e.to_string();
                            // Check if error suggests range requests aren't supported
                            if error_str.contains("Requested Range Not Satisfiable")
                                || error_str.contains("416")
                            {
                                // Fall back to reading the entire file
                                let result = source
                                    .get(&uri, None, io_stats)
                                    .await
                                    .map_err(map_get_error)?;

                                let bytes = result.bytes().await.map_err(map_bytes_error)?;
                                Ok((bytes.to_vec(), false)) // Range request not supported
                            } else {
                                Err(map_get_error(e))
                            }
                        }
                    }
                }
            })
            .map_err(map_async_error)??;

        if bytes.is_empty() {
            return Ok(0);
        }
        let (bytes, supports_range) = range_result;
        self.supports_range = Some(supports_range);

        let bytes_to_copy = std::cmp::min(buf.len(), bytes.len());
        buf[..bytes_to_copy].copy_from_slice(&bytes[..bytes_to_copy]);
        if !supports_range {
            // Range requests not supported - cache the full content
            let bytes_to_read = if start < bytes.len() {
                let end = std::cmp::min(end, bytes.len());
                let bytes_to_read = end - start;

        self.position += bytes_to_copy;
                // Copy the requested portion to the output buffer
                buf[..bytes_to_read].copy_from_slice(&bytes[start..end]);

        Ok(bytes_to_copy)
    }
                bytes_to_read
            } else {
                0 // Position is beyond EOF
            };

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        let rt = common_runtime::get_io_runtime(true);
            self.position += bytes_to_read;
            self.cached_content = Some(bytes);

        let source = self.source.clone();
        let uri = self.uri.clone();
        let io_stats = self.io_stats.clone();
            Ok(bytes_to_read)
        } else {
            // Range requests supported - use the returned bytes directly
            if bytes.is_empty() {
                return Ok(0);
            }

        let size = rt
            .block_within_async_context(async move {
                source
                    .get_size(&uri, io_stats.clone())
                    .await
                    .map_err(map_get_error)
            })
            .map_err(map_async_error)??;
            let bytes_to_copy = std::cmp::min(buf.len(), bytes.len());
            buf[..bytes_to_copy].copy_from_slice(&bytes[..bytes_to_copy]);

        let source = self.source.clone();
        let uri = self.uri.clone();
        let io_stats = self.io_stats.clone();
        let position = self.position;
            self.position += bytes_to_copy;
            Ok(bytes_to_copy)
        }
    }

        let bytes = rt
            .block_within_async_context(async move {
                let range = Some(GetRange::Bounded(position..size));
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        // If we have cached content, serve from cache
        if let Some(ref content) = self.cached_content {
            if self.position >= content.len() {
                return Ok(0); // EOF
            }

                let result = source
                    .get(&uri, range, io_stats)
                    .await
                    .map_err(map_get_error)?;
            let bytes_to_read = content.len() - self.position;
            buf.extend_from_slice(&content[self.position..]);

                result.bytes().await.map_err(map_bytes_error)
            })
            .map_err(map_async_error)??;
            self.position = content.len();

        buf.reserve(bytes.len());
            return Ok(bytes_to_read);
        }

        let content = self.read_full_content()?;

        if self.position >= content.len() {
            return Ok(0);
        }

        buf.extend_from_slice(&bytes);
        let bytes_to_read = content.len() - self.position;
        buf.extend_from_slice(&content[self.position..]);

        self.position += bytes.len();
        self.cached_content = Some(content);
        self.position = self.cached_content.as_ref().unwrap().len();

        Ok(bytes.len())
        Ok(bytes_to_read)
    }
}
Review comment: It seems a little hacky to hardcode it for these two source types. Is there a way to move this logic to the individual ObjectSource implementations themselves?

Author reply: So I was initially thinking that as well, but that also felt kinda hacky, as then we need to do downcast_ref() and also add a source_type() -> SourceType method. I actually had that solution coded out in an earlier revision, and this one felt slightly less hacky to me. FWIW, I think technically some "s3-like" APIs could also return false here, depending on the implementation.
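For context, a rough sketch of the alternative the author describes (illustrative names only; SourceType, source_type(), and the trait name here are not existing Daft APIs):

// Hypothetical sketch of the rejected alternative: tag each source with a
// concrete kind and branch on it from the caller, instead of asking the
// source itself whether it supports range requests.
enum SourceType {
    S3Like,
    Http,
    Local,
    InMemory,
}

trait ObjectSourceKind {
    // Every implementation would have to report its concrete kind...
    fn source_type(&self) -> SourceType;
}

fn supports_range_requests(source: &dyn ObjectSourceKind) -> bool {
    // ...and the caller still hardcodes which kinds support ranges, which is
    // why this variant felt no less hacky. As noted above, an "s3-like"
    // endpoint could report S3Like here and still reject range requests.
    matches!(
        source.source_type(),
        SourceType::S3Like | SourceType::Local | SourceType::InMemory
    )
}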
Author follow-up: I also think this could serve as an alternative to #5188. My PR focuses solely on usage within daft.file, but we could expand the usage of this elsewhere for gracefully handling those pesky 416s when they pop up.
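A minimal sketch of what that more general fallback could look like (hypothetical helper, not code from this PR or from #5188):

use std::ops::Range;

// Hypothetical: any fallible "fetch these bytes" operation, with errors
// surfaced as strings for brevity.
type FetchResult = Result<Vec<u8>, String>;

/// Try a ranged fetch first; if the server rejects it with 416 / "Requested
/// Range Not Satisfiable", fall back to fetching the whole object and slicing
/// out the requested window locally.
fn get_with_range_fallback(
    fetch: impl Fn(Option<Range<usize>>) -> FetchResult,
    range: Range<usize>,
) -> FetchResult {
    match fetch(Some(range.clone())) {
        Ok(bytes) => Ok(bytes),
        Err(e) if e.contains("416") || e.contains("Requested Range Not Satisfiable") => {
            let full = fetch(None)?;
            let end = range.end.min(full.len());
            let start = range.start.min(end);
            Ok(full[start..end].to_vec())
        }
        Err(e) => Err(e),
    }
}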