/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/max_pool.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/utils.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/value.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/shape_utils.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kSizeFour = 4;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr auto kAttrPrimitiveTarget = "primitive_target";

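// Swaps a 4-D shape from NHWC to NCHW in place: (N, H, W, C) -> (N, C, H, W).
// An empty shape is left untouched; any other non-4-D shape raises ValueError.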
void ConvertShapeNHWCToNCHW(std::vector<int64_t> *const nhwc_shape) {
  if (nhwc_shape->empty()) {
    return;
  }
  if (nhwc_shape->size() != kSizeFour) {
    MS_EXCEPTION(ValueError) << "The size of shape should be 4, but got " << nhwc_shape->size();
  }
  int64_t tmp = (*nhwc_shape)[kIndex3];
  (*nhwc_shape)[kIndex3] = (*nhwc_shape)[kIndex2];
  (*nhwc_shape)[kIndex2] = (*nhwc_shape)[kIndex1];
  (*nhwc_shape)[kIndex1] = tmp;
}

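// Integer ceiling division: returns ceil(a / b), raising ValueError when b is
// zero. Callers below guard the dynamic-dimension sentinel (-1) before calling,
// so the operands are positive here.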
int64_t CeilDiv(int64_t a, int64_t b) {
  if (b == 0) {
    MS_EXCEPTION(ValueError) << "The number cannot be divided by zero.";
  }
  int64_t result = a / b;
  if (a % b != 0) {
    result += 1;
  }
  return result;
}

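// Validates that every element of the computed output shape is positive.
// Non-positive elements are tolerated only when the input shape itself is
// dynamic (contains -1), since the output cannot be determined yet.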
void CheckOutshapeValid(const PrimitivePtr &primitive, const std::vector<int64_t> &out_shape,
                        const std::vector<int64_t> &in_shape, const std::vector<int64_t> &kernel_size,
                        const std::vector<int64_t> &strides) {
  bool is_dynamic_shape = std::any_of(in_shape.begin(), in_shape.end(), [](int64_t in) { return in == -1; });
  for (auto out : out_shape) {
    if (out <= 0 && !is_dynamic_shape) {
      MS_EXCEPTION(ValueError)
        << "For '" << primitive->name()
        << "', each element of the output shape must be larger than 0, but got output shape: " << out_shape
        << ". The input shape: " << in_shape << ", kernel size: " << kernel_size << ", strides: " << strides
        << ". Please check the official api documents for more information about the output.";
    }
  }
}

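// Resolves the device target for this primitive: a per-primitive
// "primitive_target" attribute takes precedence over the global device target
// from the MindSpore context.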
string GetDeviceTarget(const PrimitivePtr &primitive) {
  string primitive_target;
  if (primitive->HasAttr(kAttrPrimitiveTarget)) {
    primitive_target = GetValue<std::string>(primitive->GetAttr(kAttrPrimitiveTarget));
  } else {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    primitive_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  }
  return primitive_target;
}

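// Rejects unsupported data formats: NHWC is accepted only on the GPU target,
// and anything other than NCHW or NHWC is an error.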
void CheckDataFormat(const PrimitivePtr &primitive, Format data_format, const string &primitive_target) {
  if (data_format == NHWC) {
    if (primitive_target != kGPUDevice) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                               << "', the 'NHWC' format is only supported on the GPU target.";
    }
  } else if (data_format != NCHW) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input format should be NCHW or NHWC, but got "
                             << data_format << ".";
  }
}

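// Infers the output shape of MaxPool from the input shape, kernel size,
// strides, pad mode, and data format. For pad mode SAME the spatial output is
// ceil(in / stride); for VALID it is ceil((in - (kernel - 1)) / stride).
// Dynamic dimensions (-1) propagate to the output, and a dynamic-rank input
// yields the dynamic-rank sentinel shape {-2} on targets other than Ascend.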
abstract::ShapePtr MaxPoolInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
  std::vector<int64_t> strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
  Format data_format = static_cast<Format>(CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat)));
  (void)CheckAndConvertUtils::CheckPositiveVector("kernel_size", kernel_size, op_name);
  (void)CheckAndConvertUtils::CheckPositiveVector("strides", strides, op_name);
  int64_t pad_mode_int = 0;
  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode_int, true);
  PadMode pad_mode = static_cast<PadMode>(pad_mode_int);

  (void)CheckAndConvertUtils::CheckValue<size_t>("length of kernel_size", kernel_size.size(), kEqual, kSizeFour,
                                                 op_name);
  (void)CheckAndConvertUtils::CheckValue<size_t>("length of strides", strides.size(), kEqual, kSizeFour, op_name);

  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
  if (shape_map.empty()) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input should exist, but it is missing.";
  }
  string device_target = GetDeviceTarget(primitive);
  CheckDataFormat(primitive, data_format, device_target);
  auto in_shape = shape_map[kShape];
  if (IsDynamicRank(in_shape)) {
    if (device_target == kAscendDevice) {
      MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                               << "', the Ascend platform does not support dynamic rank yet.";
    }
    std::vector<int64_t> out_shape = {-2};
    return std::make_shared<abstract::Shape>(out_shape);
  }
  (void)CheckAndConvertUtils::CheckValue<size_t>("length of input", in_shape.size(), kEqual, kSizeFour, op_name);

  if (data_format == NHWC) {
    ConvertShapeNHWCToNCHW(&in_shape);
  }

  int64_t out_h = 0;
  int64_t out_w = 0;
  if (pad_mode == PadMode::SAME) {
    out_h = in_shape[kIndex2] == -1 ? -1 : CeilDiv(in_shape[kIndex2], strides[kIndex2]);
    out_w = in_shape[kIndex3] == -1 ? -1 : CeilDiv(in_shape[kIndex3], strides[kIndex3]);
  } else if (pad_mode == PadMode::VALID) {
    out_h = in_shape[kIndex2] == -1 ? -1 : CeilDiv((in_shape[kIndex2] - (kernel_size[kIndex2] - 1)), strides[kIndex2]);
    out_w = in_shape[kIndex3] == -1 ? -1 : CeilDiv((in_shape[kIndex3] - (kernel_size[kIndex3] - 1)), strides[kIndex3]);
  } else {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the pad_mode should be same or valid, but got "
                             << pad_mode << ".";
  }

  abstract::ShapePtr shape;
  std::vector<int64_t> out_shape;
  if (data_format == NHWC) {
    out_shape = {in_shape[kIndex0], out_h, out_w, in_shape[kIndex1]};
    shape = std::make_shared<abstract::Shape>(out_shape);
  } else {
    out_shape = {in_shape[kIndex0], in_shape[kIndex1], out_h, out_w};
    shape = std::make_shared<abstract::Shape>(out_shape);
  }

  CheckOutshapeValid(primitive, out_shape, in_shape, kernel_size, strides);
  return shape;
}

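// Infers the output type of MaxPool, which is simply the type of the single
// tensor input; the input args are checked for presence and non-null first.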
TypePtr MaxPoolInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  if (input_args.empty()) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the input args used to infer shape and type are necessary, but they are missing.";
  }
  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the input args used to infer shape and type are necessary, but they are missing.";
  }
  auto type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, 0);
  return type;
}
}  // namespace

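// Attribute accessors. Each setter stores its value as a primitive attribute
// (validating positive vectors where required); each getter reads it back.
// Init sets all attributes in one call.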
void MaxPool::set_pad_mode(const PadMode &pad_mode) {
  int64_t swi = pad_mode;
  (void)this->AddAttr(kPadMode, api::MakeValue(swi));
}

PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }

void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
  (void)this->AddAttr(
    kKernelSize, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}

std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }

void MaxPool::set_strides(const std::vector<int64_t> &strides) {
  (void)this->AddAttr(kStrides,
                      api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}

std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }

void MaxPool::set_format(const Format &format) {
  int64_t f = format;
  (void)this->AddAttr(kFormat, api::MakeValue(f));
}

Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }

void MaxPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, api::MakeValue(pad)); }

std::vector<int64_t> MaxPool::get_pad() const {
  auto value_ptr = GetAttr(kPad);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

void MaxPool::set_round_mode(const RoundMode &round_mode) {
  int64_t swi = round_mode;
  (void)this->AddAttr(kRoundMode, api::MakeValue(swi));
}

RoundMode MaxPool::get_round_mode() const {
  auto value_ptr = GetAttr(kRoundMode);
  return RoundMode(GetValue<int64_t>(value_ptr));
}

void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride,
                   const PadMode &pad_mode, const Format &format, const std::vector<int64_t> &pad,
                   const RoundMode &round_mode) {
  this->set_pad_mode(pad_mode);
  this->set_kernel_size(kernel_size);
  this->set_strides(stride);
  this->set_format(format);
  this->set_pad(pad);
  this->set_round_mode(round_mode);
}

MIND_API_OPERATOR_IMPL(MaxPool, BaseOperator);
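// Combined shape-and-type inference entry point: infers the type, then the
// shape, and wraps both in an abstract value.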
abstract::AbstractBasePtr MaxPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const std::vector<abstract::AbstractBasePtr> &input_args) {
  TypePtr type = MaxPoolInferType(primitive, input_args);
  abstract::ShapePtr shape = MaxPoolInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}

// AG means auto generated
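// OpInferBase adapter that routes the inference framework to the MaxPool
// shape/type inference functions defined above.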
class MIND_API AGMaxPoolInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return MaxPoolInferShape(primitive, input_args);
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return MaxPoolInferType(primitive, input_args);
  }

  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return MaxPoolInfer(engine, primitive, input_args);
  }
};

REGISTER_PRIMITIVE_OP_INFER_IMPL(MaxPool, prim::kPrimMaxPool, AGMaxPoolInfer, false);
}  // namespace ops
}  // namespace mindspore