1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/fp32/roi_pooling_fp32.h"
17 #include "nnacl/fp32/roi_pooling_fp32.h"
18 #include <vector>
19 #include "schema/model_generated.h"
20 #include "src/kernel_registry.h"
21 #include "include/errorcode.h"
22
23 using mindspore::kernel::KERNEL_ARCH;
24 using mindspore::lite::KernelRegistrar;
25 using mindspore::lite::RET_ERROR;
26 using mindspore::lite::RET_MEMORY_FAILED;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_ROIPooling;
29
30 namespace mindspore::kernel {
Init()31 int ROIPoolingCPUKernel::Init() {
32 MS_CHECK_TRUE_RET(in_tensors_.size() == kInputSize1, RET_ERROR);
33 MS_CHECK_TRUE_RET(out_tensors_.size() == 1, RET_ERROR);
34 CHECK_NULL_RETURN(in_tensors_[0]);
35 CHECK_NULL_RETURN(in_tensors_[1]);
36 CHECK_NULL_RETURN(out_tensors_[0]);
37 if (!InferShapeDone()) {
38 return RET_OK;
39 }
40 return ReSize();
41 }
42
ReSize()43 int ROIPoolingCPUKernel::ReSize() {
44 if (max_c_ != nullptr) {
45 free(max_c_);
46 max_c_ = nullptr;
47 }
48 auto in_shape = in_tensors_.front()->shape();
49 auto out_shape = out_tensors_.front()->shape();
50 int ndims = static_cast<int>(in_shape.size());
51 if (ndims < C4NUM) {
52 MS_LOG(ERROR) << "ROIPooling in_shape.size() error ,shape dim greater than or equal to 4!";
53 return RET_ERROR;
54 }
55 if (out_shape.size() < C4NUM) {
56 MS_LOG(ERROR) << "ROIPooling out_shape.size() error ,shape dim greater than or equal to 4!";
57 return RET_ERROR;
58 }
59 param_->ndim_ = ndims;
60 param_->input_n_ = in_shape.at(0);
61 param_->input_h_ = in_shape.at(1);
62 param_->input_w_ = in_shape.at(2);
63 param_->input_c_ = in_shape.at(3);
64 param_->output_n_ = out_shape.at(0);
65 param_->output_h_ = out_shape.at(1);
66 param_->output_w_ = out_shape.at(2);
67 param_->output_c_ = out_shape.at(3);
68 param_->in_strides_[ndims - 1] = 1;
69 param_->out_strides_[ndims - 1] = 1;
70 for (int i = ndims - 2; i >= 0; --i) {
71 param_->in_strides_[i] = in_shape.at(i + 1) * param_->in_strides_[i + 1];
72 param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
73 }
74 param_->thread_num_ = MSMIN(param_->op_parameter_.thread_num_, out_shape.at(0));
75 MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(param_->input_c_, static_cast<int>(sizeof(float))), RET_ERROR, "mul overflow");
76 max_c_ = reinterpret_cast<float *>(malloc(param_->input_c_ * static_cast<int>(sizeof(float))));
77 if (max_c_ == nullptr) {
78 MS_LOG(ERROR) << "malloc max_c failed.";
79 return RET_MEMORY_FAILED;
80 }
81 return RET_OK;
82 }
83
DoExecute(int task_id)84 int ROIPoolingCPUKernel::DoExecute(int task_id) {
85 CHECK_NULL_RETURN(in_ptr_);
86 CHECK_NULL_RETURN(out_ptr_);
87 CHECK_NULL_RETURN(roi_ptr_);
88 CHECK_NULL_RETURN(max_c_);
89 CHECK_NULL_RETURN(param_);
90 MS_CHECK_FALSE_MSG(param_->thread_num_ == 0, RET_ERROR, "div zero");
91 auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, max_c_, task_id, param_);
92 if (ret != RET_OK) {
93 MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]";
94 return ret;
95 }
96 return RET_OK;
97 }
98
ROIPoolingRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)99 int ROIPoolingRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
100 auto Data = reinterpret_cast<ROIPoolingCPUKernel *>(cdata);
101 auto ret = Data->DoExecute(task_id);
102 if (ret != RET_OK) {
103 MS_LOG(ERROR) << "ROIPooling Run error task_id[" << task_id << "] error_code[" << ret << "]";
104 return ret;
105 }
106 return RET_OK;
107 }
108
Run()109 int ROIPoolingCPUKernel::Run() {
110 in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
111 out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
112 roi_ptr_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
113 auto ret = ParallelLaunch(this->ms_context_, ROIPoolingRun, this, param_->thread_num_);
114 if (ret != RET_OK) {
115 MS_LOG(ERROR) << "ROIPooling error: error_code[" << ret << "]";
116 return ret;
117 }
118 return ret;
119 }
120
121 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ROIPooling, LiteKernelCreator<ROIPoolingCPUKernel>)
122 } // namespace mindspore::kernel
123