• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/max_pool.h"
18 #include <string>
19 #include <algorithm>
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "abstract/primitive_infer_map.h"
26 
27 namespace mindspore {
28 namespace ops {
set_pad_mode(const PadMode & pad_mode)29 void MaxPool::set_pad_mode(const PadMode &pad_mode) {
30   int64_t swi = pad_mode;
31   (void)this->AddAttr(kPadMode, MakeValue(swi));
32 }
33 
get_pad_mode() const34 PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
set_kernel_size(const std::vector<int64_t> & kernel_size)35 void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
36   (void)this->AddAttr(kKernelSize,
37                       MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
38 }
39 
get_kernel_size() const40 std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
set_strides(const std::vector<int64_t> & strides)41 void MaxPool::set_strides(const std::vector<int64_t> &strides) {
42   (void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
43 }
44 
get_strides() const45 std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
46 
set_format(const Format & format)47 void MaxPool::set_format(const Format &format) {
48   int64_t f = format;
49   (void)this->AddAttr(kFormat, MakeValue(f));
50 }
51 
get_format() const52 Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
53 
set_pad(const std::vector<int64_t> & pad)54 void MaxPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
55 
get_pad() const56 std::vector<int64_t> MaxPool::get_pad() const {
57   auto value_ptr = GetAttr(kPad);
58   return GetValue<std::vector<int64_t>>(value_ptr);
59 }
60 
set_round_mode(const RoundMode & round_mode)61 void MaxPool::set_round_mode(const RoundMode &round_mode) {
62   int64_t swi = round_mode;
63   (void)this->AddAttr(kRoundMode, MakeValue(swi));
64 }
65 
get_round_mode() const66 RoundMode MaxPool::get_round_mode() const {
67   auto value_ptr = GetAttr(kRoundMode);
68   return RoundMode(GetValue<int64_t>(value_ptr));
69 }
70 
Init(const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & stride,const PadMode & pad_mode,const Format & format,const std::vector<int64_t> & pad,const RoundMode & round_mode)71 void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const PadMode &pad_mode,
72                    const Format &format, const std::vector<int64_t> &pad, const RoundMode &round_mode) {
73   this->set_pad_mode(pad_mode);
74   this->set_kernel_size(kernel_size);
75   this->set_strides(stride);
76   this->set_format(format);
77   this->set_pad(pad);
78   this->set_round_mode(round_mode);
79 }
80 
81 namespace {
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)82 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
83   MS_EXCEPTION_IF_NULL(primitive);
84   auto op_name = primitive->name();
85   MS_EXCEPTION_IF_NULL(input_args[0]);
86   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
87   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
88   if (format == NHWC) {
89     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
90   }
91   const int64_t x_rank = 4;
92   (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, x_rank, op_name);
93 
94   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
95   auto pad_mode_value = (primitive->GetAttr(kPadMode));
96   auto pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
97   auto batch = in_shape[0];
98   auto channel = in_shape[1];
99   auto in_h = in_shape[2];
100   auto in_w = in_shape[3];
101   auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
102   auto kernel_h = kernel_size[2];
103   auto kernel_w = kernel_size[3];
104   auto stride_h = strides[2];
105   auto stride_w = strides[3];
106   int64_t out_h = abstract::Shape::SHP_ANY;
107   int64_t out_w = abstract::Shape::SHP_ANY;
108   if (pad_mode == VALID) {
109     out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) + static_cast<float>(stride_h) - 1) /
110                                  static_cast<float>(stride_h));
111     out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) + static_cast<float>(stride_w) - 1) /
112                                  static_cast<float>(stride_w));
113   } else if (pad_mode == SAME) {
114     out_h = static_cast<int64_t>(ceil(in_h / static_cast<float>(stride_h)));
115     out_w = static_cast<int64_t>(ceil(in_w / static_cast<float>(stride_w)));
116   }
117   std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
118   if (format == NHWC) {
119     out_shape = {batch, out_h, out_w, channel};
120   }
121   if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
122     MS_LOG(EXCEPTION) << "Kernel size is not valid.";
123   }
124   return std::make_shared<abstract::Shape>(out_shape);
125 }
126 
InferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)127 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
128   if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
129     MS_LOG(EXCEPTION) << "nullptr";
130   }
131   auto name = prim->name();
132   MS_LOG(DEBUG) << "Infer data type for : " << name;
133   auto input_type = input_args[0]->BuildType();
134   MS_EXCEPTION_IF_NULL(input_type);
135   auto input_tensor_type = input_type->cast<TensorTypePtr>();
136   if (input_tensor_type == nullptr) {
137     MS_LOG_EXCEPTION << "The maxpool's input must be a tensor but got " << input_type->ToString();
138   }
139   return input_tensor_type->element();
140 }
141 }  // namespace
142 
MaxPoolInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)143 AbstractBasePtr MaxPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
144                              const std::vector<AbstractBasePtr> &input_args) {
145   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
146                                                     InferShape(primitive, input_args)->shape());
147 }
148 REGISTER_PRIMITIVE_C(kNameMaxPool, MaxPool);
149 }  // namespace ops
150 }  // namespace mindspore
151