• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/grad/pooling_grad.h"
18 
19 #include "mindapi/base/shared_ptr.h"
20 #include "mindapi/ir/value.h"
21 #include "mindapi/src/helper.h"
22 #include "ops/op_name.h"
23 #include "ops/primitive_c.h"
24 #include "utils/log_adapter.h"
25 #include "abstract/abstract_value.h"
26 #include "abstract/dshape.h"
27 #include "abstract/ops/primitive_infer_map.h"
28 #include "abstract/utils.h"
29 #include "ir/anf.h"
30 #include "ir/dtype/number.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/shape_vector.h"
33 #include "utils/check_convert_utils.h"
34 #include "ops/conv_pool_ops.h"
35 
36 namespace mindspore {
37 namespace ops {
38 MIND_API_OPERATOR_IMPL(PoolingGrad, BaseOperator);
Init(const PoolMode & pool_mode,const std::vector<int64_t> & window,const std::vector<int64_t> & stride,const PadMode & pad_mode,const std::vector<int64_t> & pad_list,const RoundMode & round_mode,const Format & format,const bool global)39 void PoolingGrad::Init(const PoolMode &pool_mode, const std::vector<int64_t> &window,
40                        const std::vector<int64_t> &stride, const PadMode &pad_mode,
41                        const std::vector<int64_t> &pad_list, const RoundMode &round_mode, const Format &format,
42                        const bool global) {
43   set_pool_mode(pool_mode);
44   set_window(window);
45   set_stride(stride);
46   set_pad_mode(pad_mode);
47   set_pad_list(pad_list);
48   set_round_mode(round_mode);
49   set_format(format);
50   set_global(global);
51 }
52 
set_pool_mode(const PoolMode & pool_mode)53 void PoolingGrad::set_pool_mode(const PoolMode &pool_mode) {
54   int64_t swi = pool_mode;
55   (void)this->AddAttr(kPoolMode, api::MakeValue(swi));
56 }
57 
get_pool_mode() const58 PoolMode PoolingGrad::get_pool_mode() const {
59   auto value_ptr = GetAttr(kPoolMode);
60   return PoolMode(GetValue<int64_t>(value_ptr));
61 }
62 
set_window(const std::vector<int64_t> & window)63 void PoolingGrad::set_window(const std::vector<int64_t> &window) {
64   (void)this->AddAttr(kWindow, api::MakeValue(window));
65 }
66 
get_window() const67 std::vector<int64_t> PoolingGrad::get_window() const {
68   auto value_ptr = GetAttr(kWindow);
69   MS_EXCEPTION_IF_NULL(value_ptr);
70   return GetValue<std::vector<int64_t>>(value_ptr);
71 }
72 
set_stride(const std::vector<int64_t> & stride)73 void PoolingGrad::set_stride(const std::vector<int64_t> &stride) {
74   (void)this->AddAttr(kStride, api::MakeValue(stride));
75 }
76 
get_stride() const77 std::vector<int64_t> PoolingGrad::get_stride() const {
78   auto value_ptr = GetAttr(kStride);
79   MS_EXCEPTION_IF_NULL(value_ptr);
80   return GetValue<std::vector<int64_t>>(value_ptr);
81 }
82 
set_pad_mode(const PadMode & pad_mode)83 void PoolingGrad::set_pad_mode(const PadMode &pad_mode) {
84   int64_t swi = pad_mode;
85   (void)this->AddAttr(kPadMode, api::MakeValue(swi));
86 }
87 
get_pad_mode() const88 PadMode PoolingGrad::get_pad_mode() const {
89   auto value_ptr = GetAttr(kPadMode);
90   MS_EXCEPTION_IF_NULL(value_ptr);
91   return PadMode(GetValue<int64_t>(value_ptr));
92 }
93 
set_pad_list(const std::vector<int64_t> & pad_list)94 void PoolingGrad::set_pad_list(const std::vector<int64_t> &pad_list) {
95   (void)this->AddAttr(kPadList, api::MakeValue(pad_list));
96 }
97 
get_pad_list() const98 std::vector<int64_t> PoolingGrad::get_pad_list() const {
99   auto value_ptr = GetAttr(kPadList);
100   MS_EXCEPTION_IF_NULL(value_ptr);
101   return GetValue<std::vector<int64_t>>(value_ptr);
102 }
103 
set_round_mode(const RoundMode & round_mode)104 void PoolingGrad::set_round_mode(const RoundMode &round_mode) {
105   int64_t swi = round_mode;
106   (void)this->AddAttr(kRoundMode, api::MakeValue(swi));
107 }
108 
get_round_mode() const109 RoundMode PoolingGrad::get_round_mode() const {
110   auto value_ptr = GetAttr(kRoundMode);
111   MS_EXCEPTION_IF_NULL(value_ptr);
112   return RoundMode(GetValue<int64_t>(value_ptr));
113 }
114 
set_format(const Format & format)115 void PoolingGrad::set_format(const Format &format) {
116   int64_t swi = format;
117   (void)this->AddAttr(kFormat, api::MakeValue(swi));
118 }
119 
get_format() const120 Format PoolingGrad::get_format() const {
121   auto value_ptr = GetAttr(kFormat);
122   MS_EXCEPTION_IF_NULL(value_ptr);
123   return Format(GetValue<int64_t>(value_ptr));
124 }
125 
set_global(const bool global)126 void PoolingGrad::set_global(const bool global) { (void)this->AddAttr(kGlobal, api::MakeValue(global)); }
127 
get_global() const128 bool PoolingGrad::get_global() const {
129   auto value_ptr = GetAttr(kGlobal);
130   MS_EXCEPTION_IF_NULL(value_ptr);
131   return GetValue<bool>(value_ptr);
132 }
133 
134 class MIND_API PoolingGradInfer : public abstract::OpInferBase {
135  public:
136   // This is used for backend infer by kernel tensor.
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const137   BaseShapePtr InferShape(const PrimitivePtr &primitive,
138                           const std::vector<AbstractBasePtr> &input_args) const override {
139     // Inputs: three tensors(y, dy, x).
140     constexpr auto kPoolingGradInputNum = 3;
141     const std::string op_name = primitive->name();
142     CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
143     return input_args[kIndex2]->GetShape()->Clone();
144   }
145 
146   // This is used for backend infer by kernel tensor.
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const147   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
148     // Inputs: three tensors(y, dy, x).
149     constexpr auto kPoolingGradInputNum = 3;
150     const std::string op_name = primitive->name();
151     CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
152     return input_args[kIndex1]->GetType()->Clone();
153   }
154 
InferShapeAndType(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const155   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
156                                     const std::vector<AbstractBasePtr> &input_args) const override {
157     // Inputs: three tensors(y, dy, x).
158     constexpr auto kPoolingGradInputNum = 3;
159     const std::string op_name = primitive->name();
160     CheckArgsSize(op_name, input_args, kPoolingGradInputNum);
161     auto out_y = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex0);
162     auto d_out = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex1);
163     auto input_x = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kIndex2);
164     (void)abstract::CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
165                                           op_name + "evaluator three inputs should be %s");
166 
167     AbstractBasePtr ret = d_out->Broaden();
168     auto x_shape = dyn_cast<abstract::TensorShape>(input_args[2]->GetShapeTrack());
169     MS_EXCEPTION_IF_NULL(x_shape);
170     ret->set_shape(x_shape);
171     return ret;
172   }
173 };
174 
175 REGISTER_PRIMITIVE_OP_INFER_IMPL(PoolingGrad, prim::kPrimPoolingGrad, PoolingGradInfer, false);
176 }  // namespace ops
177 }  // namespace mindspore
178