1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/int8/detection_post_process_int8.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21 #include "nnacl/int8/quant_dtype_cast_int8.h"
22
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 using mindspore::schema::PrimitiveType_DetectionPostProcess;
27
28 namespace mindspore::kernel {
DequantizeInt8ToFp32(const int task_id)29 int DetectionPostProcessInt8CPUKernel::DequantizeInt8ToFp32(const int task_id) {
30 int num_unit_thread = MSMIN(thread_n_stride_, quant_size_ - task_id * thread_n_stride_);
31 int thread_offset = task_id * thread_n_stride_;
32 int ret = DoDequantizeInt8ToFp32(data_int8_ + thread_offset, data_fp32_ + thread_offset, quant_param_.scale,
33 quant_param_.zeroPoint, num_unit_thread);
34 if (ret != RET_OK) {
35 MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
36 return RET_ERROR;
37 }
38 return RET_OK;
39 }
40
DequantizeInt8ToFp32Run(void * cdata,int task_id,float lhs_scale,float rhs_scale)41 int DequantizeInt8ToFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
42 auto KernelData = reinterpret_cast<DetectionPostProcessInt8CPUKernel *>(cdata);
43 auto ret = KernelData->DequantizeInt8ToFp32(task_id);
44 if (ret != RET_OK) {
45 MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
46 return RET_ERROR;
47 }
48 return RET_OK;
49 }
50
Dequantize(lite::Tensor * tensor,float ** data)51 int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **data) {
52 data_int8_ = reinterpret_cast<int8_t *>(tensor->data());
53 CHECK_NULL_RETURN(data_int8_);
54 *data = reinterpret_cast<float *>(ms_context_->allocator->Malloc(tensor->ElementsNum() * sizeof(float)));
55 if (*data == nullptr) {
56 MS_LOG(ERROR) << "Malloc data failed.";
57 return RET_ERROR;
58 }
59 if (tensor->quant_params().empty()) {
60 MS_LOG(ERROR) << "null quant param";
61 return RET_ERROR;
62 }
63 quant_param_ = tensor->quant_params().front();
64 data_fp32_ = *data;
65 quant_size_ = tensor->ElementsNum();
66 thread_n_stride_ = UP_DIV(quant_size_, op_parameter_->thread_num_);
67
68 auto ret = ParallelLaunch(this->ms_context_, DequantizeInt8ToFp32Run, this, op_parameter_->thread_num_);
69 if (ret != RET_OK) {
70 MS_LOG(ERROR) << "QuantDTypeCastRun error error_code[" << ret << "]";
71 ms_context_->allocator->Free(*data);
72 return RET_ERROR;
73 }
74 return RET_OK;
75 }
GetInputData()76 int DetectionPostProcessInt8CPUKernel::GetInputData() {
77 if (in_tensors_.at(0)->data_type() != kNumberTypeInt8 || in_tensors_.at(1)->data_type() != kNumberTypeInt8) {
78 MS_LOG(ERROR) << "Input data type error";
79 return RET_ERROR;
80 }
81 int status = Dequantize(in_tensors_.at(0), &input_boxes_);
82 if (status != RET_OK) {
83 return status;
84 }
85 status = Dequantize(in_tensors_.at(1), &input_scores_);
86 if (status != RET_OK) {
87 return status;
88 }
89 return RET_OK;
90 }
91
FreeAllocatedBuffer()92 void DetectionPostProcessInt8CPUKernel::FreeAllocatedBuffer() {
93 if (params_->decoded_boxes_ != nullptr) {
94 ms_context_->allocator->Free(params_->decoded_boxes_);
95 params_->decoded_boxes_ = nullptr;
96 }
97 if (params_->nms_candidate_ != nullptr) {
98 ms_context_->allocator->Free(params_->nms_candidate_);
99 params_->nms_candidate_ = nullptr;
100 }
101 if (params_->indexes_ != nullptr) {
102 ms_context_->allocator->Free(params_->indexes_);
103 params_->indexes_ = nullptr;
104 }
105 if (params_->scores_ != nullptr) {
106 ms_context_->allocator->Free(params_->scores_);
107 params_->scores_ = nullptr;
108 }
109 if (params_->all_class_indexes_ != nullptr) {
110 ms_context_->allocator->Free(params_->all_class_indexes_);
111 params_->all_class_indexes_ = nullptr;
112 }
113 if (params_->all_class_scores_ != nullptr) {
114 ms_context_->allocator->Free(params_->all_class_scores_);
115 params_->all_class_scores_ = nullptr;
116 }
117 if (params_->single_class_indexes_ != nullptr) {
118 ms_context_->allocator->Free(params_->single_class_indexes_);
119 params_->single_class_indexes_ = nullptr;
120 }
121 if (params_->selected_ != nullptr) {
122 ms_context_->allocator->Free(params_->selected_);
123 params_->selected_ = nullptr;
124 }
125 if (input_boxes_ != nullptr) {
126 ms_context_->allocator->Free(input_boxes_);
127 input_boxes_ = nullptr;
128 }
129 if (input_scores_ != nullptr) {
130 ms_context_->allocator->Free(input_scores_);
131 input_scores_ = nullptr;
132 }
133 }
134
// Registers DetectionPostProcessInt8CPUKernel with the kernel registry for the (CPU, int8, DetectionPostProcess)
// combination; LiteKernelCreator is presumably the factory used to instantiate it (NOTE(review): confirm in
// src/kernel_registry.h).
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DetectionPostProcess,
           LiteKernelCreator<DetectionPostProcessInt8CPUKernel>)
137 } // namespace mindspore::kernel
138