10
10
from collections .abc import Generator , Iterable , Iterator , Sequence
11
11
from copy import copy
12
12
from functools import wraps
13
- from secrets import token_hex
14
13
from types import GeneratorType
15
14
from typing import (
16
15
TYPE_CHECKING ,
29
28
from fsspec .callbacks import DEFAULT_CALLBACK , Callback , TqdmCallback
30
29
from sqlalchemy import Column
31
30
from sqlalchemy .sql import func as f
32
- from sqlalchemy .sql .elements import ColumnClause , ColumnElement
31
+ from sqlalchemy .sql .elements import ColumnClause , ColumnElement , Label
33
32
from sqlalchemy .sql .expression import label
34
33
from sqlalchemy .sql .schema import TableClause
35
34
from sqlalchemy .sql .selectable import Select
46
45
from datachain .error import DatasetNotFoundError , QueryScriptCancelError
47
46
from datachain .func .base import Function
48
47
from datachain .lib .listing import is_listing_dataset , listing_dataset_expired
48
+ from datachain .lib .signal_schema import SignalSchema
49
49
from datachain .lib .udf import UDFAdapter , _get_cache
50
50
from datachain .progress import CombinedDownloadCallback , TqdmCombinedDownloadCallback
51
51
from datachain .project import Project
@@ -795,28 +795,32 @@ def apply_sql_clause(self, query: Select) -> Select:
795
795
796
796
@frozen
797
797
class SQLMutate (SQLClause ):
798
- args : tuple [Union [Function , ColumnElement ], ...]
798
+ args : tuple [Label , ...]
799
+ new_schema : SignalSchema
799
800
800
801
def apply_sql_clause (self , query : Select ) -> Select :
801
802
original_subquery = query .subquery ()
802
- args = [
803
- original_subquery .c [str (c )] if isinstance (c , (str , C )) else c
804
- for c in self .parse_cols (self .args )
805
- ]
806
- to_mutate = {c .name for c in args }
803
+ to_mutate = {c .name for c in self .args }
807
804
808
- prefix = f"mutate{ token_hex (8 )} _"
809
- cols = [
810
- c .label (prefix + c .name ) if c .name in to_mutate else c
805
+ # Drop the original versions to avoid name collisions, exclude renamed
806
+ # columns. Always keep system columns (sys__*) if they exist in original query
807
+ new_schema_columns = set (self .new_schema .db_signals ())
808
+ base_cols = [
809
+ c
811
810
for c in original_subquery .c
811
+ if c .name not in to_mutate
812
+ and (c .name in new_schema_columns or c .name .startswith ("sys__" ))
812
813
]
813
- # this is needed for new column to be used in clauses
814
- # like ORDER BY, otherwise new column is not recognized
815
- subquery = (
816
- sqlalchemy . select ( * cols , * args ). select_from ( original_subquery ). subquery ()
814
+
815
+ # Create intermediate subquery to properly handle window functions
816
+ intermediate_query = sqlalchemy . select ( * base_cols , * self . args ). select_from (
817
+ original_subquery
817
818
)
819
+ intermediate_subquery = intermediate_query .subquery ()
818
820
819
- return sqlalchemy .select (* subquery .c ).select_from (subquery )
821
+ return sqlalchemy .select (* intermediate_subquery .c ).select_from (
822
+ intermediate_subquery
823
+ )
820
824
821
825
822
826
@frozen
@@ -1470,7 +1474,7 @@ def select_except(self, *args) -> "Self":
1470
1474
return query
1471
1475
1472
1476
@detach
1473
- def mutate (self , * args , ** kwargs ) -> "Self" :
1477
+ def mutate (self , * args , new_schema , ** kwargs ) -> "Self" :
1474
1478
"""
1475
1479
Add new columns to this query.
1476
1480
@@ -1482,7 +1486,7 @@ def mutate(self, *args, **kwargs) -> "Self":
1482
1486
"""
1483
1487
query_args = [v .label (k ) for k , v in dict (args , ** kwargs ).items ()]
1484
1488
query = self .clone ()
1485
- query .steps .append (SQLMutate ((* query_args ,)))
1489
+ query .steps .append (SQLMutate ((* query_args ,), new_schema ))
1486
1490
return query
1487
1491
1488
1492
@detach
0 commit comments