1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "nnacl/infer/conv2d_infer.h"
17 #include "nnacl/infer/infer_register.h"
18
/**
 * Integer ceiling of numerator / denominator.
 *
 * Works for any sign combination; the caller must guarantee denominator != 0.
 */
static int ConvCeilDiv(int numerator, int denominator) {
  int quotient = numerator / denominator;
  // C integer division truncates toward zero. That already is the ceiling when the exact
  // quotient is negative, so only an inexact, positive quotient has to be rounded up.
  if ((numerator % denominator != 0) && ((numerator < 0) == (denominator < 0))) {
    ++quotient;
  }
  return quotient;
}

/**
 * Computes the output height/width of a 2-D convolution.
 *
 * Kernel size, stride, dilation, pad mode and (for the non-"same" modes) the paddings are
 * read from `param`:
 *   - Pad_same:  output = ceil(input / stride). The total padding needed to reach that size
 *                is written back to param->pad_u_/pad_d_/pad_l_/pad_r_ (all zero when no
 *                padding is needed); an odd total puts the extra element on pad_d_/pad_r_.
 *   - Pad_valid: output = ceil((input + pads - (kernel - 1) * dilation) / stride).
 *   - otherwise: output = (input + pads - dilated_kernel) / stride + 1.
 *
 * All arithmetic is done on integers, so the result is exact for every representable size.
 *
 * @param input_h   input height
 * @param input_w   input width
 * @param output_h  receives the computed output height
 * @param output_w  receives the computed output width
 * @param param     convolution attributes; the pad_*_ fields are overwritten in Pad_same mode
 * @return NNACL_OK on success;
 *         NNACL_PARAM_INVALID when a stride is zero or when the kernel (plain or dilated)
 *         is larger than the padded input;
 *         NNACL_ERRCODE_MUL_OVERFLOW when INT_MUL_OVERFLOW flags kernel * dilation.
 *         Note: *output_h / *output_w are already written when the size check fails.
 */
int ConvInferShape(int input_h, int input_w, int *output_h, int *output_w, ConvParameter *param) {
  int kernel_w = param->kernel_w_;
  int kernel_h = param->kernel_h_;
  int stride_w = param->stride_w_;
  int stride_h = param->stride_h_;
  int dilate_w = param->dilation_w_;
  int dilate_h = param->dilation_h_;

  if (stride_w == 0 || stride_h == 0) {
    return NNACL_PARAM_INVALID;
  }
  if (INT_MUL_OVERFLOW(kernel_h, dilate_h) || INT_MUL_OVERFLOW(kernel_w, dilate_w)) {
    return NNACL_ERRCODE_MUL_OVERFLOW;
  }

  // Spatial extent covered by the kernel once dilation is applied.
  int extent_w = (kernel_w - 1) * dilate_w + 1;
  int extent_h = (kernel_h - 1) * dilate_h + 1;

  if (param->pad_mode_ == Pad_same) {
    *output_w = ConvCeilDiv(input_w, stride_w);
    *output_h = ConvCeilDiv(input_h, stride_h);
    int pad_h_all = (*output_h - 1) * stride_h + extent_h - input_h;
    int pad_w_all = (*output_w - 1) * stride_w + extent_w - input_w;
    if (pad_h_all < 0) {
      param->pad_u_ = param->pad_d_ = 0;
    } else {
      param->pad_u_ = pad_h_all / 2;
      param->pad_d_ = pad_h_all - param->pad_u_;
    }
    if (pad_w_all < 0) {
      param->pad_l_ = param->pad_r_ = 0;
    } else {
      param->pad_l_ = pad_w_all / 2;
      param->pad_r_ = pad_w_all - param->pad_l_;
    }
  } else if (param->pad_mode_ == Pad_valid) {
    // (extent - 1) == (kernel - 1) * dilation
    *output_w = ConvCeilDiv(input_w + param->pad_l_ + param->pad_r_ - (extent_w - 1), stride_w);
    *output_h = ConvCeilDiv(input_h + param->pad_u_ + param->pad_d_ - (extent_h - 1), stride_h);
  } else {
    *output_w = (input_w + param->pad_l_ + param->pad_r_ - extent_w) / stride_w + 1;
    *output_h = (input_h + param->pad_u_ + param->pad_d_ - extent_h) / stride_h + 1;
  }

  // The kernel must fit inside the padded input. The dilated extent is checked as well as the
  // raw kernel size: otherwise truncating division in the explicit-pad branch can report an
  // output of 1 for a dilated kernel that does not fit at all.
  int padded_h = input_h + param->pad_u_ + param->pad_d_;
  int padded_w = input_w + param->pad_l_ + param->pad_r_;
  if (kernel_h > padded_h || kernel_w > padded_w || extent_h > padded_h || extent_w > padded_w) {
    return NNACL_PARAM_INVALID;
  }
  return NNACL_OK;
}
68
/**
 * Shape inference for Conv2D-style operators.
 *
 * inputs[0] is the activation tensor (Format_NHWC or Format_KHWC is required) and inputs[1]
 * is the weight tensor. The output tensor inherits format and data type from the input; its
 * shape is the input shape with dims 1 and 2 replaced by the spatial size computed by
 * ConvInferShape and dim 3 replaced by GetBatch(weight_tensor).
 *
 * Side effects on the ConvParameter behind `parameter`:
 *   - group_ is taken from weight_tensor->shape_[0] when it is 0;
 *   - kernel_h_/kernel_w_ are taken from the weight shape when they are -1;
 *   - the pad_*_ fields may be rewritten by ConvInferShape;
 *   - input_*_ and output_*_ are filled from the resolved shapes.
 *
 * @return NNACL_OK on success;
 *         the error from CheckAugmentNullSizeInputTwo when the argument check fails;
 *         NNACL_FORMAT_ERROR for an unsupported input format;
 *         NNACL_INFER_INVALID when InferFlag() is false or the input shape is still empty;
 *         NNACL_PARAM_INVALID for non 4-D input/weight, a zero group or stride, or a channel
 *         mismatch; otherwise whatever ConvInferShape reports.
 */
