|
| 1 | +from typing import Any, Generic, List, Sequence, TypeVar, Union, get_args, get_origin |
| 2 | + |
| 3 | +import torch |
| 4 | +from pydantic import BaseModel, model_validator |
| 5 | +from tensordict import ( |
| 6 | + MetaData, |
| 7 | + NonTensorData, |
| 8 | + NonTensorStack, |
| 9 | + TensorDict, |
| 10 | + is_tensor_collection, |
| 11 | + set_list_to_stack, |
| 12 | +) |
| 13 | +from torch import Size, Tensor |
| 14 | + |
| 15 | +set_list_to_stack(True).set() |
| 16 | + |
| 17 | +T = TypeVar("T") |
| 18 | + |
| 19 | + |
| 20 | +class NestedList(Generic[T]): |
| 21 | + """A type that accepts either a single value or nested lists of values of type T""" |
| 22 | + |
| 23 | + def __init__(self, value: Union[T, List["NestedList[T]"]]): |
| 24 | + self.value = value |
| 25 | + |
| 26 | + @classmethod |
| 27 | + def __get_pydantic_core_schema__(cls, source_type, _handler): |
| 28 | + T = get_args(source_type)[0] # Get the type parameter |
| 29 | + |
| 30 | + def validate_nested(v): |
| 31 | + if isinstance(v, (list, tuple)): |
| 32 | + return [validate_nested(x) for x in v] |
| 33 | + return T(v) if not isinstance(v, T) else v |
| 34 | + |
| 35 | + return {"type": "any", "mode": "before", "function": validate_nested} |
| 36 | + |
| 37 | + |
| 38 | +class TensorClass(BaseModel): |
| 39 | + model_config = { |
| 40 | + "arbitrary_types_allowed": True, |
| 41 | + "extra": "allow", # Allow extra fields like batch_size |
| 42 | + } |
| 43 | + |
| 44 | + def __init__(self, **data): |
| 45 | + batch_size = data.pop("batch_size", None) |
| 46 | + device = data.pop("device", None) |
| 47 | + super().__init__(**data) |
| 48 | + self.__dict__["_tensordict"] = TensorDict( |
| 49 | + {}, batch_size=batch_size, device=device |
| 50 | + ) |
| 51 | + # Initialize tensordict with current values |
| 52 | + for field in self.__class__.model_fields: |
| 53 | + if hasattr(self, field): |
| 54 | + item = getattr(self, field) |
| 55 | + if isinstance(item, torch.Tensor) or is_tensor_collection(item): |
| 56 | + self._tensordict[field] = item |
| 57 | + else: |
| 58 | + # Get the field's type annotation |
| 59 | + field_type = self.__class__.model_fields[field].annotation |
| 60 | + |
| 61 | + def has_nested_list(type_): |
| 62 | + """Check if a type contains NestedList""" |
| 63 | + if get_origin(type_) is NestedList: |
| 64 | + return True |
| 65 | + if get_origin(type_) in (Union, None): |
| 66 | + # For Union types or simple types, check each argument |
| 67 | + args = get_args(type_) |
| 68 | + return any(has_nested_list(arg) for arg in args) |
| 69 | + return False |
| 70 | + |
| 71 | + def get_primary_type(type_): |
| 72 | + """Get the primary type (non-NestedList) from a Union or simple type""" |
| 73 | + if get_origin(type_) in (Union, None): |
| 74 | + # Look through Union args for non-NestedList type |
| 75 | + for arg in get_args(type_): |
| 76 | + if not has_nested_list(arg): |
| 77 | + return arg |
| 78 | + return None |
| 79 | + |
| 80 | + # Check if it's a NestedList type or contains NestedList |
| 81 | + if has_nested_list(field_type): |
| 82 | + primary_type = get_primary_type(field_type) |
| 83 | + # If it matches the primary type (e.g. str for b), treat as scalar |
| 84 | + if primary_type and isinstance(item, primary_type): |
| 85 | + self._tensordict[field] = NonTensorData(item) |
| 86 | + # Otherwise if it's a sequence (but not str), treat as list |
| 87 | + elif isinstance(item, Sequence) and not isinstance( |
| 88 | + item, (str, bytes) |
| 89 | + ): |
| 90 | + stack = NonTensorStack.from_list(item) |
| 91 | + self._tensordict[field] = stack |
| 92 | + else: |
| 93 | + self._tensordict[field] = NonTensorData(item) |
| 94 | + else: |
| 95 | + self._tensordict[field] = MetaData(item) |
| 96 | + delattr(self, field) |
| 97 | + |
| 98 | + @property |
| 99 | + def device(self) -> torch.device: |
| 100 | + return self._tensordict.device |
| 101 | + |
| 102 | + @device.setter |
| 103 | + def device(self, device: torch.device): |
| 104 | + self._tensordict.device = device |
| 105 | + |
| 106 | + @property |
| 107 | + def batch_size(self) -> Size: |
| 108 | + """Get the batch size of the underlying TensorDict""" |
| 109 | + return self._tensordict.batch_size |
| 110 | + |
| 111 | + @batch_size.setter |
| 112 | + def batch_size(self, size): |
| 113 | + """Set the batch size of the underlying TensorDict""" |
| 114 | + td = self.__dict__["_tensordict"] |
| 115 | + td.batch_size = size |
| 116 | + |
| 117 | + @model_validator(mode="after") |
| 118 | + def sync_to_tensordict(self, data): |
| 119 | + # Ensure _tensordict exists with proper batch size |
| 120 | + if not hasattr(self, "_tensordict"): |
| 121 | + self.__dict__["_tensordict"] = TensorDict({}, batch_size=[]) |
| 122 | + |
| 123 | + # Sync all fields to tensordict |
| 124 | + for field in self.__class__.model_fields: |
| 125 | + if hasattr(self, field): |
| 126 | + self._tensordict[field] = getattr(self, field) |
| 127 | + return self |
| 128 | + |
| 129 | + def __getattr__(self, name: str) -> Any: |
| 130 | + if name == "_tensordict": |
| 131 | + try: |
| 132 | + return self.__dict__[name] |
| 133 | + except KeyError: |
| 134 | + raise AttributeError( |
| 135 | + f"{self.__class__.__name__} has no attribute {name}" |
| 136 | + ) |
| 137 | + try: |
| 138 | + return self._tensordict[name] |
| 139 | + except (KeyError, AttributeError): |
| 140 | + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") |
| 141 | + |
| 142 | + def __repr__(self): |
| 143 | + fields_repr = ", ".join( |
| 144 | + f"{field}={getattr(self, field)}" for field in self.__class__.model_fields |
| 145 | + ) |
| 146 | + extra_fields = { |
| 147 | + "batch_size": self.batch_size, |
| 148 | + "device": self._tensordict.device, |
| 149 | + } |
| 150 | + extra_repr = ", ".join(f"{k}={v}" for k, v in extra_fields.items()) |
| 151 | + return f"{self.__class__.__name__}({fields_repr}, {extra_repr})" |
| 152 | + |
| 153 | + |
| 154 | +if __name__ == "__main__": |
| 155 | + class MyClass(TensorClass): |
| 156 | + a: int | NestedList[int] # NonTensorStack |
| 157 | + b: str | NestedList[str] # Now accepts str or list[str] or list[list[str]] etc |
| 158 | + c: Tensor |
| 159 | + d: str # MetaData |
| 160 | + |
| 161 | + |
| 162 | + # Default (empty) batch size |
| 163 | + model = MyClass(a=1, b="hello", c=torch.tensor([1.0, 2.0, 3.0]), d="a string") |
| 164 | + print(f"{model=}") |
| 165 | + print(f"Model attributes: a={model.a}, b={model.b}") |
| 166 | + print(f"Model tensor c={model.c}") |
| 167 | + print(f"TensorDict contents: {model._tensordict}") |
| 168 | + print(f"Default batch_size: {model.batch_size}") |
| 169 | + |
| 170 | + # Integer batch size |
| 171 | + model = MyClass( |
| 172 | + a=[1, 2], |
| 173 | + b=["hello", "world"], |
| 174 | + c=torch.tensor([[1.0], [2.0]]), # 2x1 tensor |
| 175 | + d="a string", |
| 176 | + batch_size=2, |
| 177 | + ) |
| 178 | + print(f"{model=}") |
| 179 | + print(f"{model._tensordict=}") |
| 180 | + print(f"Integer batch_size: {model.batch_size}") |
0 commit comments