From 464a6a0499d215d5c624041d6ce255b860a54a35 Mon Sep 17 00:00:00 2001 From: j00575040 Date: Tue, 9 Apr 2024 21:34:17 +0800 Subject: [PATCH] fix argminmax int bug && support swish int8 && fix VAD asan bug --- mindspore/lite/src/litert/kernel/cpu/BUILD.gn | 1 + .../src/litert/kernel/cpu/base/custom_base.cc | 14 ++-- .../litert/kernel/cpu/int8/activation_int8.cc | 4 ++ .../litert/kernel/cpu/int8/argminmax_int8.cc | 35 +++++----- .../src/litert/kernel/cpu/int8/sigmoid_int8.h | 2 +- .../src/litert/kernel/cpu/int8/swish_int8.cc | 67 +++++++++++++++++++ .../src/litert/kernel/cpu/int8/swish_int8.h | 38 +++++++++++ mindspore/lite/src/litert/scheduler.cc | 9 ++- 8 files changed, 139 insertions(+), 31 deletions(-) create mode 100644 mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.cc create mode 100644 mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.h diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn index 7b813314..297fc6f6 100644 --- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn +++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn @@ -245,6 +245,7 @@ int8_kernel_sources = [ "int8/split_int8.cc", "int8/squeeze_int8.cc", "int8/sub_int8.cc", + "int8/swish_int8.cc", "int8/tanh_int8.cc", "int8/topk_int8.cc", "int8/transpose_int8.cc", diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc index 9921e063..0459c417 100644 --- a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc +++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc @@ -28,19 +28,15 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Custom; namespace mindspore::kernel { -int CustomBaseCPUKernel::Prepare() { - return RET_OK; -} +int CustomBaseCPUKernel::Prepare() { return RET_OK; } -int CustomBaseCPUKernel::ReSize() { - return RET_OK; -} +int CustomBaseCPUKernel::ReSize() { return RET_OK; } -int CustomBaseCPUKernel::Run() { - return RET_OK; -} +int CustomBaseCPUKernel::Run() { return RET_OK; } REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimType_Inner_ThirdPartyModel, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeUInt8, PrimType_Inner_ThirdPartyModel, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_ThirdPartyModel, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/activation_int8.cc b/mindspore/lite/src/litert/kernel/cpu/int8/activation_int8.cc index 9bc410e7..10b6cd5a 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/activation_int8.cc +++ b/mindspore/lite/src/litert/kernel/cpu/int8/activation_int8.cc @@ -16,6 +16,7 @@ #include "src/litert/kernel/cpu/int8/relux_int8.h" #include "src/litert/kernel/cpu/int8/hswish_int8.h" +#include "src/litert/kernel/cpu/int8/swish_int8.h" #include "src/litert/kernel/cpu/int8/sigmoid_int8.h" #include "src/litert/kernel/cpu/int8/tanh_int8.h" #include "src/litert/kernel/cpu/int8/leaky_relu_int8.h" @@ -50,6 +51,9 @@ kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vectordata_type() != mindspore::kNumberTypeInt8 || - out_tensors_[0]->data_type() != mindspore::kNumberTypeInt8) { - MS_LOG(ERROR) << "Datatype error, input0 data_type is " << in_tensors_[0]->data_type() << ", output data_type is " - << out_tensors_[0]->data_type(); - return RET_ERROR; - } in_quant_arg_ = reinterpret_cast(malloc(sizeof(QuantArg))); if (in_quant_arg_ == nullptr) { MS_LOG(ERROR) << "Malloc QuantArg for argmin or argmax int8 op failed!"; @@ -64,18 +58,7 @@ int ArgMinMaxInt8CPUKernel::Prepare() { in_quant_arg_->scale_ = in_quant_args.front().scale; in_quant_arg_->zp_ = in_quant_args.front().zeroPoint; - auto *out_tensor = out_tensors_.at(kOutputIndex); - auto out_quant_args = out_tensor->quant_params(); - CHECK_LESS_RETURN(out_quant_args.size(), 1); - out_quant_arg_ = reinterpret_cast(malloc(sizeof(QuantArg))); - out_quant_arg_->scale_ = out_quant_args.front().scale; - out_quant_arg_->zp_ = out_quant_args.front().zeroPoint; - if (out_quant_arg_ == nullptr) { - MS_LOG(ERROR) << "Malloc QuantArg for argmin or argmax int8 op failed!"; - return RET_ERROR; - } - - compute_param_ = reinterpret_cast(sizeof(ArgMinMaxComputeParam)); + compute_param_ = reinterpret_cast(malloc(sizeof(ArgMinMaxComputeParam))); if (compute_param_ == nullptr) { MS_LOG(ERROR) << "Malloc ArgMinMaxComputeParam for argmin or argmax int8 op failed!"; return RET_ERROR; @@ -87,6 +70,22 @@ int ArgMinMaxInt8CPUKernel::Prepare() { compute_param_->out_value_ = param->out_value_; compute_param_->keep_dims_ = param->keep_dims_; + out_quant_arg_ = reinterpret_cast(malloc(sizeof(QuantArg))); + if (out_quant_arg_ == nullptr) { + MS_LOG(ERROR) << "Malloc QuantArg for argmin or argmax int8 op failed!"; + return RET_ERROR; + } + if (out_tensors_.size() == Num2 || compute_param_->out_value_) { + auto *out_tensor = out_tensors_.at(kOutputIndex); + auto out_quant_args = out_tensor->quant_params(); + CHECK_LESS_RETURN(out_quant_args.size(), 1); + out_quant_arg_->scale_ = out_quant_args.front().scale; + out_quant_arg_->zp_ = out_quant_args.front().zeroPoint; + } else { // set default quant value + out_quant_arg_->scale_ = 1.0f; + out_quant_arg_->zp_ = 0; + } + if (!InferShapeDone()) { return RET_OK; } diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/sigmoid_int8.h b/mindspore/lite/src/litert/kernel/cpu/int8/sigmoid_int8.h index 1f383ae6..9080852f 100644 --- a/mindspore/lite/src/litert/kernel/cpu/int8/sigmoid_int8.h +++ b/mindspore/lite/src/litert/kernel/cpu/int8/sigmoid_int8.h @@ -34,7 +34,7 @@ class SigmoidInt8CPUKernel : public LiteKernel { int Run() override; int DoActivation(int task_id); - private: + protected: int8_t table_list_[256]{0}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.cc b/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.cc new file mode 100644 index 00000000..501793af --- /dev/null +++ b/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/litert/kernel/cpu/int8/swish_int8.h" +#include +#include +#include "nnacl/int8/quantize.h" +#include "src/litert/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_SIGMOID; + +namespace mindspore::kernel { +// Calculate the quantization results of 0-255 in advance +void CalculateSwishTableList(int8_t *table, const float input_scale, const int32_t input_zp, const float output_scale, + const int32_t output_zp) { + int32_t min_value = std::numeric_limits::min(); + int32_t max_value = std::numeric_limits::max(); + for (int i = min_value; i < max_value; ++i) { + const float real_input_value = input_scale * (i - input_zp); + const float sigmoid_value = 1.0f / (1.0f + std::exp(-real_input_value)); + const int32_t quantized = (std::round(real_input_value * sigmoid_value / output_scale) + output_zp); + int8_t out_value = static_cast(std::max(std::min(quantized, max_value), min_value)); + uint8_t index = static_cast(i); + table[index] = out_value; + } +} + +int SwishInt8CPUKernel::Prepare() { + CHECK_LESS_RETURN(in_tensors_.size(), C1NUM); + CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); + CHECK_NULL_RETURN(in_tensors_[0]); + CHECK_NULL_RETURN(out_tensors_[0]); + if (in_tensors_[0]->data_type() != mindspore::kNumberTypeInt8 || + out_tensors_[0]->data_type() != mindspore::kNumberTypeInt8) { + MS_LOG(ERROR) << "Datatype error, input0 data_type is " << in_tensors_[0]->data_type() << ", output data_type is " + << out_tensors_[0]->data_type(); + return RET_ERROR; + } + lite::Tensor *input = in_tensors_.at(0); + lite::Tensor *output = out_tensors_.at(0); + MS_CHECK_TRUE_RET(!input->quant_params().empty() && !output->quant_params().empty(), RET_ERROR); + const float input_scale = input->quant_params().front().scale; + const int32_t input_zp = input->quant_params().front().zeroPoint; + const float output_scale = output->quant_params().front().scale; + const int32_t output_zp = output->quant_params().front().zeroPoint; + CalculateSwishTableList(table_list_, input_scale, input_zp, output_scale, output_zp); + return RET_OK; +} +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.h b/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.h new file mode 100644 index 00000000..7b8ef9ca --- /dev/null +++ b/mindspore/lite/src/litert/kernel/cpu/int8/swish_int8.h @@ -0,0 +1,38 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_SWISH_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_SWISH_INT8_H_ + +#include +#include "src/litert/lite_kernel.h" +#include "src/litert/kernel/cpu/int8/sigmoid_int8.h" +#include "nnacl/int8/softmax_int8.h" +#include "nnacl/int8/quantize.h" + +namespace mindspore::kernel { +class SwishInt8CPUKernel : public SigmoidInt8CPUKernel { + public: + SwishInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : SigmoidInt8CPUKernel(parameter, inputs, outputs, ctx) {} + ~SwishInt8CPUKernel() override = default; + + int Prepare() override; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_SWISH_INT8_H_ diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc index 199b4361..96efd972 100644 --- a/mindspore/lite/src/litert/scheduler.cc +++ b/mindspore/lite/src/litert/scheduler.cc @@ -511,8 +511,8 @@ int Scheduler::ReplaceDelegateKernels(std::vector *dst_ker if (context_->IsDeviceTypeEnabled(DT_NNRT)) { auto delegate = static_cast(delegate_.get()); delegate->ShallowCopyLiteGraph(this->src_model_->graph_); - void *meta_graph = reinterpret_cast(const_cast( - mindspore::schema::GetMetaGraph(this->src_model_->buf))); + void *meta_graph = reinterpret_cast( + const_cast(mindspore::schema::GetMetaGraph(this->src_model_->buf))); delegate->SetMetaGraph(meta_graph); } #endif @@ -865,7 +865,9 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) { infer_subgraph_index_.push_back(subgraph_index); auto subgraph = src_model_->graph_.sub_graphs_.at(subgraph_index); int subgraph_infershape_ret = RET_OK; - for (auto node_index : subgraph->node_indices_) { + auto node_indexes = subgraph->node_indices_; + for (size_t i = 0; i < node_indexes.size(); ++i) { + auto node_index = node_indexes[i]; auto node = src_model_->graph_.all_nodes_[node_index]; MS_ASSERT(node != nullptr); auto *primitive = node->primitive_; @@ -877,6 +879,7 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index) { // convert shape to built-in shape MS_CHECK_TRUE_RET(node->input_indices_.size() == 1, RET_ERROR); shape_fusion_pass_->Run(node, subgraph_index); + node_indexes = subgraph->node_indices_; } auto ret = InferNodeShape(node); if (ret == RET_INFER_INVALID) { -- 2.31.1