1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <stdint.h>
12 #include <xnnpack/math.h>
13 #include <xnnpack/operator.h>
14
15
xnn_pack_q8_gemm_goi_w(size_t g,size_t nc,size_t kc,uint32_t nr,uint32_t kr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)16 static inline void xnn_pack_q8_gemm_goi_w(
17 size_t g,
18 size_t nc,
19 size_t kc,
20 uint32_t nr,
21 uint32_t kr,
22 uint8_t izp,
23 uint8_t kzp,
24 const uint8_t* k,
25 const int32_t* b,
26 void* packed_w)
27 {
28 const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
29 do {
30 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
31 const size_t nr_block_size = min(nc - nr_block_start, nr);
32 int32_t* packed_b = (int32_t*) packed_w;
33 if XNN_LIKELY(b != NULL) {
34 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
35 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
36 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
37 }
38 } else {
39 size_t n = nr_block_size;
40 do {
41 *((int32_t*) packed_w) = boff;
42 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
43 } while (--n != 0);
44 }
45 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
46 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
47 const size_t kr_block_size = min(kc - kr_block_start, kr);
48 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
49 int32_t ksum = 0;
50 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
51 const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
52 ksum += (int32_t) kv;
53 *((uint8_t*) packed_w) = kv;
54 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
55 }
56 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
57 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
58 }
59 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
60 }
61 }
62 k += nc * kc;
63 if XNN_UNPREDICTABLE(b != NULL) {
64 b += nc;
65 }
66 } while (--g != 0);
67 }
68
xnn_pack_q8_gemm_io_w(size_t nc,size_t kc,uint32_t nr,uint32_t kr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)69 static inline void xnn_pack_q8_gemm_io_w(
70 size_t nc,
71 size_t kc,
72 uint32_t nr,
73 uint32_t kr,
74 uint8_t izp,
75 uint8_t kzp,
76 const uint8_t* k,
77 const int32_t* b,
78 void* packed_w)
79 {
80 const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
81 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
82 const size_t nr_block_size = min(nc - nr_block_start, nr);
83 int32_t* packed_b = (int32_t*) packed_w;
84 if XNN_LIKELY(b != NULL) {
85 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
86 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
87 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
88 }
89 } else {
90 size_t n = nr_block_size;
91 do {
92 *((int32_t*) packed_w) = boff;
93 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
94 } while (--n != 0);
95 }
96 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
97 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
98 const size_t kr_block_size = min(kc - kr_block_start, kr);
99 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
100 int32_t ksum = 0;
101 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
102 const uint8_t kv = k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
103 ksum += (int32_t) kv;
104 *((uint8_t*) packed_w) = kv;
105 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
106 }
107 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
108 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
109 }
110 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
111 }
112 }
113 }
114
xnn_pack_q8_conv_goki_w(size_t g,size_t nc,size_t ks,size_t kc,uint32_t nr,uint32_t kr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)115 static inline void xnn_pack_q8_conv_goki_w(
116 size_t g,
117 size_t nc,
118 size_t ks,
119 size_t kc,
120 uint32_t nr,
121 uint32_t kr,
122 uint8_t izp,
123 uint8_t kzp,
124 const uint8_t* k,
125 const int32_t* b,
126 void* packed_w)
127 {
128 const int32_t boff = (int32_t) ks * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
129 do {
130 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
131 const size_t nr_block_size = min(nc - nr_block_start, nr);
132 int32_t* packed_b = (int32_t*) packed_w;
133 if XNN_LIKELY(b != NULL) {
134 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
135 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
136 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
137 }
138 } else {
139 size_t n = nr_block_size;
140 do {
141 *((int32_t*) packed_w) = boff;
142 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
143 } while (--n != 0);
144 }
145 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
146 for (size_t ki = 0; ki < ks; ki++) {
147 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
148 const size_t kr_block_size = min(kc - kr_block_start, kr);
149 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
150 int32_t ksum = 0;
151 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
152 const uint8_t kv =
153 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
154 ksum += (int32_t) kv;
155 *((uint8_t*) packed_w) = kv;
156 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
157 }
158 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
159 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
160 }
161 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
162 }
163 }
164 }
165 k += ks * kc * nc;
166 if XNN_UNPREDICTABLE(b != NULL) {
167 b += nc;
168 }
169 } while (--g != 0);
170 }
171
xnn_pack_q8_conv_kgo_w(size_t g,size_t nc,size_t ks,uint32_t nr,uint32_t kr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)172 static inline void xnn_pack_q8_conv_kgo_w(
173 size_t g,
174 size_t nc,
175 size_t ks,
176 uint32_t nr,
177 uint32_t kr,
178 uint8_t izp,
179 uint8_t kzp,
180 const uint8_t* k,
181 const int32_t* b,
182 void* packed_w)
183 {
184 const int32_t boff = (int32_t) ks * (int32_t) izp * (int32_t) kzp;
185 for (size_t i = 0; i < g; i++) {
186 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
187 const size_t nr_block_size = min(nc - nr_block_start, nr);
188 int32_t* packed_b = (int32_t*) packed_w;
189 if XNN_LIKELY(b != NULL) {
190 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
191 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
192 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
193 }
194 } else {
195 size_t n = nr_block_size;
196 do {
197 *((int32_t*) packed_w) = boff;
198 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
199 } while (--n != 0);
200 }
201 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
202 for (size_t ki = 0; ki < ks; ki++) {
203 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
204 const uint8_t kv =
205 k[ki * g * nc + (nr_block_start + nr_block_offset)];
206 *((uint8_t*) packed_w) = kv;
207 packed_b[nr_block_offset] -= (int32_t) kv * (int32_t) izp;
208 packed_w = (void*) ((uintptr_t) packed_w + kr * sizeof(uint8_t));
209 }
210 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
211 }
212 }
213 k += nc;
214 if XNN_UNPREDICTABLE(b != NULL) {
215 b += nc;
216 }
217 }
218 }
219
xnn_pack_q8_deconv_goki_w(size_t g,size_t nc,size_t kh,size_t kw,size_t kc,size_t sh,size_t sw,size_t nr,size_t kr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w,struct subconvolution_params * params)220 static inline void xnn_pack_q8_deconv_goki_w(
221 size_t g,
222 size_t nc,
223 size_t kh,
224 size_t kw,
225 size_t kc,
226 size_t sh,
227 size_t sw,
228 size_t nr,
229 size_t kr,
230 uint8_t izp,
231 uint8_t kzp,
232 const uint8_t* k,
233 const int32_t* b,
234 void* packed_w,
235 struct subconvolution_params* params)
236 {
237 for (size_t i = 0; i < g; i++) {
238 for (size_t oy = 0; oy < sh; oy++) {
239 for (size_t ox = 0; ox < sw; ox++) {
240 if (i == 0) {
241 (*params++).weights = packed_w;
242 }
243 const int32_t boff = (int32_t) divide_round_up(kh - oy, sh) * (int32_t) divide_round_up(kw - ox, sw) * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
244 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
245 const size_t nr_block_size = min(nc - nr_block_start, nr);
246 int32_t* packed_b = (int32_t*) packed_w;
247 if XNN_LIKELY(b != 0) {
248 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
249 *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
250 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
251 }
252 } else {
253 size_t n = nr_block_size;
254 do {
255 *((int32_t*) packed_w) = boff;
256 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
257 } while (--n != 0);
258 }
259 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
260 for (size_t ky = oy; ky < kh; ky += sh) {
261 for (size_t kx = ox; kx < kw; kx += sw) {
262 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
263 const size_t kr_block_size = min(kc - kr_block_start, kr);
264 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
265 int32_t ksum = 0;
266 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
267 const uint8_t kv =
268 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
269 ksum += (int32_t) kv;
270 *((uint8_t*) packed_w) = kv;
271 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
272 }
273 packed_b[nr_block_offset] -= ksum * (int32_t) izp;
274 packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
275 }
276 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
277 }
278 }
279 }
280 }
281 }
282 }
283 k += kh * kw * kc * nc;
284 if XNN_UNPREDICTABLE(b != NULL) {
285 b += nc;
286 }
287 }
288 }
289
xnn_pack_q8_dwconv_ghw_w(size_t h,size_t w,size_t c,size_t cr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)290 static inline void xnn_pack_q8_dwconv_ghw_w(
291 size_t h,
292 size_t w,
293 size_t c,
294 size_t cr,
295 uint8_t izp,
296 uint8_t kzp,
297 const uint8_t* k,
298 const int32_t* b,
299 void* packed_w)
300 {
301 const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
302 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
303 const size_t cr_block_size = min(c - cr_block_start, cr);
304 int32_t* packed_b = (int32_t*) packed_w;
305 if XNN_LIKELY(b != NULL) {
306 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
307 *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
308 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
309 }
310 } else {
311 size_t n = cr_block_size;
312 do {
313 *((int32_t*) packed_w) = boff;
314 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
315 } while (--n != 0);
316 }
317 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
318 for (size_t x = 0; x < w; x++) {
319 for (size_t y = 0; y < h; y++) {
320 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
321 const uint8_t kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
322 packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
323 *((uint8_t*) packed_w) = kv;
324 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
325 }
326 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
327 }
328 }
329 }
330 }
331
xnn_pack_q8_dwconv_hwg_w(size_t h,size_t w,size_t c,size_t cr,uint8_t izp,uint8_t kzp,const uint8_t * k,const int32_t * b,void * packed_w)332 static inline void xnn_pack_q8_dwconv_hwg_w(
333 size_t h,
334 size_t w,
335 size_t c,
336 size_t cr,
337 uint8_t izp,
338 uint8_t kzp,
339 const uint8_t* k,
340 const int32_t* b,
341 void* packed_w)
342 {
343 const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
344 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
345 const size_t cr_block_size = min(c - cr_block_start, cr);
346 int32_t* packed_b = (int32_t*) packed_w;
347 if XNN_LIKELY(b != NULL) {
348 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
349 *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
350 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
351 }
352 } else {
353 size_t n = cr_block_size;
354 do {
355 *((int32_t*) packed_w) = boff;
356 packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
357 } while (--n != 0);
358 }
359 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
360 for (size_t x = 0; x < w; x++) {
361 for (size_t y = 0; y < h; y++) {
362 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
363 const uint8_t kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
364 packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
365 *((uint8_t*) packed_w) = kv;
366 packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
367 }
368 packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
369 }
370 }
371 }
372 }
373
xnn_pack_f16_gemm_goi_w(size_t g,size_t nc,size_t kc,size_t nr,size_t kr,const uint16_t * k,const uint16_t * b,uint16_t * packed_w)374 static inline void xnn_pack_f16_gemm_goi_w(
375 size_t g,
376 size_t nc,
377 size_t kc,
378 size_t nr,
379 size_t kr,
380 const uint16_t* k,
381 const uint16_t* b,
382 uint16_t* packed_w)
383 {
384 do {
385 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
386 const size_t nr_block_size = min(nc - nr_block_start, nr);
387 if XNN_LIKELY(b != NULL) {
388 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
389 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
390 }
391 }
392 packed_w += nr;
393 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
394 const size_t kr_block_size = min(kc - kr_block_start, kr);
395 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
396 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
397 *packed_w++ =
398 k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
399 }
400 packed_w += kr - kr_block_size;
401 }
402 packed_w += (nr - nr_block_size) * kr;
403 }
404 }
405 k += nc * kc;
406 if XNN_UNPREDICTABLE(b != NULL) {
407 b += nc;
408 }
409 } while (--g != 0);
410 }
411
xnn_pack_f16_gemm_io_w(size_t nc,size_t kc,size_t nr,size_t kr,const uint16_t * k,const uint16_t * b,uint16_t * packed_w)412 static inline void xnn_pack_f16_gemm_io_w(
413 size_t nc,
414 size_t kc,
415 size_t nr,
416 size_t kr,
417 const uint16_t* k,
418 const uint16_t* b,
419 uint16_t* packed_w)
420 {
421 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
422 const size_t nr_block_size = min(nc - nr_block_start, nr);
423 if XNN_LIKELY(b != NULL) {
424 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
425 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
426 }
427 }
428 packed_w += nr;
429 for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
430 const size_t kr_block_size = min(kc - kr_block_start, kr);
431 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
432 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
433 *packed_w++ =
434 k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
435 }
436 packed_w += kr - kr_block_size;
437 }
438 packed_w += (nr - nr_block_size) * kr;
439 }
440 }
441 }
442
xnn_pack_f32_gemm_goi_w(size_t g,size_t nc,size_t kc,size_t nr,size_t kr,size_t sr,const float * k,const float * b,float * packed_w)443 static inline void xnn_pack_f32_gemm_goi_w(
444 size_t g,
445 size_t nc,
446 size_t kc,
447 size_t nr,
448 size_t kr,
449 size_t sr,
450 const float* k,
451 const float* b,
452 float* packed_w)
453 {
454 const size_t skr = sr * kr;
455 const size_t skc = round_down_po2(kc, skr);
456 const size_t sr_mask = (sr - 1) * kr;
457 do {
458 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
459 const size_t nr_block_size = min(nc - nr_block_start, nr);
460 if XNN_LIKELY(b != NULL) {
461 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
462 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
463 }
464 }
465 packed_w += nr;
466
467 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
468 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
469 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
470 *packed_w++ =
471 k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
472 }
473 }
474 packed_w += (nr - nr_block_size) * kr;
475 }
476
477 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
478 const size_t kr_block_size = min(kc - kr_block_start, kr);
479 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
480 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
481 *packed_w++ =
482 k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
483 }
484 packed_w += kr - kr_block_size;
485 }
486 packed_w += (nr - nr_block_size) * kr;
487 }
488 }
489 k += nc * kc;
490 if XNN_UNPREDICTABLE(b != NULL) {
491 b += nc;
492 }
493 } while (--g != 0);
494 }
495
xnn_pack_f32_gemm_io_w(size_t nc,size_t kc,size_t nr,size_t kr,size_t sr,const float * k,const float * b,float * packed_w)496 static inline void xnn_pack_f32_gemm_io_w(
497 size_t nc,
498 size_t kc,
499 size_t nr,
500 size_t kr,
501 size_t sr,
502 const float* k,
503 const float* b,
504 float* packed_w)
505 {
506 const size_t skr = sr * kr;
507 const size_t skc = round_down_po2(kc, skr);
508 const size_t sr_mask = (sr - 1) * kr;
509 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
510 const size_t nr_block_size = min(nc - nr_block_start, nr);
511 if XNN_LIKELY(b != NULL) {
512 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
513 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
514 }
515 }
516 packed_w += nr;
517
518 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
519 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
520 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
521 *packed_w++ =
522 k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
523 }
524 }
525 packed_w += (nr - nr_block_size) * kr;
526 }
527
528 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
529 const size_t kr_block_size = min(kc - kr_block_start, kr);
530 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
531 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
532 *packed_w++ =
533 k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
534 }
535 packed_w += kr - kr_block_size;
536 }
537 packed_w += (nr - nr_block_size) * kr;
538 }
539 }
540 }
541
// Packs float GEMM weights laid out as [g][nc][kc] into the blocked buffer at
// packed_w, without any bias slots.
//
// For every group and every block of up to nr output channels it emits:
//   1. For the leading skc input channels (kc rounded down to a multiple of
//      sr * kr by round_down_po2): per kr-sized step, nr groups of kr floats.
//      The kr-chunk read for output channel n is chosen inside the enclosing
//      sr * kr span by ((k_start + n * kr) & sr_mask).
//      NOTE(review): the mask arithmetic presumably assumes sr and kr are
//      powers of two - confirm against callers.
//   2. For the remaining input channels: plain kr-blocked copy.
// Padding of partial nr / kr blocks is skipped over, not written.
//
//   g          - number of groups; 0 packs nothing.
//   nc, kc     - output / input channels per group.
//   nr, kr, sr - block sizes of the packed layout.
//   k          - kernel values, nc * kc per group.
//   packed_w   - destination buffer.
static inline void xnn_pack_f32_gemminc_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  // Counted loop rather than do/while(--g): with g == 0 the decrement would
  // wrap around size_t and the loop would run far past every buffer.
  for (size_t group = 0; group < g; group++) {
    for (size_t n_start = 0; n_start < nc; n_start += nr) {
      const size_t n_count = min(nc - n_start, nr);

      // Part covered by whole sr * kr spans: channel-dependent chunk selection.
      for (size_t k_start = 0; k_start < skc; k_start += kr) {
        const size_t span_base = round_down_po2(k_start, skr);
        for (size_t n = 0; n < n_count; n++) {
          const size_t chunk = (k_start + n * kr) & sr_mask;
          const float* src = &k[(n_start + n) * kc + span_base + chunk];
          for (size_t i = 0; i < kr; i++) {
            packed_w[i] = src[i];
          }
          packed_w += kr;
        }
        packed_w += (nr - n_count) * kr;
      }

      // Remainder: straight copy, padding partial kr groups.
      for (size_t k_start = skc; k_start < kc; k_start += kr) {
        const size_t k_count = min(kc - k_start, kr);
        for (size_t n = 0; n < n_count; n++) {
          const float* src = &k[(n_start + n) * kc + k_start];
          for (size_t i = 0; i < k_count; i++) {
            packed_w[i] = src[i];
          }
          packed_w += kr;
        }
        packed_w += (nr - n_count) * kr;
      }
    }
    k += nc * kc;
  }
}
584
xnn_pack_f32_conv_goki_w(size_t g,size_t nc,size_t ks,size_t kc,size_t nr,size_t kr,size_t sr,const float * k,const float * b,float * packed_w)585 static inline void xnn_pack_f32_conv_goki_w(
586 size_t g,
587 size_t nc,
588 size_t ks,
589 size_t kc,
590 size_t nr,
591 size_t kr,
592 size_t sr,
593 const float* k,
594 const float* b,
595 float* packed_w)
596 {
597 const size_t skr = sr * kr;
598 const size_t skc = round_down_po2(kc, skr);
599 const size_t sr_mask = (sr - 1) * kr;
600 do {
601 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
602 const size_t nr_block_size = min(nc - nr_block_start, nr);
603 if XNN_LIKELY(b != NULL) {
604 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
605 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
606 }
607 }
608 packed_w += nr;
609
610 for (size_t ki = 0; ki < ks; ki++) {
611 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
612 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
613 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
614 *packed_w++ =
615 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
616 }
617 }
618 packed_w += (nr - nr_block_size) * kr;
619 }
620
621 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
622 const size_t kr_block_size = min(kc - kr_block_start, kr);
623 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
624 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
625 *packed_w++ =
626 k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
627 }
628 packed_w += kr - kr_block_size;
629 }
630 packed_w += (nr - nr_block_size) * kr;
631 }
632 }
633 }
634 k += ks * kc * nc;
635 if XNN_UNPREDICTABLE(b != NULL) {
636 b += nc;
637 }
638 } while (--g != 0);
639 }
640
xnn_pack_f32_conv_kgo_w(size_t g,size_t nc,size_t ks,size_t nr,size_t kr,const float * k,const float * b,float * packed_w)641 static inline void xnn_pack_f32_conv_kgo_w(
642 size_t g,
643 size_t nc,
644 size_t ks,
645 size_t nr,
646 size_t kr,
647 const float* k,
648 const float* b,
649 float* packed_w)
650 {
651 for (size_t i = 0; i < g; i++) {
652 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
653 const size_t nr_block_size = min(nc - nr_block_start, nr);
654 if XNN_LIKELY(b != NULL) {
655 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
656 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
657 }
658 }
659 packed_w += nr;
660 for (size_t ki = 0; ki < ks; ki++) {
661 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
662 *packed_w =
663 k[ki * g * nc + (nr_block_start + nr_block_offset)];
664 packed_w += kr;
665 }
666 packed_w += (nr - nr_block_size) * kr;
667 }
668 }
669 k += nc;
670 if XNN_UNPREDICTABLE(b != NULL) {
671 b += nc;
672 }
673 }
674 }
675
xnn_pack_f32_dconv_oki_w(size_t nc,size_t kc,size_t nr,size_t kh,size_t kw,const float * k,const float * b,float * packed_w)676 static inline void xnn_pack_f32_dconv_oki_w(
677 size_t nc,
678 size_t kc,
679 size_t nr,
680 size_t kh,
681 size_t kw,
682 const float* k,
683 const float* b,
684 float* packed_w)
685 {
686 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
687 const size_t nr_block_size = min(nc - nr_block_start, nr);
688 if XNN_LIKELY(b != NULL) {
689 for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
690 *packed_w++ = b[min(nr_block_offset, nr_block_size - 1)];
691 }
692 } else {
693 size_t n = nr;
694 do {
695 *packed_w++ = 0.0f;
696 } while (--n != 0);
697 }
698
699 for (size_t kx = 0; kx < kw; kx++) {
700 for (size_t c = 0; c < kc; c++) {
701 for (size_t ky = 0; ky < kh; ky++) {
702 for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
703 *packed_w++ = k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) * kh + ky) * kw + kx) * kc + c];
704 }
705 }
706 }
707 }
708 if XNN_UNPREDICTABLE(b != NULL) {
709 b += nr;
710 }
711 }
712 }
713
xnn_pack_f32_deconv_goki_w(size_t g,size_t nc,size_t kh,size_t kw,size_t kc,size_t sh,size_t sw,size_t nr,size_t kr,size_t sr,const float * k,const float * b,float * packed_w,struct subconvolution_params * params)714 static inline void xnn_pack_f32_deconv_goki_w(
715 size_t g,
716 size_t nc,
717 size_t kh,
718 size_t kw,
719 size_t kc,
720 size_t sh,
721 size_t sw,
722 size_t nr,
723 size_t kr,
724 size_t sr,
725 const float* k,
726 const float* b,
727 float* packed_w,
728 struct subconvolution_params* params)
729 {
730 const size_t skr = sr * kr;
731 const size_t skc = round_down_po2(kc, skr);
732 const size_t sr_mask = (sr - 1) * kr;
733 for (size_t i = 0; i < g; i++) {
734 for (size_t oy = 0; oy < sh; oy++) {
735 for (size_t ox = 0; ox < sw; ox++) {
736 if (i == 0) {
737 (*params++).weights = packed_w;
738 }
739 for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
740 const size_t nr_block_size = min(nc - nr_block_start, nr);
741 if XNN_LIKELY(b != NULL) {
742 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
743 packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
744 }
745 }
746 packed_w += nr;
747 for (size_t ky = oy; ky < kh; ky += sh) {
748 for (size_t kx = ox; kx < kw; kx += sw) {
749 for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
750 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
751 for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
752 *packed_w++ =
753 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
754 }
755 }
756 packed_w += (nr - nr_block_size) * kr;
757 }
758
759 for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
760 const size_t kr_block_size = min(kc - kr_block_start, kr);
761 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
762 for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
763 *packed_w++ =
764 k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
765 }
766 packed_w += kr - kr_block_size;
767 }
768 packed_w += (nr - nr_block_size) * kr;
769 }
770 }
771 }
772 }
773 }
774 }
775 k += kh * kw * kc * nc;
776 if XNN_UNPREDICTABLE(b != NULL) {
777 b += nc;
778 }
779 }
780 }
781
xnn_pack_f32_dwconv_ghw_w(size_t h,size_t w,size_t c,size_t cr,const float * k,const float * b,float * packed_w)782 static inline void xnn_pack_f32_dwconv_ghw_w(
783 size_t h,
784 size_t w,
785 size_t c,
786 size_t cr,
787 const float* k,
788 const float* b,
789 float* packed_w)
790 {
791 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
792 const size_t cr_block_size = min(c - cr_block_start, cr);
793 if XNN_LIKELY(b != NULL) {
794 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
795 *packed_w++ = b[cr_block_start + cr_block_offset];
796 }
797 } else {
798 size_t n = cr_block_size;
799 do {
800 *packed_w++ = 0.0f;
801 } while (--n != 0);
802 }
803 packed_w += cr - cr_block_size;
804 for (size_t x = 0; x < w; x++) {
805 for (size_t y = 0; y < h; y++) {
806 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
807 const float kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
808 *packed_w++ = kv;
809 }
810 packed_w += cr - cr_block_size;
811 }
812 }
813 }
814 }
815
xnn_pack_f32_dwconv_hwg_w(size_t h,size_t w,size_t c,size_t cr,const float * k,const float * b,float * packed_w)816 static inline void xnn_pack_f32_dwconv_hwg_w(
817 size_t h,
818 size_t w,
819 size_t c,
820 size_t cr,
821 const float* k,
822 const float* b,
823 float* packed_w)
824 {
825 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
826 const size_t cr_block_size = min(c - cr_block_start, cr);
827 if XNN_LIKELY(b != NULL) {
828 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
829 *packed_w++ = b[cr_block_start + cr_block_offset];
830 }
831 } else {
832 size_t n = cr_block_size;
833 do {
834 *packed_w++ = 0.0f;
835 } while (--n != 0);
836 }
837 packed_w += cr - cr_block_size;
838 for (size_t x = 0; x < w; x++) {
839 for (size_t y = 0; y < h; y++) {
840 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
841 const float kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
842 *packed_w++ = kv;
843 }
844 packed_w += cr - cr_block_size;
845 }
846 }
847 }
848 }
849
xnn_pack_f32_spchw_dwconv_ghw_w(size_t kernel_size,size_t groups,const float * kernel,const float * bias,float * packed_weights)850 static inline void xnn_pack_f32_spchw_dwconv_ghw_w(
851 size_t kernel_size,
852 size_t groups,
853 const float* kernel,
854 const float* bias,
855 float* packed_weights)
856 {
857 for (size_t g = 0; g < groups; g++) {
858 if XNN_LIKELY(bias != NULL) {
859 *packed_weights = *bias++;
860 } else {
861 *packed_weights = 0.0f;
862 }
863 packed_weights += 1;
864 for (size_t i = 0; i < kernel_size; i++) {
865 *packed_weights++ = kernel[g * kernel_size + i];
866 }
867 }
868 }
869
xnn_pack_f32_vmulcaddc_w(size_t c,size_t cr,const float * s,const float * b,float * packed_w)870 static inline void xnn_pack_f32_vmulcaddc_w(
871 size_t c,
872 size_t cr,
873 const float* s,
874 const float* b,
875 float* packed_w)
876 {
877 for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
878 const size_t cr_block_size = min(c - cr_block_start, cr);
879 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
880 *packed_w++ = s[cr_block_start + cr_block_offset];
881 }
882 packed_w += cr - cr_block_size;
883 if XNN_LIKELY(b != NULL) {
884 for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
885 *packed_w++ = b[cr_block_start + cr_block_offset];
886 }
887 } else {
888 size_t n = cr_block_size;
889 do {
890 *packed_w++ = 0.0f;
891 } while (--n != 0);
892 }
893 packed_w += cr - cr_block_size;
894 }
895 }
896