• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_NNACL_FP32_CONV_SW_H_
18 #define MINDSPORE_NNACL_FP32_CONV_SW_H_
19 
/*
 * GenerateConvSWFunc(backend, oc_unit_num, row_num_list, kernel_list, compute_core, outer_compute)
 *
 * Expands to two function definitions:
 *
 *   SWBorder<backend>(dst, src, weight, bias, top, bottom, left, right, conv_param, sw_param, kernel, act_type,
 *                     ow_bock, oc_block, write_mode)
 *     Visits output rows [top, bottom) and, in each row, output columns [left, right) in steps of ow_bock.
 *     For every visited position it derives the input origin (ih, iw) from stride and padding, computes the
 *     kernel tap ranges [start_kh, end_kh) and [start_kw, end_kw) from the input size and dilation, offsets the
 *     source and weight pointers to the first tap (src_kernel, weight_kernel) and then runs `outer_compute`.
 *     NOTE(review): the ranges look like "taps that land inside the input", which holds if UP_DIV is ceiling
 *     division; UP_DIV is defined elsewhere, so confirm.
 *
 *   ConvSW<backend>Fp32(input_data, packed_weight, bias_data, output_data, task_id, conv_param, sw_param)
 *     Handles the output rows of one task: rows are split into chunks of UP_DIV(output_h_, thread_num_) and
 *     task_id selects the chunk; a task whose chunk is empty returns immediately. For every batch, row and
 *     output-channel block, rows outside [top_, bottom_) go entirely through SWBorder<backend>; other rows use
 *     SWBorder<backend> for columns [0, left_) and [right_, output_w_) and `compute_core` for [left_, right_).
 *
 * Macro parameters:
 *   backend        token pasted into both function names.
 *   oc_unit_num    upper bound of oc_block, the number of c_block_ units consumed per channel-loop iteration.
 *   row_num_list,
 *   kernel_list    pasted at function scope just before the batch loop. The expansion reads
 *                  ow_block_num[oc_block - 1] and kernel[oc_block - 1][0] without declaring them, so these
 *                  arguments are presumably where the two arrays come from (confirm at the instantiation site).
 *   compute_core   code pasted into the center-column loop body, in front of `src_w += ...`.
 *   outer_compute  code pasted into the border-column loop body, in front of `dst_kernel += ...`.
 *
 * The pasted arguments can see every local of the enclosing function. Several locals and parameters
 * (for example kernel_h, in_kw_step, out_block_step, src_kernel, weight_kernel, end_kh, end_kw) are never read by
 * the code visible here, so they are presumably consumed by the pasted code: do not rename or remove locals
 * without checking every instantiation.
 */
