• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/conv2d.h"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <vector>
27 
28 #include "abstract/abstract_value.h"
29 #include "abstract/dshape.h"
30 #include "abstract/ops/op_infer.h"
31 #include "abstract/ops/primitive_infer_map.h"
32 #include "abstract/utils.h"
33 #include "base/base.h"
34 #include "ir/anf.h"
35 #include "ir/dtype/number.h"
36 #include "ir/dtype/type.h"
37 #include "ir/primitive.h"
38 #include "ir/scalar.h"
39 #include "ir/value.h"
40 #include "mindapi/base/shape_vector.h"
41 #include "mindapi/base/shared_ptr.h"
42 #include "mindapi/base/type_id.h"
43 #include "mindapi/ir/value.h"
44 #include "mindapi/src/helper.h"
45 #include "mindspore/core/ops/conv_pool_ops.h"
46 #include "ops/op_name.h"
47 #include "ops/primitive_c.h"
48 #include "utils/check_convert_utils.h"
49 #include "utils/convert_utils_base.h"
50 #include "utils/log_adapter.h"
51 #include "utils/shape_utils.h"
52 
53 using mindspore::abstract::Shape;
54 namespace mindspore {
55 namespace ops {
56 namespace {
// Element counts and offsets used when unpacking the Conv2D attributes below.
constexpr size_t kernel_size_num{2};  // elements read from the "kernel_size" attribute
constexpr size_t stride_num{2};       // elements read from the "stride" attribute
constexpr size_t dilation_num{2};     // elements read from the "dilation" attribute
constexpr size_t padding_num{4};      // elements read from the "pad" attribute
constexpr size_t start_index{2};      // first tuple element read for "stride" and "dilation"
// Positions inside the 4-element padding vector.
constexpr size_t top_padding{0};
constexpr size_t bottom_padding{1};
constexpr size_t left_padding{2};
constexpr size_t right_padding{3};
67 
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)68 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
69   for (size_t i = 0; i < shape.size(); ++i) {
70     if ((shape[i] < 0) && (shape[i] != abstract::Shape::kShapeDimAny)) {
71       MS_EXCEPTION(ValueError) << "For '" << op << "',  shape element [" << i
72                                << "] must be positive integer or -1, but got: " << shape[i] << ".";
73     }
74   }
75 }
76 
CheckAttrPositiveInt64(const std::string & op,const ValuePtr & attr,const std::string & attr_name)77 int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
78   MS_EXCEPTION_IF_NULL(attr);
79   auto attr_value = attr->cast<Int64ImmPtr>();
80   MS_EXCEPTION_IF_NULL(attr_value);
81   int64_t attr_val = attr_value->value();
82   if (attr_val <= 0) {
83     MS_LOG(EXCEPTION) << "For '" << op << "', '" << attr_name << "' should be greater than 0, but got: " << attr_val
84                       << ".";
85   }
86   return attr_val;
87 }
88 
CheckAttrIntOrTuple(const ValuePtr & attr,const size_t start_idx,const size_t num_element)89 std::vector<int64_t> CheckAttrIntOrTuple(const ValuePtr &attr, const size_t start_idx, const size_t num_element) {
90   std::vector<int64_t> result;
91   MS_EXCEPTION_IF_NULL(attr);
92   if (attr->isa<ValueTuple>()) {
93     std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
94     if (attr_vec.size() < start_idx + num_element) {
95       MS_LOG(EXCEPTION) << "ValueTuple size verify failed. ValueTuple size is " << attr_vec.size()
96                         << ", start index is " << start_idx << ", element number is " << num_element;
97     }
98     auto it_start = attr_vec.begin() + SizeToLong(start_idx);
99     (void)std::transform(it_start, it_start + SizeToLong(num_element), std::back_inserter(result),
100                          [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
101   } else {
102     (void)result.insert(result.begin(), num_element, GetValue<int64_t>(attr));
103   }
104   return result;
105 }
106 
Conv2DPadFunction(std::vector<int64_t> * output_hw,std::vector<int64_t> * pad_list,const int64_t x_h,const int64_t x_w,const std::vector<int64_t> & kernel,const std::vector<int64_t> & stride,const std::vector<int64_t> & dilation,const int64_t & pad_mode,const std::vector<int64_t> & padding,const bool is_min_shape=false)107 void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
108                        const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
109                        const std::vector<int64_t> &dilation, const int64_t &pad_mode,
110                        const std::vector<int64_t> &padding, const bool is_min_shape = false) {
111   MS_EXCEPTION_IF_NULL(pad_list);
112   if (pad_mode == PadMode::VALID) {
113     int64_t out_h = -1;
114     int64_t out_w = -1;
115     if (x_h != abstract::Shape::kShapeDimAny) {
116       out_h =
117         static_cast<int64_t>(std::ceil(((x_h * 1.0) - static_cast<float>(dilation[0] * (kernel[0] - 1))) / stride[0]));
118       if (is_min_shape && out_h < 1) {
119         out_h = 1L;
120       }
121     }
122     if (x_w != abstract::Shape::kShapeDimAny) {
123       out_w =
124         static_cast<int64_t>(std::ceil(((x_w * 1.0) - static_cast<float>(dilation[1] * (kernel[1] - 1))) / stride[1]));
125       if (is_min_shape && out_w < 1) {
126         out_w = 1L;
127       }
128     }
129     output_hw->push_back(out_h);
130     output_hw->push_back(out_w);
131     constexpr size_t pad_size = 4;
132     (void)pad_list->insert(pad_list->begin(), pad_size, 0);
133   } else if (pad_mode == PadMode::SAME) {
134     if (x_h == abstract::Shape::kShapeDimAny) {
135       output_hw->push_back(abstract::Shape::kShapeDimAny);
136       pad_list->push_back(abstract::Shape::kShapeDimAny);
137       pad_list->push_back(abstract::Shape::kShapeDimAny);
138     } else {
139       output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
140       int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
141       pad_needed_h = std::max(int64_t(0), pad_needed_h);
142       pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
143       pad_list->push_back(pad_needed_h - pad_list->at(0));
144     }
145 
146     if (x_w == abstract::Shape::kShapeDimAny) {
147       output_hw->push_back(abstract::Shape::kShapeDimAny);
148       pad_list->push_back(abstract::Shape::kShapeDimAny);
149       pad_list->push_back(abstract::Shape::kShapeDimAny);
150     } else {
151       output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
152       int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
153       pad_needed_w = std::max(int64_t(0), pad_needed_w);
154       pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
155       pad_list->push_back(pad_needed_w - pad_list->at(kInputIndex2));
156     }
157   } else if (pad_mode == PadMode::PAD) {
158     (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
159     int64_t out_h = -1;
160     int64_t out_w = -1;
161     if (x_h != abstract::Shape::kShapeDimAny) {
162       out_h = static_cast<int64_t>(std::floor(1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] -
163                                                    static_cast<float>((kernel[0] - 1) * (dilation[0] - 1))) /
164                                                     stride[0]));
165       if (is_min_shape && out_h < 1) {
166         out_h = 1L;
167       }
168     }
169     if (x_w != abstract::Shape::kShapeDimAny) {
170       out_w =
171         static_cast<int64_t>(std::floor(1 + ((x_w * 1.0) + pad_list->at(kInputIndex2) + pad_list->at(kInputIndex3) -
172                                              kernel[1] - static_cast<float>((kernel[1] - 1) * (dilation[1] - 1))) /
173                                               stride[1]));
174       if (is_min_shape && out_w < 1) {
175         out_w = 1L;
176       }
177     }
178     output_hw->push_back(out_h);
179     output_hw->push_back(out_w);
180   }
181 }
182 
CheckConv2dShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & w_shape,const std::vector<int64_t> & padding,int64_t pad_mode,uint64_t w_axis,uint64_t h_axis)183 bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
184                       const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape,
185                       const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) {
186   auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
187   auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
188   if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) {
189     return true;
190   }
191   if (w_shape[w_axis] != abstract::Shape::kShapeDimAny && pad_mode != PadMode::SAME) {
192     int64_t input_height = x_shape[h_axis];
193     int64_t input_width = x_shape[w_axis];
194     if (pad_mode == PadMode::PAD) {
195       input_width += padding[left_padding] + padding[right_padding];
196       input_height += padding[top_padding] + padding[bottom_padding];
197     }
198     if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) {
199       return false;
200     }
201   }
202   return true;
203 }
204 
Conv2dInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)205 abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
206   auto prim_name = primitive->name();
207   auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
208   auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape());
209   auto x_shape = x_shape_map[kShape];
210   auto w_shape = w_shape_map[kShape];
211 
212   ShapeVector output_shape;
213   const auto shape_size = 4;
214   if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) {
215     std::vector<int64_t> pad_list(shape_size, abstract::Shape::kShapeDimAny);
216     std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[kIndex0]), MakeValue(pad_list[kIndex1]),
217                                           MakeValue(pad_list[kIndex2]), MakeValue(pad_list[kIndex3])};
218     primitive->set_attr("pad_list", MakeValue(pad_list_val));
219     output_shape = {abstract::Shape::kShapeRankAny};
220     return std::make_shared<abstract::Shape>(output_shape);
221   }
222 
223   (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
224   (void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, shape_size, prim_name);
225   CheckShapeAnyAndPositive(prim_name + " x_shape", x_shape);
226   CheckShapeAnyAndPositive(prim_name + " w_shape", w_shape);
227   const uint64_t n_axis = 0;
228   uint64_t c_axis = 1;
229   uint64_t h_axis = 2;
230   uint64_t w_axis = 3;
231   int64_t data_format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
232   if (data_format == static_cast<int64_t>(Format::NHWC)) {
233     c_axis = 3;
234     h_axis = 1;
235     w_axis = 2;
236   }
237   int64_t group = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("group"), "group");
238   if ((x_shape[c_axis] != abstract::Shape::kShapeDimAny) && (w_shape[c_axis] != abstract::Shape::kShapeDimAny) &&
239       ((x_shape[c_axis] / group) != w_shape[c_axis])) {
240     MS_LOG(EXCEPTION) << "For '" << prim_name
241                       << "', 'C_in' of input 'x' shape divide by parameter 'group' must be "
242                          "equal to 'C_in' of input 'weight' shape: "
243                       << w_shape[c_axis] << ", but got 'C_in' of input 'x' shape: " << x_shape[c_axis]
244                       << ", and 'group': " << group << ".";
245   }
246   int64_t out_channel = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("out_channel"), "out_channel");
247   if (w_shape[n_axis] == abstract::Shape::kShapeDimAny) {
248     out_channel = w_shape[n_axis];
249   } else {
250     if (w_shape[n_axis] != out_channel) {
251       MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << n_axis
252                         << "]' must be equal to 'out_channel', but got 'w_shape[" << n_axis << "]': " << w_shape[n_axis]
253                         << ", 'out_channel': " << out_channel << ".";
254     }
255   }
256   std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(primitive->GetAttr("kernel_size"), 0, kernel_size_num);
257   if ((w_shape[h_axis] != abstract::Shape::kShapeDimAny) && (w_shape[h_axis] != kernel_size[0])) {
258     MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << h_axis
259                       << "]' must be equal to 'kernel_size[0]', but got 'w_shape[" << h_axis
260                       << "]': " << w_shape[h_axis] << ", 'kernel_size[0]': " << kernel_size[0] << ".";
261   }
262   if ((w_shape[w_axis] != abstract::Shape::kShapeDimAny) && (w_shape[w_axis] != kernel_size[1])) {
263     MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << w_axis
264                       << "]' must be equal to 'kernel_size[1]', but got 'w_shape[" << w_axis
265                       << "]': " << w_shape[w_axis] << ", 'kernel_size[1]': " << kernel_size[1] << ".";
266   }
267   std::vector<int64_t> stride = CheckAttrIntOrTuple(primitive->GetAttr("stride"), start_index, stride_num);
268   std::vector<int64_t> dilation = CheckAttrIntOrTuple(primitive->GetAttr("dilation"), start_index, dilation_num);
269   std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num);
270   int64_t pad_mode;
271   CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
272   if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) {
273     MS_EXCEPTION(ValueError) << "For 'Conv2d', input shape's h and w after padding must be greater than or equal to "
274                                 "kernel_size's h and w respectively.";
275   }
276   std::vector<int64_t> output_hw;
277   std::vector<int64_t> pad_list;
278   Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
279                     padding);
280   std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
281                                         MakeValue(pad_list[3])};
282   primitive->set_attr("pad_list", MakeValue(pad_list_val));
283 
284   output_shape = (data_format == static_cast<int64_t>(Format::NHWC))
285                    ? ShapeVector{x_shape[n_axis], output_hw[0], output_hw[1], out_channel}
286                    : ShapeVector{x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
287   CheckShapeAnyAndPositive(prim_name + " output_shape", output_shape);
288   return std::make_shared<abstract::Shape>(output_shape);
289 }
290 
Conv2dInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)291 TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
292   const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32, kBFloat16};
293   auto out_type = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->GetType(), valid_types, prim->name());
294   if (out_type->type_id() == TypeId::kNumberTypeInt8) {
295     out_type = kInt32;
296   }
297   return out_type;
298 }
299 }  // namespace
300 
301 MIND_API_OPERATOR_IMPL(Conv2D, BaseOperator);
Init(int64_t out_channel,const std::vector<int64_t> & kernel_size,int64_t mode,const PadMode & pad_mode,const std::vector<int64_t> & pad,const std::vector<int64_t> & stride,const std::vector<int64_t> & dilation,int64_t group,const Format & format)302 void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
303                   const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
304                   const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
305   set_kernel_size(kernel_size);
306   set_stride(stride);
307   set_dilation(dilation);
308   set_pad(pad);
309   set_pad_mode(pad_mode);
310   set_mode(mode);
311   set_out_channel(out_channel);
312   set_group(group);
313   set_format(format);
314 }
315 
set_out_channel(int64_t out_channel)316 void Conv2D::set_out_channel(int64_t out_channel) {
317   (void)AddAttr(kOutChannel,
318                 api::MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
319 }
320 
set_kernel_size(const std::vector<int64_t> & kernel_size)321 void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
322   (void)AddAttr(kKernelSize,
323                 api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
324 }
325 
set_stride(const std::vector<int64_t> & stride)326 void Conv2D::set_stride(const std::vector<int64_t> &stride) {
327   (void)AddAttr(kStride, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
328 }
329 
set_dilation(const std::vector<int64_t> & dilation)330 void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
331   (void)AddAttr(kDilation, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
332 }
333 
set_pad_mode(const PadMode & pad_mode)334 void Conv2D::set_pad_mode(const PadMode &pad_mode) {
335   std::vector<int64_t> pad = get_pad();
336   if (pad_mode == PAD) {
337     for (auto item : pad) {
338       CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, 0, name());
339     }
340   } else {
341     CheckAndConvertUtils::Check(kPad, pad, kEqual, {0, 0, 0, 0}, name());
342   }
343   int64_t swi = pad_mode;
344   (void)AddAttr(kPadMode, api::MakeValue(swi));
345 }
346 
set_pad(const std::vector<int64_t> & pad)347 void Conv2D::set_pad(const std::vector<int64_t> &pad) {
348   const int64_t pad_size = 4;
349   (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
350   (void)AddAttr(kPad, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
351 }
352 
set_mode(int64_t mode)353 void Conv2D::set_mode(int64_t mode) {
354   (void)AddAttr(kMode, api::MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
355 }
356 
set_group(int64_t group)357 void Conv2D::set_group(int64_t group) {
358   (void)AddAttr(kGroup, api::MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
359 }
360 
set_format(const Format & format)361 void Conv2D::set_format(const Format &format) {
362   int64_t f = format;
363   (void)AddAttr(kFormat, api::MakeValue(f));
364 }
365 
get_out_channel() const366 int64_t Conv2D::get_out_channel() const {
367   auto value_ptr = GetAttr(kOutChannel);
368   return GetValue<int64_t>(value_ptr);
369 }
370 
get_kernel_size() const371 std::vector<int64_t> Conv2D::get_kernel_size() const {
372   auto value_ptr = GetAttr(kKernelSize);
373   return GetValue<std::vector<int64_t>>(value_ptr);
374 }
375 
get_stride() const376 std::vector<int64_t> Conv2D::get_stride() const {
377   auto value_ptr = GetAttr(kStride);
378   return GetValue<std::vector<int64_t>>(value_ptr);
379 }
380 
get_dilation() const381 std::vector<int64_t> Conv2D::get_dilation() const {
382   auto value_ptr = GetAttr(kDilation);
383   return GetValue<std::vector<int64_t>>(value_ptr);
384 }
385 
get_pad_mode() const386 PadMode Conv2D::get_pad_mode() const {
387   auto value_ptr = GetAttr(kPadMode);
388   return PadMode(GetValue<int64_t>(value_ptr));
389 }
390 
get_pad() const391 std::vector<int64_t> Conv2D::get_pad() const {
392   auto value_ptr = GetAttr(kPad);
393   return GetValue<std::vector<int64_t>>(value_ptr);
394 }
395 
get_mode() const396 int64_t Conv2D::get_mode() const {
397   auto value_ptr = GetAttr(kMode);
398   return GetValue<int64_t>(value_ptr);
399 }
400 
get_group() const401 int64_t Conv2D::get_group() const {
402   auto value_ptr = GetAttr(kGroup);
403   return GetValue<int64_t>(value_ptr);
404 }
405 
get_format() const406 Format Conv2D::get_format() const {
407   auto value_ptr = GetAttr(kFormat);
408   return Format(GetValue<int64_t>(value_ptr));
409 }
410 
Conv2dInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)411 AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
412                             const std::vector<AbstractBasePtr> &input_args) {
413   MS_EXCEPTION_IF_NULL(primitive);
414   for (auto item : input_args) {
415     MS_EXCEPTION_IF_NULL(item);
416   }
417 
418   const int64_t input_num = 2;
419   (void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
420                                            primitive->name());
421   const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32, kBFloat16};
422   std::map<std::string, TypePtr> types;
423   (void)types.emplace("x", input_args[0]->GetType());
424   (void)types.emplace("w", input_args[1]->GetType());
425   (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
426   return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
427 }
428 
429 // AG means auto generated
430 class MIND_API AGConv2dInfer : public abstract::OpInferBase {
431  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const432   BaseShapePtr InferShape(const PrimitivePtr &primitive,
433                           const std::vector<AbstractBasePtr> &input_args) const override {
434     return Conv2dInferShape(primitive, input_args);
435   }
436 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const437   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
438     return Conv2dInferType(primitive, input_args);
439   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const440   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
441                                     const std::vector<AbstractBasePtr> &input_args) const override {
442     return Conv2dInfer(engine, primitive, input_args);
443   }
444 };
445 
446 REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv2D, prim::kPrimConv2D, AGConv2dInfer, false);
447 }  // namespace ops
448 }  // namespace mindspore
449