1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <memory>
18
19 #include "ops/op_utils.h"
20 #include "utils/check_convert_utils.h"
21 #include "ops/ops_func_impl/im2col_ext.h"
22 #include "ops/ops_func_impl/col2im_ext.h"
23 #include "ops/ops_func_impl/col2im_grad.h"
24 #include "ops/ops_func_impl/simple_infer.h"
25
26 namespace mindspore {
27 namespace ops {
28 namespace {
/**
 * Computes the number of sliding-block positions along height and width for Im2Col / Col2Im.
 *
 *   out = floor((in + 2 * pad - (dilation * (kernel - 1) + 1)) / stride) + 1
 *
 * The division must round toward negative infinity. C++ integer division truncates toward zero.
 * If the dilated kernel overhangs the padded input by less than one stride, the numerator is
 * slightly negative. Truncation would then yield 0 and report one valid block instead of zero,
 * which hides the "too small" error in the callers.
 *
 * @param input_hw (height, width) of the spatial input (Im2Col) or of output_size (Col2Im).
 * @param kernel_size Two elements, (height, width).
 * @param dilation Two elements, (height, width).
 * @param padding Two elements, (height, width).
 * @param stride Two elements, (height, width); must be non-zero.
 * @return (output_height, output_width). A value below 1 means the kernel does not fit.
 */
std::pair<int64_t, int64_t> Im2ColComputeOutputHeightAndWeight(const std::pair<int64_t, int64_t> &input_hw,
                                                               const std::vector<int64_t> &kernel_size,
                                                               const std::vector<int64_t> &dilation,
                                                               const std::vector<int64_t> &padding,
                                                               const std::vector<int64_t> &stride) {
  // Integer division rounding toward negative infinity (floor), independent of operand signs.
  auto floor_div = [](int64_t numerator, int64_t denominator) -> int64_t {
    int64_t quotient = numerator / denominator;
    if (numerator % denominator != 0 && ((numerator < 0) != (denominator < 0))) {
      --quotient;
    }
    return quotient;
  };
  // Number of block positions along a single spatial axis.
  auto blocks_along = [&floor_div](int64_t in, int64_t kernel, int64_t dil, int64_t pad, int64_t step) -> int64_t {
    const int64_t effective_kernel = dil * (kernel - 1) + 1;
    return floor_div(in + 2 * pad - effective_kernel, step) + 1;
  };

  const auto &[input_height, input_width] = input_hw;
  const int64_t output_height = blocks_along(input_height, kernel_size[0], dilation[0], padding[0], stride[0]);
  const int64_t output_width = blocks_along(input_width, kernel_size[1], dilation[1], padding[1], stride[1]);
  return std::make_pair(output_height, output_width);
}
50
Im2ColAndCol2ImCommonCheckArray(const PrimitivePtr & primitive,const ArrayValue<int64_t> & array,const std::string & arg_name,size_t ele_num,int64_t min_value)51 std::vector<int64_t> Im2ColAndCol2ImCommonCheckArray(const PrimitivePtr &primitive, const ArrayValue<int64_t> &array,
52 const std::string &arg_name, size_t ele_num, int64_t min_value) {
53 std::vector<int64_t> values(ele_num, abstract::Shape::kShapeDimAny);
54 MS_CHECK_VALUE(array.size() == ele_num,
55 CheckAndConvertUtils::FormatCheckIntegerMsg("number of " + arg_name, SizeToLong(array.size()), kEqual,
56 SizeToLong(ele_num), primitive));
57 for (size_t i = 0; i < array.size(); ++i) {
58 if (MS_UNLIKELY(array.IsValueUnknown(i))) {
59 continue;
60 }
61 MS_CHECK_VALUE(array[i] > min_value,
62 CheckAndConvertUtils::FormatCheckIntegerMsg(arg_name, array[i], kGreaterThan, min_value, primitive));
63 values[i] = array[i];
64 }
65 return values;
66 }
67
Im2ColAndCol2ImCommonCheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args,const std::vector<std::string> & arg_names,size_t ele_num,const std::vector<int64_t> & min_values,size_t start_idx)68 std::pair<int32_t, std::vector<ArrayValue<int64_t>>> Im2ColAndCol2ImCommonCheckValidation(
69 const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
70 const std::vector<std::string> &arg_names, size_t ele_num, const std::vector<int64_t> &min_values, size_t start_idx) {
71 assert((input_args.size() - start_idx) == arg_names.size());
72
73 std::vector<ArrayValue<int64_t>> arrays;
74 for (size_t i = 0; i < arg_names.size(); ++i) {
75 auto array_opt = GetArrayValue<int64_t>(input_args[start_idx + i]);
76 if (MS_UNLIKELY(!array_opt.has_value())) {
77 return std::make_pair(OP_CHECK_RETRY, std::move(arrays));
78 }
79 auto array = std::move(array_opt.value());
80 (void)Im2ColAndCol2ImCommonCheckArray(primitive, array, arg_names[i], ele_num, min_values[i]);
81 arrays.emplace_back(std::move(array));
82 }
83
84 return std::make_pair(OP_CHECK_SUCCESS, std::move(arrays));
85 }
86
Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr & primitive,const std::vector<int64_t> & input_shape,size_t no_batch_rank,size_t batch_rank)87 void Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr &primitive, const std::vector<int64_t> &input_shape,
88 size_t no_batch_rank, size_t batch_rank) {
89 auto input_rank = input_shape.size();
90 MS_CHECK_VALUE(input_rank >= no_batch_rank && input_rank <= batch_rank,
91 CheckAndConvertUtils::FormatCheckInRangeMsg("input rank", input_rank, kIncludeBoth,
92 {no_batch_rank, batch_rank}, primitive));
93
94 auto ShapeElementCheckFunc = [](int64_t dim_value) {
95 if (dim_value != abstract::TensorShape::kShapeDimAny && dim_value <= 0) {
96 return false;
97 }
98 return true;
99 };
100 auto first_dim_after_batch = input_shape.size() == batch_rank ? kIndex1 : kIndex0;
101 auto check_result =
102 std::all_of(input_shape.begin() + first_dim_after_batch, input_shape.end(), ShapeElementCheckFunc);
103 if (MS_UNLIKELY(!check_result)) {
104 MS_EXCEPTION(ValueError)
105 << "For " << primitive->name() << ", expected " << no_batch_rank << "D or " << batch_rank
106 << "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got "
107 << input_shape;
108 }
109 }
110
Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr & primitive,const AbstractBasePtr & input_arg,size_t no_batch_rank,size_t batch_rank)111 int32_t Im2ColAndCol2ImCommonCheckShape(const PrimitivePtr &primitive, const AbstractBasePtr &input_arg,
112 size_t no_batch_rank, size_t batch_rank) {
113 MS_EXCEPTION_IF_NULL(input_arg);
114 const auto &input_shape = input_arg->GetShape()->GetShapeVector();
115 if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
116 return OP_CHECK_RETRY;
117 }
118 Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank, batch_rank);
119 return OP_CHECK_SUCCESS;
120 }
121
ArrayOptHasUnknownValue(const std::optional<ArrayValue<int64_t>> & array_opt)122 bool ArrayOptHasUnknownValue(const std::optional<ArrayValue<int64_t>> &array_opt) {
123 if (MS_UNLIKELY(!array_opt.has_value())) {
124 return true;
125 }
126 const auto &array = array_opt.value();
127 return array.HasUnknownValue();
128 }
129
Im2ColOutputLengthError(const PrimitivePtr & primitive,const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & dilation,const std::vector<int64_t> & padding,const std::vector<int64_t> & stride,int64_t input_height,int64_t output_height,int64_t input_width,int64_t output_width)130 void Im2ColOutputLengthError(const PrimitivePtr &primitive, const std::vector<int64_t> &kernel_size,
131 const std::vector<int64_t> &dilation, const std::vector<int64_t> &padding,
132 const std::vector<int64_t> &stride, int64_t input_height, int64_t output_height,
133 int64_t input_width, int64_t output_width) {
134 MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", given input with spatial size (" << input_height << ", "
135 << input_width << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
136 << "), dilation=(" << dilation[0] << ", " << dilation[1] << "), padding=(" << padding[0]
137 << ", " << padding[1] << "), calculated shape of the array of sliding blocks as ("
138 << output_height << ", " << output_width << "), which is too small (non-positive).";
139 }
140
Col2ImCheckNInputPlane(const PrimitivePtr & primitive,int64_t n_input_plane,const std::vector<int64_t> & kernel_size)141 void Col2ImCheckNInputPlane(const PrimitivePtr &primitive, int64_t n_input_plane,
142 const std::vector<int64_t> &kernel_size) {
143 if (MS_UNLIKELY(n_input_plane % (kernel_size[0] * kernel_size[1]) != 0)) {
144 MS_EXCEPTION(ValueError) << "For " << primitive->name()
145 << ", expected size of input's dimension 1 to be divisible by the product of "
146 "kernel_size, but got input.size(1)="
147 << n_input_plane << " and kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
148 << ").";
149 }
150 }
151
Col2ImCheckInputLength(const PrimitivePtr & primitive,const std::vector<int64_t> & output_size,const std::vector<int64_t> & kernel_size,const std::vector<int64_t> & dilation,const std::vector<int64_t> & padding,const std::vector<int64_t> & stride,int64_t input_length)152 void Col2ImCheckInputLength(const PrimitivePtr &primitive, const std::vector<int64_t> &output_size,
153 const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &dilation,
154 const std::vector<int64_t> &padding, const std::vector<int64_t> &stride,
155 int64_t input_length) {
156 auto [n_blocks_height, n_blocks_width] = Im2ColComputeOutputHeightAndWeight(
157 std::make_pair(output_size[0], output_size[1]), kernel_size, dilation, padding, stride);
158
159 if (MS_UNLIKELY(n_blocks_height < 1 || n_blocks_width < 1)) {
160 MS_EXCEPTION(ValueError) << "For " << primitive->name() << ", given output_size=(" << output_size[0] << ", "
161 << output_size[1] << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1]
162 << "), dilation=(" << dilation[0] << ", " << dilation[1] << "), padding=(" << padding[0]
163 << ", " << padding[1] << "), stride=(" << stride[0] << ", " << stride[1]
164 << "), calculated shape of the array of sliding blocks as (" << n_blocks_height << ", "
165 << n_blocks_width << "), which is too small (non-positive)";
166 }
167
168 if (MS_UNLIKELY(input_length != (n_blocks_height * n_blocks_width))) {
169 MS_EXCEPTION(ValueError)
170 << "For " << primitive->name() << ", given output_size=(" << output_size[0] << ", " << output_size[1]
171 << "), kernel_size=(" << kernel_size[0] << ", " << kernel_size[1] << "), dilation=(" << dilation[0] << ", "
172 << dilation[1] << "), padding=(" << padding[0] << ", " << padding[1] << "), stride=(" << stride[0] << ", "
173 << stride[1] << "), expected size of the input's dimension 2 to match the calculated number of sliding blocks "
174 << n_blocks_height << " * " << n_blocks_width << " = " << (n_blocks_height * n_blocks_width)
175 << ", but got input.size(2)=" << input_length << ".";
176 }
177 }
178 } // namespace
179
InferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values) const180 ShapeArray Im2ColExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
181 const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
182 MS_EXCEPTION_IF_NULL(input_tensor);
183 const auto &input_shape = input_tensor->shape();
184 Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank_, batch_rank_);
185
186 const size_t ele_num = 2;
187 auto kernel_size_array = GetArrayValue<int64_t>(input_values[kIndex1]).value();
188 auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
189 auto dilation_array = GetArrayValue<int64_t>(input_values[kIndex2]).value();
190 auto dilation = Im2ColAndCol2ImCommonCheckArray(primitive, dilation_array, "dilation", ele_num, 0);
191 auto padding_array = GetArrayValue<int64_t>(input_values[kIndex3]).value();
192 auto padding = Im2ColAndCol2ImCommonCheckArray(primitive, padding_array, "padding", ele_num, -1);
193 auto stride_array = GetArrayValue<int64_t>(input_values[kIndex4]).value();
194 auto stride = Im2ColAndCol2ImCommonCheckArray(primitive, stride_array, "stride", ele_num, 0);
195
196 std::vector<int64_t> out_shape;
197 auto input_rank = input_shape.size();
198 if (input_rank == batch_rank_) {
199 out_shape.push_back(input_shape[kIndex0]);
200 }
201
202 auto n_output_plane = input_shape[input_rank - kIndex3] * kernel_size[0] * kernel_size[1];
203 out_shape.push_back(n_output_plane);
204
205 auto input_height = input_shape[input_rank - kIndex2];
206 auto input_width = input_shape[input_rank - kIndex1];
207 auto [output_height, output_width] = Im2ColComputeOutputHeightAndWeight(std::make_pair(input_height, input_width),
208 kernel_size, dilation, padding, stride);
209 if (MS_UNLIKELY(output_height < 1 || output_width < 1)) {
210 Im2ColOutputLengthError(primitive, kernel_size, dilation, padding, stride, input_height, output_height, input_width,
211 output_width);
212 }
213 auto output_length = output_height * output_width;
214 out_shape.push_back(output_length);
215
216 return {std::move(out_shape)};
217 }
218
InferType(const PrimitivePtr & primitive,const ValuePtrList & input_values) const219 TypePtrList Im2ColExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
220 const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
221 MS_EXCEPTION_IF_NULL(input_tensor);
222 return {input_tensor->Dtype()};
223 }
224
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const225 BaseShapePtr Im2ColExtFuncImpl::InferShape(const PrimitivePtr &primitive,
226 const std::vector<AbstractBasePtr> &input_args) const {
227 const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
228 if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
229 return std::make_shared<abstract::TensorShape>(std::vector<int64_t>{abstract::TensorShape::kShapeRankAny});
230 }
231
232 auto is_dynamic_dim = [](int64_t dim_value) { return dim_value == abstract::TensorShape::kShapeDimAny; };
233 auto input_rank = input_shape.size();
234 std::vector<int64_t> out_shape;
235 if (input_rank == batch_rank_) {
236 out_shape.push_back(input_shape[kIndex0]);
237 }
238
239 auto channle_dim = input_shape[input_rank - kIndex3];
240 auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
241 auto is_kernel_unknown = ArrayOptHasUnknownValue(kernel_size_opt);
242 if (MS_UNLIKELY(is_dynamic_dim(channle_dim) || is_kernel_unknown)) {
243 out_shape.push_back(abstract::TensorShape::kShapeDimAny);
244 } else {
245 const auto &kernel_size = kernel_size_opt.value();
246 auto n_output_plane = channle_dim * kernel_size[0] * kernel_size[1];
247 out_shape.push_back(n_output_plane);
248 }
249
250 auto input_height = input_shape[input_rank - kIndex2];
251 auto input_width = input_shape[input_rank - kIndex1];
252
253 auto dilation_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
254 auto is_dilation_unknown = ArrayOptHasUnknownValue(dilation_opt);
255 auto padding_opt = GetArrayValue<int64_t>(input_args[kIndex3]);
256 auto is_padding_unknown = ArrayOptHasUnknownValue(padding_opt);
257 auto stride_opt = GetArrayValue<int64_t>(input_args[kIndex4]);
258 auto is_stride_unknown = ArrayOptHasUnknownValue(stride_opt);
259 if (MS_UNLIKELY(is_dynamic_dim(input_height) || is_dynamic_dim(input_width) || is_kernel_unknown ||
260 is_padding_unknown || is_dilation_unknown || is_stride_unknown)) {
261 out_shape.push_back(abstract::TensorShape::kShapeDimAny);
262 } else {
263 const auto &kernel_size = kernel_size_opt.value().ToVector();
264 const auto &dilation = dilation_opt.value().ToVector();
265 const auto &padding = padding_opt.value().ToVector();
266 const auto &stride = stride_opt.value().ToVector();
267 auto [output_height, output_width] = Im2ColComputeOutputHeightAndWeight(std::make_pair(input_height, input_width),
268 kernel_size, dilation, padding, stride);
269 if (MS_UNLIKELY(output_height < 1 || output_width < 1)) {
270 Im2ColOutputLengthError(primitive, kernel_size, dilation, padding, stride, input_height, output_height,
271 input_width, output_width);
272 }
273 auto output_length = output_height * output_width;
274 out_shape.push_back(output_length);
275 }
276
277 return std::make_shared<abstract::TensorShape>(std::move(out_shape));
278 }
279
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const280 TypePtr Im2ColExtFuncImpl::InferType(const PrimitivePtr &primitive,
281 const std::vector<AbstractBasePtr> &input_args) const {
282 MS_EXCEPTION_IF_NULL(input_args.at(0));
283 auto input_type = input_args[0]->GetType();
284 return input_type;
285 }
286
CheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const287 int32_t Im2ColExtFuncImpl::CheckValidation(const PrimitivePtr &primitive,
288 const std::vector<AbstractBasePtr> &input_args) const {
289 auto check_shape = Im2ColAndCol2ImCommonCheckShape(primitive, input_args[kIndex0], no_batch_rank_, batch_rank_);
290
291 const size_t ele_num = 2;
292 static std::vector<std::string> arg_names{"kernel_size", "dilation", "padding", "stride"};
293 static std::vector<int64_t> min_values{0, 0, -1, 0};
294 auto check_pair =
295 Im2ColAndCol2ImCommonCheckValidation(primitive, input_args, arg_names, ele_num, min_values, kIndex1);
296 auto &check_args = check_pair.first;
297
298 if (MS_UNLIKELY(check_shape == OP_CHECK_RETRY || check_args == OP_CHECK_RETRY)) {
299 return OP_CHECK_RETRY;
300 }
301 return OP_CHECK_SUCCESS;
302 }
303
// Register Im2ColExtFuncImpl and Col2ImGradFuncImpl via REGISTER_SIMPLE_INFER.
// NOTE(review): Col2ImGradFuncImpl's definitions are not visible in this file (it is declared in
// ops/ops_func_impl/col2im_grad.h); it is registered here alongside Im2ColExt — confirm this is intended.
REGISTER_SIMPLE_INFER(kNameIm2ColExt, Im2ColExtFuncImpl)
REGISTER_SIMPLE_INFER(kNameCol2ImGrad, Col2ImGradFuncImpl)
306
307 BaseShapePtr Col2ImExtFuncImpl::InferShape(const PrimitivePtr &primitive,
308 const std::vector<AbstractBasePtr> &input_args) const {
309 const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
310 if (MS_UNLIKELY(IsDynamicRank(input_shape))) {
311 return std::make_shared<abstract::TensorShape>(std::vector<int64_t>{abstract::TensorShape::kShapeRankAny});
312 }
313
314 std::vector<int64_t> out_shape;
315 if (input_shape.size() == batch_rank_) {
316 out_shape.push_back(input_shape[kIndex0]);
317 }
318
319 auto n_input_plane = input_shape[input_shape.size() - kIndex2];
320 auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
321 if (MS_UNLIKELY(n_input_plane == abstract::Shape::kShapeDimAny || ArrayOptHasUnknownValue(kernel_size_opt))) {
322 out_shape.push_back(abstract::Shape::kShapeDimAny);
323 } else {
324 const auto &kernel_size = kernel_size_opt.value();
325 out_shape.emplace_back(n_input_plane / (kernel_size[0] * kernel_size[1]));
326 }
327
328 auto output_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
329 if (MS_UNLIKELY(!output_size_opt.has_value())) {
330 out_shape.insert(out_shape.end(), kIndex2, abstract::Shape::kShapeDimAny);
331 } else {
332 auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_opt.value(), "output_size", kIndex2, 0);
333 out_shape.insert(out_shape.end(), output_size.begin(), output_size.end());
334 }
335
336 return std::make_shared<abstract::TensorShape>(std::move(out_shape));
337 }
338
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const339 TypePtr Col2ImExtFuncImpl::InferType(const PrimitivePtr &primitive,
340 const std::vector<AbstractBasePtr> &input_args) const {
341 MS_EXCEPTION_IF_NULL(input_args.at(0));
342 auto input_type = input_args[0]->GetType();
343 return input_type;
344 }
345
CheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const346 int32_t Col2ImExtFuncImpl::CheckValidation(const PrimitivePtr &primitive,
347 const std::vector<AbstractBasePtr> &input_args) const {
348 auto check_shape = Im2ColAndCol2ImCommonCheckShape(primitive, input_args[kIndex0], no_batch_rank_, batch_rank_);
349 if (MS_UNLIKELY(check_shape == OP_CHECK_RETRY)) {
350 return OP_CHECK_RETRY;
351 }
352
353 const auto &input_shape = input_args[0]->GetShape()->GetShapeVector();
354 auto n_input_plane = input_shape[input_shape.size() - kIndex2];
355 if (MS_UNLIKELY(n_input_plane == abstract::TensorShape::kShapeDimAny)) {
356 return OP_CHECK_RETRY;
357 }
358
359 auto kernel_size_opt = GetArrayValue<int64_t>(input_args[kIndex2]);
360 if (MS_UNLIKELY(!kernel_size_opt.has_value())) {
361 return OP_CHECK_RETRY;
362 }
363
364 const size_t ele_num = 2;
365 const auto &kernel_size_array = kernel_size_opt.value();
366 auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
367 if (MS_UNLIKELY(kernel_size_array.HasUnknownValue())) {
368 return OP_CHECK_RETRY;
369 }
370
371 Col2ImCheckNInputPlane(primitive, n_input_plane, kernel_size);
372
373 auto input_length = input_shape[input_shape.size() - kIndex1];
374 auto output_size_opt = GetArrayValue<int64_t>(input_args[kIndex1]);
375 if (MS_UNLIKELY(input_length == abstract::TensorShape::kShapeDimAny || ArrayOptHasUnknownValue(output_size_opt))) {
376 return OP_CHECK_RETRY;
377 }
378
379 const auto &output_size_array = output_size_opt.value();
380 auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_array, "output_size", ele_num, 0);
381
382 static std::vector<std::string> arg_names{"dilation", "padding", "stride"};
383 static std::vector<int64_t> min_values{0, -1, 0};
384 auto check_pair =
385 Im2ColAndCol2ImCommonCheckValidation(primitive, input_args, arg_names, ele_num, min_values, kIndex3);
386 auto &check_other_args = check_pair.first;
387 if (MS_UNLIKELY(check_other_args != OP_CHECK_SUCCESS)) {
388 return OP_CHECK_RETRY;
389 }
390
391 const auto &arrays = check_pair.second;
392 const auto &dilation = arrays[kIndex0].ToVector();
393 const auto &padding = arrays[kIndex1].ToVector();
394 const auto &stride = arrays[kIndex2].ToVector();
395 Col2ImCheckInputLength(primitive, output_size, kernel_size, dilation, padding, stride, input_length);
396
397 return OP_CHECK_SUCCESS;
398 }
399
InferShape(const PrimitivePtr & primitive,const ValuePtrList & input_values) const400 ShapeArray Col2ImExtFuncImpl::InferShape(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
401 const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
402 MS_EXCEPTION_IF_NULL(input_tensor);
403 const auto &input_shape = input_tensor->shape();
404 Im2ColAndCol2ImCommonCheckShape(primitive, input_shape, no_batch_rank_, batch_rank_);
405
406 auto input_rank = input_shape.size();
407 std::vector<int64_t> out_shape;
408 if (input_rank == batch_rank_) {
409 out_shape.push_back(input_shape[kIndex0]);
410 }
411
412 const size_t ele_num = 2;
413 auto output_size_array = GetArrayValue<int64_t>(input_values[kIndex1]).value();
414 auto output_size = Im2ColAndCol2ImCommonCheckArray(primitive, output_size_array, "output_size", ele_num, 0);
415 auto kernel_size_array = GetArrayValue<int64_t>(input_values[kIndex2]).value();
416 auto kernel_size = Im2ColAndCol2ImCommonCheckArray(primitive, kernel_size_array, "kernel_size", ele_num, 0);
417 auto dilation_array = GetArrayValue<int64_t>(input_values[kIndex3]).value();
418 auto dilation = Im2ColAndCol2ImCommonCheckArray(primitive, dilation_array, "dilation", ele_num, 0);
419 auto padding_array = GetArrayValue<int64_t>(input_values[kIndex4]).value();
420 auto padding = Im2ColAndCol2ImCommonCheckArray(primitive, padding_array, "padding", ele_num, -1);
421 auto stride_array = GetArrayValue<int64_t>(input_values[kIndex5]).value();
422 auto stride = Im2ColAndCol2ImCommonCheckArray(primitive, stride_array, "stride", ele_num, 0);
423
424 auto n_input_plane = input_shape[input_rank - kIndex2];
425 Col2ImCheckNInputPlane(primitive, n_input_plane, kernel_size);
426 out_shape.emplace_back(n_input_plane / (kernel_size[0] * kernel_size[1]));
427
428 auto input_length = input_shape[input_rank - kIndex1];
429 Col2ImCheckInputLength(primitive, output_size, kernel_size, dilation, padding, stride, input_length);
430 out_shape.insert(out_shape.end(), output_size.begin(), output_size.end());
431
432 return {std::move(out_shape)};
433 }
434
InferType(const PrimitivePtr & primitive,const ValuePtrList & input_values) const435 TypePtrList Col2ImExtFuncImpl::InferType(const PrimitivePtr &primitive, const ValuePtrList &input_values) const {
436 const auto &input_tensor = input_values[kIndex0]->cast<tensor::BaseTensorPtr>();
437 MS_EXCEPTION_IF_NULL(input_tensor);
438 return {input_tensor->Dtype()};
439 }
440
// Register Col2ImExtFuncImpl for Col2ImExt via REGISTER_SIMPLE_INFER.
REGISTER_SIMPLE_INFER(kNameCol2ImExt, Col2ImExtFuncImpl)
442 } // namespace ops
443 } // namespace mindspore
444