/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <math.h>  // ceil()
#include "nnacl/infer/conv2d_infer.h"
#include "nnacl/infer/infer_register.h"
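// ConvInferShape computes the output height/width for one convolution according to the pad mode, and for
// Pad_same it also splits the required total padding into pad_u_/pad_d_ and pad_l_/pad_r_.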
int ConvInferShape(int input_h, int input_w, int *output_h, int *output_w, ConvParameter *param) {
  int kernel_w = param->kernel_w_;
  int kernel_h = param->kernel_h_;
  int stride_w = param->stride_w_;
  int stride_h = param->stride_h_;
  int dilate_w = param->dilation_w_;
  int dilate_h = param->dilation_h_;

  if (stride_w == 0 || stride_h == 0) {
    return NNACL_PARAM_INVALID;
  }
  if (INT_MUL_OVERFLOW(kernel_h, dilate_h) || INT_MUL_OVERFLOW(kernel_w, dilate_w)) {
    return NNACL_ERRCODE_MUL_OVERFLOW;
  }
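  // Pad_same: output = ceil(input / stride); the total padding needed to realize that output size is then
  // split as evenly as possible, with the extra pixel going to the bottom/right side.
  //   Example: input_h = 5, stride_h = 2, kernel_h = 3, dilate_h = 1
  //   -> output_h = ceil(5 / 2) = 3, pad_h_all = (3 - 1) * 2 + (3 - 1) * 1 + 1 - 5 = 2, pad_u_ = pad_d_ = 1.
  // Pad_valid: output = ceil((input + pads - (kernel - 1) * dilation) / stride).
  // Any other mode: explicit padding, using integer (floor) division over the dilated kernel extent.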
  if (param->pad_mode_ == Pad_same) {  // maybe error
    *output_w = ceil((float)(input_w) / (float)(stride_w));
    *output_h = ceil((float)(input_h) / (float)(stride_h));
    int pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h);
    int pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w);
    if (pad_h_all < 0) {
      param->pad_u_ = param->pad_d_ = 0;
    } else {
      param->pad_u_ = pad_h_all / 2;
      param->pad_d_ = pad_h_all - param->pad_u_;
    }
    if (pad_w_all < 0) {
      param->pad_l_ = param->pad_r_ = 0;
    } else {
      param->pad_l_ = pad_w_all / 2;
      param->pad_r_ = pad_w_all - param->pad_l_;
    }
  } else if (param->pad_mode_ == Pad_valid) {
    *output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - ((float)(kernel_w)-1) * (float)(dilate_w)) /
                     (float)(stride_w));
    *output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - ((float)(kernel_h)-1) * (float)(dilate_h)) /
                     (float)(stride_h));
  } else {
    int kernel_width = (kernel_w - 1) * dilate_w + 1;
    int kernel_height = (kernel_h - 1) * dilate_h + 1;
    *output_w = ((input_w) + param->pad_l_ + param->pad_r_ - kernel_width) / stride_w + 1;
    *output_h = ((input_h) + param->pad_u_ + param->pad_d_ - kernel_height) / stride_h + 1;
  }

  // The (possibly padded) input must be at least as large as the kernel in each spatial dimension.
  if (param->kernel_h_ > input_h + param->pad_u_ + param->pad_d_ ||
      param->kernel_w_ > input_w + param->pad_l_ + param->pad_r_) {
    return NNACL_PARAM_INVALID;
  }
  return NNACL_OK;
}

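// Conv2dInferShape checks input/output tensor counts and formats, fills kernel/group defaults from the
// weight tensor, computes the output height/width via ConvInferShape, and records the resolved shapes in
// both the output tensor and the ConvParameter.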
int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  const TensorC *input_tensor = inputs[0];
  if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC) {
    return NNACL_FORMAT_ERROR;
  }
  const TensorC *weight_tensor = inputs[1];
  TensorC *out_tensor = outputs[0];

  out_tensor->format_ = input_tensor->format_;
  out_tensor->data_type_ = input_tensor->data_type_;
  ConvParameter *param = (ConvParameter *)parameter;
  if (param->group_ == 0) {
    param->group_ = weight_tensor->shape_[0];
  }
  param->output_channel_ = weight_tensor->shape_[0];
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  const int *in_shape = input_tensor->shape_;
  if (input_tensor->shape_size_ == 0) {
    return NNACL_INFER_INVALID;
  }
  int input_h = in_shape[1];
  int input_w = in_shape[2];
  int input_c = in_shape[3];
  int output_w = 0, output_h = 0;

  // common conv: input_c == weight_tensor->shape_[3]
  // conv depthwise: input_c == 1
  // group conv: input_c / group == weight_tensor->shape_[3]
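  // Example for the group case: input_c = 8, group_ = 2, weight_tensor->shape_[3] = 4 passes the check
  // below, since 8 / 2 == 4.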
  MS_CHECK_FALSE(param->group_ == 0, NNACL_PARAM_INVALID);
  if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) {
    return NNACL_PARAM_INVALID;
  }
  if (param->stride_h_ == 0 || param->stride_w_ == 0) {
    return NNACL_PARAM_INVALID;
  }

  param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1];
  param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2];
  int ret = ConvInferShape(input_h, input_w, &output_h, &output_w, param);
  if (ret != NNACL_OK) {
    return ret;
  }

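  // The output stays NHWC: batch copied from the input, H/W from the inference above (a negative result
  // falls back to 1), and the channel count taken from the weight tensor's batch dimension.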
  int out_shape[MAX_SHAPE_SIZE];
  size_t out_shape_size = 0;
  ShapeSet(out_shape, &out_shape_size, input_tensor->shape_, input_tensor->shape_size_);
  out_shape[1] = output_h >= 0 ? output_h : 1;
  out_shape[2] = output_w >= 0 ? output_w : 1;
  out_shape[3] = GetBatch(weight_tensor);
  SetShapeArray(out_tensor, out_shape, out_shape_size);

  param->input_batch_ = in_shape[0];
  param->input_h_ = in_shape[1];
  param->input_w_ = in_shape[2];
  param->input_channel_ = in_shape[3];
  param->output_batch_ = out_shape[0];
  param->output_h_ = out_shape[1];
  param->output_w_ = out_shape[2];
  param->output_channel_ = out_shape[3];

  return NNACL_OK;
}

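// Both AdderFusion and Conv2DFusion reuse this shape-inference routine through the registration macro.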
REG_INFER(Adder, PrimType_AdderFusion, Conv2dInferShape)
REG_INFER(Conv2D, PrimType_Conv2DFusion, Conv2dInferShape)