1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/max_pool.h"
18 #include <algorithm>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "abstract/abstract_value.h"
25 #include "abstract/dshape.h"
26 #include "abstract/ops/op_infer.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "base/base.h"
30 #include "ir/anf.h"
31 #include "ir/primitive.h"
32 #include "ir/value.h"
33 #include "mindapi/base/shared_ptr.h"
34 #include "mindapi/ir/value.h"
35 #include "mindapi/src/helper.h"
36 #include "mindspore/core/ops/conv_pool_ops.h"
37 #include "ops/op_name.h"
38 #include "ops/primitive_c.h"
39 #include "utils/check_convert_utils.h"
40 #include "utils/log_adapter.h"
41 #include "utils/ms_context.h"
42 #include "utils/shape_utils.h"
43
44 namespace mindspore {
45 namespace ops {
46 namespace {
// Required length of the shape and attribute vectors handled by this op.
constexpr size_t kSizeFour{4};
// Positions inside a 4-element shape / attribute vector.
constexpr size_t kIndex0{0};
constexpr size_t kIndex1{1};
constexpr size_t kIndex2{2};
constexpr size_t kIndex3{3};
// Primitive attribute that, when present, takes precedence over the context's device target.
constexpr const char *kAttrPrimitiveTarget = "primitive_target";
53
ConvertShapeNHWCToNCHW(std::vector<int64_t> * const nhwc_shape)54 void ConvertShapeNHWCToNCHW(std::vector<int64_t> *const nhwc_shape) {
55 if (nhwc_shape->empty()) {
56 return;
57 }
58 if (nhwc_shape->size() != kSizeFour) {
59 MS_EXCEPTION(ValueError) << "The size of shape should be 4, but got " << nhwc_shape->size();
60 }
61 int64_t tmp = (*nhwc_shape)[kIndex3];
62 (*nhwc_shape)[kIndex3] = (*nhwc_shape)[kIndex2];
63 (*nhwc_shape)[kIndex2] = (*nhwc_shape)[kIndex1];
64 (*nhwc_shape)[kIndex1] = tmp;
65 }
66
CeilDiv(int64_t a,int64_t b)67 int64_t CeilDiv(int64_t a, int64_t b) {
68 if (b == 0) {
69 MS_EXCEPTION(ValueError) << "The number can not be divided by zero.";
70 }
71 int64_t result = a / b;
72 if (a % b != 0) {
73 result += 1;
74 }
75 return result;
76 }
77
CheckOutshapeValid(const PrimitivePtr & primitive,const std::vector<int64_t> & out_shape,const std::vector<int64_t> & in_shape,const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & strides)78 void CheckOutshapeValid(const PrimitivePtr &primitive, const std::vector<int64_t> &out_shape,
79 const std::vector<int64_t> &in_shape, const std::vector<int64_t> &kernel_size,
80 const std::vector<int64_t> &strides) {
81 bool is_dynamic_shape = std::any_of(in_shape.begin(), in_shape.end(), [](int64_t in) { return in == -1; });
82 for (auto out : out_shape) {
83 if (out <= 0 && !is_dynamic_shape) {
84 MS_EXCEPTION(ValueError)
85 << "For '" << primitive->name()
86 << "', the each element of the output shape must be larger than 0, but got output shape: " << out_shape
87 << ". The input shape: " << in_shape << ", kernel size: " << kernel_size << ", strides: " << strides
88 << ". Please check the official api documents for more information about the output.";
89 }
90 }
91 }
92
GetDeviceTarget(const PrimitivePtr & primitive)93 string GetDeviceTarget(const PrimitivePtr &primitive) {
94 string primitive_target;
95 if (primitive->HasAttr(kAttrPrimitiveTarget)) {
96 primitive_target = GetValue<std::string>(primitive->GetAttr(kAttrPrimitiveTarget));
97 } else {
98 auto ms_context = MsContext::GetInstance();
99 MS_EXCEPTION_IF_NULL(ms_context);
100 primitive_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
101 }
102 return primitive_target;
103 }
104
CheckDataFormat(const PrimitivePtr & primitive,Format data_format,const string & primitive_target)105 void CheckDataFormat(const PrimitivePtr &primitive, Format data_format, const string &primitive_target) {
106 if (data_format == NHWC) {
107 if (primitive_target != kGPUDevice) {
108 MS_EXCEPTION(ValueError) << "For '" << primitive->name()
109 << "', the 'NHWC' format is only supported in GPU target.";
110 }
111 } else if (data_format != NCHW) {
112 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input format should be NCHW or NHWC, but got "
113 << data_format << ".";
114 }
115 }
116
MaxPoolInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)117 abstract::ShapePtr MaxPoolInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
118 MS_EXCEPTION_IF_NULL(primitive);
119 auto op_name = primitive->name();
120 std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
121 std::vector<int64_t> strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
122 Format data_format = static_cast<Format>(CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat)));
123 (void)CheckAndConvertUtils::CheckPositiveVector("kernel_size", kernel_size, op_name);
124 (void)CheckAndConvertUtils::CheckPositiveVector("strides", strides, op_name);
125 int64_t pad_mode_int = 0;
126 CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode_int, true);
127 PadMode pad_mode = static_cast<PadMode>(pad_mode_int);
128
129 (void)CheckAndConvertUtils::CheckValue<size_t>("length of kernel_size", kernel_size.size(), kEqual, kSizeFour,
130 op_name);
131 (void)CheckAndConvertUtils::CheckValue<size_t>("length of strides", strides.size(), kEqual, kSizeFour, op_name);
132
133 auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
134 if (shape_map.empty()) {
135 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input should exist, but missed.";
136 }
137 string device_target = GetDeviceTarget(primitive);
138 CheckDataFormat(primitive, data_format, device_target);
139 auto in_shape = shape_map[kShape];
140 if (IsDynamicRank(in_shape)) {
141 if (device_target == kAscendDevice) {
142 MS_EXCEPTION(ValueError) << "For '" << primitive->name()
143 << "', the Ascend platform hasn't support dynamic rank yet.";
144 }
145 std::vector<int64_t> out_shape = {-2};
146 return std::make_shared<abstract::Shape>(out_shape);
147 }
148 (void)CheckAndConvertUtils::CheckValue<size_t>("length of input", in_shape.size(), kEqual, kSizeFour, op_name);
149
150 if (data_format == NHWC) {
151 ConvertShapeNHWCToNCHW(&in_shape);
152 }
153
154 int64_t out_h = 0;
155 int64_t out_w = 0;
156 if (pad_mode == PadMode::SAME) {
157 out_h = in_shape[kIndex2] == -1 ? -1 : CeilDiv(in_shape[kIndex2], strides[kIndex2]);
158 out_w = in_shape[kIndex3] == -1 ? -1 : CeilDiv(in_shape[kIndex3], strides[kIndex3]);
159 } else if (pad_mode == PadMode::VALID) {
160 out_h = in_shape[kIndex2] == -1 ? -1 : CeilDiv((in_shape[kIndex2] - (kernel_size[kIndex2] - 1)), strides[kIndex2]);
161 out_w = in_shape[kIndex3] == -1 ? -1 : CeilDiv((in_shape[kIndex3] - (kernel_size[kIndex3] - 1)), strides[kIndex3]);
162 } else {
163 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the pad_mode should be same or valid, but got "
164 << pad_mode << ".";
165 }
166
167 abstract::ShapePtr shape;
168 std::vector<int64_t> out_shape;
169 if (data_format == NHWC) {
170 out_shape = {in_shape[kIndex0], out_h, out_w, in_shape[kIndex1]};
171 shape = std::make_shared<abstract::Shape>(out_shape);
172 } else {
173 out_shape = {in_shape[kIndex0], in_shape[kIndex1], out_h, out_w};
174 shape = std::make_shared<abstract::Shape>(out_shape);
175 }
176
177 CheckOutshapeValid(primitive, out_shape, in_shape, kernel_size, strides);
178 return shape;
179 }
180
MaxPoolInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)181 TypePtr MaxPoolInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
182 if (input_args.size() == 0) {
183 MS_EXCEPTION(TypeError) << "For '" << primitive->name()
184 << "', the input args used for infer shape and type is necessary, but missing it.";
185 }
186 if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
187 MS_EXCEPTION(TypeError) << "For '" << primitive->name()
188 << "', the input args used for infer shape and type is necessary, but missing it.";
189 }
190 auto type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, 0);
191 return type;
192 }
193 } // namespace
194
set_pad_mode(const PadMode & pad_mode)195 void MaxPool::set_pad_mode(const PadMode &pad_mode) {
196 int64_t swi = pad_mode;
197 (void)this->AddAttr(kPadMode, api::MakeValue(swi));
198 }
199
get_pad_mode() const200 PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
set_kernel_size(const std::vector<int64_t> & kernel_size)201 void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
202 (void)this->AddAttr(
203 kKernelSize, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
204 }
205
get_kernel_size() const206 std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
set_strides(const std::vector<int64_t> & strides)207 void MaxPool::set_strides(const std::vector<int64_t> &strides) {
208 (void)this->AddAttr(kStrides,
209 api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
210 }
211
get_strides() const212 std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
213
set_format(const Format & format)214 void MaxPool::set_format(const Format &format) {
215 int64_t f = format;
216 (void)this->AddAttr(kFormat, api::MakeValue(f));
217 }
218
get_format() const219 Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
220
set_pad(const std::vector<int64_t> & pad)221 void MaxPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, api::MakeValue(pad)); }
222
get_pad() const223 std::vector<int64_t> MaxPool::get_pad() const {
224 auto value_ptr = GetAttr(kPad);
225 return GetValue<std::vector<int64_t>>(value_ptr);
226 }
227
set_round_mode(const RoundMode & round_mode)228 void MaxPool::set_round_mode(const RoundMode &round_mode) {
229 int64_t swi = round_mode;
230 (void)this->AddAttr(kRoundMode, api::MakeValue(swi));
231 }
232
get_round_mode() const233 RoundMode MaxPool::get_round_mode() const {
234 auto value_ptr = GetAttr(kRoundMode);
235 return RoundMode(GetValue<int64_t>(value_ptr));
236 }
237
Init(const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & stride,const PadMode & pad_mode,const Format & format,const std::vector<int64_t> & pad,const RoundMode & round_mode)238 void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode,
239 const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) {
240 this->set_pad_mode(pad_mode);
241 this->set_kernel_size(kernel_size);
242 this->set_strides(stride);
243 this->set_format(format);
244 this->set_pad(pad);
245 this->set_round_mode(round_mode);
246 }
247
248 MIND_API_OPERATOR_IMPL(MaxPool, BaseOperator);
MaxPoolInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<abstract::AbstractBasePtr> & input_args)249 abstract::AbstractBasePtr MaxPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
250 const std::vector<abstract::AbstractBasePtr> &input_args) {
251 TypePtr type = MaxPoolInferType(primitive, input_args);
252 abstract::ShapePtr shape = MaxPoolInferShape(primitive, input_args);
253 return abstract::MakeAbstract(shape, type);
254 }
255
256 // AG means auto generated
257 class MIND_API AGMaxPoolInfer : public abstract::OpInferBase {
258 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const259 BaseShapePtr InferShape(const PrimitivePtr &primitive,
260 const std::vector<AbstractBasePtr> &input_args) const override {
261 return MaxPoolInferShape(primitive, input_args);
262 }
263
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const264 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
265 return MaxPoolInferType(primitive, input_args);
266 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const267 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
268 const std::vector<AbstractBasePtr> &input_args) const override {
269 return MaxPoolInfer(engine, primitive, input_args);
270 }
271 };
272
273 REGISTER_PRIMITIVE_OP_INFER_IMPL(MaxPool, prim::kPrimMaxPool, AGMaxPoolInfer, false);
274 } // namespace ops
275 } // namespace mindspore
276