1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/avg_pool.h"
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "abstract/primitive_infer_map.h"
26
27 namespace mindspore {
28 namespace ops {
set_pad_mode(const PadMode & pad_mode)29 void AvgPool::set_pad_mode(const PadMode &pad_mode) {
30 int64_t swi = pad_mode;
31 (void)this->AddAttr(kPadMode, MakeValue(swi));
32 }
33
get_pad_mode() const34 PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
set_kernel_size(const std::vector<int64_t> & kernel_size)35 void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
36 (void)this->AddAttr(kKernelSize,
37 MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
38 }
39
get_kernel_size() const40 std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
set_strides(const std::vector<int64_t> & strides)41 void AvgPool::set_strides(const std::vector<int64_t> &strides) {
42 (void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
43 }
44
get_strides() const45 std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
46
set_format(const Format & format)47 void AvgPool::set_format(const Format &format) {
48 int64_t f = format;
49 (void)this->AddAttr(kFormat, MakeValue(f));
50 }
51
get_format() const52 Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
53
set_pad(const std::vector<int64_t> & pad)54 void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
55
get_pad() const56 std::vector<int64_t> AvgPool::get_pad() const {
57 auto value_ptr = GetAttr(kPad);
58 return GetValue<std::vector<int64_t>>(value_ptr);
59 }
60
set_round_mode(const RoundMode & round_mode)61 void AvgPool::set_round_mode(const RoundMode &round_mode) {
62 int64_t swi = round_mode;
63 (void)this->AddAttr(kRoundMode, MakeValue(swi));
64 }
65
get_round_mode() const66 RoundMode AvgPool::get_round_mode() const {
67 auto value_ptr = GetAttr(kRoundMode);
68 return RoundMode(GetValue<int64_t>(value_ptr));
69 }
70
Init(const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & stride,const PadMode & pad_mode,const Format & format,const std::vector<int64_t> & pad,const RoundMode & round_mode)71 void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode,
72 const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) {
73 this->set_pad_mode(pad_mode);
74 this->set_kernel_size(kernel_size);
75 this->set_strides(stride);
76 this->set_format(format);
77 this->set_pad(pad);
78 this->set_round_mode(round_mode);
79 }
80
81 namespace {
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)82 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
83 auto op_name = primitive->name();
84 auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
85 auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
86 const int64_t x_size = 4;
87 const int64_t attr_size = 4;
88 (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_size, op_name);
89 if (format == NHWC) {
90 in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
91 }
92 auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
93 auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
94 auto batch = in_shape[0];
95 auto channel = in_shape[1];
96 auto in_h = in_shape[2];
97 auto in_w = in_shape[3];
98 auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
99 (void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, attr_size, op_name);
100 (void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, attr_size, op_name);
101 if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride <= 0; })) {
102 MS_LOG(EXCEPTION) << "Strides is not valid, strides must be positive.";
103 }
104 if (std::any_of(kernel_size.begin(), kernel_size.end(), [](int64_t size) { return size <= 0; })) {
105 MS_LOG(EXCEPTION) << "Kernel size is not valid, kernel size must be positive.";
106 }
107 auto kernel_h = kernel_size[2];
108 auto kernel_w = kernel_size[3];
109 auto stride_h = strides[2];
110 auto stride_w = strides[3];
111 int64_t out_h = abstract::Shape::SHP_ANY;
112 int64_t out_w = abstract::Shape::SHP_ANY;
113 if (pad_mode == VALID) {
114 out_h = static_cast<int64_t>(std::ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
115 out_w = static_cast<int64_t>(std::ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
116 } else if (pad_mode == SAME) {
117 out_h = static_cast<int64_t>(std::ceil(in_h / static_cast<float>(stride_h)));
118 out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
119 }
120 std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
121 if (format == NHWC) {
122 out_shape = {batch, out_h, out_w, channel};
123 }
124 return std::make_shared<abstract::Shape>(out_shape);
125 }
126
InferType(const std::vector<AbstractBasePtr> & input_args)127 TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
128 } // namespace
129
AvgPoolInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)130 AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
131 const std::vector<AbstractBasePtr> &input_args) {
132 MS_EXCEPTION_IF_NULL(primitive);
133 const int64_t input_num = 1;
134 CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
135 return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
136 }
// Registers the AvgPool primitive class under the name kNameAvgPool
// (presumably so it can be looked up/constructed by name — confirm against the macro definition).
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
138 } // namespace ops
139 } // namespace mindspore
140