1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_NNACL_WINOGRAD_UTILS_H_ 18 #define MINDSPORE_NNACL_WINOGRAD_UTILS_H_ 19 20 #ifdef ENABLE_ARM 21 #include <arm_neon.h> 22 #endif 23 #include "nnacl/conv_parameter.h" 24 #include "nnacl/op_base.h" 25 26 #ifdef __cplusplus 27 extern "C" { 28 #endif 29 typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); 30 31 typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, 32 int dst_step, int out_c, int r_w, int r_h, int r_c); 33 34 #define Load16Data \ 35 src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ 36 src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ 37 src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ 38 src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ 39 src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ 40 src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ 41 src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ 42 src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ 43 src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ 44 src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ 45 src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ 46 src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ 47 src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ 48 src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ 49 src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ 50 src[15] = MS_LDQ_F32(src_data + 15 * src_step); 51 52 #define Load36Data \ 53 src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ 54 src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ 55 src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ 56 src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ 57 src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ 58 src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ 59 src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ 60 src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ 61 src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ 62 src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ 63 src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ 64 src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ 65 src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ 66 src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ 67 src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ 68 src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ 69 src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ 70 src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ 71 src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ 72 src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ 73 src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ 74 src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ 75 src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ 76 src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ 77 src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ 78 src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ 79 src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ 80 src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ 81 src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ 82 src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ 83 src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ 84 src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ 85 src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ 86 src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ 87 src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ 88 src[35] = MS_LDQ_F32(src_data + 35 * src_step); 89 90 #define Load64Data \ 91 src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ 92 src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ 93 src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ 94 src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ 95 src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ 96 src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ 97 src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ 98 src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ 99 src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ 100 src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ 101 src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ 102 src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ 103 src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ 104 src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ 105 src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ 106 src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ 107 src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ 108 src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ 109 src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ 110 src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ 111 src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ 112 src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ 113 src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ 114 src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ 115 src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ 116 src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ 117 src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ 118 src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ 119 src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ 120 src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ 121 src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ 122 src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ 123 src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ 124 src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ 125 src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ 126 src[35] = MS_LDQ_F32(src_data + 35 * src_step); \ 127 src[36] = MS_LDQ_F32(src_data + 36 * src_step); \ 128 src[37] = MS_LDQ_F32(src_data + 37 * src_step); \ 129 src[38] = MS_LDQ_F32(src_data + 38 * src_step); \ 130 src[39] = MS_LDQ_F32(src_data + 39 * src_step); \ 131 src[40] = MS_LDQ_F32(src_data + 40 * src_step); \ 132 src[41] = MS_LDQ_F32(src_data + 41 * src_step); \ 133 src[42] = MS_LDQ_F32(src_data + 42 * src_step); \ 134 src[43] = MS_LDQ_F32(src_data + 43 * src_step); \ 135 src[44] = MS_LDQ_F32(src_data + 44 * src_step); \ 136 src[45] = MS_LDQ_F32(src_data + 45 * src_step); \ 137 src[46] = MS_LDQ_F32(src_data + 46 * src_step); \ 138 src[47] = MS_LDQ_F32(src_data + 47 * src_step); \ 139 src[48] = MS_LDQ_F32(src_data + 48 * src_step); \ 140 src[49] = MS_LDQ_F32(src_data + 49 * src_step); \ 141 src[50] = MS_LDQ_F32(src_data + 50 * src_step); \ 142 src[51] = MS_LDQ_F32(src_data + 51 * src_step); \ 143 src[52] = MS_LDQ_F32(src_data + 52 * src_step); \ 144 src[53] = MS_LDQ_F32(src_data + 53 * src_step); \ 145 src[54] = MS_LDQ_F32(src_data + 54 * src_step); \ 146 src[55] = MS_LDQ_F32(src_data + 55 * src_step); \ 147 src[56] = MS_LDQ_F32(src_data + 56 * src_step); \ 148 src[57] = MS_LDQ_F32(src_data + 57 * src_step); \ 149 src[58] = MS_LDQ_F32(src_data + 58 * src_step); \ 150 src[59] = MS_LDQ_F32(src_data + 59 * src_step); \ 151 src[60] = MS_LDQ_F32(src_data + 60 * src_step); \ 152 src[61] = MS_LDQ_F32(src_data + 61 * src_step); \ 153 src[62] = MS_LDQ_F32(src_data + 62 * src_step); \ 154 src[63] = MS_LDQ_F32(src_data + 63 * src_step); 155 156 InputTransFunc GetInputTransFunc(int input_unit); 157 158 void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); 159 160 void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); 161 162 void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); 163 164 OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); 165 166 #define Store4Data \ 167 MS_STQ_F32(dst_data, m[0]); \ 168 MS_STQ_F32(dst_data + out_c, m[1]); \ 169 MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \ 170 MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]); 171 172 #define Store9Data \ 173 MS_STQ_F32(dst_data, m[0]); \ 174 MS_STQ_F32(dst_data + out_c, m[1]); \ 175 MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ 176 MS_STQ_F32(dst_data + dst_step * out_c, m[3]); \ 177 MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]); \ 178 MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ 179 MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]); \ 180 MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ 181 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); 182 183 #define Store16Data \ 184 MS_STQ_F32(dst_data, m[0]); \ 185 MS_STQ_F32(dst_data + out_c, m[1]); \ 186 MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ 187 MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ 188 MS_STQ_F32(dst_data + dst_step * out_c, m[4]); \ 189 MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]); \ 190 MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ 191 MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ 192 MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]); \ 193 MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ 194 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ 195 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ 196 MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]); \ 197 MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ 198 MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ 199 MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); 200 201 #define Store25Data \ 202 MS_STQ_F32(dst_data, m[0]); \ 203 MS_STQ_F32(dst_data + out_c, m[1]); \ 204 MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ 205 MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ 206 MS_STQ_F32(dst_data + 4 * out_c, m[4]); \ 207 MS_STQ_F32(dst_data + dst_step * out_c, m[5]); \ 208 MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]); \ 209 MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ 210 MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ 211 MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ 212 MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]); \ 213 MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ 214 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ 215 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ 216 MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ 217 MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]); \ 218 MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ 219 MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ 220 MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ 221 MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ 222 MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]); \ 223 MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ 224 MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ 225 MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ 226 MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); 227 228 void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 229 int out_c, int r_w, int r_h, int r_c); 230 void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 231 int dst_step, int out_c, int r_w, int r_h, int r_c); 232 void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 233 int dst_step, int out_c, int r_w, int r_h, int r_c); 234 void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 235 int out_c, int r_w, int r_h, int r_c); 236 void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 237 int dst_step, int out_c, int r_w, int r_h, int r_c); 238 void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 239 int dst_step, int out_c, int r_w, int r_h, int r_c); 240 241 void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 242 int out_c, int r_w, int r_h, int r_c); 243 void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 244 int dst_step, int out_c, int r_w, int r_h, int r_c); 245 void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 246 int dst_step, int out_c, int r_w, int r_h, int r_c); 247 void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 248 int out_c, int r_w, int r_h, int r_c); 249 void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 250 int dst_step, int out_c, int r_w, int r_h, int r_c); 251 void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 252 int dst_step, int out_c, int r_w, int r_h, int r_c); 253 void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 254 int out_c, int r_w, int r_h, int r_c); 255 void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 256 int dst_step, int out_c, int r_w, int r_h, int r_c); 257 void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 258 int dst_step, int out_c, int r_w, int r_h, int r_c); 259 void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 260 int out_c, int r_w, int r_h, int r_c); 261 void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 262 int dst_step, int out_c, int r_w, int r_h, int r_c); 263 void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 264 int dst_step, int out_c, int r_w, int r_h, int r_c); 265 266 void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 267 int out_c, int r_w, int r_h, int r_c); 268 void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 269 int dst_step, int out_c, int r_w, int r_h, int r_c); 270 void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 271 int dst_step, int out_c, int r_w, int r_h, int r_c); 272 void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 273 int out_c, int r_w, int r_h, int r_c); 274 void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 275 int dst_step, int out_c, int r_w, int r_h, int r_c); 276 void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 277 int dst_step, int out_c, int r_w, int r_h, int r_c); 278 void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 279 int out_c, int r_w, int r_h, int r_c); 280 void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 281 int dst_step, int out_c, int r_w, int r_h, int r_c); 282 void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 283 int dst_step, int out_c, int r_w, int r_h, int r_c); 284 void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 285 int out_c, int r_w, int r_h, int r_c); 286 void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 287 int dst_step, int out_c, int r_w, int r_h, int r_c); 288 void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 289 int dst_step, int out_c, int r_w, int r_h, int r_c); 290 void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 291 int out_c, int r_w, int r_h, int r_c); 292 void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 293 int dst_step, int out_c, int r_w, int r_h, int r_c); 294 void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 295 int dst_step, int out_c, int r_w, int r_h, int r_c); 296 void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, 297 int out_c, int r_w, int r_h, int r_c); 298 void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 299 int dst_step, int out_c, int r_w, int r_h, int r_c); 300 void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, 301 int dst_step, int out_c, int r_w, int r_h, int r_c); 302 303 int SelectOutputUnit(const ConvParameter *conv_param); 304 305 #ifdef __cplusplus 306 } 307 #endif 308 309 #endif // MINDSPORE_NNACL_WINOGRAD_UTILS_H_ 310