1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/grad/pooling_grad.h"
18
19 #include "mindapi/base/shared_ptr.h"
20 #include "mindapi/ir/value.h"
21 #include "mindapi/src/helper.h"
22 #include "ops/op_name.h"
23 #include "ops/primitive_c.h"
24 #include "utils/log_adapter.h"
25 #include "abstract/abstract_value.h"
26 #include "abstract/dshape.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "ir/anf.h"
30 #include "ir/dtype/number.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/shape_vector.h"
33 #include "utils/check_convert_utils.h"
34 #include "ops/conv_pool_ops.h"
35
36 namespace mindspore {
37 namespace ops {
38 MIND_API_OPERATOR_IMPL(PoolingGrad, BaseOperator);
Init(const PoolMode & pool_mode,const std::vector<int64_t> & window,const std::vector<int64_t> & stride,const PadMode & pad_mode,const std::vector<int64_t> & pad_list,const RoundMode & round_mode,const Format & format,const bool global)39 void PoolingGrad::Init(const PoolMode &pool_mode, const std::vector<int64_t> &window,
40 const std::vector<int64_t> &stride, const PadMode &pad_mode,
41 const std::vector<int64_t> &pad_list, const RoundMode &round_mode, const Format &format,
42 const bool global) {
43 set_pool_mode(pool_mode);
44 set_window(window);
45 set_stride(stride);
46 set_pad_mode(pad_mode);
47 set_pad_list(pad_list);
48 set_round_mode(round_mode);
49 set_format(format);
50 set_global(global);
51 }
52
set_pool_mode(const PoolMode & pool_mode)53 void PoolingGrad::set_pool_mode(const PoolMode &pool_mode) {
54 int64_t swi = pool_mode;
55 (void)this->AddAttr(kPoolMode, api::MakeValue(swi));
56 }
57
get_pool_mode() const58 PoolMode PoolingGrad::get_pool_mode() const {
59 auto value_ptr = GetAttr(kPoolMode);
60 return PoolMode(GetValue<int64_t>(value_ptr));
61 }
62
set_window(const std::vector<int64_t> & window)63 void PoolingGrad::set_window(const std::vector<int64_t> &window) {
64 (void)this->AddAttr(kWindow, api::MakeValue(window));
65 }
66
get_window() const67 std::vector<int64_t> PoolingGrad::get_window() const {
68 auto value_ptr = GetAttr(kWindow);
69 MS_EXCEPTION_IF_NULL(value_ptr);
70 return GetValue<std::vector<int64_t>>(value_ptr);
71 }
72
set_stride(const std::vector<int64_t> & stride)73 void PoolingGrad::set_stride(const std::vector<int64_t> &stride) {
74 (void)this->AddAttr(kStride, api::MakeValue(stride));
75 }
76
get_stride() const77 std::vector<int64_t> PoolingGrad::get_stride() const {
78 auto value_ptr = GetAttr(kStride);
79 MS_EXCEPTION_IF_NULL(value_ptr);
80 return GetValue<std::vector<int64_t>>(value_ptr);
81 }
82
set_pad_mode(const PadMode & pad_mode)83 void PoolingGrad::set_pad_mode(const PadMode &pad_mode) {
84 int64_t swi = pad_mode;
85 (void)this->AddAttr(kPadMode, api::MakeValue(swi));
86 }
87
get_pad_mode() const88 PadMode PoolingGrad::get_pad_mode() const {
89 auto value_ptr = GetAttr(kPadMode);
90 MS_EXCEPTION_IF_NULL(value_ptr);
91 return PadMode(GetValue<int64_t>(value_ptr));
92 }
93
set_pad_list(const std::vector<int64_t> & pad_list)94 void PoolingGrad::set_pad_list(const std::vector<int64_t> &pad_list) {
95 (void)this->AddAttr(kPadList, api::MakeValue(pad_list));
96 }
97
get_pad_list() const98 std::vector<int64_t> PoolingGrad::get_pad_list() const {
99 auto value_ptr = GetAttr(kPadList);
100 MS_EXCEPTION_IF_NULL(value_ptr);
101 return GetValue<std::vector<int64_t>>(value_ptr);
102 }
103
set_round_mode(const RoundMode & round_mode)104 void PoolingGrad::set_round_mode(const RoundMode &round_mode) {
105 int64_t swi = round_mode;
106 (void)this->AddAttr(kRoundMode, api::MakeValue(swi));
107 }
108
get_round_mode() const109 RoundMode PoolingGrad::get_round_mode() const {
110 auto value_ptr = GetAttr(kRoundMode);
111 MS_EXCEPTION_IF_NULL(value_ptr);
112 return RoundMode(GetValue<int64_t>(value_ptr));
113 }
114
set_format(const Format & format)115 void PoolingGrad::set_format(const Format &format) {
116 int64_t swi = format;
117 (void)this->AddAttr(kFormat, api::MakeValue(swi));
118 }
119
get_format() const120 Format PoolingGrad::get_format() const {
121 auto value_ptr = GetAttr(kFormat);
122 MS_EXCEPTION_IF_NULL(value_ptr);
123 return Format(GetValue<int64_t>(value_ptr));
124 }
125
set_global(const bool global)126 void PoolingGrad::set_global(const bool global) { (void)this->AddAttr(kGlobal, api::MakeValue(global)); }
127
get_global() const128 bool PoolingGrad::get_global() const {
129 auto value_ptr = GetAttr(kGlobal);
130 MS_EXCEPTION_IF_NULL(value_ptr);
131 return GetValue<bool>(value_ptr);
132 }
133
134 class MIND_API PoolingGradInfer : public abstract::OpInferBase {
135 public:
136 // This is used for backend infer by kernel tensor.
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const137 BaseShapePtr InferShape(const PrimitivePtr &primitive,
138 const std::vector<AbstractBasePtr> &input_args) const override {
139 // Inputs: three tensors(y, dy, x).
140 constexpr auto kPoolingGradInputNum = 3;
141 const std::string op_name = primitive->name();
142 CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
143 return input_args[kIndex2]->GetShape()->Clone();
144 }
145
146 // This is used for backend infer by kernel tensor.
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const147 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
148 // Inputs: three tensors(y, dy, x).
149 constexpr auto kPoolingGradInputNum = 3;
150 const std::string op_name = primitive->name();
151 CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
152 return input_args[kIndex1]->GetType()->Clone();
153 }
154
InferShapeAndType(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const155 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
156 const std::vector<AbstractBasePtr> &input_args) const override {
157 // Inputs: three tensors(y, dy, x).
158 constexpr auto kPoolingGradInputNum = 3;
159 const std::string op_name = primitive->name();
160 CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
161 auto out_y = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex0);
162 auto d_out = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex1);
163 auto input_x = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex2);
164 (void)abstract::CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
165 op_name + "evaluator three inputs should be %s");
166
167 AbstractBasePtr ret = d_out->Broaden();
168 auto x_shape = dyn_cast<abstract::TensorShape>(input_args[2]->GetShapeTrack());
169 MS_EXCEPTION_IF_NULL(x_shape);
170 ret->set_shape(x_shape);
171 return ret;
172 }
173 };
174
175 REGISTER_PRIMITIVE_OP_INFER_IMPL(PoolingGrad, prim::kPrimPoolingGrad, PoolingGradInfer, false);
176 } // namespace ops
177 } // namespace mindspore
178