• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <memory>
18 
19 #include "ops/op_utils.h"
20 #include "utils/check_convert_utils.h"
21 #include "ops/ops_func_impl/im2col_ext.h"
22 #include "ops/ops_func_impl/col2im_ext.h"
23 #include "ops/ops_func_impl/col2im_grad.h"
24 #include "ops/ops_func_impl/simple_infer.h"
25 
26 namespace mindspore {
27 namespace ops {
28 namespace {
// Computes the number of sliding-block positions along height and width ("Weight" in the name means width).
// For each spatial axis:
//   out = floor((in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1
// The division must round toward negative infinity (as PyTorch's im2col/col2im shape checks do with div_rtn).
// Plain C++ '/' truncates toward zero. That would turn a numerator in (-stride, 0) into 0 and report one valid
// block even though the dilated kernel does not fit inside the padded input.
// kernel_size, dilation, padding and stride are indexed [0] = height, [1] = width. Callers are expected to have
// validated stride > 0. The returned extents may be < 1; callers treat that as an error.
std::pair<int64_t, int64_t> Im2ColComputeOutputHeightAndWeight(const std::pair<int64_t, int64_t> &input_hw,
                                                               const std::vector<int64_t> &kernel_size,
                                                               const std::vector<int64_t> &dilation,
                                                               const std::vector<int64_t> &padding,
                                                               const std::vector<int64_t> &stride) {
  // Integer division rounding toward negative infinity.
  auto floor_div = [](int64_t numerator, int64_t denominator) -> int64_t {
    int64_t quotient = numerator / denominator;
    if ((numerator % denominator != 0) && ((numerator < 0) != (denominator < 0))) {
      --quotient;
    }
    return quotient;
  };
  // Number of block positions along one axis.
  auto axis_output = [&floor_div](int64_t input, int64_t kernel, int64_t dil, int64_t pad, int64_t step) -> int64_t {
    const int64_t effective_kernel = dil * (kernel - 1) + 1;
    return floor_div(input + 2 * pad - effective_kernel, step) + 1;
  };

  const auto &[input_height, input_width] = input_hw;
  const int64_t output_height = axis_output(input_height, kernel_size[0], dilation[0], padding[0], stride[0]);
  const int64_t output_width = axis_output(input_width, kernel_size[1], dilation[1], padding[1], stride[1]);
  return std::make_pair(output_height, output_width);
}
50 
Im2ColAndCol2ImCommonCheckArray(const PrimitivePtr & primitive,const ArrayValue<int64_t> & array,const std::string & arg_name,size_t ele_num,int64_t min_value)51 std::vector<int64_t> Im2ColAndCol2ImCommonCheckArray(const PrimitivePtr &primitive, const ArrayValue<int64_t> &array,
52                                                      const std::string &arg_name, size_t ele_num, int64_t min_value) {
53   std::vector<int64_t> values(ele_num, abstract::Shape::kShapeDimAny);
54   MS_CHECK_VALUE(array.size() == ele_num,
55                  CheckAndConvertUtils::FormatCheckIntegerMsg("number of " + arg_name, SizeToLong(array.size()), kEqual,
56                                                              SizeToLong(ele_num), primitive));
57   for (size_t i = 0; i < array.size(); ++i) {
58     if (MS_UNLIKELY(array.IsValueUnknown(i))) {
59       continue;
60     }
61     MS_CHECK_VALUE(array[i] > min_value,
62                    CheckAndConvertUtils::FormatCheckIntegerMsg(arg_name, array[i], kGreaterThan, min_value, primitive));
63     values[i] = array[i];
64   }
65   return values;
66 }
67 
Im2ColAndCol2ImCommonCheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args,const std::vector<std::string> & arg_names,size_t ele_num,const std::vector<int64_t> & min_values,size_t start_idx)68 std::pair<int32_t, std::vector<ArrayValue<int64_t>>> Im2ColAndCol2ImCommonCheckValidation(
69   const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
70   const std::vector<std::string> &arg_names, size_t ele_num, const std::vector<int64_t> &min_values, size_t start_idx) {
71   assert((input_args.size() - start_idx) == arg_names.size());
72 
73   std::vector<ArrayValue<int64_t>> arrays;
74   for (size_t i = 0; i < arg_names.size(); ++i) {
75     auto array_opt = GetArrayValue<int64_t>(input_args[start_idx + i]);
76     if (MS_UNLIKELY(!array_opt.has_value())) {
77       return std::make_pair(OP_CHECK_RETRY, std::move(arrays));
78     }
79     auto array = std::move(array_opt.value());
80     (void)Im2ColAndCol2ImCommonCheckArray(primitive, array, arg_names[i], ele_num, min_values[i]);
81     arrays.emplace_back(std::move(array));
82   }
83 
84   return std::make_pair(OP_CHECK_SUCCESS, std::move(arrays));
85 }
86 
Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr & primitive,const std::vector<int64_t> & input_shape,size_t no_batch_rank,size_t batch_rank)87 void Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr &primitive, const std::vector<int64_t> &input_shape,
88                                      size_t no_batch_rank, size_t batch_rank) {
89   auto input_rank = input_shape.size();
90   MS_CHECK_VALUE(input_rank >= no_batch_rank && input_rank <= batch_rank,
91                  CheckAndConvertUtils::FormatCheckInRangeMsg("input rank", input_rank, kIncludeBoth,
92                                                              {no_batch_rank, batch_rank}, primitive));
93 
94   auto ShapeElementCheckFunc = [](int64_t dim_value) {
95     if (dim_value != abstract::TensorShape::kShapeDimAny && dim_value <= 0) {
96       return false;
97     }
98     return true;
99   };
100   auto first_dim_after_batch = input_shape.size() == batch_rank ? kIndex1 : kIndex0;
101   auto check_result =
102     std::all_of(input_shape.begin() + first_dim_after_batch, input_shape.end(), ShapeElementCheckFunc);
103   if (MS_UNLIKELY(!check_result)) {
104     MS_EXCEPTION(ValueError)
105       << "For " << primitive->name() << ", expected " << no_batch_rank << "D or " << batch_rank
106       << "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got "
107       << input_shape;
108   }
109 }
110 
Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr & primitive,const AbstractBasePtr & input_arg,size_t no_batch_rank,size_t batch_rank)111 int32_t Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr &primitive, const AbstractBasePtr &input_arg,
112                                         size_t no_batch_rank, size_t batch_rank) {
113   MS_EXCEPTION_IF_NULL(input_arg);
114   const auto &input_shape = input_arg->GetShape()->GetShapeVector();
115   if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
116     return OP_CHECK_RETRY;
117   }
118   Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank, batch_rank);
119   return OP_CHECK_SUCCESS;
120 }
121 
ArrayOptHasUnknownValue(const std::optional<ArrayValue<int64_t>> & array_opt)122 bool ArrayOptHasUnknownValue(const std::optional<ArrayValue<int64_t>> &array_opt) {
123   if (MS_UNLIKELY(!array_opt.has_value())) {
124     return true;
125   }
126   const auto &array = array_opt.value();
127   return array.HasUnknownValue();
128 }
129 
Im2ColOutputLengthError(const PrimitivePtr & primitive,const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & dilation,const std::vector<int64_t> & padding,const std::vector<int64_t> & stride,int64_t input_height,int64_t output_height,int64_t input_width,int64_t output_width)130 void Im2ColOutputLengthError(const PrimitivePtr &primitive, const std::vector<int64_t> &kernel_size,
131                              const std::vector<int64_t> &dilation, const std::vector<int64_t> &padding,
132                              const std::vector<int64_t> &stride, int64_t input_height, int64_t output_height,
133                              int64_t input_width, int64_t output_width) {
134   MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", given input with spatial size (" << input_height << ", "
135                            << input_width << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
136                            << "), dilation=(" << dilation[0] << ", " << dilation[1] << "), padding=(" << padding[0]
137                            << ", " << padding[1] << "), calculated shape of the array of sliding blocks as ("
138                            << output_height << ", " << output_width << "), which is too small (non-positive).";
139 }
140 
Col2ImCheckNInputPlane(const PrimitivePtr & primitive,int64_t n_input_plane,const std::vector<int64_t> & kernel_size)141 void Col2ImCheckNInputPlane(const PrimitivePtr &primitive, int64_t n_input_plane,
142                             const std::vector<int64_t> &kernel_size) {
143   if (MS_UNLIKELY(n_input_plane % (kernel_size[0] * kernel_size[1]) != 0)) {
144     MS_EXCEPTION(ValueError) << "For " << primitive->name()
145                              << ", expected size of input's dimension 1 to be divisible by the product of "
146                                 "kernel_size, but got input.size(1)="
147                              << n_input_plane << " and kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
148                              << ").";
149   }
150 }
151 
Col2ImCheckInputLength(const PrimitivePtr & primitive,const std::vector<int64_t> & output_size,const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & dilation,const std::vector<int64_t> & padding,const std::vector<int64_t> & stride,int64_t input_length)152 void Col2ImCheckInputLength(const PrimitivePtr &primitive, const std::vector<int64_t> &output_size,
153                             const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &dilation,
154                             const std::vector<int64_t> &padding, const std::vector<int64_t> &stride,
155                             int64_t input_length) {
156   auto [n_blocks_height, n_blocks_width] = Im2ColComputeOutputHeightAndWeight(
157     std::make_pair(output_size[0], output_size[1]), kernel_size, dilation, padding, stride);
158 
159   if (MS_UNLIKELY(n_blocks_height < 1 || n_blocks_width < 1)) {
160     MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", given output_size=(" << output_size[0] << ", "
161                              << output_size[1] << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
162                              << "), dilation=(" << dilation[0] << ", " << dilation[1] << "), padding=(" << padding[0]
163                              << ", " << padding[1] << "), stride=(" << stride[0] << ", " << stride[1]
164                              << "), calculated shape of the array of sliding blocks as (" << n_blocks_height << ", "
165                              << n_blocks_width << "), which is too small (non-positive)";
166   }
167 
168   if (MS_UNLIKELY(input_length != (n_blocks_height * n_blocks_width))) {
169     MS_EXCEPTION(ValueError)
170       << "For " << primitive->name() << ", given output_size=(" << output_size[0] << ", " << output_size[1]
171       << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1] << "), dilation=(" << dilation[0] << ", "
172       << dilation[1] << "), padding=(" << padding[0] << ", " << padding[1] << "), stride=(" << stride[0] << ", "
173       << stride[1] << "), expected size of the input's dimension 2 to match the calculated number of sliding blocks "
174       << n_blocks_height << " * " << n_blocks_width << " = " << (n_blocks_height * n_blocks_width)
175       << ", but got input.size(2)=" << input_length << ".";
176   }
177 }
178 }  // namespace
179 
InferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values) const180 ShapeArray Im2ColExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
181   const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
182   MS_EXCEPTION_IF_NULL(input_tensor);
183   const auto &input_shape = input_tensor->shape();
184   Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank_, batch_rank_);
185 
186   const size_t ele_num = 2;
187   auto kernel_size_array = GetArrayValue<int64_t>(input_values[kIndex1]).value();
188   auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
189   auto dilation_array = GetArrayValue<int64_t>(input_values[kIndex2]).value();
190   auto dilation = Im2ColAndCol2ImCommonCheckArray(primitive, dilation_array, "dilation", ele_num, 0);
191   auto padding_array = GetArrayValue<int64_t>(input_values[kIndex3]).value();
192   auto padding = Im2ColAndCol2ImCommonCheckArray(primitive, padding_array, "padding", ele_num, -1);
193   auto stride_array = GetArrayValue<int64_t>(input_values[kIndex4]).value();
194   auto stride = Im2ColAndCol2ImCommonCheckArray(primitive, stride_array, "stride", ele_num, 0);
195 
196   std::vector<int64_t> out_shape;
197   auto input_rank = input_shape.size();
198   if (input_rank == batch_rank_) {
199     out_shape.push_back(input_shape[kIndex0]);
200   }
201 
202   auto n_output_plane = input_shape[input_rank - kIndex3] * kernel_size[0] * kernel_size[1];
203   out_shape.push_back(n_output_plane);
204 
205   auto input_height = input_shape[input_rank - kIndex2];
206   auto input_width = input_shape[input_rank - kIndex1];
207   auto [output_height, output_width] = Im2ColComputeOutputHeightAndWeight(std::make_pair(input_height, input_width),
208                                                                           kernel_size, dilation, padding, stride);
209   if (MS_UNLIKELY(output_height < 1 || output_width < 1)) {
210     Im2ColOutputLengthError(primitive, kernel_size, dilation, padding, stride, input_height, output_height, input_width,
211                             output_width);
212   }
213   auto output_length = output_height * output_width;
214   out_shape.push_back(output_length);
215 
216   return {std::move(out_shape)};
217 }
218 
InferType(const PrimitivePtr & primitive,const ValuePtrList & input_values) const219 TypePtrList Im2ColExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
220   const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
221   MS_EXCEPTION_IF_NULL(input_tensor);
222   return {input_tensor->Dtype()};
223 }
224 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const225 BaseShapePtr Im2ColExtFuncImpl::InferShape(const PrimitivePtr &primitive,
226                                            const std::vector<AbstractBasePtr> &input_args) const {
227   const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
228   if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
229     return std::make_shared<abstract::TensorShape>(std::vector<int64_t>{abstract::TensorShape::kShapeRankAny});
230   }
231 
232   auto is_dynamic_dim = [](int64_t dim_value) { return dim_value == abstract::TensorShape::kShapeDimAny; };
233   auto input_rank = input_shape.size();
234   std::vector<int64_t> out_shape;
235   if (input_rank == batch_rank_) {
236     out_shape.push_back(input_shape[kIndex0]);
237   }
238 
239   auto channle_dim = input_shape[input_rank - kIndex3];
240   auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
241   auto is_kernel_unknown = ArrayOptHasUnknownValue(kernel_size_opt);
242   if (MS_UNLIKELY(is_dynamic_dim(channle_dim) || is_kernel_unknown)) {
243     out_shape.push_back(abstract::TensorShape::kShapeDimAny);
244   } else {
245     const auto &kernel_size = kernel_size_opt.value();
246     auto n_output_plane = channle_dim * kernel_size[0] * kernel_size[1];
247     out_shape.push_back(n_output_plane);
248   }
249 
250   auto input_height = input_shape[input_rank - kIndex2];
251   auto input_width = input_shape[input_rank - kIndex1];
252 
253   auto dilation_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
254   auto is_dilation_unknown = ArrayOptHasUnknownValue(dilation_opt);
255   auto padding_opt = GetArrayValue<int64_t>(input_args[kIndex3]);
256   auto is_padding_unknown = ArrayOptHasUnknownValue(padding_opt);
257   auto stride_opt = GetArrayValue<int64_t>(input_args[kIndex4]);
258   auto is_stride_unknown = ArrayOptHasUnknownValue(stride_opt);
259   if (MS_UNLIKELY(is_dynamic_dim(input_height) || is_dynamic_dim(input_width) || is_kernel_unknown ||
260                   is_padding_unknown || is_dilation_unknown || is_stride_unknown)) {
261     out_shape.push_back(abstract::TensorShape::kShapeDimAny);
262   } else {
263     const auto &kernel_size = kernel_size_opt.value().ToVector();
264     const auto &dilation = dilation_opt.value().ToVector();
265     const auto &padding = padding_opt.value().ToVector();
266     const auto &stride = stride_opt.value().ToVector();
267     auto [output_height, output_width] = Im2ColComputeOutputHeightAndWeight(std::make_pair(input_height, input_width),
268                                                                             kernel_size, dilation, padding, stride);
269     if (MS_UNLIKELY(output_height < 1 || output_width < 1)) {
270       Im2ColOutputLengthError(primitive, kernel_size, dilation, padding, stride, input_height, output_height,
271                               input_width, output_width);
272     }
273     auto output_length = output_height * output_width;
274     out_shape.push_back(output_length);
275   }
276 
277   return std::make_shared<abstract::TensorShape>(std::move(out_shape));
278 }
279 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const280 TypePtr Im2ColExtFuncImpl::InferType(const PrimitivePtr &primitive,
281                                      const std::vector<AbstractBasePtr> &input_args) const {
282   MS_EXCEPTION_IF_NULL(input_args.at(0));
283   auto input_type = input_args[0]->GetType();
284   return input_type;
285 }
286 
CheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const287 int32_t Im2ColExtFuncImpl::CheckValidation(const PrimitivePtr &primitive,
288                                            const std::vector<AbstractBasePtr> &input_args) const {
289   auto check_shape = Im2ColAndCol2ImCommonCheckShape(primitive, input_args[kIndex0], no_batch_rank_, batch_rank_);
290 
291   const size_t ele_num = 2;
292   static std::vector<std::string> arg_names{"kernel_size", "dilation", "padding", "stride"};
293   static std::vector<int64_t> min_values{0, 0, -1, 0};
294   auto check_pair =
295     Im2ColAndCol2ImCommonCheckValidation(primitive, input_args, arg_names, ele_num, min_values, kIndex1);
296   auto &check_args = check_pair.first;
297 
298   if (MS_UNLIKELY(check_shape == OP_CHECK_RETRY || check_args == OP_CHECK_RETRY)) {
299     return OP_CHECK_RETRY;
300   }
301   return OP_CHECK_SUCCESS;
302 }
303 
// Register the value-based (simple) infer implementations of Im2ColExt and Col2ImGrad.
REGISTER_SIMPLE_INFER(kNameIm2ColExt, Im2ColExtFuncImpl)
REGISTER_SIMPLE_INFER(kNameCol2ImGrad, Col2ImGradFuncImpl)
306 
307 BaseShapePtr Col2ImExtFuncImpl::InferShape(const PrimitivePtr &primitive,
308                                            const std::vector<AbstractBasePtr> &input_args) const {
309   const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
310   if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
311     return std::make_shared<abstract::TensorShape>(std::vector<int64_t>{abstract::TensorShape::kShapeRankAny});
312   }
313 
314   std::vector<int64_t> out_shape;
315   if (input_shape.size() == batch_rank_) {
316     out_shape.push_back(input_shape[kIndex0]);
317   }
318 
319   auto n_input_plane = input_shape[input_shape.size() - kIndex2];
320   auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
321   if (MS_UNLIKELY(n_input_plane == abstract::Shape::kShapeDimAny || ArrayOptHasUnknownValue(kernel_size_opt))) {
322     out_shape.push_back(abstract::Shape::kShapeDimAny);
323   } else {
324     const auto &kernel_size = kernel_size_opt.value();
325     out_shape.emplace_back(n_input_plane / (kernel_size[0] * kernel_size[1]));
326   }
327 
328   auto output_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
329   if (MS_UNLIKELY(!output_size_opt.has_value())) {
330     out_shape.insert(out_shape.end(), kIndex2, abstract::Shape::kShapeDimAny);
331   } else {
332     auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_opt.value(), "output_size", kIndex2, 0);
333     out_shape.insert(out_shape.end(), output_size.begin(), output_size.end());
334   }
335 
336   return std::make_shared<abstract::TensorShape>(std::move(out_shape));
337 }
338 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const339 TypePtr Col2ImExtFuncImpl::InferType(const PrimitivePtr &primitive,
340                                      const std::vector<AbstractBasePtr> &input_args) const {
341   MS_EXCEPTION_IF_NULL(input_args.at(0));
342   auto input_type = input_args[0]->GetType();
343   return input_type;
344 }
345 
CheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const346 int32_t Col2ImExtFuncImpl::CheckValidation(const PrimitivePtr &primitive,
347                                            const std::vector<AbstractBasePtr> &input_args) const {
348   auto check_shape = Im2ColAndCol2ImCommonCheckShape(primitive, input_args[kIndex0], no_batch_rank_, batch_rank_);
349   if (MS_UNLIKELY(check_shape == OP_CHECK_RETRY)) {
350     return OP_CHECK_RETRY;
351   }
352 
353   const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
354   auto n_input_plane = input_shape[input_shape.size() - kIndex2];
355   if (MS_UNLIKELY(n_input_plane == abstract::TensorShape::kShapeDimAny)) {
356     return OP_CHECK_RETRY;
357   }
358 
359   auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
360   if (MS_UNLIKELY(!kernel_size_opt.has_value())) {
361     return OP_CHECK_RETRY;
362   }
363 
364   const size_t ele_num = 2;
365   const auto &kernel_size_array = kernel_size_opt.value();
366   auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
367   if (MS_UNLIKELY(kernel_size_array.HasUnknownValue())) {
368     return OP_CHECK_RETRY;
369   }
370 
371   Col2ImCheckNInputPlane(primitive, n_input_plane, kernel_size);
372 
373   auto input_length = input_shape[input_shape.size() - kIndex1];
374   auto output_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
375   if (MS_UNLIKELY(input_length == abstract::TensorShape::kShapeDimAny || ArrayOptHasUnknownValue(output_size_opt))) {
376     return OP_CHECK_RETRY;
377   }
378 
379   const auto &output_size_array = output_size_opt.value();
380   auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_array, "output_size", ele_num, 0);
381 
382   static std::vector<std::string> arg_names{"dilation", "padding", "stride"};
383   static std::vector<int64_t> min_values{0, -1, 0};
384   auto check_pair =
385     Im2ColAndCol2ImCommonCheckValidation(primitive, input_args, arg_names, ele_num, min_values, kIndex3);
386   auto &check_other_args = check_pair.first;
387   if (MS_UNLIKELY(check_other_args != OP_CHECK_SUCCESS)) {
388     return OP_CHECK_RETRY;
389   }
390 
391   const auto &arrays = check_pair.second;
392   const auto &dilation = arrays[kIndex0].ToVector();
393   const auto &padding = arrays[kIndex1].ToVector();
394   const auto &stride = arrays[kIndex2].ToVector();
395   Col2ImCheckInputLength(primitive, output_size, kernel_size, dilation, padding, stride, input_length);
396 
397   return OP_CHECK_SUCCESS;
398 }
399 
InferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values) const400 ShapeArray Col2ImExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
401   const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
402   MS_EXCEPTION_IF_NULL(input_tensor);
403   const auto &input_shape = input_tensor->shape();
404   Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank_, batch_rank_);
405 
406   auto input_rank = input_shape.size();
407   std::vector<int64_t> out_shape;
408   if (input_rank == batch_rank_) {
409     out_shape.push_back(input_shape[kIndex0]);
410   }
411 
412   const size_t ele_num = 2;
413   auto output_size_array = GetArrayValue<int64_t>(input_values[kIndex1]).value();
414   auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_array, "output_size", ele_num, 0);
415   auto kernel_size_array = GetArrayValue<int64_t>(input_values[kIndex2]).value();
416   auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
417   auto dilation_array = GetArrayValue<int64_t>(input_values[kIndex3]).value();
418   auto dilation = Im2ColAndCol2ImCommonCheckArray(primitive, dilation_array, "dilation", ele_num, 0);
419   auto padding_array = GetArrayValue<int64_t>(input_values[kIndex4]).value();
420   auto padding = Im2ColAndCol2ImCommonCheckArray(primitive, padding_array, "padding", ele_num, -1);
421   auto stride_array = GetArrayValue<int64_t>(input_values[kIndex5]).value();
422   auto stride = Im2ColAndCol2ImCommonCheckArray(primitive, stride_array, "stride", ele_num, 0);
423 
424   auto n_input_plane = input_shape[input_rank - kIndex2];
425   Col2ImCheckNInputPlane(primitive, n_input_plane, kernel_size);
426   out_shape.emplace_back(n_input_plane / (kernel_size[0] * kernel_size[1]));
427 
428   auto input_length = input_shape[input_rank - kIndex1];
429   Col2ImCheckInputLength(primitive, output_size, kernel_size, dilation, padding, stride, input_length);
430   out_shape.insert(out_shape.end(), output_size.begin(), output_size.end());
431 
432   return {std::move(out_shape)};
433 }
434 
InferType(const PrimitivePtr & primitive,const ValuePtrList & input_values) const435 TypePtrList Col2ImExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
436   const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
437   MS_EXCEPTION_IF_NULL(input_tensor);
438   return {input_tensor->Dtype()};
439 }
440 
441 REGISTER_SIMPLE_INFER(kNameCol2ImExt, Col2ImExtFuncImpl)
442 }  // namespace ops
443 }  // namespace mindspore
444