• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/int8/detection_post_process_int8.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21 #include "nnacl/int8/quant_dtype_cast_int8.h"
22 
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 using mindspore::schema::PrimitiveType_DetectionPostProcess;
27 
28 namespace mindspore::kernel {
DequantizeInt8ToFp32(const int task_id)29 int DetectionPostProcessInt8CPUKernel::DequantizeInt8ToFp32(const int task_id) {
30   int num_unit_thread = MSMIN(thread_n_stride_, quant_size_ - task_id * thread_n_stride_);
31   int thread_offset = task_id * thread_n_stride_;
32   int ret = DoDequantizeInt8ToFp32(data_int8_ + thread_offset, data_fp32_ + thread_offset, quant_param_.scale,
33                                    quant_param_.zeroPoint, num_unit_thread);
34   if (ret != RET_OK) {
35     MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
36     return RET_ERROR;
37   }
38   return RET_OK;
39 }
40 
DequantizeInt8ToFp32Run(void * cdata,int task_id,float lhs_scale,float rhs_scale)41 int DequantizeInt8ToFp32Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
42   auto KernelData = reinterpret_cast<DetectionPostProcessInt8CPUKernel *>(cdata);
43   auto ret = KernelData->DequantizeInt8ToFp32(task_id);
44   if (ret != RET_OK) {
45     MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
46     return RET_ERROR;
47   }
48   return RET_OK;
49 }
50 
Dequantize(lite::Tensor * tensor,float ** data)51 int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **data) {
52   data_int8_ = reinterpret_cast<int8_t *>(tensor->data());
53   CHECK_NULL_RETURN(data_int8_);
54   *data = reinterpret_cast<float *>(ms_context_->allocator->Malloc(tensor->ElementsNum() * sizeof(float)));
55   if (*data == nullptr) {
56     MS_LOG(ERROR) << "Malloc data failed.";
57     return RET_ERROR;
58   }
59   if (tensor->quant_params().empty()) {
60     MS_LOG(ERROR) << "null quant param";
61     return RET_ERROR;
62   }
63   quant_param_ = tensor->quant_params().front();
64   data_fp32_ = *data;
65   quant_size_ = tensor->ElementsNum();
66   thread_n_stride_ = UP_DIV(quant_size_, op_parameter_->thread_num_);
67 
68   auto ret = ParallelLaunch(this->ms_context_, DequantizeInt8ToFp32Run, this, op_parameter_->thread_num_);
69   if (ret != RET_OK) {
70     MS_LOG(ERROR) << "QuantDTypeCastRun error error_code[" << ret << "]";
71     ms_context_->allocator->Free(*data);
72     return RET_ERROR;
73   }
74   return RET_OK;
75 }
GetInputData()76 int DetectionPostProcessInt8CPUKernel::GetInputData() {
77   if (in_tensors_.at(0)->data_type() != kNumberTypeInt8 || in_tensors_.at(1)->data_type() != kNumberTypeInt8) {
78     MS_LOG(ERROR) << "Input data type error";
79     return RET_ERROR;
80   }
81   int status = Dequantize(in_tensors_.at(0), &input_boxes_);
82   if (status != RET_OK) {
83     return status;
84   }
85   status = Dequantize(in_tensors_.at(1), &input_scores_);
86   if (status != RET_OK) {
87     return status;
88   }
89   return RET_OK;
90 }
91 
FreeAllocatedBuffer()92 void DetectionPostProcessInt8CPUKernel::FreeAllocatedBuffer() {
93   if (params_->decoded_boxes_ != nullptr) {
94     ms_context_->allocator->Free(params_->decoded_boxes_);
95     params_->decoded_boxes_ = nullptr;
96   }
97   if (params_->nms_candidate_ != nullptr) {
98     ms_context_->allocator->Free(params_->nms_candidate_);
99     params_->nms_candidate_ = nullptr;
100   }
101   if (params_->indexes_ != nullptr) {
102     ms_context_->allocator->Free(params_->indexes_);
103     params_->indexes_ = nullptr;
104   }
105   if (params_->scores_ != nullptr) {
106     ms_context_->allocator->Free(params_->scores_);
107     params_->scores_ = nullptr;
108   }
109   if (params_->all_class_indexes_ != nullptr) {
110     ms_context_->allocator->Free(params_->all_class_indexes_);
111     params_->all_class_indexes_ = nullptr;
112   }
113   if (params_->all_class_scores_ != nullptr) {
114     ms_context_->allocator->Free(params_->all_class_scores_);
115     params_->all_class_scores_ = nullptr;
116   }
117   if (params_->single_class_indexes_ != nullptr) {
118     ms_context_->allocator->Free(params_->single_class_indexes_);
119     params_->single_class_indexes_ = nullptr;
120   }
121   if (params_->selected_ != nullptr) {
122     ms_context_->allocator->Free(params_->selected_);
123     params_->selected_ = nullptr;
124   }
125   if (input_boxes_ != nullptr) {
126     ms_context_->allocator->Free(input_boxes_);
127     input_boxes_ = nullptr;
128   }
129   if (input_scores_ != nullptr) {
130     ms_context_->allocator->Free(input_scores_);
131     input_scores_ = nullptr;
132   }
133 }
134 
// Kernel registration: ties LiteKernelCreator<DetectionPostProcessInt8CPUKernel> to the key
// (kCPU, kNumberTypeInt8, PrimitiveType_DetectionPostProcess).
// NOTE(review): REG_KERNEL is defined outside this file; presumably it adds the creator to the kernel
// registry at static-initialization time - confirm against src/kernel_registry.h.
135 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_DetectionPostProcess,
136            LiteKernelCreator<DetectionPostProcessInt8CPUKernel>)
137 }  // namespace mindspore::kernel
138