• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/fusion/avg_pool_fusion.h"
18 
19 namespace mindspore {
20 namespace ops {
Init(const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & stride,const PadMode & pad_mode,const Format & format,const std::vector<int64_t> & pad,const RoundMode & round_mode,const bool global,const ActivationType activation_type)21 void AvgPoolFusion::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride,
22                          const PadMode &pad_mode, const Format &format, const std::vector<int64_t> &pad,
23                          const RoundMode &round_mode, const bool global, const ActivationType activation_type) {
24   this->set_pad_mode(pad_mode);
25   this->set_kernel_size(kernel_size);
26   this->set_strides(stride);
27   this->set_format(format);
28   this->set_pad(pad);
29   this->set_round_mode(round_mode);
30   this->set_global(global);
31   this->set_activation_type(activation_type);
32 }
33 
set_global(const bool global)34 void AvgPoolFusion::set_global(const bool global) { (void)AddAttr(kGlobal, MakeValue(global)); }
35 
set_activation_type(ActivationType activation_type)36 void AvgPoolFusion::set_activation_type(ActivationType activation_type) {
37   int64_t swi = activation_type;
38   (void)this->AddAttr(kActivationType, MakeValue(swi));
39 }
40 
get_global() const41 bool AvgPoolFusion::get_global() const {
42   auto value_ptr = GetAttr(kGlobal);
43   MS_EXCEPTION_IF_NULL(value_ptr);
44   return GetValue<bool>(value_ptr);
45 }
46 
get_activation_type() const47 ActivationType AvgPoolFusion::get_activation_type() const {
48   auto value_ptr = GetAttr(kActivationType);
49   MS_EXCEPTION_IF_NULL(value_ptr);
50   return ActivationType(GetValue<int64_t>(value_ptr));
51 }
52 
53 namespace {
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)54 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
55   MS_EXCEPTION_IF_NULL(primitive);
56   for (auto item : input_args) {
57     MS_EXCEPTION_IF_NULL(item);
58   }
59   auto op_name = primitive->name();
60   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
61   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
62   if (format == NHWC) {
63     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
64   }
65   const int64_t x_rank = 4;
66   (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
67   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
68   auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
69   auto batch = in_shape[0];
70   auto channel = in_shape[1];
71   auto in_h = in_shape[2];
72   auto in_w = in_shape[3];
73 
74   auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
75   (void)CheckAndConvertUtils::CheckPositiveVector(kStride, strides, op_name);
76   auto kernel_h = kernel_size[2];
77   auto kernel_w = kernel_size[3];
78   auto stride_h = strides[2];
79   auto stride_w = strides[3];
80   int64_t out_h = abstract::Shape::SHP_ANY;
81   int64_t out_w = abstract::Shape::SHP_ANY;
82   if (pad_mode == VALID) {
83     out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
84     out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
85   } else if (pad_mode == SAME) {
86     out_h = static_cast<int64_t>(ceil(in_h / static_cast<float>(stride_h)));
87     out_w = static_cast<int64_t>(ceil(in_w / static_cast<float>(stride_w)));
88   }
89   std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
90   if (format == NHWC) {
91     out_shape = {batch, out_h, out_w, channel};
92   }
93   if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
94     MS_LOG(EXCEPTION) << "Kernel size is not valid.";
95   }
96   return std::make_shared<abstract::Shape>(out_shape);
97 }
98 
InferType(const std::vector<AbstractBasePtr> & input_args)99 TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
100   for (auto item : input_args) {
101     MS_EXCEPTION_IF_NULL(item);
102   }
103   return input_args[0]->BuildType();
104 }
105 }  // namespace
106 
AvgPoolFusionInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)107 AbstractBasePtr AvgPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
108                                    const std::vector<AbstractBasePtr> &input_args) {
109   return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args));
110 }
111 REGISTER_PRIMITIVE_C(kNameAvgPoolFusion, AvgPoolFusion);
112 }  // namespace ops
113 }  // namespace mindspore
114