// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdint.h>
12 #include <stddef.h>
13 
14 #include <xnnpack/common.h>
15 #include <xnnpack/operator.h>
16 
17 
18 #ifdef __cplusplus
19 extern "C" {
20 #endif
21 
22 
// Extra parameters passed (as `params`) to the qu8 packing routines declared
// in this header: one unsigned 8-bit zero point for the input and one for the
// kernel.
struct xnn_qu8_packing_params {
  uint8_t input_zero_point;
  uint8_t kernel_zero_point;
};
27 
// Extra parameters passed (as `params`) to the qs8 packing routines declared
// in this header: a single signed 8-bit input zero point.
struct xnn_qs8_packing_params {
  int8_t input_zero_point;
};
31 
32 
// Function-pointer type mirroring the `xnn_pack_*_gemm_goi_w` prototypes with
// every data pointer erased to `void*`.
// NOTE(review): parameter meanings are not stated in this header; presumably
// g = groups, nc/kc = output/input channels, nr/kr/sr = micro-kernel tiling
// factors, k = weights, b = bias, packed_w = destination buffer and
// extra_bytes = additional space reserved in the packed output -- confirm
// against the implementation.
typedef void (*xnn_pack_gemm_goi_w_function)(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const void* k,
  const void* b,
  void* packed_w,
  size_t extra_bytes,
  const void* params);
45 
46 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w(
47   size_t g,
48   size_t nc,
49   size_t kc,
50   size_t nr,
51   size_t kr,
52   size_t sr,
53   const float* k,
54   const float* b,
55   float* packed_w,
56   size_t extra_bytes,
57   const void* params);
58 
59 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w(
60   size_t g,
61   size_t nc,
62   size_t kc,
63   size_t nr,
64   size_t kr,
65   size_t sr,
66   const uint16_t* k,
67   const uint16_t* b,
68   uint16_t* packed_w,
69   size_t extra_bytes,
70   const void* params);
71 
72 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_goi_w(
73   size_t g,
74   size_t nc,
75   size_t kc,
76   size_t nr,
77   size_t kr,
78   size_t sr,
79   const float* k,
80   const float* b,
81   uint16_t* packed_w,
82   size_t extra_bytes,
83   const void* params);
84 
85 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w(
86   size_t g,
87   size_t nc,
88   size_t kc,
89   size_t nr,
90   size_t kr,
91   size_t sr,
92   const uint8_t* k,
93   const int32_t* b,
94   void* packed_w,
95   size_t extra_bytes,
96   const struct xnn_qu8_packing_params* params);
97 
98 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w(
99   size_t g,
100   size_t nc,
101   size_t kc,
102   size_t nr,
103   size_t kr,
104   size_t sr,
105   const int8_t* k,
106   const int32_t* b,
107   void* packed_w,
108   size_t extra_bytes,
109   const struct xnn_qs8_packing_params* params);
110 
111 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w(
112   size_t g,
113   size_t nc,
114   size_t kc,
115   size_t nr,
116   size_t kr,
117   size_t sr,
118   const int8_t* k,
119   const int32_t* b,
120   void* packed_w,
121   size_t extra_bytes,
122   const struct xnn_qs8_packing_params* params);
123 
124 
// Function-pointer type mirroring the `xnn_pack_*_gemm_io_w` prototypes with
// every data pointer erased to `void*`. Unlike the `goi` pointer type it has
// no `g` and no `extra_bytes` parameter.
// NOTE(review): presumably nc/kc = output/input channels, nr/kr/sr =
// micro-kernel tiling factors, k = weights, b = bias, packed_w = destination
// buffer -- confirm against the implementation.
typedef void (*xnn_pack_gemm_io_w_function)(
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const void* k,
  const void* b,
  void* packed_w,
  const void* params);
135 
136 XNN_INTERNAL void xnn_pack_f32_gemm_io_w(
137   size_t nc,
138   size_t kc,
139   size_t nr,
140   size_t kr,
141   size_t sr,
142   const float* k,
143   const float* b,
144   float* packed_w,
145   const void* params);
146 
147 XNN_INTERNAL void xnn_pack_f16_gemm_io_w(
148   size_t nc,
149   size_t kc,
150   size_t nr,
151   size_t kr,
152   size_t sr,
153   const uint16_t* k,
154   const uint16_t* b,
155   uint16_t* packed_w,
156   const void* params);
157 
158 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_io_w(
159   size_t nc,
160   size_t kc,
161   size_t nr,
162   size_t kr,
163   size_t sr,
164   const float* k,
165   const float* b,
166   uint16_t* packed_w,
167   const void* params);
168 
169 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w(
170   size_t nc,
171   size_t kc,
172   size_t nr,
173   size_t kr,
174   size_t sr,
175   const uint8_t* k,
176   const int32_t* b,
177   void* packed_w,
178   const struct xnn_qu8_packing_params* params);
179 
180 XNN_INTERNAL void xnn_pack_qs8_gemm_io_w(
181   size_t nc,
182   size_t kc,
183   size_t nr,
184   size_t kr,
185   size_t sr,
186   const int8_t* k,
187   const int32_t* b,
188   void* packed_w,
189   const struct xnn_qs8_packing_params* params);
190 
191 
// Function-pointer type mirroring the `xnn_pack_*_conv_goki_w` prototypes with
// every data pointer erased to `void*`.
// NOTE(review): parameter meanings are not stated in this header; presumably
// g = groups, nc = output channels, ks = kernel size (spatial taps), kc =
// input channels, nr/kr/sr = micro-kernel tiling factors, k = weights,
// b = bias, packed_w = destination buffer -- confirm against the
// implementation.
typedef void (*xnn_pack_conv_goki_w_function)(
  size_t g,
  size_t nc,
  size_t ks,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  const void* k,
  const void* b,
  void* packed_w,
  size_t extra_bytes,
  const void* params);
205 
206 XNN_INTERNAL void xnn_pack_f32_conv_goki_w(
207   size_t g,
208   size_t nc,
209   size_t ks,
210   size_t kc,
211   size_t nr,
212   size_t kr,
213   size_t sr,
214   const float* k,
215   const float* b,
216   float* packed_w,
217   size_t extra_bytes,
218   const void* params);
219 
220 XNN_INTERNAL void xnn_pack_f16_conv_goki_w(
221   size_t g,
222   size_t nc,
223   size_t ks,
224   size_t kc,
225   size_t nr,
226   size_t kr,
227   size_t sr,
228   const uint16_t* k,
229   const uint16_t* b,
230   uint16_t* packed_w,
231   size_t extra_bytes,
232   const void* params);
233 
234 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_goki_w(
235   size_t g,
236   size_t nc,
237   size_t ks,
238   size_t kc,
239   size_t nr,
240   size_t kr,
241   size_t sr,
242   const float* k,
243   const float* b,
244   uint16_t* packed_w,
245   size_t extra_bytes,
246   const void* params);
247 
248 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w(
249   size_t g,
250   size_t nc,
251   size_t ks,
252   size_t kc,
253   size_t nr,
254   size_t kr,
255   size_t sr,
256   const uint8_t* k,
257   const int32_t* b,
258   void* packed_w,
259   size_t extra_bytes,
260   const struct xnn_qu8_packing_params* params);
261 
262 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w(
263   size_t g,
264   size_t nc,
265   size_t ks,
266   size_t kc,
267   size_t nr,
268   size_t kr,
269   size_t sr,
270   const int8_t* k,
271   const int32_t* b,
272   void* packed_w,
273   size_t extra_bytes,
274   const struct xnn_qs8_packing_params* params);
275 
276 
// Function-pointer type mirroring the `xnn_pack_*_conv_kgo_w` prototypes with
// every data pointer erased to `void*`. Compared with the `goki` pointer type
// it has no `kc` parameter.
// NOTE(review): presumably g = groups, nc = output channels, ks = kernel
// size, nr/kr/sr = micro-kernel tiling factors, k = weights, b = bias,
// packed_w = destination buffer -- confirm against the implementation.
typedef void (*xnn_pack_conv_kgo_w_function)(
  size_t g,
  size_t nc,
  size_t ks,
  size_t nr,
  size_t kr,
  size_t sr,
  const void* k,
  const void* b,
  void* packed_w,
  size_t extra_bytes,
  const void* params);
289 
290 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w(
291   size_t g,
292   size_t nc,
293   size_t ks,
294   size_t nr,
295   size_t kr,
296   size_t sr,
297   const float* k,
298   const float* b,
299   float* packed_w,
300   size_t extra_bytes,
301   const void* params);
302 
303 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w(
304   size_t g,
305   size_t nc,
306   size_t ks,
307   size_t nr,
308   size_t kr,
309   size_t sr,
310   const uint16_t* k,
311   const uint16_t* b,
312   uint16_t* packed_w,
313   size_t extra_bytes,
314   const void* params);
315 
316 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_kgo_w(
317   size_t g,
318   size_t nc,
319   size_t ks,
320   size_t nr,
321   size_t kr,
322   size_t sr,
323   const float* k,
324   const float* b,
325   uint16_t* packed_w,
326   size_t extra_bytes,
327   const void* params);
328 
329 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w(
330   size_t g,
331   size_t nc,
332   size_t ks,
333   size_t nr,
334   size_t kr,
335   size_t sr,
336   const uint8_t* k,
337   const int32_t* b,
338   void* packed_w,
339   size_t extra_bytes,
340   const struct xnn_qu8_packing_params* params);
341 
342 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w(
343   size_t g,
344   size_t nc,
345   size_t ks,
346   size_t nr,
347   size_t kr,
348   size_t sr,
349   const int8_t* k,
350   const int32_t* b,
351   void* packed_w,
352   size_t extra_bytes,
353   const struct xnn_qs8_packing_params* params);
354 
355 
// Forward declaration so the struct tag used in the parameter list below has
// file scope even if no earlier header declared it; redeclaring an existing
// tag is harmless.
struct subconvolution_params;

// Function-pointer type mirroring the `xnn_pack_*_deconv_goki_w` prototypes
// with every data pointer erased to `void*`. `subconv_params` is the only
// non-const pointer besides `packed_w`.
// NOTE(review): parameter meanings are not stated in this header; presumably
// g = groups, nc/kc = output/input channels, kh/kw = kernel height/width,
// sh/sw = stride height/width, nr/kr/sr = micro-kernel tiling factors,
// k = weights, b = bias, packed_w = destination buffer -- confirm against the
// implementation.
typedef void (*xnn_pack_deconv_goki_w_function)(
  size_t g,
  size_t nc,
  size_t kh,
  size_t kw,
  size_t kc,
  size_t sh,
  size_t sw,
  size_t nr,
  size_t kr,
  size_t sr,
  const void* k,
  const void* b,
  void* packed_w,
  struct subconvolution_params* subconv_params,
  const void* params);
372 
373 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w(
374   size_t g,
375   size_t nc,
376   size_t kh,
377   size_t kw,
378   size_t kc,
379   size_t sh,
380   size_t sw,
381   size_t nr,
382   size_t kr,
383   size_t sr,
384   const float* k,
385   const float* b,
386   float* packed_w,
387   struct subconvolution_params* subconv_params,
388   const void* params);
389 
390 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w(
391   size_t g,
392   size_t nc,
393   size_t kh,
394   size_t kw,
395   size_t kc,
396   size_t sh,
397   size_t sw,
398   size_t nr,
399   size_t kr,
400   size_t sr,
401   const uint16_t* k,
402   const uint16_t* b,
403   uint16_t* packed_w,
404   struct subconvolution_params* subconv_params,
405   const void* params);
406 
407 XNN_INTERNAL void xnn_pack_qs8_deconv_goki_w(
408   size_t g,
409   size_t nc,
410   size_t kh,
411   size_t kw,
412   size_t kc,
413   size_t sh,
414   size_t sw,
415   size_t nr,
416   size_t kr,
417   size_t sr,
418   const int8_t* k,
419   const int32_t* b,
420   void* packed_w,
421   struct subconvolution_params* subconv_params,
422   const struct xnn_qs8_packing_params* params);
423 
424 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w(
425   size_t g,
426   size_t nc,
427   size_t kh,
428   size_t kw,
429   size_t kc,
430   size_t sh,
431   size_t sw,
432   size_t nr,
433   size_t kr,
434   size_t sr,
435   const uint8_t* k,
436   const int32_t* b,
437   void* packed_w,
438   struct subconvolution_params* subconv_params,
439   const struct xnn_qu8_packing_params* params);
440 
441 
// Function-pointer type mirroring the `xnn_pack_*_dwconv_ghw_w` prototypes
// with every data pointer erased to `void*`.
// NOTE(review): parameter meanings are not stated in this header; presumably
// h/w = kernel height/width, c = channels, cr = channel tile, k = weights,
// b = bias, packed_w = destination buffer -- confirm against the
// implementation.
typedef void (*xnn_pack_dwconv_ghw_w_function)(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const void* k,
  const void* b,
  void* packed_w,
  size_t extra_bytes,
  const void* params);
452 
453 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w(
454   size_t h,
455   size_t w,
456   size_t c,
457   size_t cr,
458   const float* k,
459   const float* b,
460   float* packed_w,
461   size_t extra_bytes,
462   const void* params);
463 
464 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w(
465   size_t h,
466   size_t w,
467   size_t c,
468   size_t cr,
469   const uint16_t* k,
470   const uint16_t* b,
471   uint16_t* packed_w,
472   size_t extra_bytes,
473   const void* params);
474 
475 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_ghw_w(
476   size_t h,
477   size_t w,
478   size_t c,
479   size_t cr,
480   const float* k,
481   const float* b,
482   uint16_t* packed_w,
483   size_t extra_bytes,
484   const void* params);
485 
486 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w(
487   size_t h,
488   size_t w,
489   size_t c,
490   size_t cr,
491   const uint8_t* k,
492   const int32_t* b,
493   void* packed_w,
494   size_t extra_bytes,
495   const struct xnn_qu8_packing_params* params);
496 
497 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w(
498   size_t h,
499   size_t w,
500   size_t c,
501   size_t cr,
502   const int8_t* k,
503   const int32_t* b,
504   void* packed_w,
505   size_t extra_bytes,
506   const struct xnn_qs8_packing_params* params);
507 
508 
// Function-pointer type mirroring the `xnn_pack_*_dwconv_hwg_w` prototypes
// with every data pointer erased to `void*`. Its parameter list is identical
// to xnn_pack_dwconv_ghw_w_function.
// NOTE(review): presumably h/w = kernel height/width, c = channels, cr =
// channel tile, k = weights, b = bias, packed_w = destination buffer, and
// `hwg` vs `ghw` names the source weight layout -- confirm against the
// implementation.
typedef void (*xnn_pack_dwconv_hwg_w_function)(
  size_t h,
  size_t w,
  size_t c,
  size_t cr,
  const void* k,
  const void* b,
  void* packed_w,
  size_t extra_bytes,
  const void* params);
519 
520 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w(
521   size_t h,
522   size_t w,
523   size_t c,
524   size_t cr,
525   const float* k,
526   const float* b,
527   float* packed_w,
528   size_t extra_bytes,
529   const void* params);
530 
531 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w(
532   size_t h,
533   size_t w,
534   size_t c,
535   size_t cr,
536   const uint16_t* k,
537   const uint16_t* b,
538   uint16_t* packed_w,
539   size_t extra_bytes,
540   const void* params);
541 
542 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_hwg_w(
543   size_t h,
544   size_t w,
545   size_t c,
546   size_t cr,
547   const float* k,
548   const float* b,
549   uint16_t* packed_w,
550   size_t extra_bytes,
551   const void* params);
552 
553 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w(
554   size_t h,
555   size_t w,
556   size_t c,
557   size_t cr,
558   const uint8_t* k,
559   const int32_t* b,
560   void* packed_w,
561   size_t extra_bytes,
562   const struct xnn_qu8_packing_params* params);
563 
564 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w(
565   size_t h,
566   size_t w,
567   size_t c,
568   size_t cr,
569   const int8_t* k,
570   const int32_t* b,
571   void* packed_w,
572   size_t extra_bytes,
573   const struct xnn_qs8_packing_params* params);
574 
575 
576 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w(
577   size_t g,
578   size_t nc,
579   size_t kc,
580   size_t nr,
581   size_t kr,
582   size_t sr,
583   const float* k,
584   float* packed_w,
585   const void* params);
586 
587 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w(
588   size_t g,
589   size_t nc,
590   size_t kc,
591   size_t nr,
592   size_t kr,
593   size_t sr,
594   const uint16_t* k,
595   uint16_t* packed_w,
596   const void* params);
597 
598 
599 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w(
600   size_t nc,
601   size_t kc,
602   size_t nr,
603   size_t kh,
604   size_t kw,
605   const float* k,
606   const float* b,
607   float* packed_w,
608   const void* params);
609 
610 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w(
611   size_t nc,
612   size_t kc,
613   size_t nr,
614   size_t kh,
615   size_t kw,
616   const uint16_t* k,
617   const uint16_t* b,
618   uint16_t* packed_w,
619   const void* params);
620 
621 
622 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w(
623   size_t kernel_size,
624   size_t groups,
625   const float* kernel,
626   const float* bias,
627   float* packed_weights,
628   const void* params);
629 
630 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w(
631   size_t kernel_size,
632   size_t groups,
633   const uint16_t* kernel,
634   const uint16_t* bias,
635   uint16_t* packed_weights,
636   const void* params);
637 
638 
639 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w(
640   size_t kernel_size,
641   size_t groups,
642   const float* kernel,
643   const float* bias,
644   float* packed_weights,
645   const void* params);
646 
647 
// Function-pointer type mirroring the `xnn_pack_*_vmulcaddc_w` prototypes with
// every data pointer erased to `void*`.
// NOTE(review): parameter meanings are not stated in this header; presumably
// c = channels, cr = channel tile, s = per-channel scale, b = per-channel
// bias, packed_w = destination buffer -- confirm against the implementation.
typedef void (*xnn_pack_vmulcaddc_w_function)(
  size_t c,
  size_t cr,
  const void* s,
  const void* b,
  void* packed_w,
  const void* params);
655 
656 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w(
657   size_t c,
658   size_t cr,
659   const float* s,
660   const float* b,
661   float* packed_w,
662   const void* params);
663 
664 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w(
665   size_t c,
666   size_t cr,
667   const uint16_t* s,
668   const uint16_t* b,
669   uint16_t* packed_w,
670   const void* params);
671 
672 XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w(
673   size_t c,
674   size_t cr,
675   const float* s,
676   const float* b,
677   uint16_t* packed_w,
678   const void* params);
679 
680 
// Function-pointer type mirroring the `xnn_pack_*_prelu_w` prototypes with
// both data pointers erased to `void*`.
// NOTE(review): presumably c = channel count, s = per-channel slopes and
// packed_w = destination buffer -- confirm against the implementation.
typedef void (*xnn_pack_prelu_w_function)(
  size_t c,
  const void* s,
  void* packed_w);
685 
686 XNN_INTERNAL void xnn_pack_f32_prelu_w(
687   size_t c,
688   const float* s,
689   float* packed_w);
690 
691 XNN_INTERNAL void xnn_pack_f16_prelu_w(
692   size_t c,
693   const uint16_t* s,
694   uint16_t* packed_w);
695 
696 XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w(
697   size_t c,
698   const float* s,
699   uint16_t* packed_w);
700 
701 
// Closes the C-linkage region opened near the top of this header.
#ifdef __cplusplus
}  // extern "C"
#endif
705