• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/fp32/pooling_fp32.h"
18 #include <cfloat>
19 #include "nnacl/fp32/pooling_fp32.h"
20 #include "src/kernel_registry.h"
21 #include "include/errorcode.h"
22 #include "nnacl/op_base.h"
23 
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 using mindspore::schema::PrimitiveType_AvgPoolFusion;
29 using mindspore::schema::PrimitiveType_MaxPoolFusion;
30 
31 namespace mindspore::kernel {
Init()32 int PoolingCPUKernel::Init() {
33   auto ret = PoolingBaseCPUKernel::Init();
34   if (ret != RET_OK) {
35     MS_LOG(ERROR) << "PoolingBase Init failed.";
36     return RET_ERROR;
37   }
38   if (!InferShapeDone()) {
39     return RET_OK;
40   }
41   return ReSize();
42 }
43 
ReSize()44 int PoolingCPUKernel::ReSize() {
45   auto ret = PoolingBaseCPUKernel::ReSize();
46   if (ret != RET_OK) {
47     MS_LOG(ERROR) << "PoolingBase ReSize fai1!ret: " << ret;
48     return ret;
49   }
50   return RET_OK;
51 }
52 
RunImpl(int task_id)53 int PoolingCPUKernel::RunImpl(int task_id) {
54   auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
55   CHECK_NULL_RETURN(input_ptr);
56   auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
57   CHECK_NULL_RETURN(output_ptr);
58   float minf = -FLT_MAX;
59   float maxf = FLT_MAX;
60   if (pooling_param_->act_type_ == ActType_Relu) {
61     minf = 0.f;
62   } else if (pooling_param_->act_type_ == ActType_Relu6) {
63     minf = 0.f;
64     maxf = 6.f;
65   }
66   int ret = 0;
67   if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
68     ret = MaxPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
69   } else {
70     ret = AvgPooling(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
71   }
72   if (ret != RET_OK) {
73     MS_LOG(ERROR) << "AcgPooling run failed.";
74     return ret;
75   }
76   return RET_OK;
77 }
78 
PoolingImpl(void * cdata,int task_id,float lhs_scale,float rhs_scale)79 int PoolingImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
80   auto pooling = reinterpret_cast<PoolingCPUKernel *>(cdata);
81   auto error_code = pooling->RunImpl(task_id);
82   if (error_code != RET_OK) {
83     MS_LOG(ERROR) << "Pooling Run error task_id[" << task_id << "] error_code[" << error_code << "]";
84     return RET_ERROR;
85   }
86   return RET_OK;
87 }
88 
Run()89 int PoolingCPUKernel::Run() {
90   int error_code = ParallelLaunch(this->ms_context_, PoolingImpl, this, thread_count_);
91   if (error_code != RET_OK) {
92     MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
93     return RET_ERROR;
94   }
95   return RET_OK;
96 }
97 
98 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AvgPoolFusion, LiteKernelCreator<PoolingCPUKernel>)
99 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MaxPoolFusion, LiteKernelCreator<PoolingCPUKernel>)
100 }  // namespace mindspore::kernel
101