
Commit df9b803 ("init"), 1 parent: e03c25e

File tree: 1 file changed (+180, -0 lines)

tensordict/tensorclass_v2.py

@@ -0,0 +1,180 @@
import types
from typing import Any, Generic, List, Sequence, TypeVar, Union, get_args, get_origin

import torch
from pydantic import BaseModel, model_validator
from pydantic_core import core_schema
from tensordict import (
    MetaData,
    NonTensorData,
    NonTensorStack,
    TensorDict,
    is_tensor_collection,
    set_list_to_stack,
)
from torch import Size, Tensor

# Opt in to the list-to-stack behaviour globally: lists assigned into a
# TensorDict are stacked (NonTensorStack) rather than stored as a single leaf.
set_list_to_stack(True).set()

T = TypeVar("T")


class NestedList(Generic[T]):
    """A type that accepts either a single value or nested lists of values of type T."""

    def __init__(self, value: Union[T, List["NestedList[T]"]]):
        self.value = value

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, _handler):
        item_type = get_args(source_type)[0]  # the type parameter of NestedList[...]

        def validate_nested(v):
            # Recurse through (possibly nested) lists, coercing leaves to item_type
            if isinstance(v, (list, tuple)):
                return [validate_nested(x) for x in v]
            return item_type(v) if not isinstance(v, item_type) else v

        # Run validate_nested before accepting any value
        return core_schema.no_info_before_validator_function(
            validate_nested, core_schema.any_schema()
        )
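
# Illustration (assumption, not exercised by this commit): as a pydantic field
# annotation, NestedList[int] should accept 1, [1, 2], or [[1], [2]], coercing
# every leaf to int via validate_nested above.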


class TensorClass(BaseModel):
    model_config = {
        "arbitrary_types_allowed": True,
        "extra": "allow",  # allow extra fields such as batch_size
    }

    def __init__(self, **data):
        batch_size = data.pop("batch_size", None)
        device = data.pop("device", None)
        super().__init__(**data)
        self.__dict__["_tensordict"] = TensorDict(
            {}, batch_size=batch_size, device=device
        )
        # Initialize the tensordict with the current field values
        for field in self.__class__.model_fields:
            if hasattr(self, field):
                item = getattr(self, field)
                if isinstance(item, torch.Tensor) or is_tensor_collection(item):
                    self._tensordict[field] = item
                else:
                    # Get the field's type annotation
                    field_type = self.__class__.model_fields[field].annotation

                    def has_nested_list(type_):
                        """Check whether a type contains NestedList."""
                        if get_origin(type_) is NestedList:
                            return True
                        # PEP 604 unions (int | str) report types.UnionType,
                        # typing.Union reports Union, plain types report None
                        if get_origin(type_) in (Union, types.UnionType, None):
                            args = get_args(type_)
                            return any(has_nested_list(arg) for arg in args)
                        return False

                    def get_primary_type(type_):
                        """Get the primary (non-NestedList) type from a Union or simple type."""
                        if get_origin(type_) in (Union, types.UnionType, None):
                            # Look through the union args for a non-NestedList type
                            for arg in get_args(type_):
                                if not has_nested_list(arg):
                                    return arg
                        return None

                    # Check whether the annotation is or contains NestedList
                    if has_nested_list(field_type):
                        primary_type = get_primary_type(field_type)
                        # If the value matches the primary type (e.g. str for b), treat it as a scalar
                        if primary_type and isinstance(item, primary_type):
                            self._tensordict[field] = NonTensorData(item)
                        # Otherwise, if it is a sequence (but not str/bytes), treat it as a list
                        elif isinstance(item, Sequence) and not isinstance(
                            item, (str, bytes)
                        ):
                            self._tensordict[field] = NonTensorStack.from_list(item)
                        else:
                            self._tensordict[field] = NonTensorData(item)
                    else:
                        self._tensordict[field] = MetaData(item)
                delattr(self, field)

    @property
    def device(self) -> torch.device:
        return self._tensordict.device

    @device.setter
    def device(self, device: torch.device):
        self._tensordict.device = device

    @property
    def batch_size(self) -> Size:
        """Get the batch size of the underlying TensorDict."""
        return self._tensordict.batch_size

    @batch_size.setter
    def batch_size(self, size):
        """Set the batch size of the underlying TensorDict."""
        td = self.__dict__["_tensordict"]
        td.batch_size = size
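
    # Illustration (assumption, not tested in this commit): after construction,
    # `obj.batch_size = torch.Size([2])` resizes the underlying TensorDict,
    # which validates existing leaf shapes against the new batch size.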

    @model_validator(mode="after")
    def sync_to_tensordict(self):
        # Ensure _tensordict exists with a proper batch size
        if not hasattr(self, "_tensordict"):
            self.__dict__["_tensordict"] = TensorDict({}, batch_size=[])

        # Sync all fields to the tensordict
        for field in self.__class__.model_fields:
            if hasattr(self, field):
                self._tensordict[field] = getattr(self, field)
        return self

    def __getattr__(self, name: str) -> Any:
        if name == "_tensordict":
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(
                    f"{self.__class__.__name__} has no attribute {name}"
                )
        try:
            return self._tensordict[name]
        except (KeyError, AttributeError):
            raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

    def __repr__(self):
        fields_repr = ", ".join(
            f"{field}={getattr(self, field)}" for field in self.__class__.model_fields
        )
        extra_fields = {
            "batch_size": self.batch_size,
            "device": self._tensordict.device,
        }
        extra_repr = ", ".join(f"{k}={v}" for k, v in extra_fields.items())
        return f"{self.__class__.__name__}({fields_repr}, {extra_repr})"


if __name__ == "__main__":

    class MyClass(TensorClass):
        a: int | NestedList[int]  # NonTensorData scalar, or NonTensorStack for lists
        b: str | NestedList[str]  # accepts str, list[str], list[list[str]], etc.
        c: Tensor
        d: str  # MetaData

    # Default (empty) batch size
    model = MyClass(a=1, b="hello", c=torch.tensor([1.0, 2.0, 3.0]), d="a string")
    print(f"{model=}")
    print(f"Model attributes: a={model.a}, b={model.b}")
    print(f"Model tensor c={model.c}")
    print(f"TensorDict contents: {model._tensordict}")
    print(f"Default batch_size: {model.batch_size}")

    # Integer batch size
    model = MyClass(
        a=[1, 2],
        b=["hello", "world"],
        c=torch.tensor([[1.0], [2.0]]),  # 2x1 tensor
        d="a string",
        batch_size=2,
    )
    print(f"{model=}")
    print(f"{model._tensordict=}")
    print(f"Integer batch_size: {model.batch_size}")
