Skip to content

Commit c93d12c

Browse files
authored
[luci-interpreter] Relax StrideSlice rank limitations (#14507)
This commit removes assert that limit StrideSlice rank to 4. TF 2.8 allows to inference this operator with rank 5. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <[email protected]>
1 parent 518bd72 commit c93d12c

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

compiler/luci-interpreter/src/kernels/StridedSlice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void StridedSlice::configure()
4444
assert(begin()->element_type() == DataType::S32);
4545
assert(end()->element_type() == DataType::S32);
4646
assert(strides()->element_type() == DataType::S32);
47-
assert(input()->shape().num_dims() <= 4);
47+
assert(input()->shape().num_dims() <= 5);
4848
if (params().ellipsis_mask != 0)
4949
{
5050
throw std::runtime_error("ellipsis_mask is not implemented yet.");

compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "kernels/TestUtils.h"
1919
#include "luci_interpreter/TestMemoryManager.h"
2020

21+
#include <numeric>
22+
2123
namespace luci_interpreter
2224
{
2325
namespace kernels
@@ -107,6 +109,48 @@ TEST(StridedSliceTest, Uint8)
107109
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
108110
}
109111

112+
TEST(StridedSliceTest, 5DCase)
113+
{
114+
std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();
115+
116+
Shape input_shape{2, 3, 2, 2, 3};
117+
std::vector<float> input_data(input_shape.num_elements());
118+
std::iota(std::begin(input_data), std::end(input_data), 0);
119+
Shape begin_shape{5};
120+
std::vector<int32_t> begin_data{0, 0, 0, 0, 0};
121+
Shape end_shape{5};
122+
std::vector<int32_t> end_data{2, 3, 2, 2, 1};
123+
Shape strides_shape{5};
124+
std::vector<int32_t> strides_data{1, 1, 1, 1, 1};
125+
Tensor input_tensor =
126+
makeInputTensor<DataType::U8>(input_shape, 1.0f, 0, input_data, memory_manager.get());
127+
Tensor begin_tensor =
128+
makeInputTensor<DataType::S32>(begin_shape, begin_data, memory_manager.get());
129+
Tensor end_tensor = makeInputTensor<DataType::S32>(end_shape, end_data, memory_manager.get());
130+
Tensor strides_tensor =
131+
makeInputTensor<DataType::S32>(strides_shape, strides_data, memory_manager.get());
132+
Tensor output_tensor = makeOutputTensor(DataType::U8, 1.0f, 0);
133+
134+
StridedSliceParams params{};
135+
params.begin_mask = 0;
136+
params.end_mask = 0;
137+
params.ellipsis_mask = 0;
138+
params.new_axis_mask = 0;
139+
params.shrink_axis_mask = 0;
140+
141+
StridedSlice kernel(&input_tensor, &begin_tensor, &end_tensor, &strides_tensor, &output_tensor,
142+
params);
143+
kernel.configure();
144+
memory_manager->allocate_memory(output_tensor);
145+
kernel.execute();
146+
147+
std::vector<int32_t> output_shape{2, 3, 2, 2, 1};
148+
std::vector<float> output_data{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33,
149+
36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69};
150+
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(output_data));
151+
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
152+
}
153+
110154
} // namespace
111155
} // namespace kernels
112156
} // namespace luci_interpreter

0 commit comments

Comments
 (0)