1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23
24 #include "mindapi/src/helper.h"
25 #include "mindspore/core/ops/conv_pool_ops.h"
26 #include "ops/grad/conv2d_backprop_input.h"
27 #include "ops/op_utils.h"
28 #include "utils/check_convert_utils.h"
29
30 namespace mindspore {
31 namespace ops {
32 namespace {
using abstract::Shape;
// Positional indices of Conv2DBackpropInput's inputs: dout (gradient), x (filter
// weights), and the requested input size, respectively.
constexpr size_t kConv2DBackpropInputDoutIndex = 0;
constexpr size_t kConv2DBackpropInputInputIndex = 1;
constexpr size_t kConv2DBackpropInputSizeIndex = 2;
// pad_list always carries 4 values: [pad_top, pad_bottom, pad_left, pad_right].
constexpr auto kPadSize = 4;
38
/// \brief Computes the explicit [top, bottom, left, right] padding that pad_mode
/// SAME implies for the given output/input spatial sizes.
///
/// Any dimension that is still dynamic (kShapeDimAny) keeps kShapeDimAny in the
/// resulting pad_list; a fully dynamic rank on either shape yields an all-dynamic
/// pad_list.
ShapeVector CalPadListForSameMode(const ShapeVector &dout_shape_norm, const ShapeVector &x_size_v,
                                  const ShapeVector &kernel_size, const ShapeVector &stride,
                                  const ShapeVector &dilation) {
  ShapeVector pad_list(kPadSize, Shape::kShapeDimAny);
  if (IsDynamicRank(dout_shape_norm) || IsDynamicRank(x_size_v)) {
    return pad_list;
  }
  constexpr int64_t pad_divisor = 2;
  // Total padding needed along one spatial dimension for SAME semantics,
  // clamped at zero (previously the H and W branches used two different
  // hand-rolled clamp idioms; std::max unifies them).
  auto pad_needed = [](int64_t out_dim, int64_t in_dim, int64_t kernel, int64_t stride_v, int64_t dilation_v) {
    return std::max<int64_t>((out_dim - 1) * stride_v + dilation_v * (kernel - 1) + 1 - in_dim, 0);
  };
  if (dout_shape_norm[kInputIndex2] != Shape::kShapeDimAny && x_size_v[kInputIndex2] != Shape::kShapeDimAny) {
    const auto pad_needed_h = pad_needed(dout_shape_norm[kInputIndex2], x_size_v[kInputIndex2], kernel_size[kIndex0],
                                         stride[kIndex2], dilation[kIndex2]);
    // Split total padding: top gets the floor half, bottom gets the remainder.
    pad_list[kIndex0] = pad_needed_h / pad_divisor;
    pad_list[kIndex1] = pad_needed_h - pad_list[kIndex0];
  }
  if (dout_shape_norm[kInputIndex3] != Shape::kShapeDimAny && x_size_v[kInputIndex3] != Shape::kShapeDimAny) {
    const auto pad_needed_w = pad_needed(dout_shape_norm[kInputIndex3], x_size_v[kInputIndex3], kernel_size[kIndex1],
                                         stride[kIndex3], dilation[kIndex3]);
    pad_list[kIndex2] = pad_needed_w / pad_divisor;
    pad_list[kIndex3] = pad_needed_w - pad_list[kIndex2];
  }
  return pad_list;
}
69
SetPadList(const PrimitivePtr & primitive,const std::vector<int64_t> & dout_shape_norm,const std::vector<int64_t> & x_size_v)70 void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
71 const std::vector<int64_t> &x_size_v) {
72 MS_EXCEPTION_IF_NULL(primitive);
73 auto prim_name = primitive->name();
74 // check
75 auto kernel_size =
76 CheckAndConvertUtils::CheckIntOrTupleInt("attribute[kernel_size]", primitive->GetAttr(kKernelSize), prim_name);
77 auto stride = CheckAndConvertUtils::CheckIntOrTupleInt("attribute[stride]", primitive->GetAttr(kStride), prim_name);
78 auto dilation =
79 CheckAndConvertUtils::CheckIntOrTupleInt("attribute[dilation]", primitive->GetAttr(kDilation), prim_name);
80
81 auto attr_pad_list_prt = primitive->GetAttr(kPadList);
82 MS_EXCEPTION_IF_NULL(attr_pad_list_prt);
83 int64_t pad_mode;
84 CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
85
86 ShapeVector pad_list(kPadSize, Shape::kShapeDimAny);
87 auto is_valid_pad_attr = [&attr_pad_list_prt]() -> bool {
88 if (attr_pad_list_prt->isa<None>()) {
89 return false;
90 }
91 auto attr_pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
92 return std::all_of(attr_pad_list.begin(), attr_pad_list.end(), [](int64_t val) { return val >= 0; });
93 };
94 if (is_valid_pad_attr()) {
95 pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
96 } else if (pad_mode == VALID) {
97 std::fill(pad_list.begin(), pad_list.end(), 0);
98 } else if (pad_mode == SAME) {
99 pad_list = CalPadListForSameMode(dout_shape_norm, x_size_v, kernel_size, stride, dilation);
100 } else if (pad_mode == PAD) {
101 pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
102 }
103 (void)primitive->AddAttr(kPadList, MakeValue(pad_list));
104 }
105 } // namespace
106
Conv2DBackpropInputInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)107 BaseShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
108 const std::vector<AbstractBasePtr> &input_args) {
109 MS_EXCEPTION_IF_NULL(primitive);
110 auto input_size = input_args[kConv2DBackpropInputSizeIndex];
111 auto out_shape = GetShapeValue(primitive, input_size);
112 auto ret_shape = std::make_shared<abstract::Shape>(out_shape);
113 auto dout_shape =
114 CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kConv2DBackpropInputDoutIndex]->GetShape())[kShape];
115
116 constexpr size_t kRank = 4;
117 if (!IsDynamicRank(dout_shape) && dout_shape.size() < kRank) {
118 MS_LOG(EXCEPTION) << "For " << primitive->name() << ", the rank of input[0] can't be less than " << kRank
119 << ", but got a invalid shape: " << ShapeVectorToStr(dout_shape);
120 }
121 auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
122 ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
123 auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
124 SetPadList(primitive, dout_shape_norm, out_shape);
125 return ret_shape;
126 }
127
Conv2DBackpropInputInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)128 TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
129 MS_EXCEPTION_IF_NULL(prim);
130 auto prim_name = prim->name();
131 // check
132 std::map<std::string, TypePtr> types;
133 // todo: check input_sizes
134 (void)types.emplace("x", input_args[kConv2DBackpropInputInputIndex]->GetType());
135 (void)types.emplace("doutput", input_args[kConv2DBackpropInputDoutIndex]->GetType());
136 std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32, kBFloat16};
137 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
138 }
139
140 MIND_API_OPERATOR_IMPL(Conv2DBackpropInput, BaseOperator);
Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)141 AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
142 const std::vector<AbstractBasePtr> &input_args) {
143 MS_EXCEPTION_IF_NULL(primitive);
144 auto prim_name = primitive->name();
145 // check
146 const int64_t input_num = 3;
147 (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
148 prim_name);
149 for (const auto &item : input_args) {
150 MS_EXCEPTION_IF_NULL(item);
151 }
152 auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
153 Conv2DBackpropInputInferShape(primitive, input_args));
154 return abs;
155 }
156
// Initializes every attribute of the operator through its checked setter.
// NOTE(review): set_pad_mode() reads get_pad() (the kPad attribute), which is only
// written by set_pad() two calls later — presumably a default kPad value exists on
// the primitive before Init runs, otherwise this call order would fail; confirm
// against the operator's attribute defaults before reordering.
void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
                               const PadMode &pad_mode, const std::vector<int64_t> &pad,
                               const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
                               const Format &format, const std::vector<int64_t> &pad_list) {
  set_out_channel(out_channel);
  set_kernel_size(kernel_size);
  set_mode(mode);
  set_pad_mode(pad_mode);
  set_pad(pad);
  set_stride(stride);
  set_dilation(dilation);
  set_group(group);
  set_format(format);
  set_pad_list(pad_list);
}
172
set_out_channel(int64_t out_channel)173 void Conv2DBackpropInput::set_out_channel(int64_t out_channel) {
174 (void)AddAttr(kOutChannel,
175 api::MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
176 }
177
set_kernel_size(const std::vector<int64_t> & kernel_size)178 void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_size) {
179 (void)AddAttr(kKernelSize,
180 api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
181 }
182
set_stride(const std::vector<int64_t> & stride)183 void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
184 (void)AddAttr(kStride, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
185 }
186
set_dilation(const std::vector<int64_t> & dilation)187 void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
188 (void)AddAttr(kDilation, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
189 }
190
set_pad_mode(const PadMode & pad_mode)191 void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
192 std::vector<int64_t> pad = get_pad();
193 if (pad_mode == PAD) {
194 for (auto item : pad) {
195 CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, 0, name());
196 }
197 } else {
198 CheckAndConvertUtils::Check(kPad, pad, kEqual, {0, 0, 0, 0}, name());
199 }
200 int64_t swi = pad_mode;
201 (void)AddAttr(kPadMode, api::MakeValue(swi));
202 }
203
set_pad(const std::vector<int64_t> & pad)204 void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
205 const int64_t pad_size = 4;
206 (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
207 (void)AddAttr(kPad, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
208 }
209
set_mode(int64_t mode)210 void Conv2DBackpropInput::set_mode(int64_t mode) {
211 (void)AddAttr(kMode, api::MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
212 }
213
set_group(int64_t group)214 void Conv2DBackpropInput::set_group(int64_t group) {
215 (void)AddAttr(kGroup, api::MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
216 }
217
set_format(const Format & format)218 void Conv2DBackpropInput::set_format(const Format &format) {
219 int64_t f = format;
220 (void)AddAttr(kFormat, api::MakeValue(f));
221 }
222
set_pad_list(const std::vector<int64_t> & pad_list)223 void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
224 (void)this->AddAttr(kPadList, api::MakeValue(pad_list));
225 }
226
get_out_channel() const227 int64_t Conv2DBackpropInput::get_out_channel() const {
228 auto value_ptr = GetAttr(kOutChannel);
229 MS_EXCEPTION_IF_NULL(value_ptr);
230 return GetValue<int64_t>(value_ptr);
231 }
232
get_kernel_size() const233 std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
234 auto value_ptr = GetAttr(kKernelSize);
235 MS_EXCEPTION_IF_NULL(value_ptr);
236 return GetValue<std::vector<int64_t>>(value_ptr);
237 }
238
get_stride() const239 std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
240 auto value_ptr = GetAttr(kStride);
241 MS_EXCEPTION_IF_NULL(value_ptr);
242 return GetValue<std::vector<int64_t>>(value_ptr);
243 }
244
get_dilation() const245 std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
246 auto value_ptr = GetAttr(kDilation);
247 MS_EXCEPTION_IF_NULL(value_ptr);
248 return GetValue<std::vector<int64_t>>(value_ptr);
249 }
250
get_pad_mode() const251 PadMode Conv2DBackpropInput::get_pad_mode() const {
252 auto value_ptr = GetAttr(kPadMode);
253 MS_EXCEPTION_IF_NULL(value_ptr);
254 return PadMode(GetValue<int64_t>(value_ptr));
255 }
256
get_pad() const257 std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
258 auto value_ptr = GetAttr(kPad);
259 MS_EXCEPTION_IF_NULL(value_ptr);
260 return GetValue<std::vector<int64_t>>(value_ptr);
261 }
262
get_mode() const263 int64_t Conv2DBackpropInput::get_mode() const {
264 auto value_ptr = GetAttr(kMode);
265 MS_EXCEPTION_IF_NULL(value_ptr);
266 return GetValue<int64_t>(value_ptr);
267 }
268
get_group() const269 int64_t Conv2DBackpropInput::get_group() const {
270 auto value_ptr = GetAttr(kGroup);
271 MS_EXCEPTION_IF_NULL(value_ptr);
272 return GetValue<int64_t>(value_ptr);
273 }
274
get_format() const275 Format Conv2DBackpropInput::get_format() const {
276 auto value_ptr = GetAttr(kFormat);
277 MS_EXCEPTION_IF_NULL(value_ptr);
278 return Format(GetValue<int64_t>(value_ptr));
279 }
280
get_pad_list() const281 std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
282 auto value_ptr = GetAttr(kPadList);
283 MS_EXCEPTION_IF_NULL(value_ptr);
284 return GetValue<std::vector<int64_t>>(value_ptr);
285 }
286
287 // AG means auto generated
// Adapter that exposes the free-function infer implementations above through the
// OpInferBase interface used by the primitive registry.
class MIND_API AGConv2DBackpropInputInfer : public abstract::OpInferBase {
 public:
  // Shape-only inference; also refreshes kPadList on the primitive (see
  // Conv2DBackpropInputInferShape).
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2DBackpropInputInferShape(primitive, input_args);
  }

  // Full abstract inference (validates inputs, then combines type and shape).
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2DBackpropInputInfer(engine, primitive, input_args);
  }

  // Type-only inference for dout/x.
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2DBackpropInputInferType(primitive, input_args);
  }

  // Shape inference needs the *value* of input 2 (input_sizes,
  // kConv2DBackpropInputSizeIndex), not just its abstract shape.
  std::set<int64_t> GetValueDependArgIndices() const override { return {2}; }
};
306
// Registers the infer adapter for the Conv2DBackpropInput primitive.
// NOTE(review): the trailing `false` presumably indicates no value-inference
// (constant folding) implementation is provided — confirm against the
// REGISTER_PRIMITIVE_OP_INFER_IMPL macro definition.
REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, AGConv2DBackpropInputInfer,
                                 false);
309 } // namespace ops
310 } // namespace mindspore
311