20 #define GenerateConvSWFunc(backend, oc_unit_num, row_num_list, kernel_list, compute_core, outer_compute)            \
21   void SWBorder##backend(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, \
22                          int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,  \
23                          const SWConvKernel kernel, int act_type, int ow_bock, int oc_block, size_t write_mode) {   \
24     for (int oh = top; oh < bottom; oh++) {                                                                         \
25       int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;                                                     \
26       int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));                                                \
27       int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));        \
28       const float *src_h = src + ih * sw_param->in_h_step_;                                                         \
29       float *dst_kernel = dst + left * sw_param->out_w_step_;                                                       \
30       for (int ow = left; ow < right; ow += ow_bock) {                                                              \
31         int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;                                                   \
32         int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));                                              \
33         int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));      \
34         const float *src_w = src_h + iw * sw_param->ic_align_;                                                      \
35         const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;      \
36         const float *weight_kernel =                                                                                \
37           weight + (start_kh * conv_param->kernel_w_ + start_kw) * sw_param->ic_align_ * C8NUM * oc_block;          \
38         outer_compute dst_kernel += ow_bock * sw_param->out_w_step_;                                                \
39       }                                                                                                             \
40       dst += sw_param->out_h_step_;                                                                                 \
41     }                                                                                                               \
42   }                                                                                                                 \
43                                                                                                                     \
44   void ConvSW##backend##Fp32(const float *input_data, const float *packed_weight, const float *bias_data,           \
45                              float *output_data, int task_id, ConvParameter *conv_param,                            \
46                              SlidingWindowParam *sw_param) {                                                        \
47     int out_h = conv_param->output_h_;                                                                              \
48     int oh_step = UP_DIV(out_h, conv_param->thread_num_); /* output rows per task */                                \
49     int oh_start = oh_step * task_id;                                                                               \
50     int oh_end = MSMIN(oh_start + oh_step, out_h);                                                                  \
51     if (oh_start >= oh_end) {                                                                                       \
52       return;                                                                                                       \
53     }                                                                                                               \
54     int oc_tile_ = C8NUM; /* oc is aligned to C8NUM on arm64 */                                                     \
55     int act_type = 0; /* ends up 0 (other), 2 (Relu) or 3 (Relu6) */                                                \
56     if (conv_param->act_type_ == ActType_Relu6) {                                                                   \
57       act_type += 1;                                                                                                \
58     }                                                                                                               \
59     if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {                          \
60       act_type += 2;                                                                                                \
61     }                                                                                                               \
62     int kernel_h = conv_param->kernel_h_;                                                                           \
63     int kernel_w = conv_param->kernel_w_;                                                                           \
64     int ic_algin = sw_param->ic_align_;                                                                             \
65     int in_sw_step = sw_param->in_sw_step_;                                                                         \
66     int in_kw_step = sw_param->in_kw_step_;                                                                         \
67     int in_kh_step = sw_param->in_kh_step_;                                                                         \
68     int in_sh_step = sw_param->in_sh_step_;                                                                         \
69     int out_h_step = sw_param->out_h_step_;                                                                         \
70     int out_c_step = sw_param->out_c_step_;                                                                         \
71     int out_w_step = sw_param->out_w_step_;                                                                         \
72     int out_block_step = sw_param->out_block_step_;                                                                 \
73     int kernel_step = sw_param->kernel_step_;                                                                       \
74     int in_step = sw_param->in_step_;                                                                               \
75     int out_step = sw_param->out_step_;                                                                             \
76     int c_block = sw_param->c_block_;                                                                               \
77     int top = sw_param->top_;                                                                                       \
78     int left = sw_param->left_;                                                                                     \
79     int right = sw_param->right_;                                                                                   \
80     int bottom = sw_param->bottom_;                                                                                 \
81     int stride_h = conv_param->stride_h_;                                                                           \
82     int stride_w = conv_param->stride_w_;                                                                           \
83     int out_w = conv_param->output_w_;                                                                              \
84     int pad_u = conv_param->pad_u_;                                                                                 \
85     int pad_l = conv_param->pad_l_;                                                                                 \
86     int in_h_step = sw_param->in_h_step_;                                                                           \
87     int out_batch = conv_param->output_batch_;                                                                      \
88     int in_h_start = top * stride_h - pad_u;                                                                        \
89     int in_w_start = left * stride_w - pad_l;                                                                       \
90     int center_step = in_h_start * in_h_step + in_w_start * ic_algin; /* offset of (in_h_start, in_w_start) */      \
91     int write_mode = conv_param->out_format_;                                                                       \
92     row_num_list kernel_list for (int b = 0; b < out_batch; b++) { /* pasted declarations, then batch loop */       \
93       for (int oh = oh_start; oh < oh_end; oh += 1) {                                                               \
94         float *dst_oh = output_data + oh * out_h_step;                                                              \
95         const float *src_h = input_data + center_step;                                                              \
96                                                                                                                     \
97         int oc_block = 0; /* assigned in the loop body before `oc += oc_block` runs */                              \
98         const float *bias = bias_data;                                                                              \
99         for (int oc = 0; oc < c_block; oc += oc_block) {                                                            \
100           oc_block = MSMIN(oc_unit_num, c_block - oc);                                                              \
101           const float *weight = packed_weight + oc * kernel_step;                                                   \
102           if (bias != NULL) {                                                                                       \
103             bias = bias_data + oc * oc_tile_;                                                                       \
104           }                                                                                                         \
105           /* nhwc dst_w = dst_oh + oc * oc_tile_;  nc8hw8 dst_w = dst_oh * oc * ow * oh * oc_tile_; */              \
106           float *dst_oc = dst_oh + oc * out_c_step;                                                                 \
107           const SWConvKernel kernel_border = kernel[oc_block - 1][0]; /* border calls use ow_bock == 1 */           \
108           if (oh < top || oh >= bottom) { /* oh in up or down border */                                             \
109             SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param,         \
110                               kernel_border, act_type, 1, oc_block, write_mode);                                    \
111           } else { /* oh in center */                                                                               \
112             /* ow in left */                                                                                        \
113             SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param,          \
114                               kernel_border, act_type, 1, oc_block, write_mode);                                    \
115             /* ow in center */                                                                                      \
116             const float *src_w = src_h + (oh - top) * in_sh_step;                                                   \
117             int ow_block = ow_block_num[oc_block - 1];                                                              \
118             for (int ow = left; ow < right; ow += ow_block) { /* left ~ right */                                    \
119               ow_block = MSMIN(ow_block, right - ow);                                                               \
120               compute_core src_w += ow_block * in_sw_step;                                                          \
121             }                                                                                                       \
122             /* ow in right */                                                                                       \
123             SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param,     \
124                               kernel_border, act_type, 1, oc_block, write_mode);                                    \
125           }                                                                                                         \
126         }                                                                                                           \
127       } /* output h loop */                                                                                         \
128       input_data += in_step;                                                                                        \
129       output_data += out_step;                                                                                      \
130     } /* batch loop */                                                                                              \
131   }
132 #endif  // MINDSPORE_NNACL_FP32_CONV_SW_H_
133