Commit ef69e2c

refactor(lance): implement Lance DataSink batching mechanism with row-based flush control
This commit introduces a batching mechanism for the Lance DataSink that improves write performance and merges small files by combining multiple micropartitions before writing, producing larger Lance files and reducing fragmentation.

Key features:
- Add a batch_size parameter to control the number of micropartitions to batch together (default: 1)
- Add a max_batch_rows parameter for row-based flush control (default: 100,000)
- Implement batching logic with dual flush conditions:
  * flush when batch_size micropartitions have been accumulated
  * flush when max_batch_rows total rows have been reached
- Maintain full backward compatibility with existing code
1 parent 70116a6 commit ef69e2c
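
The dual flush conditions described in the commit message reduce to a single predicate over the accumulated batch. A minimal, self-contained sketch of that decision; should_flush is an illustrative standalone function (not part of the diff), and the parameter names mirror the ones introduced by this commit:

def should_flush(batch_count: int, batch_row_count: int, batch_size: int, max_batch_rows: int) -> bool:
    # Flush when either threshold is reached: enough micropartitions, or enough accumulated rows.
    return batch_count >= batch_size or batch_row_count >= max_batch_rows

# With batch_size=10 and max_batch_rows=50_000, three micropartitions totalling 60,000 rows
# flush on the row condition before the count condition is ever met.
assert should_flush(3, 60_000, 10, 50_000)
assert not should_flush(3, 20_000, 10, 50_000)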

File tree

4 files changed: +541 additions, -25 deletions

daft/dataframe/dataframe.py

Lines changed: 12 additions & 1 deletion
@@ -1472,6 +1472,8 @@ def write_lance(
         mode: Literal["create", "append", "overwrite"] = "create",
         io_config: Optional[IOConfig] = None,
         schema: Optional[Schema] = None,
+        batch_size: int = 1,
+        max_batch_rows: int = 100000,
         **kwargs: Any,
     ) -> "DataFrame":
         """Writes the DataFrame to a Lance table.
@@ -1480,12 +1482,19 @@ def write_lance(
             uri: The URI of the Lance table to write to
             mode: The write mode. One of "create", "append", or "overwrite"
             io_config (IOConfig, optional): configurations to use when interacting with remote storage.
+            batch_size (int, optional): Number of micropartitions to batch together before writing. Default is 1 (no batching).
+            max_batch_rows (int, optional): Maximum number of rows to accumulate before flushing a batch. Default is 100,000.
             **kwargs: Additional keyword arguments to pass to the Lance writer.

         Note:
             `write_lance` requires python 3.9 or higher
             This call is **blocking** and will execute the DataFrame when called

+        Batching Parameters:
+            - batch_size=1 (default): No batching, maintains backward compatibility
+            - batch_size>1: Enables batching to combine multiple micropartitions
+            - max_batch_rows: Row-based flush control for predictable batching behavior
+
         Returns:
             DataFrame: A DataFrame containing metadata about the written Lance table, such as number of fragments, number of deleted rows, number of small files, and version.

@@ -1530,12 +1539,14 @@ def write_lance(
             ╰───────────────┴──────────────────┴─────────────────┴─────────╯
             <BLANKLINE>
             (Showing first 1 of 1 rows)
+            >>> # Enable batching for improved performance with large datasets
+            >>> df.write_lance("/tmp/lance/my_table.lance", batch_size=10, max_batch_rows=50000) # doctest: +SKIP
         """
         from daft.io.lance.lance_data_sink import LanceDataSink

         if schema is None:
             schema = self.schema()
-        sink = LanceDataSink(uri, schema, mode, io_config, **kwargs)
+        sink = LanceDataSink(uri, schema, mode, io_config, batch_size, max_batch_rows, **kwargs)
         return self.write_sink(sink)

     @DataframePublicAPI

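For reference, the new parameters are used from the public API exactly like the doctest above. A hedged usage sketch (the path and column values are placeholders, and the lance package must be installed):

import daft

df = daft.from_pydict({"id": [1, 2, 3], "value": ["x", "y", "z"]})

# Default: batch_size=1, every micropartition becomes its own fragment write (previous behaviour).
df.write_lance("/tmp/lance/my_table.lance")

# Batched: up to 10 micropartitions, or 50,000 rows, whichever threshold is hit first,
# are combined into a single fragment write.
df.write_lance("/tmp/lance/my_table.lance", mode="overwrite", batch_size=10, max_batch_rows=50000)
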
daft/io/lance/lance_data_sink.py

Lines changed: 177 additions & 24 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 import pathlib
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Literal
@@ -21,6 +22,9 @@
     from daft.daft import IOConfig


+logger = logging.getLogger(__name__)
+
+
 def pyarrow_schema_castable(src: pa.Schema, dst: pa.Schema) -> bool:
     if len(src) != len(dst):
         return False
@@ -50,6 +54,8 @@ def __init__(
         schema: Schema,
         mode: Literal["create", "append", "overwrite"],
         io_config: IOConfig | None = None,
+        batch_size: int = 1,
+        max_batch_rows: int = 100000,
         **kwargs: Any,
     ) -> None:
         from daft.io.object_store_options import io_config_to_storage_options
@@ -64,6 +70,18 @@ def __init__(

         self._storage_options = io_config_to_storage_options(self._io_config, self._table_uri)

+        if batch_size < 1:
+            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
+        if max_batch_rows < 1:
+            raise ValueError(f"max_batch_rows must be >= 1, got {max_batch_rows}")
+
+        self._batch_size = batch_size
+        self._max_batch_rows = max_batch_rows
+
+        self._batch_tables: list[pa.Table] = []
+        self._batch_row_count: int = 0
+        self._batch_count: int = 0
+
         self._pyarrow_schema = schema.to_pyarrow_schema()

         try:
@@ -101,37 +119,172 @@ def schema(self) -> Schema:
         return self._schema

     def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
-        """Writes fragments from the given micropartitions."""
+        """Writes fragments from the given micropartitions with class-level batching support."""
+        self._reset_batch_state()
+
         lance = self._import_lance()

-        for micropartition in micropartitions:
-            arrow_table = pa.Table.from_batches(
-                micropartition.to_arrow().to_batches(),
-                self._pyarrow_schema,
-            )
-            if self._table_schema is not None:
-                arrow_table = arrow_table.cast(self._table_schema)
-
-            bytes_written = arrow_table.nbytes
-            rows_written = arrow_table.num_rows
-
-            fragments = lance.fragment.write_fragments(
-                arrow_table,
-                dataset_uri=self._table_uri,
-                mode=self._mode,
-                storage_options=self._storage_options,
-                **self._kwargs,
-            )
-            yield WriteResult(
-                result=fragments,
-                bytes_written=bytes_written,
-                rows_written=rows_written,
-            )
+        if self._batch_size == 1:
+            for micropartition in micropartitions:
+                arrow_table = pa.Table.from_batches(
+                    micropartition.to_arrow().to_batches(),
+                    self._pyarrow_schema,
+                )
+                if self._table_schema is not None:
+                    arrow_table = arrow_table.cast(self._table_schema)
+
+                bytes_written = arrow_table.nbytes
+                rows_written = arrow_table.num_rows
+
+                fragments = lance.fragment.write_fragments(
+                    arrow_table,
+                    dataset_uri=self._table_uri,
+                    mode=self._mode,
+                    storage_options=self._storage_options,
+                    **self._kwargs,
+                )
+                yield WriteResult(
+                    result=fragments,
+                    bytes_written=bytes_written,
+                    rows_written=rows_written,
+                )
+        else:
+            try:
+                for micropartition in micropartitions:
+                    try:
+                        arrow_table = pa.Table.from_batches(
+                            micropartition.to_arrow().to_batches(),
+                            self._pyarrow_schema,
+                        )
+                        if self._table_schema is not None:
+                            arrow_table = arrow_table.cast(self._table_schema)
+
+                        self._batch_tables.append(arrow_table)
+                        self._batch_row_count += arrow_table.num_rows
+                        self._batch_count += 1
+
+                        if self._should_flush_batch(self._batch_count, self._batch_row_count):
+                            yield from self._flush_batch(lance)
+
+                    except Exception as e:
+                        logger.error("Error processing micropartition: %s", e)
+                        if self._batch_tables:
+                            try:
+                                yield from self._flush_batch(lance)
+                            except Exception as flush_error:
+                                logger.error("Batch flush failed: %s", flush_error)
+                                for table in self._batch_tables:
+                                    yield from self._write_single_table(table, lance)
+                                self._reset_batch_state()
+
+                        try:
+                            arrow_table = pa.Table.from_batches(
+                                micropartition.to_arrow().to_batches(),
+                                self._pyarrow_schema,
+                            )
+                            if self._table_schema is not None:
+                                arrow_table = arrow_table.cast(self._table_schema)
+                            yield from self._write_single_table(arrow_table, lance)
+                        except Exception as individual_error:
+                            logger.error("Failed to process micropartition individually: %s", individual_error)
+                            raise e
+            finally:
+                if self._batch_tables:
+                    try:
+                        yield from self._flush_batch(lance)
+                    except Exception as final_flush_error:
+                        logger.error("Final batch flush failed: %s", final_flush_error)
+                        for table in self._batch_tables:
+                            try:
+                                yield from self._write_single_table(table, lance)
+                            except Exception as table_error:
+                                logger.error("Failed to write individual table in final cleanup: %s", table_error)
+                self._reset_batch_state()
+
+    def _should_flush_batch(self, batch_count: int, batch_row_count: int) -> bool:
+        """Determines if the current batch should be flushed."""
+        should_flush = batch_count >= self._batch_size or batch_row_count >= self._max_batch_rows
+
+        logger.debug(
+            "Batch status: count=%d/%d, rows=%d/%d, should_flush=%s",
+            batch_count,
+            self._batch_size,
+            batch_row_count,
+            self._max_batch_rows,
+            should_flush,
+        )
+
+        return should_flush
+
+    def _flush_batch(self, lance: ModuleType) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
+        """Flushes the class-level batch of Arrow tables by concatenating and writing them."""
+        if not self._batch_tables:
+            return
+
+        logger.debug(
+            "Flushing class-level batch with %d tables, total rows: %d", len(self._batch_tables), self._batch_row_count
+        )
+
+        try:
+            if len(self._batch_tables) == 1:
+                yield from self._write_single_table(self._batch_tables[0], lance)
+            else:
+                try:
+                    combined_table = pa.concat_tables(self._batch_tables)
+                    yield from self._write_single_table(combined_table, lance)
+                except Exception as concat_error:
+                    logger.warning("Table concatenation failed: %s, falling back to individual writes", concat_error)
+                    for i, table in enumerate(self._batch_tables):
+                        try:
+                            yield from self._write_single_table(table, lance)
+                        except Exception as table_error:
+                            logger.error("Failed to write individual table %d: %s", i, table_error)
+                            raise
+        except Exception as e:
+            logger.error("Batch flush failed with %d tables: %s", len(self._batch_tables), e)
+            raise
+        finally:
+            self._reset_batch_state()
+
+    def _reset_batch_state(self) -> None:
+        """Resets the class-level batch state."""
+        self._batch_tables = []
+        self._batch_row_count = 0
+        self._batch_count = 0
+
+    def flush_batch(self) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
+        """Public method to manually flush any accumulated batch data."""
+        if self._batch_tables:
+            lance = self._import_lance()
+            yield from self._flush_batch(lance)
+
+    def _write_single_table(
+        self, arrow_table: pa.Table, lance: ModuleType
+    ) -> Iterator[WriteResult[list[lance.FragmentMetadata]]]:
+        """Writes a single Arrow table to Lance."""
+        bytes_written = arrow_table.nbytes
+        rows_written = arrow_table.num_rows
+
+        fragments = lance.fragment.write_fragments(
+            arrow_table,
+            dataset_uri=self._table_uri,
+            mode=self._mode,
+            storage_options=self._storage_options,
+            **self._kwargs,
+        )
+        yield WriteResult(
+            result=fragments,
+            bytes_written=bytes_written,
+            rows_written=rows_written,
+        )

     def finalize(self, write_results: list[WriteResult[list[lance.FragmentMetadata]]]) -> MicroPartition:
         """Commits the fragments to the Lance dataset. Returns a DataFrame with the stats of the dataset."""
         lance = self._import_lance()

+        remaining_results = list(self.flush_batch())
+        write_results.extend(remaining_results)
+
         fragments = list(chain.from_iterable(write_result.result for write_result in write_results))

         if self._mode == "create" or self._mode == "overwrite":

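At its core, _flush_batch concatenates the accumulated Arrow tables with pa.concat_tables and hands the result to a single write_fragments call. A minimal pyarrow-only sketch of that combining step (table contents are placeholders; the Lance write itself is omitted):

import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("value", pa.string())])
batch_tables = [
    pa.table({"id": [1, 2], "value": ["a", "b"]}, schema=schema),
    pa.table({"id": [3], "value": ["c"]}, schema=schema),
    pa.table({"id": [4, 5, 6], "value": ["d", "e", "f"]}, schema=schema),
]

# Combining the small per-micropartition tables before writing is what lets a batched flush
# produce one larger Lance fragment instead of several small ones.
combined = pa.concat_tables(batch_tables)
assert combined.num_rows == sum(t.num_rows for t in batch_tables)

Falling back to per-table writes when concatenation fails, as the diff above does, trades fragment size for robustness when the accumulated tables cannot be combined.
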
daft/io/lance/lance_scan.py

Lines changed: 4 additions & 0 deletions
@@ -231,6 +231,10 @@ def _create_regular_scan_tasks(
             combined_filter = combined_filter.__and__(filter_expr)
         pushed_expr = Expression._from_pyexpr(combined_filter).to_arrow_expr()

+        if fragment.count_rows(pushed_expr) == 0:
+            logger.debug("Skipping fragment with fragment_id %s with 0 rows", fragment.fragment_id)
+            continue
+
         yield ScanTask.python_factory_func_scan_task(
             module=_lancedb_table_factory_function.__module__,
             func_name=_lancedb_table_factory_function.__name__,

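The lance_scan.py change prunes fragments that contribute no rows under the pushed-down filter before a ScanTask is created for them. A rough standalone sketch of the same pruning idea against a Lance dataset (hedged: lance.dataset, get_fragments, and the fragment-level count_rows(filter) call follow my reading of the pylance API and the diff above; the URI and predicate are placeholders):

import lance
import pyarrow.compute as pc

dataset = lance.dataset("/tmp/lance/my_table.lance")  # placeholder URI
pushed_expr = pc.field("id") > 100  # stand-in for the pushed-down filter expression

for fragment in dataset.get_fragments():
    # Mirror the commit's early `continue`: fragments with no matching rows never become scan tasks.
    if fragment.count_rows(pushed_expr) == 0:
        continue
    # ...build and emit a scan task for this fragment here...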