Skip to content

Commit eaafc18

Browse files
author
Vincent Moens
committed
[Feature] Fix type assertion in Seq build
ghstack-source-id: 83d3dca Pull Request resolved: #1143
1 parent 7df2062 commit eaafc18

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

tensordict/nn/common.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,18 @@
99
import inspect
1010
import warnings
1111
from textwrap import indent
12-
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
12+
from typing import (
13+
Any,
14+
Callable,
15+
Dict,
16+
Iterable,
17+
List,
18+
MutableSequence,
19+
Optional,
20+
Sequence,
21+
Tuple,
22+
Union,
23+
)
1324

1425
import torch
1526
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
@@ -981,20 +992,20 @@ def __init__(
981992
else:
982993
if isinstance(in_keys, (str, tuple)):
983994
in_keys = [in_keys]
984-
elif not isinstance(in_keys, list):
995+
elif not isinstance(in_keys, MutableSequence):
985996
raise ValueError(self._IN_KEY_ERR)
986997
self._kwargs = None
987998

988999
if isinstance(out_keys, (str, tuple)):
9891000
out_keys = [out_keys]
990-
elif not isinstance(out_keys, list):
1001+
elif not isinstance(out_keys, MutableSequence):
9911002
raise ValueError(self._OUT_KEY_ERR)
9921003
try:
993-
in_keys = unravel_key_list(in_keys)
1004+
in_keys = unravel_key_list(list(in_keys))
9941005
except Exception:
9951006
raise ValueError(self._IN_KEY_ERR)
9961007
try:
997-
out_keys = unravel_key_list(out_keys)
1008+
out_keys = unravel_key_list(list(out_keys))
9981009
except Exception:
9991010
raise ValueError(self._OUT_KEY_ERR)
10001011

test/test_nn.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import unittest
1212
import weakref
1313
from collections import OrderedDict
14+
from collections.abc import MutableSequence
1415

1516
import pytest
1617
import torch
@@ -118,6 +119,34 @@ def test_from_str_correct_raise(self, unsupported_type_str):
118119

119120

120121
class TestTDModule:
122+
class MyMutableSequence(MutableSequence):
123+
def __init__(self, initial_data=None):
124+
self._data = [] if initial_data is None else list(initial_data)
125+
126+
def __getitem__(self, index):
127+
return self._data[index]
128+
129+
def __setitem__(self, index, value):
130+
self._data[index] = value
131+
132+
def __delitem__(self, index):
133+
del self._data[index]
134+
135+
def __len__(self):
136+
return len(self._data)
137+
138+
def insert(self, index, value):
139+
self._data.insert(index, value)
140+
141+
def test_mutable_sequence(self):
142+
in_keys = self.MyMutableSequence(["a", "b", "c"])
143+
out_keys = self.MyMutableSequence(["d", "e", "f"])
144+
mod = TensorDictModule(lambda *x: x, in_keys=in_keys, out_keys=out_keys)
145+
td = mod(TensorDict(a=0, b=0, c=0))
146+
assert "d" in td
147+
assert "e" in td
148+
assert "f" in td
149+
121150
def test_auto_unravel(self):
122151
tdm = TensorDictModule(
123152
lambda x: x,

0 commit comments

Comments
 (0)