/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <vector>
#include <string>
#include <memory>
#include <set>

#include "ops/grad/conv2d_backprop_input.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kDoutIndex = 0;
constexpr size_t kInputIndex = 1;
constexpr int64_t kSizeIndex = 2;

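// Derives the effective pad list for Conv2DBackpropInput and attaches it to the primitive as the
// kPadList attribute. An explicit pad_list attribute takes precedence. For pad_mode SAME the
// required padding of each spatial dimension is
//   pad_needed = (dout - 1) * stride + dilation * (kernel - 1) + 1 - x_size,
// clamped at zero and split between the two sides, with any odd remainder going to the
// bottom/right side; spatial dimensions that are still dynamic keep SHP_ANY. For pad_mode PAD the
// kPad attribute is used; otherwise (VALID) the padding stays all zeros.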
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
                const std::vector<int64_t> &x_size_v) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // check
  auto kernel_size =
    CheckAndConvertUtils::CheckAttrIntOrTupleInt("kernel_size", primitive->GetAttr(kKernelSize), prim_name);
  auto stride = CheckAndConvertUtils::CheckAttrIntOrTupleInt("stride", primitive->GetAttr(kStride), prim_name);
  auto dilation = CheckAndConvertUtils::CheckAttrIntOrTupleInt("dilation", primitive->GetAttr(kDilation), prim_name);
  // default pad mode is valid
  auto attr_pad_list_prt = primitive->GetAttr(kPadList);
  int64_t pad_mode;
  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode, true);
  ShapeVector pad_list = {0, 0, 0, 0};
  if (!attr_pad_list_prt->isa<None>()) {
    pad_list = GetValue<ShapeVector>(attr_pad_list_prt);
  } else if (pad_mode == SAME) {
    auto stride_h = stride[2];
    auto stride_w = stride[3];
    auto kernel_h = kernel_size[0];
    auto kernel_w = kernel_size[1];
    auto dilation_h = dilation[2];
    auto dilation_w = dilation[3];
    int64_t pad_top = abstract::Shape::SHP_ANY;
    int64_t pad_bottom = abstract::Shape::SHP_ANY;
    int64_t pad_left = abstract::Shape::SHP_ANY;
    int64_t pad_right = abstract::Shape::SHP_ANY;
    if (dout_shape_norm[kInputIndex2] != abstract::Shape::SHP_ANY &&
        x_size_v[kInputIndex2] != abstract::Shape::SHP_ANY) {
      auto pad_needed_h =
        (dout_shape_norm[kInputIndex2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[kInputIndex2];
      pad_needed_h = 0 > pad_needed_h ? 0 : pad_needed_h;
      pad_top = pad_needed_h / 2;
      pad_bottom = pad_needed_h - pad_top;
    }
    if (dout_shape_norm[kInputIndex3] != abstract::Shape::SHP_ANY &&
        x_size_v[kInputIndex3] != abstract::Shape::SHP_ANY) {
      auto pad_needed_w =
        (dout_shape_norm[kInputIndex3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[kInputIndex3];
      pad_needed_w = pad_needed_w > 0L ? pad_needed_w : 0L;
      pad_left = pad_needed_w / 2;
      pad_right = pad_needed_w - pad_left;
    }
    pad_list = {pad_top, pad_bottom, pad_left, pad_right};
  } else if (pad_mode == PAD) {
    pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
  }
  (void)primitive->AddAttr(kPadList, MakeValue(pad_list));
}

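// Infers the output shape of Conv2DBackpropInput from its third input, the size of the forward
// input x (which is also the shape of this op's output). The size may be a constant tuple, a
// constant tensor, or a tensor whose value is only known at run time; in the last case the
// min/max value pair recorded on the abstract tensor is used to build a shape with SHP_ANY
// placeholders. The inferred shape is also used to set the pad list on the primitive.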
abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
                                                 const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  std::vector<int64_t> out_shape;
  abstract::ShapePtr ret_shape;
  auto input_size = input_args[kSizeIndex];
  auto input_size_v = input_size->BuildValue();
  MS_EXCEPTION_IF_NULL(input_size_v);

  if (input_size->isa<abstract::AbstractTensor>()) {
    if (input_size_v->isa<tensor::Tensor>()) {
      out_shape = CheckAndConvertUtils::CheckTensorIntValue("input x size", input_size_v, prim_name);
      ret_shape = std::make_shared<abstract::Shape>(out_shape);
    } else {
      auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kSizeIndex);
      MS_EXCEPTION_IF_NULL(shape_ptr);
      auto shape_shape = shape_ptr->shape();
      if (shape_shape.size() != 1) {
        MS_LOG(EXCEPTION) << "The " << prim_name << "'s x size must be 1-D.";
      }

      auto abstract_tensor = input_size->cast<abstract::AbstractTensorPtr>();
      MS_EXCEPTION_IF_NULL(abstract_tensor);
      auto shape_max_value = abstract_tensor->get_max_value();
      auto shape_min_value = abstract_tensor->get_min_value();
      if (shape_max_value == nullptr || shape_min_value == nullptr) {
        MS_LOG(EXCEPTION) << "The max value or min value of x size cannot be empty when its value is dynamic.";
      }

      auto shape_max = GetValue<std::vector<int64_t>>(shape_max_value);
      auto shape_min = GetValue<std::vector<int64_t>>(shape_min_value);

      auto x_size_len = LongToSize(shape_shape[0]);
      if (shape_max.size() != x_size_len || shape_min.size() != x_size_len) {
        MS_LOG(EXCEPTION) << "For " << prim_name << ", the length of x size's min or max value must match x size.";
      }

      for (size_t i = 0; i < x_size_len; i++) {
        if (shape_min[i] == shape_max[i]) {
          out_shape.push_back(shape_min[i]);
        } else {
          out_shape.push_back(abstract::Shape::SHP_ANY);
        }
      }
      ret_shape = std::make_shared<abstract::Shape>(out_shape, shape_min, shape_max);
    }
  } else if (input_size->isa<abstract::AbstractTuple>()) {
    // check tensor, tuple or int to raise error.
    out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("input x size", input_size_v, prim_name);
    ret_shape = std::make_shared<abstract::Shape>(out_shape);
  } else {
    MS_EXCEPTION(TypeError) << "Conv2DBackpropInput x_size must be a tuple or tensor, but got "
                            << input_size->ToString();
  }
  auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];

  auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
  ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
  auto dout_shape_norm = format == Format::NCHW ? dout_shape : tmp_shape;
  SetPadList(primitive, dout_shape_norm, out_shape);
  return ret_shape;
}

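// Infers the output dtype: the x and dout inputs must be tensors of the same type, drawn from the
// set of types supported by this primitive.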
TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // check
  std::map<std::string, TypePtr> types;
  // TODO: check input_sizes
  (void)types.emplace("x", input_args[kInputIndex]->BuildType());
  (void)types.emplace("doutput", input_args[kDoutIndex]->BuildType());
  std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
}
}  // namespace

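// Combined infer routine for Conv2DBackpropInput: validates the number of inputs and assembles
// the abstract tensor from the inferred dtype and shape.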
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // check
  const int64_t input_num = 3;
  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
                                           prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto abs = std::make_shared<abstract::AbstractTensor>(Conv2DBackpropInputInferType(primitive, input_args),
                                                        Conv2DBackpropInputInferShape(primitive, input_args));
  return abs;
}

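// Initializes every Conv2DBackpropInput attribute from the given arguments by delegating to the
// individual setters below.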
void Conv2DBackpropInput::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode,
                               const PadMode &pad_mode, const std::vector<int64_t> &pad,
                               const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, int64_t group,
                               const Format &format, const std::vector<int64_t> &pad_list) {
  set_out_channel(out_channel);
  set_kernel_size(kernel_size);
  set_mode(mode);
  set_pad_mode(pad_mode);
  set_pad(pad);
  set_stride(stride);
  set_dilation(dilation);
  set_group(group);
  set_format(format);
  set_pad_list(pad_list);
}

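// Attribute setters. Each setter validates its argument where applicable before storing it as a
// primitive attribute.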
void Conv2DBackpropInput::set_out_channel(int64_t out_channel) {
  (void)AddAttr(kOutChannel,
                MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
}

void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_size) {
  (void)AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
}

void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) {
  (void)AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
}

void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) {
  (void)AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
}

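// Setting the pad mode also checks that the existing pad attribute is consistent with it: with
// pad_mode PAD every pad entry must be non-negative, otherwise the pad list must be all zeros.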
void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
  std::vector<int64_t> pad = get_pad();
  if (pad_mode == PAD) {
    for (auto item : pad) {
      CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
    }
  } else {
    CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
  }
  int64_t swi = pad_mode;
  (void)AddAttr(kPadMode, MakeValue(swi));
}

void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
  const int64_t pad_size = 4;
  (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
  (void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
}

void Conv2DBackpropInput::set_mode(int64_t mode) {
  (void)AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
}

void Conv2DBackpropInput::set_group(int64_t group) {
  (void)AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
}

void Conv2DBackpropInput::set_format(const Format &format) {
  int64_t f = format;
  (void)AddAttr(kFormat, MakeValue(f));
}

void Conv2DBackpropInput::set_pad_list(const std::vector<int64_t> &pad_list) {
  (void)this->AddAttr(kPadList, MakeValue(pad_list));
}

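// Attribute getters. Each getter reads the corresponding primitive attribute and converts it back
// to its C++ type.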
int64_t Conv2DBackpropInput::get_out_channel() const {
  auto value_ptr = GetAttr(kOutChannel);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<int64_t>(value_ptr);
}

std::vector<int64_t> Conv2DBackpropInput::get_kernel_size() const {
  auto value_ptr = GetAttr(kKernelSize);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

std::vector<int64_t> Conv2DBackpropInput::get_stride() const {
  auto value_ptr = GetAttr(kStride);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

std::vector<int64_t> Conv2DBackpropInput::get_dilation() const {
  auto value_ptr = GetAttr(kDilation);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

PadMode Conv2DBackpropInput::get_pad_mode() const {
  auto value_ptr = GetAttr(kPadMode);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return PadMode(GetValue<int64_t>(value_ptr));
}

std::vector<int64_t> Conv2DBackpropInput::get_pad() const {
  auto value_ptr = GetAttr(kPad);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

int64_t Conv2DBackpropInput::get_mode() const {
  auto value_ptr = GetAttr(kMode);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<int64_t>(value_ptr);
}

int64_t Conv2DBackpropInput::get_group() const {
  auto value_ptr = GetAttr(kGroup);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<int64_t>(value_ptr);
}

Format Conv2DBackpropInput::get_format() const {
  auto value_ptr = GetAttr(kFormat);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return Format(GetValue<int64_t>(value_ptr));
}

std::vector<int64_t> Conv2DBackpropInput::get_pad_list() const {
  auto value_ptr = GetAttr(kPadList);
  MS_EXCEPTION_IF_NULL(value_ptr);
  return GetValue<std::vector<int64_t>>(value_ptr);
}
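// Register the shape/type infer implementation for the Conv2DBackpropInput primitive.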
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropInput, prim::kPrimConv2DBackpropInput, Conv2DBackpropInputInfer, nullptr,
                             true);
}  // namespace ops
}  // namespace mindspore