|
18 | 18 | #include "kernels/TestUtils.h" |
19 | 19 | #include "luci_interpreter/TestMemoryManager.h" |
20 | 20 |
|
| 21 | +#include <numeric> |
| 22 | + |
21 | 23 | namespace luci_interpreter |
22 | 24 | { |
23 | 25 | namespace kernels |
@@ -107,6 +109,48 @@ TEST(StridedSliceTest, Uint8) |
107 | 109 | EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); |
108 | 110 | } |
109 | 111 |
|
| 112 | +TEST(StridedSliceTest, 5DCase) |
| 113 | +{ |
| 114 | + std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>(); |
| 115 | + |
| 116 | + Shape input_shape{2, 3, 2, 2, 3}; |
| 117 | + std::vector<float> input_data(input_shape.num_elements()); |
| 118 | + std::iota(std::begin(input_data), std::end(input_data), 0); |
| 119 | + Shape begin_shape{5}; |
| 120 | + std::vector<int32_t> begin_data{0, 0, 0, 0, 0}; |
| 121 | + Shape end_shape{5}; |
| 122 | + std::vector<int32_t> end_data{2, 3, 2, 2, 1}; |
| 123 | + Shape strides_shape{5}; |
| 124 | + std::vector<int32_t> strides_data{1, 1, 1, 1, 1}; |
| 125 | + Tensor input_tensor = |
| 126 | + makeInputTensor<DataType::U8>(input_shape, 1.0f, 0, input_data, memory_manager.get()); |
| 127 | + Tensor begin_tensor = |
| 128 | + makeInputTensor<DataType::S32>(begin_shape, begin_data, memory_manager.get()); |
| 129 | + Tensor end_tensor = makeInputTensor<DataType::S32>(end_shape, end_data, memory_manager.get()); |
| 130 | + Tensor strides_tensor = |
| 131 | + makeInputTensor<DataType::S32>(strides_shape, strides_data, memory_manager.get()); |
| 132 | + Tensor output_tensor = makeOutputTensor(DataType::U8, 1.0f, 0); |
| 133 | + |
| 134 | + StridedSliceParams params{}; |
| 135 | + params.begin_mask = 0; |
| 136 | + params.end_mask = 0; |
| 137 | + params.ellipsis_mask = 0; |
| 138 | + params.new_axis_mask = 0; |
| 139 | + params.shrink_axis_mask = 0; |
| 140 | + |
| 141 | + StridedSlice kernel(&input_tensor, &begin_tensor, &end_tensor, &strides_tensor, &output_tensor, |
| 142 | + params); |
| 143 | + kernel.configure(); |
| 144 | + memory_manager->allocate_memory(output_tensor); |
| 145 | + kernel.execute(); |
| 146 | + |
| 147 | + std::vector<int32_t> output_shape{2, 3, 2, 2, 1}; |
| 148 | + std::vector<float> output_data{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, |
| 149 | + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69}; |
| 150 | + EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(output_data)); |
| 151 | + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); |
| 152 | +} |
| 153 | + |
110 | 154 | } // namespace |
111 | 155 | } // namespace kernels |
112 | 156 | } // namespace luci_interpreter |
0 commit comments