• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/fp32/roi_pooling_fp32.h"
17 #include "nnacl/fp32/roi_pooling_fp32.h"
18 #include <vector>
19 #include "schema/model_generated.h"
20 #include "src/kernel_registry.h"
21 #include "include/errorcode.h"
22 
23 using mindspore::kernel::KERNEL_ARCH;
24 using mindspore::lite::KernelRegistrar;
25 using mindspore::lite::RET_ERROR;
26 using mindspore::lite::RET_MEMORY_FAILED;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_ROIPooling;
29 
30 namespace mindspore::kernel {
Init()31 int ROIPoolingCPUKernel::Init() {
32   MS_CHECK_TRUE_RET(in_tensors_.size() == kInputSize1, RET_ERROR);
33   MS_CHECK_TRUE_RET(out_tensors_.size() == 1, RET_ERROR);
34   CHECK_NULL_RETURN(in_tensors_[0]);
35   CHECK_NULL_RETURN(in_tensors_[1]);
36   CHECK_NULL_RETURN(out_tensors_[0]);
37   if (!InferShapeDone()) {
38     return RET_OK;
39   }
40   return ReSize();
41 }
42 
ReSize()43 int ROIPoolingCPUKernel::ReSize() {
44   if (max_c_ != nullptr) {
45     free(max_c_);
46     max_c_ = nullptr;
47   }
48   auto in_shape = in_tensors_.front()->shape();
49   auto out_shape = out_tensors_.front()->shape();
50   int ndims = static_cast<int>(in_shape.size());
51   if (ndims < C4NUM) {
52     MS_LOG(ERROR) << "ROIPooling in_shape.size() error ,shape dim greater than or equal to 4!";
53     return RET_ERROR;
54   }
55   if (out_shape.size() < C4NUM) {
56     MS_LOG(ERROR) << "ROIPooling out_shape.size() error ,shape dim greater than or equal to 4!";
57     return RET_ERROR;
58   }
59   param_->ndim_ = ndims;
60   param_->input_n_ = in_shape.at(0);
61   param_->input_h_ = in_shape.at(1);
62   param_->input_w_ = in_shape.at(2);
63   param_->input_c_ = in_shape.at(3);
64   param_->output_n_ = out_shape.at(0);
65   param_->output_h_ = out_shape.at(1);
66   param_->output_w_ = out_shape.at(2);
67   param_->output_c_ = out_shape.at(3);
68   param_->in_strides_[ndims - 1] = 1;
69   param_->out_strides_[ndims - 1] = 1;
70   for (int i = ndims - 2; i >= 0; --i) {
71     param_->in_strides_[i] = in_shape.at(i + 1) * param_->in_strides_[i + 1];
72     param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
73   }
74   param_->thread_num_ = MSMIN(param_->op_parameter_.thread_num_, out_shape.at(0));
75   MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(param_->input_c_, static_cast<int>(sizeof(float))), RET_ERROR, "mul overflow");
76   max_c_ = reinterpret_cast<float *>(malloc(param_->input_c_ * static_cast<int>(sizeof(float))));
77   if (max_c_ == nullptr) {
78     MS_LOG(ERROR) << "malloc max_c failed.";
79     return RET_MEMORY_FAILED;
80   }
81   return RET_OK;
82 }
83 
DoExecute(int task_id)84 int ROIPoolingCPUKernel::DoExecute(int task_id) {
85   CHECK_NULL_RETURN(in_ptr_);
86   CHECK_NULL_RETURN(out_ptr_);
87   CHECK_NULL_RETURN(roi_ptr_);
88   CHECK_NULL_RETURN(max_c_);
89   CHECK_NULL_RETURN(param_);
90   MS_CHECK_FALSE_MSG(param_->thread_num_ == 0, RET_ERROR, "div zero");
91   auto ret = ROIPooling(in_ptr_, out_ptr_, roi_ptr_, max_c_, task_id, param_);
92   if (ret != RET_OK) {
93     MS_LOG(ERROR) << "ROIPooling Execute error task_id[" << task_id << "] error_code[" << ret << "]";
94     return ret;
95   }
96   return RET_OK;
97 }
98 
ROIPoolingRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)99 int ROIPoolingRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
100   auto Data = reinterpret_cast<ROIPoolingCPUKernel *>(cdata);
101   auto ret = Data->DoExecute(task_id);
102   if (ret != RET_OK) {
103     MS_LOG(ERROR) << "ROIPooling Run error task_id[" << task_id << "] error_code[" << ret << "]";
104     return ret;
105   }
106   return RET_OK;
107 }
108 
Run()109 int ROIPoolingCPUKernel::Run() {
110   in_ptr_ = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
111   out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
112   roi_ptr_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
113   auto ret = ParallelLaunch(this->ms_context_, ROIPoolingRun, this, param_->thread_num_);
114   if (ret != RET_OK) {
115     MS_LOG(ERROR) << "ROIPooling error: error_code[" << ret << "]";
116     return ret;
117   }
118   return ret;
119 }
120 
121 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ROIPooling, LiteKernelCreator<ROIPoolingCPUKernel>)
122 }  // namespace mindspore::kernel
123