diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 9d26e508c61..698855ec015 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -239,6 +239,7 @@ tflm_kernel_cc_library( "decode.cc", "decode_state.cc", "decode_state_lut.cc", + "decode_state_prune.cc", "depth_to_space.cc", "depthwise_conv.cc", "depthwise_conv_common.cc", @@ -332,6 +333,7 @@ tflm_kernel_cc_library( "conv.h", "decode_state.h", "decode_state_lut.h", + "decode_state_prune.h", "depthwise_conv.h", "dequantize.h", "ethosu.h", diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc index 778a516224c..1e06a3390ef 100644 --- a/tensorflow/lite/micro/kernels/decode.cc +++ b/tensorflow/lite/micro/kernels/decode.cc @@ -63,6 +63,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { break; } + TF_LITE_ENSURE(context, IsConstantTensor(input)); + TF_LITE_ENSURE(context, IsConstantTensor(ancillary)); + if (DecodeState::Version(*ancillary) != 1) { MicroPrintf("version %u != 1", DecodeState::Version(*ancillary)); status = kTfLiteError; @@ -75,6 +78,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { dsp = DecodeState::CreateDecodeStateLUT( context, micro_context->GetAlternateProfiler()); break; + case DecodeState::kDcmTypePrune: + dsp = DecodeState::CreateDecodeStatePrune( + context, micro_context->GetAlternateProfiler()); + break; case DecodeState::kDcmTypeCustom: MicroPrintf("Custom decode type not yet supported"); break; diff --git a/tensorflow/lite/micro/kernels/decode_state.cc b/tensorflow/lite/micro/kernels/decode_state.cc index a55b4b4148b..adcdf913be8 100644 --- a/tensorflow/lite/micro/kernels/decode_state.cc +++ b/tensorflow/lite/micro/kernels/decode_state.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/lite/micro/kernels/decode_state.h" #include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" #include "tensorflow/lite/micro/micro_context.h" namespace tflite { @@ -33,4 +34,17 @@ DecodeState* DecodeState::CreateDecodeStateLUT( return dsp; } +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); + void* buffer = + micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune)); + if (buffer == nullptr) { + return nullptr; + } + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); + + return dsp; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h index 3818781b9dc..baebfb5ea63 100644 --- a/tensorflow/lite/micro/kernels/decode_state.h +++ b/tensorflow/lite/micro/kernels/decode_state.h @@ -43,6 +43,8 @@ class DecodeState { static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context, MicroProfilerInterface* profiler); + static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context, + MicroProfilerInterface* profiler); static uint8_t Type(const TfLiteTensor& ancillary) { return GetTensorData(&ancillary)[kDcmDecodeTypeOffset]; @@ -66,6 +68,7 @@ class DecodeState { // Decode Common Metadata constants public: static constexpr uint8_t kDcmTypeLUT = 0; + static constexpr uint8_t kDcmTypePrune = 2; static constexpr uint8_t kDcmTypeCustom = 127; static constexpr size_t kDcmSizeInBytes = 16; diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.cc b/tensorflow/lite/micro/kernels/decode_state_prune.cc new file mode 100644 index 00000000000..aadfd8445ee --- /dev/null +++ b/tensorflow/lite/micro/kernels/decode_state_prune.cc @@ -0,0 +1,206 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
+
+#include <algorithm>
+#include <cstddef>
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+
+namespace tflite {
+
+TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input,
+                                     const TfLiteTensor& ancillary,
+                                     const TfLiteTensor& output) {
+  const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
+  if (ancillary_data[kDcmVersionOffset] != 1) {
+    MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
+    return kTfLiteError;
+  }
+
+  // resolve num_channels_, use_alternate_axis_, and zero points
+  if (output.quantization.type == kTfLiteAffineQuantization &&
+      output.quantization.params != nullptr) {
+    const TfLiteAffineQuantization* quantization =
+        reinterpret_cast<const TfLiteAffineQuantization*>(
+            output.quantization.params);
+    num_channels_ = quantization->scale->size;
+    if ((quantization->quantized_dimension == output.dims->size - 1) &&
+        num_channels_ > 1) {
+      use_alternate_axis_ = true;
+    } else if (quantization->quantized_dimension != 0) {
+      MicroPrintf("unsupported quantization axis %u",
+                  quantization->quantized_dimension);
+      return kTfLiteError;
+    }
+
+    TFLITE_DCHECK(num_channels_ ==
+                  static_cast<size_t>(quantization->zero_point->size));
+    bool has_non_zero_zp =
+        std::any_of(quantization->zero_point->data,
+                    quantization->zero_point->data + num_channels_,
+                    [](int zp) { return zp != 0; });
+
+    if (output.type != kTfLiteInt8) {
+      // make sure all zero points are 0 (zero)
+      TF_LITE_ENSURE_MSG(const_cast<TfLiteContext*>(context_),
+                         has_non_zero_zp == false,
+                         "All zero-points must be zero");
+    }
+
+    if (num_channels_ > 1 && has_non_zero_zp) {
+      // copy zero points
+      MicroContext* micro_context = GetMicroContext(context_);
+      const size_t bufsize = num_channels_ * sizeof(*zero_points_);
+      zero_points_ = static_cast<int8_t*>(
+          micro_context->AllocatePersistentBuffer(bufsize));
+      if (zero_points_ == nullptr) {
+        MicroPrintf("unable to allocate zero_points_");
+        return kTfLiteError;
+      }
+      std::copy_n(quantization->zero_point->data, num_channels_, zero_points_);
+    } else {
+      single_zero_point_ = quantization->zero_point->data[0];
+    }
+  }
+
+  compressed_indices_ = GetTensorData<uint8_t>(&input);
+  count_indices_ = NumElements(&output);
+  elements_per_channel_ =
+      use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
+  value_table_ = &ancillary_data[kDcmSizeInBytes];
+
+  return kTfLiteOk;
+}
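+
+// The pruned encoding stores one bit per output element in
+// compressed_indices_, consumed MSB-first within each byte: a 1 bit emits
+// the next value from value_table_, a 0 bit emits the zero point.  For
+// example, with value_table {13, 14, 15, 16} and a zero point of 0, the
+// encoded byte 0xA5 (binary 1010'0101) decodes to
+// {13, 0, 14, 0, 0, 15, 0, 16}.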
+
+TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input,
+                                      const TfLiteEvalTensor& ancillary,
+                                      const TfLiteEvalTensor& output) {
+  void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
+  TFLITE_DCHECK(buffer != nullptr);
+
+  switch (output.type) {
+    case kTfLiteBool:
+      DecompressToBuffer<bool>(buffer);
+      break;
+    case kTfLiteFloat32:
+      DecompressToBuffer<float>(buffer);
+      break;
+    case kTfLiteInt8:
+      if (num_channels_ > 1 && zero_points_ != nullptr) {
+        DecompressToBufferPerChannelInt8(buffer);
+      } else {
+        DecompressToBuffer<int8_t>(buffer);
+      }
+      break;
+    case kTfLiteInt16:
+      DecompressToBuffer<int16_t>(buffer);
+      break;
+    case kTfLiteInt32:
+      DecompressToBuffer<int32_t>(buffer);
+      break;
+    case kTfLiteInt64:
+      DecompressToBuffer<int64_t>(buffer);
+      break;
+    default:
+      MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type));
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+template <typename T>
+void DecodeStatePrune::DecompressToBuffer(void* vp) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  T* buffer = static_cast<T*>(vp);
+  const T* value_table = static_cast<const T*>(value_table_);
+  const size_t max_count = count_indices_;
+  const uint8_t* const indices = compressed_indices_;
+
+  for (size_t index = 0; index < max_count; index++) {
+    size_t shift = ~index & 0b111;
+    size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1;
+
+    if (is_not_zp) {
+      *buffer++ = *value_table++;
+    } else {
+      *buffer++ = single_zero_point_;
+    }
+  }
+}
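+
+// Per-channel int8 variant: the bit stream is walked exactly as above, but
+// each pruned position is filled with the zero point of its channel.  With
+// the default axis (channel first) the elements of a channel are contiguous;
+// with the alternate axis (channel last) the channel index advances once per
+// element and wraps after num_channels_ elements.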
+
+void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) {
+  TFLITE_DCHECK(zero_points_ != nullptr);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  int8_t* buffer = static_cast<int8_t*>(vp);
+  size_t current_offset = 0;
+  const uint8_t* const indices = compressed_indices_;
+  const int8_t* value_table = static_cast<const int8_t*>(value_table_);
+
+  if (use_alternate_axis_) {
+    const size_t max_channels = num_channels_;
+    size_t count = count_indices_;
+
+    while (count > 0) {
+      for (size_t channel = 0; channel < max_channels; channel++) {
+        const int8_t zp = zero_points_[channel];
+        size_t shift = ~current_offset & 0b111;
+        size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
+
+        if (is_not_zp) {
+          *buffer++ = *value_table++;
+        } else {
+          *buffer++ = zp;
+        }
+        current_offset++;
+      }
+      count -= max_channels;
+    }
+  } else {
+    const size_t max_count = elements_per_channel_;
+
+    for (size_t channel = 0; channel < num_channels_; channel++) {
+      size_t count = max_count;
+      const int8_t zp = zero_points_[channel];
+
+      while (count-- > 0) {
+        size_t shift = ~current_offset & 0b111;
+        size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
+
+        if (is_not_zp) {
+          *buffer++ = *value_table++;
+        } else {
+          *buffer++ = zp;
+        }
+        current_offset++;
+      }
+    }
+  }
+}
+
+template void DecodeStatePrune::DecompressToBuffer(void*);
+template void DecodeStatePrune::DecompressToBuffer(void*);
+template void DecodeStatePrune::DecompressToBuffer(void*);
+template void DecodeStatePrune::DecompressToBuffer(void*);
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.h b/tensorflow/lite/micro/kernels/decode_state_prune.h
new file mode 100644
index 00000000000..de5ddd84249
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_prune.h
@@ -0,0 +1,69 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+
+namespace tflite {
+
+struct DecodeStatePrune : public DecodeState {
+  DecodeStatePrune() = delete;
+
+  DecodeStatePrune(const TfLiteContext* context,
+                   MicroProfilerInterface* profiler)
+      : DecodeState(context, profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) override;
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) override;
+
+ private:
+  // Prune Decode Common Metadata constants
+  static constexpr size_t kDcmVersionOffset = 4;
+
+ protected:
+  virtual ~DecodeStatePrune() = default;
+
+  template <typename T>
+  void DecompressToBuffer(void* buffer);
+
+  void DecompressToBufferPerChannelInt8(void* buffer);
+
+ protected:
+  const uint8_t* compressed_indices_ = nullptr;
+  size_t count_indices_ = 0;
+  size_t num_channels_ = 1;
+  size_t elements_per_channel_ = 0;    // computed from use_alternate_axis_
+  const void* value_table_ = nullptr;  // original non-pruned values
+  int8_t* zero_points_ = nullptr;      // quantized per-channel zero points
+  int8_t single_zero_point_ = 0;       // single channel zero point
+  bool use_alternate_axis_ = false;    // shape channel axis:
+                                       // false = first, true = last
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc
index 593b608e2ec..42e440a8945 100644
--- a/tensorflow/lite/micro/kernels/decode_test.cc
+++ b/tensorflow/lite/micro/kernels/decode_test.cc
@@ -59,6 +59,9 @@ struct AncillaryData {
   T value_table_[N > 0 ?
N : 1]; // assure not zero length }; +// +// LUT test data +// constexpr int kBitWidthLUT = 2; constexpr int8_t kAncillaryDataLUT0[] = {1, 2, 3, 4}; @@ -98,6 +101,169 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)}; constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1}; constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5}; +// +// Prune test data +// +constexpr int8_t kAncillaryDataPrune0[] = { + 1, 2, 3, 4, 1, // chan 0 + 2, 3, 4, 1, 2, // chan 0 + 3, 4, 1, 2, 3, // chan 0 + 4, 1, 2, 3, 4, // chan 0 + 11, 12, 13, 14, 11, // chan 1 + 12, 13, 14, 11, 12, // chan 1 + 13, 14, 11, 12, 13, // chan 1 + 14, 11, 12, 13, 14 // chan 1 +}; +constexpr int16_t kAncillaryDataPrune1[] = { + 5, 6, 7, 8, 5, // chan 0 + 6, 7, 8, 5, 6, // chan 0 + 7, 8, 5, 6, 7, // chan 0 + 8, 5, 6, 7, 8, // chan 0 + 15, 16, 17, 18, 15, // chan 1 + 16, 17, 18, 15, 16, // chan 1 + 17, 18, 15, 16, 17, // chan 1 + 18, 15, 16, 17, 18 // chan 1 +}; +constexpr float kAncillaryDataPrune2[] = { + 9.0f, 10.0f, 11.0f, 12.0f, // encoded byte 0 + 9.0f, 10.0f, 11.0f, 12.0f, // encoded byte 1 + 9.0f, 10.0f, 11.0f, 12.0f, // encoded byte 2 + 9.0f, 10.0f, 11.0f, 12.0f, // encoded byte 3 + 9.0f, 10.0f, 11.0f, 12.0f, // encoded byte 4 + 19.0f, 20.0f, 21.0f, 22.0f, // encoded byte 5 + 19.0f, 20.0f, 21.0f, 22.0f, // encoded byte 6 + 19.0f, 20.0f, 21.0f, 22.0f, // encoded byte 7 + 19.0f, 20.0f, 21.0f, 22.0f, // encoded byte 8 + 19.0f, 20.0f, 21.0f, 22.0f // encoded byte 9 +}; +constexpr int8_t kAncillaryDataPrune3[] = { + 13, 14, 15, 16, 13, // chan 0 + 14, 15, 16, 13, 14, // chan 0 + 15, 16, 13, 14, 15, // chan 0 + 16, 13, 14, 15, 16, // chan 0 + 113, 114, 115, 116, 113, // chan 1 + 114, 115, 116, 113, 114, // chan 1 + 115, 116, 113, 114, 115, // chan 1 + 116, 113, 114, 115, 116 // chan 1 +}; +constexpr int8_t kAncillaryDataPrune4[] = { + 17, 18, 19, 20, 17, 18, 19, 20, 17, 18, // group 0 + 19, 20, 17, 18, 19, 20, 17, 18, 19, 20, // group 0 + 21, 22, 23, 24, 21, 22, 23, 24, 21, 22, // group 1 + 23, 24, 21, 22, 23, 24, 21, 22, 23, 24, // group 1 +}; +constexpr int8_t kAncillaryDataPrune5[] = { + 13, 14, 15, 16, 13, // chan 0 + 14, 15, 16, 13, 14, // chan 0 + 15, 16, 13, 14, 15, // chan 0 + 16, 13, 14, 15, 16, // chan 0 + 23, 24, 25, 26, 23, // chan 0 + 24, 25, 26, 23, 24, // chan 0 + 25, 26, 23, 24, 25, // chan 0 + 26, 23, 24, 25, 26 // chan 0 +}; + +constexpr uint8_t kDcmPrune[tflite::DecodeState::kDcmSizeInBytes] = { + tflite::DecodeState::kDcmTypePrune, // type: Prune + 1, // DCM version: 1 + 0, // reserved + 0, // reserved + 1, // Prune version: 1 +}; + +// Align the tensor data the same as a Buffer in the TfLite schema. +// Use 0x5A in byte 1 to check byte ordering in the low-level code. +alignas(16) const uint8_t kEncodedPrune[] = {0xA5, 0x5A, 0xA5, 0xA5, 0xA5, + 0xA5, 0xA5, 0xA5, 0xA5, 0xA5}; + +// Tensor shapes as TfLiteIntArray +constexpr int kEncodedShapePrune[] = {1, sizeof(kEncodedPrune)}; +constexpr int kOutputShapePrune[] = {4, 2, 5, 8, 1}; // 2 channels +constexpr int kOutputShapePrune4[] = {4, 1, 2, 1, 40}; // 40 channels, alt-axis +constexpr int kOutputShapePrune5[] = {4, 1, 8, 10, 1}; // 1 channel + +// Quantization datum as TfLiteIntArray. 
+constexpr int kZeroPointsPrune0[] = {2, -128, 0}; +constexpr int kZeroPointsPrune1[] = {2, 0, 0}; +constexpr int kZeroPointsPrune1_Invalid[] = {2, 0, -1}; +constexpr int kZeroPointsPrune3[] = {2, 0, 0}; +constexpr int kZeroPointsPrune4[] = { + 40, // size + 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, + -14, -15, -16, -17, -18, -19, 0, -1, -2, -3, -4, -5, -6, -7, + -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, +}; +constexpr int kZeroPointsPrune5[] = {1, -44}; + +constexpr int8_t kExpectPrune0[] = { + 1, -128, 2, -128, -128, 3, -128, 4, // chan 0 + -128, 1, -128, 2, 3, -128, 4, -128, // chan 0 + 1, -128, 2, -128, -128, 3, -128, 4, // chan 0 + 1, -128, 2, -128, -128, 3, -128, 4, // chan 0 + 1, -128, 2, -128, -128, 3, -128, 4, // chan 0 + 11, 0, 12, 0, 0, 13, 0, 14, // chan 1 + 11, 0, 12, 0, 0, 13, 0, 14, // chan 1 + 11, 0, 12, 0, 0, 13, 0, 14, // chan 1 + 11, 0, 12, 0, 0, 13, 0, 14, // chan 1 + 11, 0, 12, 0, 0, 13, 0, 14 // chan 1 +}; +constexpr int16_t kExpectPrune1[] = { + 5, 0, 6, 0, 0, 7, 0, 8, // chan 0 + 0, 5, 0, 6, 7, 0, 8, 0, // chan 0 + 5, 0, 6, 0, 0, 7, 0, 8, // chan 0 + 5, 0, 6, 0, 0, 7, 0, 8, // chan 0 + 5, 0, 6, 0, 0, 7, 0, 8, // chan 0 + 15, 0, 16, 0, 0, 17, 0, 18, // chan 1 + 15, 0, 16, 0, 0, 17, 0, 18, // chan 1 + 15, 0, 16, 0, 0, 17, 0, 18, // chan 1 + 15, 0, 16, 0, 0, 17, 0, 18, // chan 1 + 15, 0, 16, 0, 0, 17, 0, 18 // chan 1 +}; +constexpr float kExpectPrune2[] = { + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // encode byte 0 + 0.0f, 9.0f, 0.0f, 10.0f, 11.0f, 0.0f, 12.0f, 0.0f, // encode byte 1 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // encode byte 2 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // encode byte 3 + 9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f, // encode byte 4 + 19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f, // encode byte 5 + 19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f, // encode byte 6 + 19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f, // encode byte 7 + 19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f, // encode byte 8 + 19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f // encode byte 9 +}; +constexpr int8_t kExpectPrune3[] = { + 13, 0, 14, 0, 0, 15, 0, 16, // chan 0 + 0, 13, 0, 14, 15, 0, 16, 0, // chan 0 + 13, 0, 14, 0, 0, 15, 0, 16, // chan 0 + 13, 0, 14, 0, 0, 15, 0, 16, // chan 0 + 13, 0, 14, 0, 0, 15, 0, 16, // chan 0 + 113, 0, 114, 0, 0, 115, 0, 116, // chan 1 + 113, 0, 114, 0, 0, 115, 0, 116, // chan 1 + 113, 0, 114, 0, 0, 115, 0, 116, // chan 1 + 113, 0, 114, 0, 0, 115, 0, 116, // chan 1 + 113, 0, 114, 0, 0, 115, 0, 116 // chan 1 +}; +constexpr int8_t kExpectPrune4[] = { + 17, -1, 18, -3, -4, 19, -6, 20, -8, 17, // group 0 + -10, 18, 19, -13, 20, -15, 17, -17, 18, -19, // group 0 + 0, 19, -2, 20, 17, -5, 18, -7, -8, 19, // group 0 + -10, 20, 17, -13, 18, -15, -16, 19, -18, 20, // group 0 + 21, -1, 22, -3, -4, 23, -6, 24, 21, -9, // group 1 + 22, -11, -12, 23, -14, 24, 21, -17, 22, -19, // group 1 + 0, 23, -2, 24, 21, -5, 22, -7, -8, 23, // group 1 + -10, 24, 21, -13, 22, -15, -16, 23, -18, 24 // group 1 +}; +constexpr int8_t kExpectPrune5[] = { + 13, -44, 14, -44, -44, 15, -44, 16, -44, 13, // chan 0 + -44, 14, 15, -44, 16, -44, 13, -44, 14, -44, // chan 0 + -44, 15, -44, 16, 13, -44, 14, -44, -44, 15, // chan 0 + -44, 16, 13, -44, 14, -44, -44, 15, -44, 16, // chan 0 + 23, -44, 24, -44, -44, 25, -44, 26, 23, -44, // chan 0 + 24, -44, -44, 25, -44, 26, 23, -44, 24, -44, // chan 0 + -44, 25, -44, 26, 23, -44, 24, -44, -44, 25, // chan 0 + -44, 26, 23, -44, 24, -44, 
-44, 25, -44, 26 // chan 0 +}; + template TfLiteStatus CheckOutput(const TfLiteTensor& output, const void* const expected) { @@ -170,10 +336,14 @@ void TestDecode(const std::initializer_list& encodes, tensors[i] = CreateTensor(tid_encode.data, const_cast(&tid_encode.dims), false, kTfLiteUInt8); + // must be a const tensor + tensors[i].allocation_type = kTfLiteMmapRo; const TensorInDatum& tid_ancillary = *ancillaries.begin()[i / 2]; tensors[i + 1] = CreateTensor( tid_ancillary.data, const_cast(&tid_ancillary.dims), false, kTfLiteUInt8); + // must be a const tensor + tensors[i + 1].allocation_type = kTfLiteMmapRo; } for (size_t i = 0; i < kNumOutputs; i++) { const TensorOutDatum& tod = *outputs.begin()[i]; @@ -204,13 +374,37 @@ TF_LITE_MICRO_TESTS_BEGIN using tflite::testing::AncillaryData; using tflite::testing::kAncillaryDataLUT0; using tflite::testing::kAncillaryDataLUT1; +using tflite::testing::kAncillaryDataPrune0; +using tflite::testing::kAncillaryDataPrune1; +using tflite::testing::kAncillaryDataPrune2; +using tflite::testing::kAncillaryDataPrune3; +using tflite::testing::kAncillaryDataPrune4; +using tflite::testing::kAncillaryDataPrune5; using tflite::testing::kDcmLUT0; using tflite::testing::kDcmLUT1; +using tflite::testing::kDcmPrune; using tflite::testing::kEncodedLUT; +using tflite::testing::kEncodedPrune; using tflite::testing::kEncodedShapeLUT; +using tflite::testing::kEncodedShapePrune; using tflite::testing::kExpectLUT0; using tflite::testing::kExpectLUT1; +using tflite::testing::kExpectPrune0; +using tflite::testing::kExpectPrune1; +using tflite::testing::kExpectPrune2; +using tflite::testing::kExpectPrune3; +using tflite::testing::kExpectPrune4; +using tflite::testing::kExpectPrune5; using tflite::testing::kOutputShapeLUT; +using tflite::testing::kOutputShapePrune; +using tflite::testing::kOutputShapePrune4; +using tflite::testing::kOutputShapePrune5; +using tflite::testing::kZeroPointsPrune0; +using tflite::testing::kZeroPointsPrune1; +using tflite::testing::kZeroPointsPrune1_Invalid; +using tflite::testing::kZeroPointsPrune3; +using tflite::testing::kZeroPointsPrune4; +using tflite::testing::kZeroPointsPrune5; using tflite::testing::TensorInDatum; using tflite::testing::TensorOutDatum; @@ -336,4 +530,387 @@ TF_LITE_MICRO_TEST(DecodeTwoTensorsLUT) { encodes, ancillaries, outputs, expected, tflite::Register_DECODE()); } +TF_LITE_MICRO_TEST(DecodePruneFloat) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) float output_data[std::size(kExpectPrune2)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune2}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + const TfLiteFloatArray kOutputScales = 
{kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteFloat32, + kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune2}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints, + 0, {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune3}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune3); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints, + 0, {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune3}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedMixedZeroPointInt8) { + // Align the 
tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune0)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune0}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune0); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints, + 0, {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune0}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedSingleChannelInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune5)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune5}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune5); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune5); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints, + 0, {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune5}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt8) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectPrune4)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune4}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int 
kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune4); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune4); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt8, + kOutputScales, + *kOutputZeroPoints, + (kOutputDims->size - 1), + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune4}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune1); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE()); +} + +TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}}; + + const TfLiteIntArray* const kEncodedDims = + tflite::testing::IntArrayFromInts(kEncodedShapePrune); + static const TensorInDatum kEncodeTID = { + kEncodedPrune, + *kEncodedDims, + }; + static constexpr std::initializer_list kEncodes = { + &kEncodeTID, + }; + + constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)}; + const TfLiteIntArray* const kAncillaryDims = + tflite::testing::IntArrayFromInts(kAncillaryShape); + static const TensorInDatum kAncillaryTID = { + &kAncillaryData, + *kAncillaryDims, + }; + static constexpr std::initializer_list kAncillaries = { + &kAncillaryTID}; + + const TfLiteIntArray* const kOutputDims = + tflite::testing::IntArrayFromInts(kOutputShapePrune); + const TfLiteIntArray* const 
kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kZeroPointsPrune1_Invalid); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum kTOD = { + output_data, + *kOutputDims, + kTfLiteInt16, + kOutputScales, + *kOutputZeroPoints, + 0, + {}, + }; + static constexpr std::initializer_list kOutputs = { + &kTOD}; + + const std::initializer_list kExpected = {kExpectPrune1}; + + tflite::testing::TestDecode( + kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(), + kTfLiteError); +} + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/xtensa/decode_state.cc b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc new file mode 100644 index 00000000000..4feec409e15 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/decode_state.cc @@ -0,0 +1,70 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/decode_state.h" + +#include "tensorflow/lite/micro/kernels/decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" +#include "tensorflow/lite/micro/micro_context.h" + +#ifdef HIFI5 +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h" +#endif // HIFI5 + +namespace tflite { + +DecodeState* DecodeState::CreateDecodeStateLUT( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStateLUT); +#else + constexpr size_t kBufferSize = sizeof(DecodeStateLUT); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStateLUT(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStateLUT(context, profiler); +#endif // HIFI5 + + return dsp; +} + +DecodeState* DecodeState::CreateDecodeStatePrune( + const TfLiteContext* context, MicroProfilerInterface* profiler) { + MicroContext* const micro_context = GetMicroContext(context); +#ifdef HIFI5 + constexpr size_t kBufferSize = sizeof(XtensaDecodeStatePrune); +#else + constexpr size_t kBufferSize = sizeof(DecodeStatePrune); +#endif // HIFI5 + void* buffer = micro_context->AllocatePersistentBuffer(kBufferSize); + if (buffer == nullptr) { + return nullptr; + } +#ifdef HIFI5 + DecodeState* dsp = new (buffer) XtensaDecodeStatePrune(context, profiler); +#else + DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler); +#endif // HIFI5 + return dsp; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc new file mode 100644 index 00000000000..de5435f4b00 --- /dev/null +++ 
b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc @@ -0,0 +1,609 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h" + +#include +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_profiler.h" + +namespace tflite { + +void XtensaDecodeStateLUT::DecompressToBufferWidth4_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + int j; + + ae_int8x8 d_out1, d_out2; + ae_int8x8 d_value_0_t, d_value_1_t; + ae_int8x8 d_value_0, d_value_1; + ae_int8x8 d_index, d_dummy; + + ae_int8x8* __restrict pIn_tmp = (ae_int8x8*)compressed_indices_; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + + const size_t stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (size_t i = 0; i < num_channels_; i++) { + value_table_t = value_table; + ae_valignx2 align_vtab = AE_LA128_PP(value_table_t); + AE_LA8X8X2_IP(d_value_0_t, d_value_1_t, align_vtab, + (ae_int8x16*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0_t, d_value_1_t, + d_shuffle_value_t); + + ae_valign align_load = AE_LA64_PP(pIn_tmp); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LA8X8_IP(d_index, align_load, pIn_tmp); + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + + value_table += stride; + if (elements_per_channel_t_rem) { + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 1)); /* Loading 48 bits for decoding 16 weight values */ + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d_index); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth3_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + 
ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + int elements_per_channel_t_by_4 = elements_per_channel_ >> 4; + int elements_per_channel_t_rem = elements_per_channel_ & 0xF; + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0x0F00050C00020000LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0x000E00040B000100LL); + ae_int8x8 d_shuffle_t3 = AE_MOVINT8X8_FROMINT64(0x0F060D040C030A01LL); + ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_4; j++) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 6); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 3); /* Loading 48 bits for decoding 16 weight values */ + + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 1)); + d2 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d3 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 3)); + d4 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 4)); + + d1 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), 0x7007007007000000LL)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d2), 0x0700700700700000LL)); + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d3), 0x0070070070070000LL)); + d4 = 
AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d4), 0x0007007007007000LL)); + + d5 = d1 | d2; + d6 = d3 | d4; + + d7 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d5), 4)); + d8 = AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d6), 4)); + + d9 = AE_SEL8X8(d5, d7, d_shuffle_t1); + d10 = AE_SEL8X8(d6, d8, d_shuffle_t2); + d11 = AE_SEL8X8(d9, d10, d_shuffle_t3); + + AE_DSEL8X8(d_out1, d_out2, d_value_0, d_value_1, d11); + AE_DSEL8X8(d_out1, d_out2, d_out1, d_out2, d_shuffle_t); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidth2_Xtensa(int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + int i, j; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + ae_int8x8* pIn_tmp = (ae_int8x8*)compressed_indices_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + const uint8_t* __restrict value_table_t = value_table; + + int num_channels_t = num_channels_; + const size_t stride = value_table_channel_stride_; + + int elements_per_channel_t_by_5 = elements_per_channel_ >> 5; + int elements_per_channel_t_rem = elements_per_channel_ & 0x1F; + int elements_per_channel_t_rem_minus_16 = 0; + if (elements_per_channel_t_rem > 16) { + elements_per_channel_t_rem_minus_16 = elements_per_channel_t_rem - 16; + } + + ae_int8x8 d_index, d_dummy; + ae_int8x8 d0, d1, d2, d3, d4, d5; + ae_int8x8 q0, q1, q2, q3; + ae_int8x8 d_out1, d_out2; + + ae_valignx2 align_index = AE_LA128_PP(pIn_tmp); + + ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL); + ae_int8x8 d_shuffle_t1 = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL); + ae_int8x8 d_shuffle_t2 = AE_MOVINT8X8_FROMINT64(0xFBEA7362D9C85140LL); + + ae_valignx2 align_store = AE_ZALIGN128(); + + for (i = 0; i < num_channels_t; i++) { + ae_int8x8 d_value_0 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + ae_int8x8 d_value_1 = AE_MOVINT8X8_FROMINT64(AE_ZERO()); + + value_table_t = value_table; + + ae_valign align_vtab = AE_LA64_PP(value_table_t); + AE_LA8X8_IP(d_value_0, align_vtab, (ae_int8x8*)value_table_t); + AE_DSEL8X8(d_value_0, d_value_1, d_value_0, d_value_1, d_shuffle_value_t); + + for (j = 0; j < elements_per_channel_t_by_5; j++) { + // AE_LA8X8_IP( d_index, align_index, pIn_tmp ); /* Loading 64 bits + // for decoding 32 weight values */ + + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + 8); /* Loading 64 bits for decoding 32 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... 
+ + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + AE_SA8X8X2_IP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp); + } + if (elements_per_channel_t_rem) { + AE_LAV8X8X2_XP(d_index, d_dummy, align_index, (ae_int8x16*)pIn_tmp, + (elements_per_channel_t_rem >> + 2)); /* Loading 48 bits for decoding 16 weight values */ + d0 = d_index; + d1 = + AE_MOVINT8X8_FROMINT64(AE_SRLI64(AE_MOVINT64_FROMINT8X8(d_index), 2)); + d2 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d0), + 0x3333333333333333LL)); // i1,i3,i5, .... + d3 = AE_MOVINT8X8_FROMINT64( + AE_AND64(AE_MOVINT64_FROMINT8X8(d1), + 0x3333333333333333LL)); // i0,i2,i4, .... + + AE_DSEL8X8(d4, d5, d3, d2, + d_shuffle_t1); // d4 = i0,i2,i1,i3,i4,i6,... d5 = + // i16,i18, i17,i19, .... + + AE_DSEL8X8(q0, q1, d_value_0, d_value_1, + d4); // q0 = 0,1,4,5,8,9,12,13 q1 = 2,3,6,7,10,11,14,15 + AE_DSEL8X8( + q2, q3, d_value_0, d_value_1, + d5); // q2 = 16,17,20,21,24,25,28,29 q3 = 18,19,22,23,26,27,30,31 + + AE_DSEL8X8(d_out1, d_out2, q0, q1, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem); + + AE_DSEL8X8(d_out1, d_out2, q2, q3, d_shuffle_t2); + + AE_SAV8X8X2_XP(d_out1, d_out2, align_store, (ae_int8x16*)p_out_tmp, + elements_per_channel_t_rem_minus_16); + } + + value_table = value_table + stride; + } + AE_SA128POS_FP(align_store, (ae_int8x16*)p_out_tmp); +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt8_Xtensa( + int8_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint8_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int8* __restrict p_out_tmp = (ae_int8*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint8_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + uint32_t index_1, index_2; + uint32_t mask_bits = (1 << compressed_bit_width_) - 1; + + for (int i = 0; i < num_channels_t; i++) { + elements_per_channel_t = elements_per_channel_; + /* if output pointer is not 2 byte aligned */ + if ((unsigned int)p_out_tmp & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + elements_per_channel_t = elements_per_channel_t - 1; + } + for (int j = 0; j < (elements_per_channel_t >> 1); j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, 2 * bw); + index_1 = (index >> compressed_bit_width_) & 
mask_bits; + index_2 = (index)&mask_bits; + ae_int8x8 d_tmp1 = AE_L8_X((const ae_int8*)value_table, index_1); + ae_int8x8 d_tmp2 = AE_L8_X((const ae_int8*)value_table, index_2); + ae_int16x4 d_tmp = + AE_MOVINT16X4_FROMINT8X8(AE_SEL8X8I(d_tmp2, d_tmp1, 21)); + AE_S16_0_IP(d_tmp, (ae_int16*)p_out_tmp, 2); + } + if (elements_per_channel_t & 0x1) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int8x8 d_tmp = AE_L8_X((const ae_int8*)value_table, index); + AE_S8_0_IP(d_tmp, p_out_tmp, 1); + } + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt16_Xtensa( + int16_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint16_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int16* __restrict p_out_tmp = (ae_int16*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint16_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int16x4 d_tmp = AE_L16_X((const ae_int16*)value_table, index << 1); + AE_S16_0_IP(d_tmp, p_out_tmp, 2); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt32_Xtensa( + int32_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint32_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int32* __restrict p_out_tmp = (ae_int32*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint32_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int32x2 d_tmp = AE_L32_X((const ae_int32*)value_table, index << 2); + AE_S32_L_IP(d_tmp, p_out_tmp, 4); + } + + value_table += stride; + } + } +} + +void 
XtensaDecodeStateLUT::DecompressToBufferWidthAnyInt64_Xtensa( + int64_t* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + const int stride = value_table_channel_stride_; + const uint64_t* __restrict value_table = + static_cast(value_table_); + + int num_channels_t = num_channels_; + short* __restrict p_stream = (short*)compressed_indices_; + uint32_t index; + ae_int64* __restrict p_out_tmp = (ae_int64*)buffer; + const size_t bw = compressed_bit_width_; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + if (use_alternate_axis_) { + int count = count_indices_; + const uint64_t* __restrict value_table_t = value_table; + + while (count > 0) { + value_table = value_table_t; + + for (int channel = 0; channel < num_channels_t; channel++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + value_table += stride; + } + + count -= num_channels_t; + } + } else { + int elements_per_channel_t = elements_per_channel_; + + for (int i = 0; i < num_channels_t; i++) { + for (int j = 0; j < elements_per_channel_t; j++) { + AE_LB_DB_IP((unsigned short*)p_stream, index, bw); + ae_int64 d_tmp = AE_L64_X((const ae_int64*)value_table, index << 3); + AE_S64_IP(d_tmp, p_out_tmp, 8); + } + + value_table += stride; + } + } +} + +void XtensaDecodeStateLUT::DecompressToBuffer(int8_t* buffer) { + if (compressed_bit_width_ == 4 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x01)) { + DecompressToBufferWidth4_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 3 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x07)) { + DecompressToBufferWidth3_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else if (compressed_bit_width_ == 2 && !use_alternate_axis_) { + if (!(elements_per_channel_ & 0x03)) { + DecompressToBufferWidth2_Xtensa(buffer); + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } + } else { + DecompressToBufferWidthAnyInt8_Xtensa(buffer); + } +} + +TfLiteStatus XtensaDecodeStateLUT::Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) { + TFLITE_DCHECK(compressed_bit_width_ <= kMaxBitWidth); + TFLITE_DCHECK(compressed_bit_width_ > 0); + + void* const buffer = const_cast(micro::GetTensorData(&output)); + TFLITE_DCHECK(buffer != nullptr); + + switch (output.type) { + case kTfLiteBool: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteFloat32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt8: + DecompressToBuffer(static_cast(buffer)); + break; + case kTfLiteInt16: + DecompressToBufferWidthAnyInt16_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt32: + DecompressToBufferWidthAnyInt32_Xtensa(static_cast(buffer)); + break; + case kTfLiteInt64: + DecompressToBufferWidthAnyInt64_Xtensa(static_cast(buffer)); + break; + default: + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h new file mode 100644 index 00000000000..b614887a4cc --- /dev/null +++ 
b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.h @@ -0,0 +1,57 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
+
+namespace tflite {
+
+struct XtensaDecodeStateLUT : public DecodeStateLUT {
+ XtensaDecodeStateLUT() = delete;
+
+ XtensaDecodeStateLUT(const TfLiteContext* context,
+ MicroProfilerInterface* profiler)
+ : DecodeStateLUT(context, profiler) {}
+
+ virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+ const TfLiteEvalTensor& ancillary,
+ const TfLiteEvalTensor& output) override;
+
+ protected:
+ virtual ~XtensaDecodeStateLUT() = default;
+
+ void DecompressToBuffer(int8_t* buffer);
+
+ void DecompressToBufferWidth4_Xtensa(int8_t* buffer);
+ void DecompressToBufferWidth3_Xtensa(int8_t* buffer);
+ void DecompressToBufferWidth2_Xtensa(int8_t* buffer);
+
+ void DecompressToBufferWidthAnyInt8_Xtensa(int8_t* buffer);
+ void DecompressToBufferWidthAnyInt16_Xtensa(int16_t* buffer);
+ void DecompressToBufferWidthAnyInt32_Xtensa(int32_t* buffer);
+ void DecompressToBufferWidthAnyInt64_Xtensa(int64_t* buffer);
+
+ private:
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_LUT_H_ diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc new file mode 100644 index 00000000000..04df0831a40 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc @@ -0,0 +1,485 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h"
+
+#include
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/xtensa/xtensa.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+
+namespace tflite {
+
+TfLiteStatus XtensaDecodeStatePrune::Decode(const TfLiteEvalTensor& input,
+ const TfLiteEvalTensor& ancillary,
+ const TfLiteEvalTensor& output) {
+ void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
+ TFLITE_DCHECK(buffer != nullptr);
+
+ switch (output.type) {
+ case kTfLiteBool:
+ DecompressToBufferInt8_Xtensa(buffer);
+ break;
+ case kTfLiteFloat32:
+ DecodeStatePrune::DecompressToBuffer(buffer);
+ break;
+ case kTfLiteInt8:
+ DecompressToBufferInt8_Xtensa(buffer);
+ break;
+ case kTfLiteInt16:
+ DecompressToBufferInt16_Xtensa(buffer);
+ break;
+ case kTfLiteInt32:
+ DecodeStatePrune::DecompressToBuffer(buffer);
+ break;
+ case kTfLiteInt64:
+ DecodeStatePrune::DecompressToBuffer(buffer);
+ break;
+ default:
+ MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type));
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+void XtensaDecodeStatePrune::DecompressToBufferInt8_Xtensa(void* buffer) {
+ if (num_channels_ > 1 && zero_points_ != nullptr) {
+ DecompressToBufferPerChannelInt8_Xtensa(buffer);
+ return;
+ }
+
+ ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+ ae_int8x16* p_weights = (ae_int8x16*)value_table_;
+ int* __restrict p_mask32 = (int*)compressed_indices_;
+ ae_valign align = AE_LA64_PP(p_weights);
+ ae_int8x8 data0, data1, data2, data3;
+ ae_int8x8 shfl0, shfl1, shfl2, shfl3;
+ const int count = count_indices_;
+ int8_t* __restrict pCoeff = static_cast<int8_t*>(buffer);
+ ae_int8x8 zero = single_zero_point_;
+ ae_int8x8 discarded;
+
+ if (single_zero_point_ == 0) {
+ for (int i = 0; i < count >> 5; i++) {
+ // unpack elements
+ int mask = *p_mask32++;
+ AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0);
+ AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 1);
+ AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 2);
+ AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 3);
+ data0 = AE_SHFL8X8(data0, shfl0);
+ data1 = AE_SHFL8X8(data1, shfl1);
+ data2 = AE_SHFL8X8(data2, shfl2);
+ data3 = AE_SHFL8X8(data3, shfl3);
+
+ // move elements to output
+ AE_S8X8X2_IP(data0, data1, (ae_int8x16*)pCoeff, 16);
+ AE_S8X8X2_IP(data2, data3, (ae_int8x16*)pCoeff, 16);
+ }
+ } else {
+ for (int i = 0; i < count >> 5; i++) {
+ // unpack elements
+ int mask = *p_mask32++;
+ AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0);
+ AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 1);
+ AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 2);
+ AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 3);
+ data0 = AE_SHFL8X8(data0, shfl0);
+ data1 = AE_SHFL8X8(data1, shfl1);
+ data2 = AE_SHFL8X8(data2, shfl2);
+ data3 = AE_SHFL8X8(data3, shfl3);
+
+ // merge into elements
+ AE_MOVT8X16_L(discarded, data0, zero, data0, mask);
+ AE_MOVT8X16_L(discarded, data1, zero, data1, mask >> 8);
+ AE_MOVT8X16_H(discarded, data2, zero, data2, mask);
+ AE_MOVT8X16_H(discarded, data3, zero, data3, mask >> 8);
+
+ // move merged elements to output
+ AE_S8X8X2_IP(data0, data1, (ae_int8x16*)pCoeff, 16);
+
AE_S8X8X2_IP(data2, data3, (ae_int8x16*)pCoeff, 16); + } + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_L(discarded, data0, zero, data0, mask); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 0); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_L(discarded, data1, zero, data1, mask); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 0); + data2 = AE_SHFL8X8(data2, shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + } + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelInt8_Xtensa( + void* buffer) { + if (use_alternate_axis_) { + DecompressToBufferPerChannelAltAxisInt8_Xtensa(buffer); + return; + } + TFLITE_DCHECK(zero_points_ != nullptr); + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + const int count = elements_per_channel_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + for (size_t channel = 0; channel < num_channels_; channel++) { + ae_int8x8 zero = zero_points_[channel]; + uint32_t mask_low, mask_high; + + if (zero_points_[channel] == 0) { + for (int i = 0; i < count >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // move elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + } + } else { + for (int i = 0; i < count >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + 
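A structural note that applies to the prune kernels in this file: each main-loop iteration consumes one 32-bit mask word and emits 32 output elements, so the element count is split into `count >> 5` full blocks plus a `count & 0x1F` tail handled with the variable-length AE_SAV stores. Judging from the code, those store lengths are byte counts, which is why the int8 tails pass `count_rem` directly while the int16 tail later in this file shifts by one (`<< 1`). A hypothetical helper that makes the same split explicit (names are illustrative only):

// Hypothetical helper mirroring the block/tail split used by the kernels
// in this file: 32 elements per mask word, remainder handled by
// variable-length stores.
#include <cstddef>

struct BlockSplit {
  size_t full_blocks;    // iterations of the 32-element main loop
  size_t tail_elements;  // 0..31 elements handled by the tail path
};

inline BlockSplit SplitInto32ElementBlocks(size_t count) {
  return BlockSplit{count >> 5, count & 0x1F};
}

// Example: count = 70 gives 2 full blocks and a 6-element tail; for an
// int16 output the tail store length would be 6 * sizeof(int16_t) = 12
// bytes, i.e. count_rem << 1.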
AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + } + } + + const int count_rem = count & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack and merge into remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + data0 = AE_SHFL8X8(data0, shfl0); + AE_MOVT8X16_H(discarded, data0, zero, data0, mask >> 8); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + data1 = AE_SHFL8X8(data1, shfl1); + AE_MOVT8X16_H(discarded, data1, zero, data1, mask); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + data2 = AE_SHFL8X8(data2, shfl2); + AE_MOVT8X16_L(discarded, data2, zero, data2, mask >> 8); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data3 = AE_SHFL8X8(data3, shfl3); + AE_MOVT8X16_L(discarded, data3, zero, data3, mask); + + // move merged elements to output + if (count_rem <= 16) { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + } + } + AE_SA128POS_FP(align2, pCoeff); +} + +void XtensaDecodeStatePrune::DecompressToBufferPerChannelAltAxisInt8_Xtensa( + void* buffer) { + TFLITE_DCHECK(zero_points_ != nullptr); + + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int8x16* p_weights = (ae_int8x16*)value_table_; + short* __restrict p_stream = (short*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_valignx2 align2 = AE_ZALIGN128(); + ae_int8x8 data0, data1, data2, data3; + ae_int8x8 shfl0, shfl1, shfl2, shfl3; + int count = count_indices_ / num_channels_; + const int max_channels = num_channels_; + int8_t* __restrict pCoeff = static_cast(buffer); + ae_int8x8 discarded; + + WUR_AE_BITPTR(0); + WUR_AE_BITHEAD(0); + + AE_DBI_IP((const unsigned short*)p_stream, 16); + AE_DBI_IP((const unsigned short*)p_stream, 16); + + while (count-- > 0) { + ae_int8x8 zero0, zero1, zero2, zero3; + uint32_t mask_low, mask_high; + // p_zero is always 16 byte aligned due to copy during Setup(). 
+ int8_t* __restrict p_zero = (int8_t*)zero_points_; + + for (int i = 0; i < max_channels >> 5; i++) { + // unpack elements + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_low, 16); + const int mask = (mask_high << 16) | mask_low; + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values + AE_L8X8X2_IP(zero0, zero1, (ae_int8x16*)p_zero, 16); + AE_L8X8X2_IP(zero2, zero3, (ae_int8x16*)p_zero, 16); + + // merge into elements + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + + // move merged elements to output + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, 16); + } + + const int count_rem = max_channels & 0x1F; + if (count_rem) { + if (count_rem > 16) { + AE_LBI_DBI_IP((unsigned short*)p_stream, mask_high, 16); + AE_LB_DB_IP((unsigned short*)p_stream, mask_low, count_rem - 16); + mask_low <<= 32 - count_rem; + } else { + AE_LB_DB_IP((unsigned short*)p_stream, mask_high, count_rem); + mask_high <<= 16 - count_rem; + mask_low = 0; + } + const int mask = (mask_high << 16) | mask_low; + + // unpack remaining elements + AE_LAVUNSQZ8X8_XP(data0, shfl0, align, p_weights, mask, 3); + AE_LAVUNSQZ8X8_XP(data1, shfl1, align, p_weights, mask, 2); + AE_LAVUNSQZ8X8_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ8X8_XP(data3, shfl3, align, p_weights, mask, 0); + data0 = AE_SHFL8X8(data0, shfl0); + data1 = AE_SHFL8X8(data1, shfl1); + data2 = AE_SHFL8X8(data2, shfl2); + data3 = AE_SHFL8X8(data3, shfl3); + + // load values, merge into elements and + // move merged elements to output + ae_valignx2 align_zero = AE_LA128_PP(p_zero); + if (count_rem <= 16) { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, + count_rem); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, count_rem); + } else { + AE_LAV8X8X2_XP(zero0, zero1, align_zero, (ae_int8x16*)p_zero, 16); + AE_LAV8X8X2_XP(zero2, zero3, align_zero, (ae_int8x16*)p_zero, + count_rem & 0xF); + AE_MOVT8X16_H(discarded, data0, zero0, data0, mask >> 8); + AE_MOVT8X16_H(discarded, data1, zero1, data1, mask); + AE_MOVT8X16_L(discarded, data2, zero2, data2, mask >> 8); + AE_MOVT8X16_L(discarded, data3, zero3, data3, mask); + AE_SAV8X8X2_XP(data0, data1, align2, (ae_int8x16*)pCoeff, 16); + AE_SAV8X8X2_XP(data2, data3, align2, (ae_int8x16*)pCoeff, + count_rem & 0xF); + } + } + } + AE_SA128POS_FP(align2, pCoeff); +} + +void XtensaDecodeStatePrune::DecompressToBufferInt16_Xtensa(void* buffer) { + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); + + ae_int16x8* p_weights = (ae_int16x8*)value_table_; + int* __restrict p_mask32 = (int*)compressed_indices_; + ae_valign align = AE_LA64_PP(p_weights); + ae_int16x4 data0, data1, data2, data3; + ae_int16x4 data4, data5, data6, data7; + ae_int16x4 shfl0, shfl1, shfl2, shfl3; + ae_int16x4 shfl4, shfl5, shfl6, 
shfl7; + const int count = count_indices_; + int16_t* __restrict pCoeff = static_cast(buffer); + + for (int i = 0; i < count >> 5; i++) { + // unpack elements and merge 0 (zero) elements + int mask = *p_mask32++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 3); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 2); + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, p_weights, mask, 5); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 4); + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 7); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 6); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + + // move merged elements to output + AE_S16X4X2_IP(data0, data1, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data2, data3, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data4, data5, (ae_int16x8*)pCoeff, 16); + AE_S16X4X2_IP(data6, data7, (ae_int16x8*)pCoeff, 16); + } + + const int count_rem = count & 0x1F; + if (count_rem) { + ae_valignx2 align2 = AE_ZALIGN128(); + int8_t* __restrict p_mask8 = reinterpret_cast(p_mask32); + + // unpack and merge into remaining elements + int mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data0, shfl0, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data1, shfl1, align, p_weights, mask, 0); + data0 = AE_SHFL16X4(data0, shfl0); + data1 = AE_SHFL16X4(data1, shfl1); + if (count_rem > 8) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data2, shfl2, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data3, shfl3, align, p_weights, mask, 0); + data2 = AE_SHFL16X4(data2, shfl2); + data3 = AE_SHFL16X4(data3, shfl3); + } + if (count_rem > 16) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data4, shfl4, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data5, shfl5, align, p_weights, mask, 0); + data4 = AE_SHFL16X4(data4, shfl4); + data5 = AE_SHFL16X4(data5, shfl5); + } + if (count_rem > 24) { + mask = *p_mask8++; + AE_LAVUNSQZ16X4_XP(data6, shfl6, align, p_weights, mask, 1); + AE_LAVUNSQZ16X4_XP(data7, shfl7, align, p_weights, mask, 0); + data6 = AE_SHFL16X4(data6, shfl6); + data7 = AE_SHFL16X4(data7, shfl7); + } + + // move merged elements to output + if (count_rem <= 8) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, + count_rem << 1); + } else if (count_rem <= 16) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, + (count_rem - 8) << 1); + } else if (count_rem <= 24) { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, + (count_rem - 16) << 1); + } else { + AE_SAV16X4X2_XP(data0, data1, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data2, data3, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data4, data5, align2, (ae_int16x8*)pCoeff, 16); + AE_SAV16X4X2_XP(data6, data7, align2, (ae_int16x8*)pCoeff, + (count_rem - 24) << 1); + } + AE_SA128POS_FP(align2, pCoeff); + } +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h new file 
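Taken together, the int8 and int16 prune kernels above are SIMD implementations of a simple scalar rule: one mask bit per output element selects either the next entry of the packed value table or that element's zero point (per tensor, per channel, or per channel along the alternate axis). A scalar sketch with hypothetical names follows, assuming mask bits are consumed most-significant-bit first within each mask word; the real bit order is defined by the reference DecodeStatePrune implementation.

// Illustrative scalar counterpart of the per-channel int8 prune decode
// (standard axis). Set mask bits take the next packed value; cleared bits
// emit the channel's zero point.
#include <cstddef>
#include <cstdint>

void PruneDecodeInt8Reference(const uint8_t* mask_bits,
                              const int8_t* packed_values,
                              const int8_t* zero_points,  // one per channel
                              size_t num_channels,
                              size_t elements_per_channel, int8_t* output) {
  size_t bit_pos = 0;
  for (size_t channel = 0; channel < num_channels; ++channel) {
    const int8_t zero_point = zero_points ? zero_points[channel] : 0;
    for (size_t j = 0; j < elements_per_channel; ++j, ++bit_pos) {
      const bool keep =
          (mask_bits[bit_pos >> 3] >> (7 - (bit_pos & 7))) & 1u;
      *output++ = keep ? *packed_values++ : zero_point;
    }
  }
}

The alternate-axis variant differs only in iteration order: channels vary fastest, so the zero point changes with every element rather than once per channel block, which is why that kernel reloads a fresh group of zero points for each 32-element mask word.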
mode 100644 index 00000000000..fb6935f3383 --- /dev/null +++ b/tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.h @@ -0,0 +1,51 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ + +#include + +#include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" + +namespace tflite { + +struct XtensaDecodeStatePrune : public DecodeStatePrune { + XtensaDecodeStatePrune() = delete; + + XtensaDecodeStatePrune(const TfLiteContext* context, + MicroProfilerInterface* profiler) + : DecodeStatePrune(context, profiler) {} + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override; + + protected: + virtual ~XtensaDecodeStatePrune() = default; + + void DecompressToBufferInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelInt8_Xtensa(void* buffer); + void DecompressToBufferPerChannelAltAxisInt8_Xtensa(void* buffer); + void DecompressToBufferInt16_Xtensa(void* buffer); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_XTENSA_DECODE_STATE_PRUNE_H_ diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 46621194601..c3e1bbab3bf 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -389,6 +389,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_lut.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc index b05a0670248..4a8c1591445 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_makefile.inc @@ -124,6 +124,13 @@ ifeq ($(OPTIMIZED_KERNEL_DIR), xtensa) MICROLITE_CC_KERNEL_SRCS += \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/kernels/kernel_util.cc + + # Additional kernel sources for DECODE operator support + ifeq ($(TARGET_ARCH), $(filter $(TARGET_ARCH), hifi5)) + MICROLITE_CC_KERNEL_SRCS += \ + $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc \ + 
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + endif endif # override KERNEL_OPTIMIZATION_LEVEL to enable higher performance @@ -131,3 +138,11 @@ endif $(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/decompress.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_lut.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@ + +$(KERNEL_OBJDIR)$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.o: $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/xtensa/xtensa_decode_state_prune.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) -O3 -LNO:simd $(INCLUDES) -c $< -o $@