Skip to content

Commit 9e21c31

Browse files
authored
fix(mutate): simplify and drop mutated columns (#1298)
* fix(mutate): simplify and drop mutated columns
* cover more edge cases, add more tests
* fix tests, keep sys columns on mutate
1 parent aec36fc commit 9e21c31

File tree

10 files changed

+368
-130
lines changed

10 files changed

+368
-130
lines changed

src/datachain/lib/dc/datachain.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1262,8 +1262,10 @@ def mutate(self, **kwargs) -> "Self":
12621262
# adding new signal
12631263
mutated[name] = value
12641264

1265+
new_schema = schema.mutate(kwargs)
12651266
return self._evolve(
1266-
query=self._query.mutate(**mutated), signal_schema=schema.mutate(kwargs)
1267+
query=self._query.mutate(new_schema=new_schema, **mutated),
1268+
signal_schema=new_schema,
12671269
)
12681270

12691271
@property

src/datachain/lib/signal_schema.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from datachain.lib.file import File
3535
from datachain.lib.model_store import ModelStore
3636
from datachain.lib.utils import DataChainParamsError
37-
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
37+
from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
3838
from datachain.sql.types import SQLType
3939

4040
if TYPE_CHECKING:
@@ -680,35 +680,46 @@ def mutate(self, args_map: dict) -> "SignalSchema":
680680
primitives = (bool, str, int, float)
681681

682682
for name, value in args_map.items():
683+
current_type = None
684+
685+
if C.is_nested(name):
686+
try:
687+
current_type = self.get_column_type(name)
688+
except SignalResolvingError as err:
689+
msg = f"Creating new nested columns directly is not allowed: {name}"
690+
raise ValueError(msg) from err
691+
683692
if isinstance(value, Column) and value.name in self.values:
684693
# renaming existing signal
694+
# Note: it won't touch nested signals here (e.g. file__path)
695+
# we don't allow removing nested columns to keep objects consistent
685696
del new_values[value.name]
686697
new_values[name] = self.values[value.name]
687-
continue
688-
if isinstance(value, Column):
698+
elif isinstance(value, Column):
689699
# adding new signal from existing signal field
690-
try:
691-
new_values[name] = self.get_column_type(
692-
value.name, with_subtree=True
693-
)
694-
continue
695-
except SignalResolvingError:
696-
pass
697-
if isinstance(value, Func):
700+
new_values[name] = self.get_column_type(value.name, with_subtree=True)
701+
elif isinstance(value, Func):
698702
# adding new signal with function
699703
new_values[name] = value.get_result_type(self)
700-
continue
701-
if isinstance(value, primitives):
704+
elif isinstance(value, primitives):
702705
# For primitives, store the type, not the value
703706
val = literal(value)
704707
val.type = python_to_sql(type(value))()
705708
new_values[name] = sql_to_python(val)
706-
continue
707-
if isinstance(value, ColumnElement):
709+
elif isinstance(value, ColumnElement):
708710
# adding new signal
709711
new_values[name] = sql_to_python(value)
710-
continue
711-
new_values[name] = value
712+
else:
713+
new_values[name] = value
714+
715+
if C.is_nested(name):
716+
if current_type != new_values[name]:
717+
msg = (
718+
f"Altering nested column type is not allowed: {name}, "
719+
f"current type: {current_type}, new type: {new_values[name]}"
720+
)
721+
raise ValueError(msg)
722+
del new_values[name]
712723

713724
return SignalSchema(new_values)
714725

src/datachain/query/dataset.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from collections.abc import Generator, Iterable, Iterator, Sequence
1111
from copy import copy
1212
from functools import wraps
13-
from secrets import token_hex
1413
from types import GeneratorType
1514
from typing import (
1615
TYPE_CHECKING,
@@ -29,7 +28,7 @@
2928
from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
3029
from sqlalchemy import Column
3130
from sqlalchemy.sql import func as f
32-
from sqlalchemy.sql.elements import ColumnClause, ColumnElement
31+
from sqlalchemy.sql.elements import ColumnClause, ColumnElement, Label
3332
from sqlalchemy.sql.expression import label
3433
from sqlalchemy.sql.schema import TableClause
3534
from sqlalchemy.sql.selectable import Select
@@ -46,6 +45,7 @@
4645
from datachain.error import DatasetNotFoundError, QueryScriptCancelError
4746
from datachain.func.base import Function
4847
from datachain.lib.listing import is_listing_dataset, listing_dataset_expired
48+
from datachain.lib.signal_schema import SignalSchema
4949
from datachain.lib.udf import UDFAdapter, _get_cache
5050
from datachain.progress import CombinedDownloadCallback, TqdmCombinedDownloadCallback
5151
from datachain.project import Project
@@ -795,28 +795,32 @@ def apply_sql_clause(self, query: Select) -> Select:
795795

796796
@frozen
797797
class SQLMutate(SQLClause):
798-
args: tuple[Union[Function, ColumnElement], ...]
798+
args: tuple[Label, ...]
799+
new_schema: SignalSchema
799800

800801
def apply_sql_clause(self, query: Select) -> Select:
801802
original_subquery = query.subquery()
802-
args = [
803-
original_subquery.c[str(c)] if isinstance(c, (str, C)) else c
804-
for c in self.parse_cols(self.args)
805-
]
806-
to_mutate = {c.name for c in args}
803+
to_mutate = {c.name for c in self.args}
807804

808-
prefix = f"mutate{token_hex(8)}_"
809-
cols = [
810-
c.label(prefix + c.name) if c.name in to_mutate else c
805+
# Drop the original versions to avoid name collisions, exclude renamed
806+
# columns. Always keep system columns (sys__*) if they exist in original query
807+
new_schema_columns = set(self.new_schema.db_signals())
808+
base_cols = [
809+
c
811810
for c in original_subquery.c
811+
if c.name not in to_mutate
812+
and (c.name in new_schema_columns or c.name.startswith("sys__"))
812813
]
813-
# this is needed for new column to be used in clauses
814-
# like ORDER BY, otherwise new column is not recognized
815-
subquery = (
816-
sqlalchemy.select(*cols, *args).select_from(original_subquery).subquery()
814+
815+
# Create intermediate subquery to properly handle window functions
816+
intermediate_query = sqlalchemy.select(*base_cols, *self.args).select_from(
817+
original_subquery
817818
)
819+
intermediate_subquery = intermediate_query.subquery()
818820

819-
return sqlalchemy.select(*subquery.c).select_from(subquery)
821+
return sqlalchemy.select(*intermediate_subquery.c).select_from(
822+
intermediate_subquery
823+
)
820824

821825

822826
@frozen
@@ -1470,7 +1474,7 @@ def select_except(self, *args) -> "Self":
14701474
return query
14711475

14721476
@detach
1473-
def mutate(self, *args, **kwargs) -> "Self":
1477+
def mutate(self, *args, new_schema, **kwargs) -> "Self":
14741478
"""
14751479
Add new columns to this query.
14761480
@@ -1482,7 +1486,7 @@ def mutate(self, *args, **kwargs) -> "Self":
14821486
"""
14831487
query_args = [v.label(k) for k, v in dict(args, **kwargs).items()]
14841488
query = self.clone()
1485-
query.steps.append(SQLMutate((*query_args,)))
1489+
query.steps.append(SQLMutate((*query_args,), new_schema))
14861490
return query
14871491

14881492
@detach

src/datachain/query/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def to_db_name(name: str) -> str:
3636
def __getattr__(cls, name: str):
3737
return cls(ColumnMeta.to_db_name(name))
3838

39+
@staticmethod
40+
def is_nested(name: str) -> bool:
41+
return DEFAULT_DELIMITER in name
42+
3943

4044
class Column(sa.ColumnClause, metaclass=ColumnMeta):
4145
inherit_cache: Optional[bool] = True

tests/func/test_data_storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from tests.utils import (
1818
DEFAULT_TREE,
1919
TARRED_TREE,
20-
create_tar_dataset_with_legacy_columns,
20+
create_tar_dataset,
2121
)
2222

2323
COMPLEX_TREE: dict[str, Any] = {
@@ -39,7 +39,7 @@ def test_dir_expansion(cloud_test_catalog, version_aware, cloud_type):
3939
# we don't want to index things in parent directory
4040
src_uri += "/"
4141

42-
chain = create_tar_dataset_with_legacy_columns(session, ctc.src_uri, "dc")
42+
chain = create_tar_dataset(session, ctc.src_uri, "dc")
4343
dataset = catalog.get_dataset(chain.name)
4444
with catalog.warehouse.clone() as warehouse:
4545
dr = warehouse.dataset_rows(dataset, column="file")

tests/func/test_datachain.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import pytest
1717
import pytz
1818
from PIL import Image
19-
from sqlalchemy import Column
2019

2120
import datachain as dc
2221
from datachain import DataModel, func
@@ -236,22 +235,6 @@ def test_read_storage_dependencies(cloud_test_catalog, cloud_type):
236235
assert dependencies[0].name == dep_name
237236

238237

239-
def test_persist_after_mutate(test_session):
240-
chain = (
241-
dc.read_values(fib=[1, 1, 2, 3, 5, 8, 13, 21], session=test_session)
242-
.map(mod3=lambda fib: fib % 3, output=int)
243-
.group_by(
244-
cnt=dc.func.count(),
245-
partition_by="mod3",
246-
)
247-
.mutate(x=1)
248-
.persist()
249-
)
250-
251-
assert chain.count() == 3
252-
assert set(chain.to_values("mod3")) == {0, 1, 2}
253-
254-
255238
def test_persist_not_affects_dependencies(tmp_dir, test_session):
256239
for i in range(4):
257240
(tmp_dir / f"file{i}.txt").write_text(f"file{i}")
@@ -776,59 +759,6 @@ def test_read_storage_check_rows(tmp_dir, test_session):
776759
)
777760

778761

779-
def test_mutate_existing_column(test_session):
780-
ds = dc.read_values(ids=[1, 2, 3], session=test_session)
781-
ds = ds.mutate(ids=Column("ids") + 1)
782-
783-
assert ds.order_by("ids").to_list() == [(2,), (3,), (4,)]
784-
785-
786-
def test_mutate_with_primitives_save_load(test_session):
787-
"""Test that mutate with primitive values properly persists schema
788-
through save/load cycle."""
789-
original_data = [1, 2, 3]
790-
791-
# Create dataset with multiple primitive columns added via mutate
792-
ds = dc.read_values(data=original_data, session=test_session).mutate(
793-
str_col="test_string",
794-
int_col=42,
795-
float_col=3.14,
796-
bool_col=True,
797-
)
798-
799-
# Verify schema before saving
800-
schema = ds.signals_schema.values
801-
assert schema.get("str_col") is str
802-
assert schema.get("int_col") is int
803-
assert schema.get("float_col") is float
804-
assert schema.get("bool_col") is bool
805-
806-
ds.save("test_mutate_primitives")
807-
808-
# Load the dataset back
809-
loaded_ds = dc.read_dataset("test_mutate_primitives", session=test_session)
810-
811-
# Verify schema after loading
812-
loaded_schema = loaded_ds.signals_schema.values
813-
assert loaded_schema.get("str_col") is str
814-
assert loaded_schema.get("int_col") is int
815-
assert loaded_schema.get("float_col") is float
816-
assert loaded_schema.get("bool_col") is bool
817-
818-
# Verify data integrity
819-
results = set(loaded_ds.to_list())
820-
assert len(results) == 3
821-
822-
# Expected tuples: (data, str_col, int_col, float_col, bool_col)
823-
expected_results = {
824-
(1, "test_string", 42, 3.14, True),
825-
(2, "test_string", 42, 3.14, True),
826-
(3, "test_string", 42, 3.14, True),
827-
}
828-
829-
assert results == expected_results
830-
831-
832762
@pytest.mark.parametrize("processes", [False, 2, True])
833763
@pytest.mark.xdist_group(name="tmpfile")
834764
def test_parallel(processes, test_session_tmpfile):

tests/func/test_dataset_query.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from datachain.dataset import DatasetDependencyType, DatasetStatus
1010
from datachain.error import DatasetNotFoundError
11+
from datachain.lib.file import File
1112
from datachain.lib.listing import parse_listing_uri
13+
from datachain.lib.signal_schema import SignalSchema
1214
from datachain.query import C, DatasetQuery, Object, Stream
1315
from datachain.sql.functions import path as pathfunc
1416
from datachain.sql.types import String
@@ -19,6 +21,12 @@ def from_result_row(col_names, row):
1921
return dict(zip(col_names, row))
2022

2123

24+
def create_dataset_query_mutate_schema(**mutations):
25+
schema_values = {"file": File}
26+
schema_values.update(mutations)
27+
return SignalSchema(schema_values)
28+
29+
2230
@pytest.fixture
2331
def dogs_cats_dataset(listed_bucket, cloud_test_catalog, dogs_dataset, cats_dataset):
2432
dataset_name = uuid.uuid4().hex
@@ -306,12 +314,14 @@ def test_distinct_count(cloud_test_catalog, animal_dataset):
306314
def test_mutate(cloud_test_catalog, save, animal_dataset):
307315
catalog = cloud_test_catalog.catalog
308316
ds = DatasetQuery(animal_dataset.name, catalog=catalog)
317+
schema = create_dataset_query_mutate_schema(size10x=int, size1000x=int)
309318
q = (
310-
ds.mutate(size10x=C("file.size") * 10)
311-
.mutate(size1000x=C.size10x * 100)
319+
ds.mutate(new_schema=schema, size10x=C("file.size") * 10)
320+
.mutate(new_schema=schema, size1000x=C.size10x * 100)
312321
.mutate(
313322
("s2", C("file.size") * 2),
314323
("s3", C("file.size") * 3),
324+
new_schema=schema,
315325
s4=C("file.size") * 4,
316326
)
317327
.filter((C.size10x < 40) | (C.size10x > 100) | C("file.path").glob("cat*"))
@@ -349,8 +359,9 @@ def test_mutate(cloud_test_catalog, save, animal_dataset):
349359
def test_order_by_after_mutate(cloud_test_catalog, save, animal_dataset):
350360
catalog = cloud_test_catalog.catalog
351361
ds = DatasetQuery(animal_dataset.name, catalog=catalog)
362+
schema = create_dataset_query_mutate_schema(size10x=int)
352363
q = (
353-
ds.mutate(size10x=C("file.size") * 10)
364+
ds.mutate(new_schema=schema, size10x=C("file.size") * 10)
354365
.filter((C.size10x < 40) | (C.size10x > 100) | C("file.path").glob("cat*"))
355366
.order_by(C.size10x.desc())
356367
)
@@ -446,10 +457,12 @@ def test_offset_limit(cloud_test_catalog, save, animal_dataset):
446457
@pytest.mark.parametrize("save", [True, False])
447458
def test_mutate_offset_limit(cloud_test_catalog, save, animal_dataset):
448459
catalog = cloud_test_catalog.catalog
460+
base_query = DatasetQuery(animal_dataset.name, catalog=catalog).order_by(
461+
C("file.path")
462+
)
463+
schema = create_dataset_query_mutate_schema(size10x=int)
449464
q = (
450-
DatasetQuery(animal_dataset.name, catalog=catalog)
451-
.order_by(C("file.path"))
452-
.mutate(size10x=C("file.size") * 10)
465+
base_query.mutate(new_schema=schema, size10x=C("file.size") * 10)
453466
.offset(3)
454467
.limit(2)
455468
)

0 commit comments

Comments (0)