13 changes: 12 additions & 1 deletion daft/dataframe/dataframe.py
@@ -1472,6 +1472,8 @@ def write_lance(
mode: Literal["create", "append", "overwrite"] = "create",
io_config: Optional[IOConfig] = None,
schema: Optional[Schema] = None,
batch_size: int = 1,
max_batch_rows: int = 100000,
**kwargs: Any,
) -> "DataFrame":
"""Writes the DataFrame to a Lance table.
@@ -1480,12 +1482,19 @@ def write_lance(
uri: The URI of the Lance table to write to
mode: The write mode. One of "create", "append", or "overwrite"
io_config (IOConfig, optional): configurations to use when interacting with remote storage.
batch_size (int, optional): Number of micropartitions to batch together before writing. Default is 1 (no batching).
max_batch_rows (int, optional): Maximum number of rows to accumulate before flushing a batch. Default is 100,000.
**kwargs: Additional keyword arguments to pass to the Lance writer.

Note:
`write_lance` requires Python 3.9 or higher
This call is **blocking** and will execute the DataFrame when called

Batching Parameters:
- batch_size=1 (default): no batching; each micropartition is written individually (previous behavior)
- batch_size>1: combines up to batch_size micropartitions into a single write
- max_batch_rows: caps accumulated rows; the pending batch is flushed as soon as either limit is reached

Returns:
DataFrame: A DataFrame containing metadata about the written Lance table, such as number of fragments, number of deleted rows, number of small files, and version.

@@ -1530,12 +1539,14 @@ def write_lance(
╰───────────────┴──────────────────┴─────────────────┴─────────╯
<BLANKLINE>
(Showing first 1 of 1 rows)
>>> # Enable batching for improved performance with large datasets
>>> df.write_lance("/tmp/lance/my_table.lance", batch_size=10, max_batch_rows=50000) # doctest: +SKIP
"""
from daft.io.lance.lance_data_sink import LanceDataSink

if schema is None:
schema = self.schema()
sink = LanceDataSink(uri, schema, mode, io_config, **kwargs)
sink = LanceDataSink(uri, schema, mode, io_config, batch_size, max_batch_rows, **kwargs)
return self.write_sink(sink)

@DataframePublicAPI
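The effect of the new parameters is visible in the dataset that gets written: larger batches mean fewer write_fragments calls and therefore fewer, larger fragments, while the total row count is unchanged. A quick way to check this after the doctest example above, sketched under the assumption that the pylance package's dataset API (lance.dataset, get_fragments, count_rows) is available and reusing the illustrative path from the example:

import lance

ds = lance.dataset("/tmp/lance/my_table.lance")
print(ds.version)                # bumps with every committed write
print(len(ds.get_fragments()))   # fewer fragments when batching is enabled
print(ds.count_rows())           # total rows are unaffected by batching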
201 changes: 177 additions & 24 deletions daft/io/lance/lance_data_sink.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import pathlib
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal
@@ -21,6 +22,9 @@
from daft.daft import IOConfig


logger = logging.getLogger(__name__)


def pyarrow_schema_castable(src: pa.Schema, dst: pa.Schema) -> bool:
if len(src) != len(dst):
return False
@@ -50,6 +54,8 @@ def __init__(
schema: Schema,
mode: Literal["create", "append", "overwrite"],
io_config: IOConfig | None = None,
batch_size: int = 1,
max_batch_rows: int = 100000,
**kwargs: Any,
) -> None:
from daft.io.object_store_options import io_config_to_storage_options
@@ -64,6 +70,18 @@

self._storage_options = io_config_to_storage_options(self._io_config, self._table_uri)

if batch_size < 1:
raise ValueError(f"batch_size must be >= 1, got {batch_size}")
if max_batch_rows < 1:
raise ValueError(f"max_batch_rows must be >= 1, got {max_batch_rows}")

self._batch_size = batch_size
self._max_batch_rows = max_batch_rows

self._batch_tables: list[pa.Table] = []
self._batch_row_count: int = 0
self._batch_count: int = 0

self._pyarrow_schema = schema.to_pyarrow_schema()

try:
@@ -101,37 +119,172 @@ def schema(self) -> Schema:
return self._schema

def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
"""Writes fragments from the given micropartitions."""
"""Writes fragments from the given micropartitions with class-level batching support."""
self._reset_batch_state()

lance = self._import_lance()

for micropartition in micropartitions:
arrow_table = pa.Table.from_batches(
micropartition.to_arrow().to_batches(),
self._pyarrow_schema,
)
if self._table_schema is not None:
arrow_table = arrow_table.cast(self._table_schema)

bytes_written = arrow_table.nbytes
rows_written = arrow_table.num_rows

fragments = lance.fragment.write_fragments(
arrow_table,
dataset_uri=self._table_uri,
mode=self._mode,
storage_options=self._storage_options,
**self._kwargs,
)
yield WriteResult(
result=fragments,
bytes_written=bytes_written,
rows_written=rows_written,
)
if self._batch_size == 1:
# Fast path: no batching, each micropartition is written as its own set of fragments.
for micropartition in micropartitions:
arrow_table = pa.Table.from_batches(
micropartition.to_arrow().to_batches(),
self._pyarrow_schema,
)
if self._table_schema is not None:
arrow_table = arrow_table.cast(self._table_schema)
yield from self._write_single_table(arrow_table, lance)
else:
try:
for micropartition in micropartitions:
try:
arrow_table = pa.Table.from_batches(
micropartition.to_arrow().to_batches(),
self._pyarrow_schema,
)
if self._table_schema is not None:
arrow_table = arrow_table.cast(self._table_schema)

self._batch_tables.append(arrow_table)
self._batch_row_count += arrow_table.num_rows
self._batch_count += 1

if self._should_flush_batch(self._batch_count, self._batch_row_count):
yield from self._flush_batch(lance)

except Exception as e:
logger.error("Error processing micropartition: %s", e)
if self._batch_tables:
# Snapshot the pending tables: _flush_batch resets the batch state in its finally block,
# so on failure the fallback must write from this copy rather than self._batch_tables.
pending_tables = list(self._batch_tables)
try:
yield from self._flush_batch(lance)
except Exception as flush_error:
logger.error("Batch flush failed: %s", flush_error)
for table in pending_tables:
yield from self._write_single_table(table, lance)
self._reset_batch_state()

try:
arrow_table = pa.Table.from_batches(
micropartition.to_arrow().to_batches(),
self._pyarrow_schema,
)
if self._table_schema is not None:
arrow_table = arrow_table.cast(self._table_schema)
yield from self._write_single_table(arrow_table, lance)
except Exception as individual_error:
logger.error("Failed to process micropartition individually: %s", individual_error)
raise e
finally:
if self._batch_tables:
# Same snapshot pattern as above: _flush_batch clears the batch state even on failure,
# so keep a copy for the per-table fallback.
pending_tables = list(self._batch_tables)
try:
yield from self._flush_batch(lance)
except Exception as final_flush_error:
logger.error("Final batch flush failed: %s", final_flush_error)
for table in pending_tables:
try:
yield from self._write_single_table(table, lance)
except Exception as table_error:
logger.error("Failed to write individual table in final cleanup: %s", table_error)
self._reset_batch_state()

def _should_flush_batch(self, batch_count: int, batch_row_count: int) -> bool:
"""Determines if the current batch should be flushed."""
should_flush = batch_count >= self._batch_size or batch_row_count >= self._max_batch_rows

logger.debug(
"Batch status: count=%d/%d, rows=%d/%d, should_flush=%s",
batch_count,
self._batch_size,
batch_row_count,
self._max_batch_rows,
should_flush,
)

return should_flush

def _flush_batch(self, lance: ModuleType) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
"""Flushes the class-level batch of Arrow tables by concatenating and writing them."""
if not self._batch_tables:
return

logger.debug(
"Flushing class-level batch with %d tables, total rows: %d", len(self._batch_tables), self._batch_row_count
)

try:
if len(self._batch_tables) == 1:
yield from self._write_single_table(self._batch_tables[0], lance)
else:
try:
combined_table = pa.concat_tables(self._batch_tables)
yield from self._write_single_table(combined_table, lance)
except Exception as concat_error:
logger.warning("Table concatenation failed: %s, falling back to individual writes", concat_error)
for i, table in enumerate(self._batch_tables):
try:
yield from self._write_single_table(table, lance)
except Exception as table_error:
logger.error("Failed to write individual table %d: %s", i, table_error)
raise
except Exception as e:
logger.error("Batch flush failed with %d tables: %s", len(self._batch_tables), e)
raise
finally:
self._reset_batch_state()

def _reset_batch_state(self) -> None:
"""Resets the class-level batch state."""
self._batch_tables = []
self._batch_row_count = 0
self._batch_count = 0

def flush_batch(self) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
"""Public method to manually flush any accumulated batch data."""
if self._batch_tables:
lance = self._import_lance()
yield from self._flush_batch(lance)

def _write_single_table(
self, arrow_table: pa.Table, lance: ModuleType
) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
"""Writes a single Arrow table to Lance."""
bytes_written = arrow_table.nbytes
rows_written = arrow_table.num_rows

fragments = lance.fragment.write_fragments(
arrow_table,
dataset_uri=self._table_uri,
mode=self._mode,
storage_options=self._storage_options,
**self._kwargs,
)
yield WriteResult(
result=fragments,
bytes_written=bytes_written,
rows_written=rows_written,
)

def finalize(self, write_results: list[WriteResult[list[lance.FragmentMetadata]]]) -> MicroPartition:
"""Commits the fragments to the Lance dataset. Returns a DataFrame with the stats of the dataset."""
lance = self._import_lance()

remaining_results = list(self.flush_batch())
write_results.extend(remaining_results)

fragments = list(chain.from_iterable(write_result.result for write_result in write_results))

if self._mode == "create" or self._mode == "overwrite":
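The flush rule implemented by _should_flush_batch is "whichever limit is hit first": a pending batch is written once it holds batch_size micropartitions or max_batch_rows accumulated rows, and any leftover partial batch is drained through flush_batch() when finalize() runs. A standalone sketch of that decision in plain Python, with illustrative row counts rather than real micropartitions:

from collections.abc import Iterator

def flush_points(row_counts: list[int], batch_size: int, max_batch_rows: int) -> Iterator[list[int]]:
    """Group per-micropartition row counts, flushing once either limit is reached."""
    batch: list[int] = []
    rows = 0
    for n in row_counts:
        batch.append(n)
        rows += n
        if len(batch) >= batch_size or rows >= max_batch_rows:  # mirrors _should_flush_batch
            yield batch
            batch, rows = [], 0
    if batch:  # leftover partial batch; the sink writes this during finalize()
        yield batch

# Six micropartitions of 40,000 rows each with batch_size=4 and max_batch_rows=100,000:
# the row limit triggers first, so two batches of three micropartitions are written.
print(list(flush_points([40_000] * 6, batch_size=4, max_batch_rows=100_000)))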
4 changes: 4 additions & 0 deletions daft/io/lance/lance_scan.py
@@ -231,6 +231,10 @@ def _create_regular_scan_tasks(
combined_filter = combined_filter.__and__(filter_expr)
pushed_expr = Expression._from_pyexpr(combined_filter).to_arrow_expr()

if fragment.count_rows(pushed_expr) == 0:
logger.debug("Skipping fragment %s with fragment_id %s with 0 rows", fragment.fragment_id)
continue
Comment on lines +234 to +236
style: This optimization adds I/O overhead by calling count_rows() with filters for every fragment. Consider measuring the performance impact - for datasets with many small fragments or complex filters, this could be slower than just processing empty fragments downstream.


yield ScanTask.python_factory_func_scan_task(
module=_lancedb_table_factory_function.__module__,
func_name=_lancedb_table_factory_function.__name__,
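The reviewer's concern above is measurable: the cost of the skip is one count_rows call per fragment against the pushed-down filter. A rough timing sketch, assuming the pylance fragment API (lance.dataset, get_fragments, and count_rows accepting a filter expression) and an illustrative dataset path and filter:

import time

import lance

ds = lance.dataset("/tmp/lance/my_table.lance")  # illustrative path
filter_expr = "id > 1000"                        # illustrative filter
fragments = ds.get_fragments()

start = time.perf_counter()
empty = sum(1 for frag in fragments if frag.count_rows(filter_expr) == 0)
elapsed = time.perf_counter() - start
print(f"{empty} of {len(fragments)} fragments match zero rows; counting took {elapsed:.3f}s")

If the counting time dominates for datasets with many small fragments, the skip may not pay for itself, which is the trade-off the comment points out.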