int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  // Both tensors are indexed up to dimension 3 below, so they must be exactly 4-D.
  const size_t conv_tensor_rank = 4;

  // NOTE(review): the trailing 2, 3, 1 presumably mean "2 or 3 inputs, 1 output" -
  // confirm against CheckAugmentNullSizeInputTwo.
  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  const TensorC *input_tensor = inputs[0];
  if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC) {
    return NNACL_FORMAT_ERROR;
  }
  const TensorC *weight_tensor = inputs[1];
  TensorC *out_tensor = outputs[0];

  out_tensor->format_ = input_tensor->format_;
  out_tensor->data_type_ = input_tensor->data_type_;
  ConvParameter *param = (ConvParameter *)parameter;
  // These fields are filled in even when shape inference has to be deferred below, but only
  // when the weight actually has a leading dimension to read.
  if (weight_tensor->shape_size_ > 0) {
    if (param->group_ == 0) {
      param->group_ = weight_tensor->shape_[0];
    }
    param->output_channel_ = weight_tensor->shape_[0];
  }
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  const int *in_shape = input_tensor->shape_;
  if (input_tensor->shape_size_ == 0) {
    return NNACL_INFER_INVALID;
  }
  // Reject ranks other than 4 before touching shape_[1..3] of either tensor and
  // out_shape[1..3]; with a shorter shape those slots hold no valid dimension.
  if (input_tensor->shape_size_ != conv_tensor_rank || weight_tensor->shape_size_ != conv_tensor_rank) {
    return NNACL_PARAM_INVALID;
  }
  int input_h = in_shape[1];
  int input_w = in_shape[2];
  int input_c = in_shape[3];
  int output_w = 0, output_h = 0;

  // common conv: input_c == weight_tensor->shape_[3]
  // conv depthwise: input_c == 1
  // group conv: input_c / group == weight_tensor->shape_[3]
  MS_CHECK_FALSE(param->group_ == 0, NNACL_PARAM_INVALID);
  if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) {
    return NNACL_PARAM_INVALID;
  }
  if (param->stride_h_ == 0 || param->stride_w_ == 0) {
    return NNACL_PARAM_INVALID;
  }

  // A kernel size of -1 means "not set in the attributes": fall back to the weight shape.
  param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1];
  param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2];
  int ret = ConvInferShape(input_h, input_w, &output_h, &output_w, param);
  if (ret != NNACL_OK) {
    return ret;
  }

  int out_shape[MAX_SHAPE_SIZE];
  size_t out_shape_size = 0;
  ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
  // A negative computed size is replaced by 1.
  out_shape[1] = output_h >= 0 ? output_h : 1;
  out_shape[2] = output_w >= 0 ? output_w : 1;
  out_shape[3] = GetBatch(weight_tensor);
  SetShapeArray(out_tensor, out_shape, out_shape_size);

  param->input_batch_ = in_shape[0];
  param->input_h_ = in_shape[1];
  param->input_w_ = in_shape[2];
  param->input_channel_ = in_shape[3];
  param->output_batch_ = out_shape[0];
  param->output_h_ = out_shape[1];
  param->output_w_ = out_shape[2];
  param->output_channel_ = out_shape[3];

  return NNACL_OK;
}
139
// NOTE(review): REG_INFER is defined outside this file; presumably each line below binds
// Conv2dInferShape as the shape-inference routine for the given primitive type
// (PrimType_AdderFusion and PrimType_Conv2DFusion) - confirm against infer_register.h.
140 REG_INFER(Adder, PrimType_AdderFusion, Conv2dInferShape)
141 REG_INFER(Conv2D, PrimType_Conv2DFusion, Conv2dInferShape)
142