1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/runtime/kernel/arm/fp32/pooling_fp32.h"
18 #include <cfloat>
19 #include "nnacl/fp32/pooling_fp32.h"
20 #include "src/kernel_registry.h"
21 #include "include/errorcode.h"
22 #include "nnacl/op_base.h"
23
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_AvgPoolFusion;
29 using mindspore::schema::PrimitiveType_MaxPoolFusion;
30
31 namespace mindspore::kernel {
Init()32 int PoolingCPUKernel::Init() {
33 auto ret = PoolingBaseCPUKernel::Init();
34 if (ret != RET_OK) {
35 MS_LOG(ERROR) << "PoolingBase Init failed.";
36 return RET_ERROR;
37 }
38 if (!InferShapeDone()) {
39 return RET_OK;
40 }
41 return ReSize();
42 }
43
ReSize()44 int PoolingCPUKernel::ReSize() {
45 auto ret = PoolingBaseCPUKernel::ReSize();
46 if (ret != RET_OK) {
47 MS_LOG(ERROR) << "PoolingBase ReSize fai1!ret: " << ret;
48 return ret;
49 }
50 return RET_OK;
51 }
52
RunImpl(int task_id)53 int PoolingCPUKernel::RunImpl(int task_id) {
54 auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
55 CHECK_NULL_RETURN(input_ptr);
56 auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
57 CHECK_NULL_RETURN(output_ptr);
58 float minf = -FLT_MAX;
59 float maxf = FLT_MAX;
60 if (pooling_param_->act_type_ == ActType_Relu) {
61 minf = 0.f;
62 } else if (pooling_param_->act_type_ == ActType_Relu6) {
63 minf = 0.f;
64 maxf = 6.f;
65 }
66 int ret = 0;
67 if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
68 ret = MaxPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
69 } else {
70 ret = AvgPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
71 }
72 if (ret != RET_OK) {
73 MS_LOG(ERROR) << "AcgPooling run failed.";
74 return ret;
75 }
76 return RET_OK;
77 }
78
PoolingImpl(void * cdata,int task_id,float lhs_scale,float rhs_scale)79 int PoolingImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
80 auto pooling = reinterpret_cast<PoolingCPUKernel *>(cdata);
81 auto error_code = pooling->RunImpl(task_id);
82 if (error_code != RET_OK) {
83 MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]";
84 return RET_ERROR;
85 }
86 return RET_OK;
87 }
88
Run()89 int PoolingCPUKernel::Run() {
90 int error_code = ParallelLaunch(this->ms_context_, PoolingImpl, this, thread_count_);
91 if (error_code != RET_OK) {
92 MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
93 return RET_ERROR;
94 }
95 return RET_OK;
96 }
97
// Register PoolingCPUKernel as the CPU float32 implementation of both AvgPoolFusion and
// MaxPoolFusion. A single kernel class serves both primitives; RunImpl() chooses max or
// average pooling at run time from pooling_param_->pool_mode_.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AvgPoolFusion, LiteKernelCreator<PoolingCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MaxPoolFusion, LiteKernelCreator<PoolingCPUKernel>)
100 } // namespace mindspore::kernel
101