// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stdint.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>


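// Packs quantized 8-bit GEMM weights from GOI layout (groups, output channels, input channels)
// into the blocked panel layout consumed by the Q8 GEMM micro-kernels. Each nr-wide panel starts
// with the per-output-channel bias, pre-adjusted by the zero-point correction kc * izp * kzp and
// further reduced by izp times the packed row sum of the weights.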
static inline void xnn_pack_q8_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        }
      } else {
        size_t n = nr_block_size;
        do {
          *((int32_t*) packed_w) = boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        } while (--n != 0);
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          int32_t ksum = 0;
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
            ksum += (int32_t) kv;
            *((uint8_t*) packed_w) = kv;
            packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
          }
          packed_b[nr_block_offset] -= ksum * (int32_t) izp;
          packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
    k += nc * kc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

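// Packs quantized 8-bit GEMM weights supplied in IO layout (input channels as rows,
// output channels as columns); otherwise identical to the GOI packing above.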
static inline void xnn_pack_q8_gemm_io_w(
  size_t nc,
  size_t kc,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) kc * (int32_t) izp * (int32_t) kzp;
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    int32_t* packed_b = (int32_t*) packed_w;
    if XNN_LIKELY(b != NULL) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
    } else {
      size_t n = nr_block_size;
      do {
        *((int32_t*) packed_w) = boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      } while (--n != 0);
    }
    packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        int32_t ksum = 0;
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
          const uint8_t kv = k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
          ksum += (int32_t) kv;
          *((uint8_t*) packed_w) = kv;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
        }
        packed_b[nr_block_offset] -= ksum * (int32_t) izp;
        packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
    }
  }
}

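// Packs quantized 8-bit convolution weights from GOKI layout (groups, output channels,
// kernel spatial positions, input channels), folding the zero-point corrections into the
// packed bias in the same way as the GEMM packing.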
static inline void xnn_pack_q8_conv_goki_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t kc,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) ks * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        }
      } else {
        size_t n = nr_block_size;
        do {
          *((int32_t*) packed_w) = boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        } while (--n != 0);
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
          const size_t kr_block_size = min(kc - kr_block_start, kr);
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            int32_t ksum = 0;
            for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
              const uint8_t kv =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
              ksum += (int32_t) kv;
              *((uint8_t*) packed_w) = kv;
              packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
            }
            packed_b[nr_block_offset] -= ksum * (int32_t) izp;
            packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
          }
          packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
        }
      }
    }
    k += ks * kc * nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

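// Packs quantized 8-bit convolution weights from KGO layout (kernel spatial positions,
// groups, output channels): a single weight per output channel and spatial position,
// stored at a stride of kr bytes.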
static inline void xnn_pack_q8_conv_kgo_w(
  size_t g,
  size_t nc,
  size_t ks,
  uint32_t nr,
  uint32_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) ks * (int32_t) izp * (int32_t) kzp;
  for (size_t i = 0; i < g; i++) {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      int32_t* packed_b = (int32_t*) packed_w;
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        }
      } else {
        size_t n = nr_block_size;
        do {
          *((int32_t*) packed_w) = boff;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
        } while (--n != 0);
      }
      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          const uint8_t kv =
            k[ki * g * nc + (nr_block_start + nr_block_offset)];
          *((uint8_t*) packed_w) = kv;
          packed_b[nr_block_offset] -= (int32_t) kv * (int32_t) izp;
          packed_w = (void*) ((uintptr_t) packed_w + kr * sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
      }
    }
    k += nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  }
}

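// Packs quantized 8-bit deconvolution (transposed convolution) weights from GOKI layout.
// The kernel is split into sh * sw subconvolutions, one per output stride phase (oy, ox),
// and the start of each subconvolution's packed weights is recorded in params.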
static inline void xnn_pack_q8_deconv_goki_w(
  size_t g,
  size_t nc,
  size_t kh,
  size_t kw,
  size_t kc,
  size_t sh,
  size_t sw,
  size_t nr,
  size_t kr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w,
  struct subconvolution_params* params)
{
  for (size_t i = 0; i < g; i++) {
    for (size_t oy = 0; oy < sh; oy++) {
      for (size_t ox = 0; ox < sw; ox++) {
        if (i == 0) {
          (*params++).weights = packed_w;
        }
        const int32_t boff = (int32_t) divide_round_up(kh - oy, sh) * (int32_t) divide_round_up(kw - ox, sw) * (int32_t) kc * (int32_t) izp * (int32_t) kzp;
        for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
          const size_t nr_block_size = min(nc - nr_block_start, nr);
          int32_t* packed_b = (int32_t*) packed_w;
          if XNN_LIKELY(b != NULL) {
            for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
              *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + boff;
              packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
            }
          } else {
            size_t n = nr_block_size;
            do {
              *((int32_t*) packed_w) = boff;
              packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
            } while (--n != 0);
          }
          packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
          for (size_t ky = oy; ky < kh; ky += sh) {
            for (size_t kx = ox; kx < kw; kx += sw) {
              for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
                const size_t kr_block_size = min(kc - kr_block_start, kr);
                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                  int32_t ksum = 0;
                  for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
                    const uint8_t kv =
                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
                    ksum += (int32_t) kv;
                    *((uint8_t*) packed_w) = kv;
                    packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
                  }
                  packed_b[nr_block_offset] -= ksum * (int32_t) izp;
                  packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t));
                }
                packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t));
              }
            }
          }
        }
      }
    }
    k += kh * kw * kc * nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  }
}

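// Packs quantized 8-bit depthwise convolution weights from GHW layout
// (channels/groups, kernel height, kernel width), folding zero-point corrections into the bias.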
static inline void xnn_pack_q8_dwconv_ghw_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*) packed_w;
    if XNN_LIKELY(b != NULL) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
        *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
    } else {
      size_t n = cr_block_size;
      do {
        *((int32_t*) packed_w) = boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      } while (--n != 0);
    }
    packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const uint8_t kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
          *((uint8_t*) packed_w) = kv;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}

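// Packs quantized 8-bit depthwise convolution weights from HWG layout
// (kernel height, kernel width, channels/groups).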
static inline void xnn_pack_q8_dwconv_hwg_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  uint8_t izp,
  uint8_t kzp,
  const uint8_t* k,
  const int32_t* b,
  void* packed_w)
{
  const int32_t boff = (int32_t) h * (int32_t) w * (int32_t) izp * (int32_t) kzp;
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    int32_t* packed_b = (int32_t*) packed_w;
    if XNN_LIKELY(b != NULL) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
        *((int32_t*) packed_w) = b[cr_block_start + cr_block_offset] + boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      }
    } else {
      size_t n = cr_block_size;
      do {
        *((int32_t*) packed_w) = boff;
        packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t));
      } while (--n != 0);
    }
    packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(int32_t));
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const uint8_t kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
          packed_b[cr_block_offset] -= (int32_t) kv * (int32_t) izp;
          *((uint8_t*) packed_w) = kv;
          packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t));
        }
        packed_w = (void*) ((uintptr_t) packed_w + (cr - cr_block_size) * sizeof(uint8_t));
      }
    }
  }
}

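// Packs half-precision GEMM weights and bias (both bit-cast to uint16_t) from GOI layout
// into blocked panels; no zero-point handling is needed for floating-point weights.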
static inline void xnn_pack_f16_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  const uint16_t* k,
  const uint16_t* b,
  uint16_t* packed_w)
{
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
        }
      }
      packed_w += nr;
      for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc * kc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

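// Same half-precision packing, with the weights supplied in IO layout.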
static inline void xnn_pack_f16_gemm_io_w(
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  const uint16_t* k,
  const uint16_t* b,
  uint16_t* packed_w)
{
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    if XNN_LIKELY(b != NULL) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
      }
    }
    packed_w += nr;
    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
          *packed_w++ =
            k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}

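// Packs single-precision GEMM weights and bias from GOI layout. When sr > 1, the kr-sized
// sub-blocks within each aligned group of sr * kr input channels are rotated per output
// channel to match micro-kernels that expect this interleaving; the kc remainder is packed
// in order.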
static inline void xnn_pack_f32_gemm_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
        }
      }
      packed_w += nr;

      for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
          }
        }
        packed_w += (nr - nr_block_size) * kr;
      }

      for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc * kc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

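// Single-precision GEMM packing with the same sr rotation, for weights supplied in IO layout.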
static inline void xnn_pack_f32_gemm_io_w(
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    if XNN_LIKELY(b != NULL) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
      }
    }
    packed_w += nr;

    for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
          *packed_w++ =
            k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
        }
      }
      packed_w += (nr - nr_block_size) * kr;
    }

    for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
      const size_t kr_block_size = min(kc - kr_block_start, kr);
      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
          *packed_w++ =
            k[(kr_block_start + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
        }
        packed_w += kr - kr_block_size;
      }
      packed_w += (nr - nr_block_size) * kr;
    }
  }
}

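// Packs single-precision GEMM weights from GOI layout for the GEMMINC micro-kernels,
// which take their initial accumulators from a separate buffer, so no bias is packed.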
static inline void xnn_pack_f32_gemminc_goi_w(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);

      for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
          }
        }
        packed_w += (nr - nr_block_size) * kr;
      }

      for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
        const size_t kr_block_size = min(kc - kr_block_start, kr);
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
            *packed_w++ =
              k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)];
          }
          packed_w += kr - kr_block_size;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc * kc;
  } while (--g != 0);
}

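// Packs single-precision convolution weights and bias from GOKI layout, applying the same
// sr rotation as the GEMM packing within each kernel spatial position.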
static inline void xnn_pack_f32_conv_goki_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  do {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
        }
      }
      packed_w += nr;

      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
              *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
            }
          }
          packed_w += (nr - nr_block_size) * kr;
        }

        for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
          const size_t kr_block_size = min(kc - kr_block_start, kr);
          for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
            for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
              *packed_w++ =
                k[((nr_block_start + nr_block_offset) * ks + ki) * kc + (kr_block_start + kr_block_offset)];
            }
            packed_w += kr - kr_block_size;
          }
          packed_w += (nr - nr_block_size) * kr;
        }
      }
    }
    k += ks * kc * nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  } while (--g != 0);
}

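// Packs single-precision convolution weights from KGO layout: one weight per output channel
// and spatial position, stored at a stride of kr elements.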
static inline void xnn_pack_f32_conv_kgo_w(
  size_t g,
  size_t nc,
  size_t ks,
  size_t nr,
  size_t kr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t i = 0; i < g; i++) {
    for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
      const size_t nr_block_size = min(nc - nr_block_start, nr);
      if XNN_LIKELY(b != NULL) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
        }
      }
      packed_w += nr;
      for (size_t ki = 0; ki < ks; ki++) {
        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
          *packed_w =
            k[ki * g * nc + (nr_block_start + nr_block_offset)];
          packed_w += kr;
        }
        packed_w += (nr - nr_block_size) * kr;
      }
    }
    k += nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  }
}

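// Packs single-precision weights and bias in OKI order (output channels, kernel spatial
// positions, input channels). Partial nr blocks are padded by replicating the last valid
// output channel (and its bias) rather than being zero-filled.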
static inline void xnn_pack_f32_dconv_oki_w(
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kh,
  size_t kw,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
    const size_t nr_block_size = min(nc - nr_block_start, nr);
    if XNN_LIKELY(b != NULL) {
      for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
        *packed_w++ = b[min(nr_block_offset, nr_block_size - 1)];
      }
    } else {
      size_t n = nr;
      do {
        *packed_w++ = 0.0f;
      } while (--n != 0);
    }

    for (size_t kx = 0; kx < kw; kx++) {
      for (size_t c = 0; c < kc; c++) {
        for (size_t ky = 0; ky < kh; ky++) {
          for (size_t nr_block_offset = 0; nr_block_offset < nr; nr_block_offset++) {
            *packed_w++ = k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) * kh + ky) * kw + kx) * kc + c];
          }
        }
      }
    }
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nr;
    }
  }
}

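// Packs single-precision deconvolution (transposed convolution) weights from GOKI layout,
// splitting the kernel into sh * sw subconvolutions, recording each one's packed weights
// pointer in params, and applying the same sr rotation as the GEMM packing.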
static inline void xnn_pack_f32_deconv_goki_w(
  size_t g,
  size_t nc,
  size_t kh,
  size_t kw,
  size_t kc,
  size_t sh,
  size_t sw,
  size_t nr,
  size_t kr,
  size_t sr,
  const float* k,
  const float* b,
  float* packed_w,
  struct subconvolution_params* params)
{
  const size_t skr = sr * kr;
  const size_t skc = round_down_po2(kc, skr);
  const size_t sr_mask = (sr - 1) * kr;
  for (size_t i = 0; i < g; i++) {
    for (size_t oy = 0; oy < sh; oy++) {
      for (size_t ox = 0; ox < sw; ox++) {
        if (i == 0) {
          (*params++).weights = packed_w;
        }
        for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
          const size_t nr_block_size = min(nc - nr_block_start, nr);
          if XNN_LIKELY(b != NULL) {
            for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
              packed_w[nr_block_offset] = b[nr_block_start + nr_block_offset];
            }
          }
          packed_w += nr;
          for (size_t ky = oy; ky < kh; ky += sh) {
            for (size_t kx = ox; kx < kw; kx += sw) {
              for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                  for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
                    *packed_w++ =
                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset];
                  }
                }
                packed_w += (nr - nr_block_size) * kr;
              }

              for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
                const size_t kr_block_size = min(kc - kr_block_start, kr);
                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                  for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
                    *packed_w++ =
                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (kr_block_start + kr_block_offset)];
                  }
                  packed_w += kr - kr_block_size;
                }
                packed_w += (nr - nr_block_size) * kr;
              }
            }
          }
        }
      }
    }
    k += kh * kw * kc * nc;
    if XNN_UNPREDICTABLE(b != NULL) {
      b += nc;
    }
  }
}

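// Packs single-precision depthwise convolution weights and bias from GHW layout
// (channels/groups, kernel height, kernel width); a missing bias is packed as zeros.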
static inline void xnn_pack_f32_dwconv_ghw_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    if XNN_LIKELY(b != NULL) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
        *packed_w++ = b[cr_block_start + cr_block_offset];
      }
    } else {
      size_t n = cr_block_size;
      do {
        *packed_w++ = 0.0f;
      } while (--n != 0);
    }
    packed_w += cr - cr_block_size;
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const float kv = k[((cr_block_start + cr_block_offset) * h + y) * w + x];
          *packed_w++ = kv;
        }
        packed_w += cr - cr_block_size;
      }
    }
  }
}

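// Packs single-precision depthwise convolution weights and bias from HWG layout
// (kernel height, kernel width, channels/groups).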
static inline void xnn_pack_f32_dwconv_hwg_w(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const float* k,
  const float* b,
  float* packed_w)
{
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    if XNN_LIKELY(b != NULL) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
        *packed_w++ = b[cr_block_start + cr_block_offset];
      }
    } else {
      size_t n = cr_block_size;
      do {
        *packed_w++ = 0.0f;
      } while (--n != 0);
    }
    packed_w += cr - cr_block_size;
    for (size_t x = 0; x < w; x++) {
      for (size_t y = 0; y < h; y++) {
        for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
          const float kv = k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
          *packed_w++ = kv;
        }
        packed_w += cr - cr_block_size;
      }
    }
  }
}

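// Packs single-precision depthwise convolution weights for the SpCHW (single-plane CHW layout)
// micro-kernels: for each group, one bias value (zero if bias is NULL) followed by all kernel taps.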
static inline void xnn_pack_f32_spchw_dwconv_ghw_w(
  size_t kernel_size,
  size_t groups,
  const float* kernel,
  const float* bias,
  float* packed_weights)
{
  for (size_t g = 0; g < groups; g++) {
    if XNN_LIKELY(bias != NULL) {
      *packed_weights = *bias++;
    } else {
      *packed_weights = 0.0f;
    }
    packed_weights += 1;
    for (size_t i = 0; i < kernel_size; i++) {
      *packed_weights++ = kernel[g * kernel_size + i];
    }
  }
}

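// Packs per-channel scale (s) and bias (b) values for the vmulcaddc micro-kernels:
// for each block of cr channels, the scales are packed first and the biases second,
// each padded to a multiple of cr; a missing bias is packed as zeros.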
static inline void xnn_pack_f32_vmulcaddc_w(
  size_t c,
  size_t cr,
  const float* s,
  const float* b,
  float* packed_w)
{
  for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
    const size_t cr_block_size = min(c - cr_block_start, cr);
    for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
      *packed_w++ = s[cr_block_start + cr_block_offset];
    }
    packed_w += cr - cr_block_size;
    if XNN_LIKELY(b != NULL) {
      for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; cr_block_offset++) {
        *packed_w++ = b[cr_block_start + cr_block_offset];
      }
    } else {
      size_t n = cr_block_size;
      do {
        *packed_w++ = 0.0f;
      } while (--n != 0);
    }
    packed_w += cr - cr_block_size;
  }
}