• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 
24 #include "mindapi/src/helper.h"
25 #include "mindspore/core/ops/conv_pool_ops.h"
26 #include "ops/grad/conv2d_backprop_input.h"
27 #include "ops/op_utils.h"
28 #include "utils/check_convert_utils.h"
29 
30 namespace mindspore {
31 namespace ops {
32 namespace {
33 using abstract::Shape;
34 constexpr size_t kConv2DBackpropInputDoutIndex = 0;
35 constexpr size_t kConv2DBackpropInputInputIndex = 1;
36 constexpr size_t kConv2DBackpropInputSizeIndex = 2;
37 constexpr auto kPadSize = 4;
38 
CalPadListForSameMode(const ShapeVector & dout_shape_norm,const ShapeVector & x_size_v,const ShapeVector & kernel_size,const ShapeVector & stride,const ShapeVector & dilation)39 ShapeVector CalPadListForSameMode(const ShapeVector &dout_shape_norm, const ShapeVector &x_size_v,
40                                   const ShapeVector &kernel_size, const ShapeVector &stride,
41                                   const ShapeVector &dilation) {
42   ShapeVector pad_list(kPadSize, Shape::kShapeDimAny);
43   if (IsDynamicRank(dout_shape_norm) || IsDynamicRank(x_size_v)) {
44     return pad_list;
45   }
46   const auto stride_h = stride[kIndex2];
47   const auto stride_w = stride[kIndex3];
48   const auto kernel_h = kernel_size[kIndex0];
49   const auto kernel_w = kernel_size[kIndex1];
50   const auto dilation_h = dilation[kIndex2];
51   const auto dilation_w = dilation[kIndex3];
52   constexpr auto pad_divisor = 2;
53   if (dout_shape_norm[kInputIndex2] != Shape::kShapeDimAny && x_size_v[kInputIndex2] != Shape::kShapeDimAny) {
54     auto pad_needed_h =
55       (dout_shape_norm[kInputIndex2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[kInputIndex2];
56     pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
57     pad_list[kIndex0] = pad_needed_h / pad_divisor;
58     pad_list[kIndex1] = pad_needed_h - pad_list[kIndex0];
59   }
60   if (dout_shape_norm[kInputIndex3] != Shape::kShapeDimAny && x_size_v[kInputIndex3] != Shape::kShapeDimAny) {
61     auto pad_needed_w =
62       (dout_shape_norm[kInputIndex3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[kInputIndex3];
63     pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
64     pad_list[kIndex2] = pad_needed_w / pad_divisor;
65     pad_list[kIndex3] = pad_needed_w - pad_list[kIndex2];
66   }
67   return pad_list;
68 }
69 
SetPadList(const PrimitivePtr & primitive,const std::vector<int64_t> & dout_shape_norm,const std::vector<int64_t> & x_size_v)70 void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
71                 const std::vector<int64_t> &x_size_v) {
72   MS_EXCEPTION_IF_NULL(primitive);
73   auto prim_name = primitive->name();
74   // check
75   auto kernel_size =
76     CheckAndConvertUtils::CheckIntOrTupleInt("attribute[kernel_size]", primitive->GetAttr(kKernelSize), prim_name);
77   auto stride = CheckAndConvertUtils::CheckIntOrTupleInt("attribute[stride]", primitive->GetAttr(kStride), prim_name);
78   auto dilation =
79     CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
80 
81   auto attr_pad_list_prt = primitive->GetAttr(kPadList);
82   MS_EXCEPTION_IF_NULL(attr_pad_list_prt);
83   int64_t pad_mode;
84   CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
85 
86   ShapeVector pad_list(kPadSize, Shape::kShapeDimAny);
87   auto is_valid_pad_attr = [&attr_pad_list_prt]() -> bool {
88     if (attr_pad_list_prt->isa<None>()) {
89       return false;
90     }
91     auto attr_pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
92     return std::all_of(attr_pad_list.begin(), attr_pad_list.end(), [](int64_t val) { return val >= 0; });
93   };
94   if (is_valid_pad_attr()) {
95     pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
96   } else if (pad_mode == VALID) {
97     std::fill(pad_list.begin(), pad_list.end(), 0);
98   } else if (pad_mode == SAME) {
99     pad_list = CalPadListForSameMode(dout_shape_norm, x_size_v, kernel_size, stride, dilation);
100   } else if (pad_mode == PAD) {
101     pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
102   }
103   (void)primitive->AddAttr(kPadList, MakeValue(pad_list));
104 }
105 }  // namespace
106 
Conv2DBackpropInputInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)107 BaseShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
108                                            const std::vector<AbstractBasePtr> &input_args) {
109   MS_EXCEPTION_IF_NULL(primitive);
110   auto input_size = input_args[kConv2DBackpropInputSizeIndex];
111   auto out_shape = GetShapeValue(primitive, input_size);
112   auto ret_shape = std::make_shared<abstract::Shape>(out_shape);
113   auto dout_shape =
114     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kConv2DBackpropInputDoutIndex]->GetShape())[kShape];
115 
116   constexpr size_t kRank = 4;
117   if (!IsDynamicRank(dout_shape) && dout_shape.size() < kRank) {
118     MS_LOG(EXCEPTION) << "For " << primitive->name() << ", the rank of input[0] can't be less than " << kRank
119                       << ", but got a invalid shape: " << ShapeVectorToStr(dout_shape);
120   }
121   auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
122   ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
123   auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
124   SetPadList(primitive, dout_shape_norm, out_shape);
125   return ret_shape;
126 }
127 
Conv2DBackpropInputInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)128 TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
129   MS_EXCEPTION_IF_NULL(prim);
130   auto prim_name = prim->name();
131   // check
132   std::map<std::string, TypePtr> types;
133   // todo: check input_sizes
134   (void)types.emplace("x", input_args[kConv2DBackpropInputInputIndex]->GetType());
135   (void)types.emplace("doutput", input_args[kConv2DBackpropInputDoutIndex]->GetType());
136   std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32, kBFloat16};
137   return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
138 }
139 
// Operator boilerplate for Conv2DBackpropInput with BaseOperator as its base.
// NOTE(review): exact expansion is defined by the MIND_API_OPERATOR_IMPL macro — confirm there.
140 MIND_API_OPERATOR_IMPL(Conv2DBackpropInput, BaseOperator);
Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)141 AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
142                                          const std::vector<AbstractBasePtr> &input_args) {
143   MS_EXCEPTION_IF_NULL(primitive);
144   auto prim_name = primitive->name();
145   // check
146   const int64_t input_num = 3;
147   (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
148                                            prim_name);
149   for (const auto &item : input_args) {
150     MS_EXCEPTION_IF_NULL(item);
151   }
152   auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
153                                                         Conv2DBackpropInputInferShape(primitive, input_args));
154   return abs;
155 }
156 
Init(int64_t out_channel,const std::vector<int64_t> & kernel_size,int64_t mode,const PadMode & pad_mode,const std::vector<int64_t> & pad,const std::vector<int64_t> & stride,const std::vector<int64_t> & dilation,int64_t group,const Format & format,const std::vector<int64_t> & pad_list)157 void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
158                                const PadMode &pad_mode, const std::vector<int64_t> &pad,
159                                const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
160                                const Format &format, const std::vector<int64_t> &pad_list) {
161   set_out_channel(out_channel);
162   set_kernel_size(kernel_size);
163   set_mode(mode);
164   set_pad_mode(pad_mode);
165   set_pad(pad);
166   set_stride(stride);
167   set_dilation(dilation);
168   set_group(group);
169   set_format(format);
170   set_pad_list(pad_list);
171 }
172 
set_out_channel(int64_t out_channel)173 void Conv2DBackpropInput::set_out_channel(int64_t out_channel) {
174   (void)AddAttr(kOutChannel,
175                 api::MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
176 }
177 
set_kernel_size(const std::vector<int64_t> & kernel_size)178 void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_size) {
179   (void)AddAttr(kKernelSize,
180                 api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
181 }
182 
set_stride(const std::vector<int64_t> & stride)183 void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
184   (void)AddAttr(kStride, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
185 }
186 
set_dilation(const std::vector<int64_t> & dilation)187 void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
188   (void)AddAttr(kDilation, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
189 }
190 
set_pad_mode(const PadMode & pad_mode)191 void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
192   std::vector<int64_t> pad = get_pad();
193   if (pad_mode == PAD) {
194     for (auto item : pad) {
195       CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, 0, name());
196     }
197   } else {
198     CheckAndConvertUtils::Check(kPad, pad, kEqual, {0, 0, 0, 0}, name());
199   }
200   int64_t swi = pad_mode;
201   (void)AddAttr(kPadMode, api::MakeValue(swi));
202 }
203 
set_pad(const std::vector<int64_t> & pad)204 void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
205   const int64_t pad_size = 4;
206   (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
207   (void)AddAttr(kPad, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
208 }
209 
set_mode(int64_t mode)210 void Conv2DBackpropInput::set_mode(int64_t mode) {
211   (void)AddAttr(kMode, api::MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
212 }
213 
set_group(int64_t group)214 void Conv2DBackpropInput::set_group(int64_t group) {
215   (void)AddAttr(kGroup, api::MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
216 }
217 
set_format(const Format & format)218 void Conv2DBackpropInput::set_format(const Format &format) {
219   int64_t f = format;
220   (void)AddAttr(kFormat, api::MakeValue(f));
221 }
222 
set_pad_list(const std::vector<int64_t> & pad_list)223 void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
224   (void)this->AddAttr(kPadList, api::MakeValue(pad_list));
225 }
226 
get_out_channel() const227 int64_t Conv2DBackpropInput::get_out_channel() const {
228   auto value_ptr = GetAttr(kOutChannel);
229   MS_EXCEPTION_IF_NULL(value_ptr);
230   return GetValue<int64_t>(value_ptr);
231 }
232 
get_kernel_size() const233 std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
234   auto value_ptr = GetAttr(kKernelSize);
235   MS_EXCEPTION_IF_NULL(value_ptr);
236   return GetValue<std::vector<int64_t>>(value_ptr);
237 }
238 
get_stride() const239 std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
240   auto value_ptr = GetAttr(kStride);
241   MS_EXCEPTION_IF_NULL(value_ptr);
242   return GetValue<std::vector<int64_t>>(value_ptr);
243 }
244 
get_dilation() const245 std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
246   auto value_ptr = GetAttr(kDilation);
247   MS_EXCEPTION_IF_NULL(value_ptr);
248   return GetValue<std::vector<int64_t>>(value_ptr);
249 }
250 
get_pad_mode() const251 PadMode Conv2DBackpropInput::get_pad_mode() const {
252   auto value_ptr = GetAttr(kPadMode);
253   MS_EXCEPTION_IF_NULL(value_ptr);
254   return PadMode(GetValue<int64_t>(value_ptr));
255 }
256 
get_pad() const257 std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
258   auto value_ptr = GetAttr(kPad);
259   MS_EXCEPTION_IF_NULL(value_ptr);
260   return GetValue<std::vector<int64_t>>(value_ptr);
261 }
262 
get_mode() const263 int64_t Conv2DBackpropInput::get_mode() const {
264   auto value_ptr = GetAttr(kMode);
265   MS_EXCEPTION_IF_NULL(value_ptr);
266   return GetValue<int64_t>(value_ptr);
267 }
268 
get_group() const269 int64_t Conv2DBackpropInput::get_group() const {
270   auto value_ptr = GetAttr(kGroup);
271   MS_EXCEPTION_IF_NULL(value_ptr);
272   return GetValue<int64_t>(value_ptr);
273 }
274 
get_format() const275 Format Conv2DBackpropInput::get_format() const {
276   auto value_ptr = GetAttr(kFormat);
277   MS_EXCEPTION_IF_NULL(value_ptr);
278   return Format(GetValue<int64_t>(value_ptr));
279 }
280 
get_pad_list() const281 std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
282   auto value_ptr = GetAttr(kPadList);
283   MS_EXCEPTION_IF_NULL(value_ptr);
284   return GetValue<std::vector<int64_t>>(value_ptr);
285 }
286 
287 // AG means auto generated
288 class MIND_API AGConv2DBackpropInputInfer : public abstract::OpInferBase {
289  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const290   BaseShapePtr InferShape(const PrimitivePtr &primitive,
291                           const std::vector<AbstractBasePtr> &input_args) const override {
292     return Conv2DBackpropInputInferShape(primitive, input_args);
293   }
294 
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const295   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
296                                     const std::vector<AbstractBasePtr> &input_args) const override {
297     return Conv2DBackpropInputInfer(engine, primitive, input_args);
298   }
299 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const300   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
301     return Conv2DBackpropInputInferType(primitive, input_args);
302   }
303 
GetValueDependArgIndices() const304   std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
305 };
306 
// Registers AGConv2DBackpropInputInfer as the infer implementation for prim::kPrimConv2DBackpropInput.
// NOTE(review): the meaning of the trailing `false` flag is defined by the macro — confirm before changing it.
307 REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, AGConv2DBackpropInputInfer,
308                                  false);
309 }  // namespace ops
310 }  // namespace mindspore
311