1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/conv2d.h"
18
19 #include <algorithm>
20 #include <cmath>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <vector>
27
28 #include "abstract/abstract_value.h"
29 #include "abstract/dshape.h"
30 #include "abstract/ops/op_infer.h"
31 #include "abstract/ops/primitive_infer_map.h"
32 #include "abstract/utils.h"
33 #include "base/base.h"
34 #include "ir/anf.h"
35 #include "ir/dtype/number.h"
36 #include "ir/dtype/type.h"
37 #include "ir/primitive.h"
38 #include "ir/scalar.h"
39 #include "ir/value.h"
40 #include "mindapi/base/shape_vector.h"
41 #include "mindapi/base/shared_ptr.h"
42 #include "mindapi/base/type_id.h"
43 #include "mindapi/ir/value.h"
44 #include "mindapi/src/helper.h"
45 #include "mindspore/core/ops/conv_pool_ops.h"
46 #include "ops/op_name.h"
47 #include "ops/primitive_c.h"
48 #include "utils/check_convert_utils.h"
49 #include "utils/convert_utils_base.h"
50 #include "utils/log_adapter.h"
51 #include "utils/shape_utils.h"
52
using mindspore::abstract::Shape;
namespace mindspore {
namespace ops {
namespace {
// check functions
// Expected element counts when reading Conv2D attributes via CheckAttrIntOrTuple.
constexpr size_t kernel_size_num = 2;  // (kernel_h, kernel_w)
constexpr size_t stride_num = 2;       // (stride_h, stride_w)
constexpr size_t dilation_num = 2;     // (dilation_h, dilation_w)
constexpr size_t padding_num = 4;      // (top, bottom, left, right)
// Offset into the stride/dilation attribute tuples at which the spatial
// (H, W) entries begin; the leading entries are skipped.
constexpr size_t start_index = 2;
// Index positions inside a 4-element padding vector.
constexpr size_t top_padding = 0;
constexpr size_t bottom_padding = 1;
constexpr size_t left_padding = 2;
constexpr size_t right_padding = 3;
67
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)68 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
69 for (size_t i = 0; i < shape.size(); ++i) {
70 if ((shape[i] < 0) && (shape[i] != abstract::Shape::kShapeDimAny)) {
71 MS_EXCEPTION(ValueError) << "For '" << op << "', shape element [" << i
72 << "] must be positive integer or -1, but got: " << shape[i] << ".";
73 }
74 }
75 }
76
CheckAttrPositiveInt64(const std::string & op,const ValuePtr & attr,const std::string & attr_name)77 int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
78 MS_EXCEPTION_IF_NULL(attr);
79 auto attr_value = attr->cast<Int64ImmPtr>();
80 MS_EXCEPTION_IF_NULL(attr_value);
81 int64_t attr_val = attr_value->value();
82 if (attr_val <= 0) {
83 MS_LOG(EXCEPTION) << "For '" << op << "', '" << attr_name << "' should be greater than 0, but got: " << attr_val
84 << ".";
85 }
86 return attr_val;
87 }
88
CheckAttrIntOrTuple(const ValuePtr & attr,const size_t start_idx,const size_t num_element)89 std::vector<int64_t> CheckAttrIntOrTuple(const ValuePtr &attr, const size_t start_idx, const size_t num_element) {
90 std::vector<int64_t> result;
91 MS_EXCEPTION_IF_NULL(attr);
92 if (attr->isa<ValueTuple>()) {
93 std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
94 if (attr_vec.size() < start_idx + num_element) {
95 MS_LOG(EXCEPTION) << "ValueTuple size verify failed. ValueTuple size is " << attr_vec.size()
96 << ", start index is " << start_idx << ", element number is " << num_element;
97 }
98 auto it_start = attr_vec.begin() + SizeToLong(start_idx);
99 (void)std::transform(it_start, it_start + SizeToLong(num_element), std::back_inserter(result),
100 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
101 } else {
102 (void)result.insert(result.begin(), num_element, GetValue<int64_t>(attr));
103 }
104 return result;
105 }
106
// Compute the output spatial dims (H, W) and the explicit pad list
// (top, bottom, left, right) of a 2-D convolution for the given pad mode.
//
// Args:
//   output_hw:    out-param; receives {out_h, out_w} (kShapeDimAny when the
//                 corresponding input dim is unknown).
//   pad_list:     out-param; receives {top, bottom, left, right} padding.
//   x_h, x_w:     input spatial dims; may be abstract::Shape::kShapeDimAny.
//   kernel:       {kernel_h, kernel_w}.
//   stride:       {stride_h, stride_w}.
//   dilation:     {dilation_h, dilation_w}.
//   pad_mode:     PadMode::VALID, PadMode::SAME, or PadMode::PAD.
//   padding:      explicit {top, bottom, left, right}; used only in PAD mode.
//   is_min_shape: when true, computed output dims are clamped up to 1.
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
                       const std::vector<int64_t> &dilation, const int64_t &pad_mode,
                       const std::vector<int64_t> &padding, const bool is_min_shape = false) {
  MS_EXCEPTION_IF_NULL(pad_list);
  if (pad_mode == PadMode::VALID) {
    // VALID: no padding; out = ceil((x - dilation*(kernel-1)) / stride).
    int64_t out_h = -1;
    int64_t out_w = -1;
    if (x_h != abstract::Shape::kShapeDimAny) {
      out_h =
        static_cast<int64_t>(std::ceil(((x_h * 1.0) - static_cast<float>(dilation[0] * (kernel[0] - 1))) / stride[0]));
      if (is_min_shape && out_h < 1) {
        out_h = 1L;
      }
    }
    if (x_w != abstract::Shape::kShapeDimAny) {
      out_w =
        static_cast<int64_t>(std::ceil(((x_w * 1.0) - static_cast<float>(dilation[1] * (kernel[1] - 1))) / stride[1]));
      if (is_min_shape && out_w < 1) {
        out_w = 1L;
      }
    }
    output_hw->push_back(out_h);
    output_hw->push_back(out_w);
    // VALID mode pads nothing: publish an all-zero 4-element pad list.
    constexpr size_t pad_size = 4;
    (void)pad_list->insert(pad_list->begin(), pad_size, 0);
  } else if (pad_mode == PadMode::SAME) {
    // SAME: out = ceil(x / stride); pad just enough so the kernel covers the
    // whole input, split between the two sides (extra pixel on the second side).
    if (x_h == abstract::Shape::kShapeDimAny) {
      output_hw->push_back(abstract::Shape::kShapeDimAny);
      pad_list->push_back(abstract::Shape::kShapeDimAny);
      pad_list->push_back(abstract::Shape::kShapeDimAny);
    } else {
      output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
      int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
      pad_needed_h = std::max(int64_t(0), pad_needed_h);
      pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
      pad_list->push_back(pad_needed_h - pad_list->at(0));
    }

    if (x_w == abstract::Shape::kShapeDimAny) {
      output_hw->push_back(abstract::Shape::kShapeDimAny);
      pad_list->push_back(abstract::Shape::kShapeDimAny);
      pad_list->push_back(abstract::Shape::kShapeDimAny);
    } else {
      output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
      int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
      pad_needed_w = std::max(int64_t(0), pad_needed_w);
      pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
      // pad_list layout is {top, bottom, left, right}; index 2 is the left pad
      // pushed just above.
      pad_list->push_back(pad_needed_w - pad_list->at(kInputIndex2));
    }
  } else if (pad_mode == PadMode::PAD) {
    // PAD: caller supplies explicit padding;
    // out = floor(1 + (x + pad_both - kernel - (kernel-1)*(dilation-1)) / stride).
    (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
    int64_t out_h = -1;
    int64_t out_w = -1;
    if (x_h != abstract::Shape::kShapeDimAny) {
      out_h = static_cast<int64_t>(std::floor(1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] -
                                                   static_cast<float>((kernel[0] - 1) * (dilation[0] - 1))) /
                                                    stride[0]));
      if (is_min_shape && out_h < 1) {
        out_h = 1L;
      }
    }
    if (x_w != abstract::Shape::kShapeDimAny) {
      out_w =
        static_cast<int64_t>(std::floor(1 + ((x_w * 1.0) + pad_list->at(kInputIndex2) + pad_list->at(kInputIndex3) -
                                             kernel[1] - static_cast<float>((kernel[1] - 1) * (dilation[1] - 1))) /
                                              stride[1]));
      if (is_min_shape && out_w < 1) {
        out_w = 1L;
      }
    }
    output_hw->push_back(out_h);
    output_hw->push_back(out_w);
  }
}
182
CheckConv2dShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,const std::vector<int64_t> & x_shape,const std::vector<int64_t> & w_shape,const std::vector<int64_t> & padding,int64_t pad_mode,uint64_t w_axis,uint64_t h_axis)183 bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
184 const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape,
185 const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) {
186 auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
187 auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
188 if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) {
189 return true;
190 }
191 if (w_shape[w_axis] != abstract::Shape::kShapeDimAny && pad_mode != PadMode::SAME) {
192 int64_t input_height = x_shape[h_axis];
193 int64_t input_width = x_shape[w_axis];
194 if (pad_mode == PadMode::PAD) {
195 input_width += padding[left_padding] + padding[right_padding];
196 input_height += padding[top_padding] + padding[bottom_padding];
197 }
198 if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) {
199 return false;
200 }
201 }
202 return true;
203 }
204
// Infer the Conv2D output shape from the input (x) and weight (w) abstract
// shapes and the primitive's attributes. As a side effect, stores the resolved
// explicit padding on the primitive as the "pad_list" attribute
// (top, bottom, left, right). Dynamic dims (kShapeDimAny) propagate through;
// dynamic rank short-circuits to a rank-unknown result.
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape());
  auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape());
  auto x_shape = x_shape_map[kShape];
  auto w_shape = w_shape_map[kShape];

  ShapeVector output_shape;
  const auto shape_size = 4;
  // Dynamic rank: nothing concrete can be derived; publish an all-unknown
  // pad_list and return a rank-unknown shape.
  if (IsDynamicRank(x_shape) || IsDynamicRank(w_shape)) {
    std::vector<int64_t> pad_list(shape_size, abstract::Shape::kShapeDimAny);
    std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[kIndex0]), MakeValue(pad_list[kIndex1]),
                                          MakeValue(pad_list[kIndex2]), MakeValue(pad_list[kIndex3])};
    primitive->set_attr("pad_list", MakeValue(pad_list_val));
    output_shape = {abstract::Shape::kShapeRankAny};
    return std::make_shared<abstract::Shape>(output_shape);
  }

  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, shape_size, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, shape_size, prim_name);
  CheckShapeAnyAndPositive(prim_name + " x_shape", x_shape);
  CheckShapeAnyAndPositive(prim_name + " w_shape", w_shape);
  // Axis positions default to NCHW and are remapped below for NHWC.
  const uint64_t n_axis = 0;
  uint64_t c_axis = 1;
  uint64_t h_axis = 2;
  uint64_t w_axis = 3;
  int64_t data_format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
  if (data_format == static_cast<int64_t>(Format::NHWC)) {
    c_axis = 3;
    h_axis = 1;
    w_axis = 2;
  }
  // Grouped convolution: C_in(x) / group must equal C_in(w) when both known.
  int64_t group = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("group"), "group");
  if ((x_shape[c_axis] != abstract::Shape::kShapeDimAny) && (w_shape[c_axis] != abstract::Shape::kShapeDimAny) &&
      ((x_shape[c_axis] / group) != w_shape[c_axis])) {
    MS_LOG(EXCEPTION) << "For '" << prim_name
                      << "', 'C_in' of input 'x' shape divide by parameter 'group' must be "
                         "equal to 'C_in' of input 'weight' shape: "
                      << w_shape[c_axis] << ", but got 'C_in' of input 'x' shape: " << x_shape[c_axis]
                      << ", and 'group': " << group << ".";
  }
  int64_t out_channel = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("out_channel"), "out_channel");
  // A dynamic weight N dim makes out_channel dynamic (-1) in the output.
  if (w_shape[n_axis] == abstract::Shape::kShapeDimAny) {
    out_channel = w_shape[n_axis];
  } else {
    if (w_shape[n_axis] != out_channel) {
      MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << n_axis
                        << "]' must be equal to 'out_channel', but got 'w_shape[" << n_axis << "]': " << w_shape[n_axis]
                        << ", 'out_channel': " << out_channel << ".";
    }
  }
  // Weight spatial dims must match the kernel_size attribute when known.
  std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(primitive->GetAttr("kernel_size"), 0, kernel_size_num);
  if ((w_shape[h_axis] != abstract::Shape::kShapeDimAny) && (w_shape[h_axis] != kernel_size[0])) {
    MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << h_axis
                      << "]' must be equal to 'kernel_size[0]', but got 'w_shape[" << h_axis
                      << "]': " << w_shape[h_axis] << ", 'kernel_size[0]': " << kernel_size[0] << ".";
  }
  if ((w_shape[w_axis] != abstract::Shape::kShapeDimAny) && (w_shape[w_axis] != kernel_size[1])) {
    MS_LOG(EXCEPTION) << "For '" << prim_name << "', 'w_shape[" << w_axis
                      << "]' must be equal to 'kernel_size[1]', but got 'w_shape[" << w_axis
                      << "]': " << w_shape[w_axis] << ", 'kernel_size[1]': " << kernel_size[1] << ".";
  }
  // stride/dilation attributes store extra leading entries; the spatial (H, W)
  // pair starts at start_index (= 2). The "pad" attribute is already 4-element.
  std::vector<int64_t> stride = CheckAttrIntOrTuple(primitive->GetAttr("stride"), start_index, stride_num);
  std::vector<int64_t> dilation = CheckAttrIntOrTuple(primitive->GetAttr("dilation"), start_index, dilation_num);
  std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num);
  int64_t pad_mode;
  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
  if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) {
    MS_EXCEPTION(ValueError) << "For 'Conv2d', input shape's h and w after padding must be greater than or equal to "
                                "kernel_size's h and w respectively.";
  }
  std::vector<int64_t> output_hw;
  std::vector<int64_t> pad_list;
  Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
                    padding);
  // Publish the resolved explicit padding for downstream passes/kernels.
  std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
                                        MakeValue(pad_list[3])};
  primitive->set_attr("pad_list", MakeValue(pad_list_val));

  output_shape = (data_format == static_cast<int64_t>(Format::NHWC))
                   ? ShapeVector{x_shape[n_axis], output_hw[0], output_hw[1], out_channel}
                   : ShapeVector{x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
  CheckShapeAnyAndPositive(prim_name + " output_shape", output_shape);
  return std::make_shared<abstract::Shape>(output_shape);
}
290
Conv2dInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)291 TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
292 const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32, kBFloat16};
293 auto out_type = CheckAndConvertUtils::CheckTypeValid("x", input_args[0]->GetType(), valid_types, prim->name());
294 if (out_type->type_id() == TypeId::kNumberTypeInt8) {
295 out_type = kInt32;
296 }
297 return out_type;
298 }
299 } // namespace
300
301 MIND_API_OPERATOR_IMPL(Conv2D, BaseOperator);
// Initialize all Conv2D attributes through the validating setters.
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
                  const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
                  const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
  set_kernel_size(kernel_size);
  set_stride(stride);
  set_dilation(dilation);
  // NOTE: set_pad must run before set_pad_mode — set_pad_mode validates the
  // current "pad" attribute via get_pad().
  set_pad(pad);
  set_pad_mode(pad_mode);
  set_mode(mode);
  set_out_channel(out_channel);
  set_group(group);
  set_format(format);
}
315
set_out_channel(int64_t out_channel)316 void Conv2D::set_out_channel(int64_t out_channel) {
317 (void)AddAttr(kOutChannel,
318 api::MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
319 }
320
set_kernel_size(const std::vector<int64_t> & kernel_size)321 void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
322 (void)AddAttr(kKernelSize,
323 api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
324 }
325
set_stride(const std::vector<int64_t> & stride)326 void Conv2D::set_stride(const std::vector<int64_t> &stride) {
327 (void)AddAttr(kStride, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
328 }
329
set_dilation(const std::vector<int64_t> & dilation)330 void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
331 (void)AddAttr(kDilation, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
332 }
333
set_pad_mode(const PadMode & pad_mode)334 void Conv2D::set_pad_mode(const PadMode &pad_mode) {
335 std::vector<int64_t> pad = get_pad();
336 if (pad_mode == PAD) {
337 for (auto item : pad) {
338 CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, 0, name());
339 }
340 } else {
341 CheckAndConvertUtils::Check(kPad, pad, kEqual, {0, 0, 0, 0}, name());
342 }
343 int64_t swi = pad_mode;
344 (void)AddAttr(kPadMode, api::MakeValue(swi));
345 }
346
set_pad(const std::vector<int64_t> & pad)347 void Conv2D::set_pad(const std::vector<int64_t> &pad) {
348 const int64_t pad_size = 4;
349 (void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, pad_size, name());
350 (void)AddAttr(kPad, api::MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
351 }
352
set_mode(int64_t mode)353 void Conv2D::set_mode(int64_t mode) {
354 (void)AddAttr(kMode, api::MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
355 }
356
set_group(int64_t group)357 void Conv2D::set_group(int64_t group) {
358 (void)AddAttr(kGroup, api::MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
359 }
360
set_format(const Format & format)361 void Conv2D::set_format(const Format &format) {
362 int64_t f = format;
363 (void)AddAttr(kFormat, api::MakeValue(f));
364 }
365
get_out_channel() const366 int64_t Conv2D::get_out_channel() const {
367 auto value_ptr = GetAttr(kOutChannel);
368 return GetValue<int64_t>(value_ptr);
369 }
370
get_kernel_size() const371 std::vector<int64_t> Conv2D::get_kernel_size() const {
372 auto value_ptr = GetAttr(kKernelSize);
373 return GetValue<std::vector<int64_t>>(value_ptr);
374 }
375
get_stride() const376 std::vector<int64_t> Conv2D::get_stride() const {
377 auto value_ptr = GetAttr(kStride);
378 return GetValue<std::vector<int64_t>>(value_ptr);
379 }
380
get_dilation() const381 std::vector<int64_t> Conv2D::get_dilation() const {
382 auto value_ptr = GetAttr(kDilation);
383 return GetValue<std::vector<int64_t>>(value_ptr);
384 }
385
get_pad_mode() const386 PadMode Conv2D::get_pad_mode() const {
387 auto value_ptr = GetAttr(kPadMode);
388 return PadMode(GetValue<int64_t>(value_ptr));
389 }
390
get_pad() const391 std::vector<int64_t> Conv2D::get_pad() const {
392 auto value_ptr = GetAttr(kPad);
393 return GetValue<std::vector<int64_t>>(value_ptr);
394 }
395
get_mode() const396 int64_t Conv2D::get_mode() const {
397 auto value_ptr = GetAttr(kMode);
398 return GetValue<int64_t>(value_ptr);
399 }
400
get_group() const401 int64_t Conv2D::get_group() const {
402 auto value_ptr = GetAttr(kGroup);
403 return GetValue<int64_t>(value_ptr);
404 }
405
get_format() const406 Format Conv2D::get_format() const {
407 auto value_ptr = GetAttr(kFormat);
408 return Format(GetValue<int64_t>(value_ptr));
409 }
410
Conv2dInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)411 AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
412 const std::vector<AbstractBasePtr> &input_args) {
413 MS_EXCEPTION_IF_NULL(primitive);
414 for (auto item : input_args) {
415 MS_EXCEPTION_IF_NULL(item);
416 }
417
418 const int64_t input_num = 2;
419 (void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
420 primitive->name());
421 const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32, kBFloat16};
422 std::map<std::string, TypePtr> types;
423 (void)types.emplace("x", input_args[0]->GetType());
424 (void)types.emplace("w", input_args[1]->GetType());
425 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
426 return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
427 }
428
429 // AG means auto generated
// Infer implementation registered for Conv2D; delegates to the file-local
// Conv2dInferShape / Conv2dInferType / Conv2dInfer helpers above.
class MIND_API AGConv2dInfer : public abstract::OpInferBase {
 public:
  // Infer only the output shape.
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2dInferShape(primitive, input_args);
  }

  // Infer only the output dtype.
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2dInferType(primitive, input_args);
  }
  // Full inference: input validation plus combined shape and dtype.
  AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) const override {
    return Conv2dInfer(engine, primitive, input_args);
  }
};
445
446 REGISTER_PRIMITIVE_OP_INFER_IMPL(Conv2D, prim::kPrimConv2D, AGConv2dInfer, false);
447 } // namespace ops
448 } // namespace mindspore
449