1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <map>
18 #include <vector>
19 #include <string>
20 #include <memory>
21 #include <set>
22
23 #include "ops/grad/conv2d_backprop_input.h"
24
25 namespace mindspore {
26 namespace ops {
27 namespace {
// Position in input_args of the argument treated as dout: its shape is read during shape inference and its
// type is checked under the name "doutput".
constexpr size_t kDoutIndex{0};
// Position in input_args of the argument whose type is checked under the name "x".
constexpr size_t kInputIndex{1};
// Position in input_args of the x size argument (accepted as a tuple or a tensor by shape inference).
constexpr int64_t kSizeIndex{2};
31
SetPadList(const PrimitivePtr & primitive,const std::vector<int64_t> & dout_shape_norm,const std::vector<int64_t> & x_size_v)32 void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
33 const std::vector<int64_t> &x_size_v) {
34 MS_EXCEPTION_IF_NULL(primitive);
35 auto prim_name = primitive->name();
36 // check
37 auto kernel_size =
38 CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
39 auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
40 auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
41 // default pad mode is valid
42 auto attr_pad_list_prt = primitive->GetAttr(kPadList);
43 int64_t pad_mode;
44 CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
45 ShapeVector pad_list = {0, 0, 0, 0};
46 if (!attr_pad_list_prt->isa<None>()) {
47 pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
48 } else if (pad_mode == SAME) {
49 auto stride_h = stride[2];
50 auto stride_w = stride[3];
51 auto kernel_h = kernel_size[0];
52 auto kernel_w = kernel_size[1];
53 auto dilation_h = dilation[2];
54 auto dilation_w = dilation[3];
55 int64_t pad_top = abstract::Shape::SHP_ANY;
56 int64_t pad_bottom = abstract::Shape::SHP_ANY;
57 int64_t pad_left = abstract::Shape::SHP_ANY;
58 int64_t pad_right = abstract::Shape::SHP_ANY;
59 if (dout_shape_norm[kInputIndex2] != abstract::Shape::SHP_ANY &&
60 x_size_v[kInputIndex2] != abstract::Shape::SHP_ANY) {
61 auto pad_needed_h =
62 (dout_shape_norm[kInputIndex2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[kInputIndex2];
63 pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
64 pad_top = pad_needed_h / 2;
65 pad_bottom = pad_needed_h - pad_top;
66 }
67 if (dout_shape_norm[kInputIndex3] != abstract::Shape::SHP_ANY &&
68 x_size_v[kInputIndex3] != abstract::Shape::SHP_ANY) {
69 auto pad_needed_w =
70 (dout_shape_norm[kInputIndex3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[kInputIndex3];
71 pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
72 pad_left = pad_needed_w / 2;
73 pad_right = pad_needed_w - pad_left;
74 }
75 pad_list = {pad_top, pad_bottom, pad_left, pad_right};
76 } else if (pad_mode == PAD) {
77 pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
78 }
79 (void)primitive->AddAttr(kPadList, MakeValue(pad_list));
80 }
81
Conv2DBackpropInputInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)82 abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
83 const std::vector<AbstractBasePtr> &input_args) {
84 MS_EXCEPTION_IF_NULL(primitive);
85 auto prim_name = primitive->name();
86 std::vector<int64_t> out_shape;
87 abstract::ShapePtr ret_shape;
88 auto input_size = input_args[kSizeIndex];
89 auto input_size_v = input_size->BuildValue();
90 MS_EXCEPTION_IF_NULL(input_size_v);
91
92 if (input_size->isa<abstract::AbstractTensor>()) {
93 if (input_size_v->isa<tensor::Tensor>()) {
94 out_shape = CheckAndConvertUtils::CheckTensorIntValue("input x size", input_size_v, prim_name);
95 ret_shape = std::make_shared<abstract::Shape>(out_shape);
96 } else {
97 auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kSizeIndex);
98 MS_EXCEPTION_IF_NULL(shape_ptr);
99 auto shape_shape = shape_ptr->shape();
100 if (shape_shape.size() != 1) {
101 MS_LOG(EXCEPTION) << "The " << prim_name << "'s x size must be 1-D.";
102 }
103
104 auto abstract_tensor = input_size->cast<abstract::AbstractTensorPtr>();
105 MS_EXCEPTION_IF_NULL(abstract_tensor);
106 auto shape_max_value = abstract_tensor->get_max_value();
107 auto shape_min_value = abstract_tensor->get_min_value();
108 if (shape_max_value == nullptr || shape_min_value == nullptr) {
109 MS_LOG(EXCEPTION) << "Max_value or min value of x size can not be empty when its value is dynamic.";
110 }
111
112 auto shape_max = GetValue<std::vector<int64_t>>(shape_max_value);
113 auto shape_min = GetValue<std::vector<int64_t>>(shape_min_value);
114
115 auto x_size_len = LongToSize(shape_shape[0]);
116 if (shape_max.size() != x_size_len || shape_min.size() != x_size_len) {
117 MS_LOG(EXCEPTION) << "For " << prim_name << ", x size's min or max value is valid.";
118 }
119
120 for (size_t i = 0; i < x_size_len; i++) {
121 if (shape_min[i] == shape_max[i]) {
122 out_shape.push_back(shape_min[i]);
123 } else {
124 out_shape.push_back(abstract::Shape::SHP_ANY);
125 }
126 }
127 ret_shape = std::make_shared<abstract::Shape>(out_shape, shape_min, shape_max);
128 }
129 } else if (input_size->isa<abstract::AbstractTuple>()) {
130 // check tensor, tuple or int to raise error.
131 out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("input x size", input_size_v, prim_name);
132 ret_shape = std::make_shared<abstract::Shape>(out_shape);
133 } else {
134 MS_EXCEPTION(TypeError) << "Conv2DBackpropInput x_size must be a tuple or tensor, but " << input_size->ToString();
135 }
136 auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
137
138 auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
139 ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
140 auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
141 SetPadList(primitive, dout_shape_norm, out_shape);
142 return ret_shape;
143 }
144
Conv2DBackpropInputInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)145 TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
146 MS_EXCEPTION_IF_NULL(prim);
147 auto prim_name = prim->name();
148 // check
149 std::map<std::string, TypePtr> types;
150 // todo: check input_sizes
151 (void)types.emplace("x", input_args[kInputIndex]->BuildType());
152 (void)types.emplace("doutput", input_args[kDoutIndex]->BuildType());
153 std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
154 return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
155 }
156 } // namespace
Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)157 AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
158 const std::vector<AbstractBasePtr> &input_args) {
159 MS_EXCEPTION_IF_NULL(primitive);
160 auto prim_name = primitive->name();
161 // check
162 const int64_t input_num = 3;
163 (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
164 prim_name);
165 for (const auto &item : input_args) {
166 MS_EXCEPTION_IF_NULL(item);
167 }
168 auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
169 Conv2DBackpropInputInferShape(primitive, input_args));
170 return abs;
171 }
172
Init(int64_t out_channel,const std::vector<int64_t> & kernel_size,int64_t mode,const PadMode & pad_mode,const std::vector<int64_t> & pad,const std::vector<int64_t> & stride,const std::vector<int64_t> & dilation,int64_t group,const Format & format,const std::vector<int64_t> & pad_list)173 void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
174 const PadMode &pad_mode, const std::vector<int64_t> &pad,
175 const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
176 const Format &format, const std::vector<int64_t> &pad_list) {
177 set_out_channel(out_channel);
178 set_kernel_size(kernel_size);
179 set_mode(mode);
180 set_pad_mode(pad_mode);
181 set_pad(pad);
182 set_stride(stride);
183 set_dilation(dilation);
184 set_group(group);
185 set_format(format);
186 set_pad_list(pad_list);
187 }
188
set_out_channel(int64_t out_channel)189 void Conv2DBackpropInput::set_out_channel(int64_t out_channel) {
190 (void)AddAttr(kOutChannel,
191 MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
192 }
193
set_kernel_size(const std::vector<int64_t> & kernel_size)194 void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_size) {
195 (void)AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
196 }
197
set_stride(const std::vector<int64_t> & stride)198 void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
199 (void)AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
200 }
201
set_dilation(const std::vector<int64_t> & dilation)202 void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
203 (void)AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
204 }
205
set_pad_mode(const PadMode & pad_mode)206 void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
207 std::vector<int64_t> pad = get_pad();
208 if (pad_mode == PAD) {
209 for (auto item : pad) {
210 CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
211 }
212 } else {
213 CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
214 }
215 int64_t swi = pad_mode;
216 (void)AddAttr(kPadMode, MakeValue(swi));
217 }
218
set_pad(const std::vector<int64_t> & pad)219 void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
220 const int64_t pad_size = 4;
221 (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
222 (void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
223 }
224
set_mode(int64_t mode)225 void Conv2DBackpropInput::set_mode(int64_t mode) {
226 (void)AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
227 }
228
set_group(int64_t group)229 void Conv2DBackpropInput::set_group(int64_t group) {
230 (void)AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
231 }
232
set_format(const Format & format)233 void Conv2DBackpropInput::set_format(const Format &format) {
234 int64_t f = format;
235 (void)AddAttr(kFormat, MakeValue(f));
236 }
237
set_pad_list(const std::vector<int64_t> & pad_list)238 void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
239 (void)this->AddAttr(kPadList, MakeValue(pad_list));
240 }
241
get_out_channel() const242 int64_t Conv2DBackpropInput::get_out_channel() const {
243 auto value_ptr = GetAttr(kOutChannel);
244 MS_EXCEPTION_IF_NULL(value_ptr);
245 return GetValue<int64_t>(value_ptr);
246 }
247
get_kernel_size() const248 std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
249 auto value_ptr = GetAttr(kKernelSize);
250 MS_EXCEPTION_IF_NULL(value_ptr);
251 return GetValue<std::vector<int64_t>>(value_ptr);
252 }
253
get_stride() const254 std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
255 auto value_ptr = GetAttr(kStride);
256 MS_EXCEPTION_IF_NULL(value_ptr);
257 return GetValue<std::vector<int64_t>>(value_ptr);
258 }
259
get_dilation() const260 std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
261 auto value_ptr = GetAttr(kDilation);
262 MS_EXCEPTION_IF_NULL(value_ptr);
263 return GetValue<std::vector<int64_t>>(value_ptr);
264 }
265
get_pad_mode() const266 PadMode Conv2DBackpropInput::get_pad_mode() const {
267 auto value_ptr = GetAttr(kPadMode);
268 MS_EXCEPTION_IF_NULL(value_ptr);
269 return PadMode(GetValue<int64_t>(value_ptr));
270 }
271
get_pad() const272 std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
273 auto value_ptr = GetAttr(kPad);
274 MS_EXCEPTION_IF_NULL(value_ptr);
275 return GetValue<std::vector<int64_t>>(value_ptr);
276 }
277
get_mode() const278 int64_t Conv2DBackpropInput::get_mode() const {
279 auto value_ptr = GetAttr(kMode);
280 MS_EXCEPTION_IF_NULL(value_ptr);
281 return GetValue<int64_t>(value_ptr);
282 }
283
get_group() const284 int64_t Conv2DBackpropInput::get_group() const {
285 auto value_ptr = GetAttr(kGroup);
286 MS_EXCEPTION_IF_NULL(value_ptr);
287 return GetValue<int64_t>(value_ptr);
288 }
289
get_format() const290 Format Conv2DBackpropInput::get_format() const {
291 auto value_ptr = GetAttr(kFormat);
292 MS_EXCEPTION_IF_NULL(value_ptr);
293 return Format(GetValue<int64_t>(value_ptr));
294 }
295
get_pad_list() const296 std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
297 auto value_ptr = GetAttr(kPadList);
298 MS_EXCEPTION_IF_NULL(value_ptr);
299 return GetValue<std::vector<int64_t>>(value_ptr);
300 }
// Registers Conv2DBackpropInputInfer for prim::kPrimConv2DBackpropInput through the project's registration macro.
// NOTE(review): the meaning of the trailing `nullptr` and `true` arguments is defined by
// REGISTER_PRIMITIVE_EVAL_IMPL, which is not visible in this file - confirm against the macro definition.
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
                             true);
303 } // namespace ops
304 } // namespace mindspore
305