• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "nnacl/base/conv_common_base.h"
17 #include "nnacl/errorcode.h"
18 
19 #define MIN_UNIT 2
20 #define MAX_UNIT 8
21 
22 #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
CheckConvDw1DWinograd(const ConvParameter * conv_param,int thread_num)23 bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) {
24   return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 &&
25          conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
26          conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 &&
27          conv_param->input_channel_ == conv_param->output_channel_ && conv_param->output_w_ >= 4 &&
28          conv_param->output_h_ >= thread_num * 4;  // better had more than 4 rows for each thread
29 }
30 #endif
31 
/*
 * Validates a winograd (input_unit, output_unit) pair.
 *
 * input_unit:  transformed input tile size; only 4, 6 and 8 are accepted.
 * output_unit: output tile size; must be at least 2 and strictly smaller
 *              than input_unit.
 *
 * Returns true when the pair is acceptable, false otherwise.
 */
bool CheckWinogradInputOutputUnit(int input_unit, int output_unit) {
  switch (input_unit) {
    case 4:
    case 6:
    case 8:
      break;
    default:
      return false;
  }
  return output_unit >= 2 && output_unit < input_unit;
}
41 
42 // Reference to the paper "Fast Algorithms for Convolutional Neural Networks"
43 // Utilize cost model to compute performance gain.
44 // If the gain is greater than got from Im2col, winograd algorithm will be chosen.
SelectOutputUnit(const ConvParameter * conv_param)45 int SelectOutputUnit(const ConvParameter *conv_param) {
46   int kernel_h = conv_param->kernel_h_;
47   int kernel_w = conv_param->kernel_w_;
48   int in_c = conv_param->input_channel_;
49   int out_w = conv_param->output_w_;
50   int out_h = conv_param->output_h_;
51   int out_c = conv_param->output_channel_;
52   if (conv_param->op_parameter_.thread_num_ == 0) {
53     return NNACL_PARAM_INVALID;
54   }
55   int unit2 = UP_DIV(out_w * out_h, C12NUM * conv_param->op_parameter_.thread_num_);
56   int max_out_unit = (int)(sqrtf((float)unit2));
57   max_out_unit = max_out_unit < MAX_UNIT ? max_out_unit : MAX_UNIT;
58   max_out_unit = max_out_unit > MIN_UNIT ? max_out_unit : MIN_UNIT;
59 
60   int unit = 0;
61   float max_rate = 0.0f;
62   float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w;
63 
64   for (int i = MIN_UNIT; i <= max_out_unit; ++i) {
65     int input_unit = i + kernel_w - 1;
66     if (!CheckWinogradInputOutputUnit(input_unit, i)) {
67       continue;
68     }
69     float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f;
70     float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) *
71                       UP_DIV(out_w, i) * UP_DIV(out_h, i);
72     float reduce_rate = common_cost / wino_cost - penalty;
73     if (reduce_rate > max_rate) {
74       max_rate = reduce_rate;
75       unit = i;
76     }
77   }
78   if (max_rate < 1.0f) {
79     return 1;
80   }
81   // If output_unit is 1, then it is conventional convolution
82   return unit;
83 }
84 
CheckIfUseWinograd(int * output_unit,const ConvParameter * conv_param)85 bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param) {
86   if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
87       conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ != 1) {
88     *output_unit = SelectOutputUnit(conv_param);
89     if (*output_unit > 1) {
90       return true;
91     }
92   }
93   return false;
94 }
95