• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stddef.h>
12 #include <stdint.h>
13 
14 #include <pthreadpool.h>
15 
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/params.h>
18 #include <xnnpack/compute.h>
19 
20 
// Kind of micro-kernel selected for an operator; stored in xnn_ukernel::type.
// Values are written out explicitly here but match declaration order exactly
// (default is 0 and each following kind increments by one).
enum xnn_ukernel_type {
  xnn_ukernel_type_default                   = 0,
  xnn_ukernel_type_average_pooling           = 1,
  xnn_ukernel_type_conv2d_hwc2chw            = 2,
  xnn_ukernel_type_dwconv                    = 3,
  xnn_ukernel_type_gemm                      = 4,
  xnn_ukernel_type_igemm                     = 5,
  xnn_ukernel_type_pixelwise_average_pooling = 6,
  xnn_ukernel_type_spmm                      = 7,
  xnn_ukernel_type_subconv2d                 = 8,
  xnn_ukernel_type_vmulcaddc                 = 9,
};
33 
// Identifies the concrete operator stored in xnn_operator::type.
// Enumerator names follow the pattern <operation>_<layout>_<datatype>
// (e.g. convolution_nhwc_f32, add_nd_qs8).
// Values are implicit: xnn_operator_type_invalid is 0 and every later entry
// is one more than the previous, so inserting an entry renumbers all entries
// after it.
// NOTE(review): the list is mostly alphabetical, but ceiling_nc_f32 follows
// clamp_*, floor_nc_f32 follows fully_connected_*, and
// global_average_pooling_ncw_f32 follows the nwc variants; reordering would
// change values, so confirm no code depends on them before tidying.
enum xnn_operator_type {
  xnn_operator_type_invalid = 0,
  xnn_operator_type_abs_nc_f32,
  xnn_operator_type_add_nd_f16,
  xnn_operator_type_add_nd_f32,
  xnn_operator_type_add_nd_qs8,
  xnn_operator_type_add_nd_qu8,
  xnn_operator_type_argmax_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_qu8,
  xnn_operator_type_bankers_rounding_nc_f32,
  xnn_operator_type_channel_shuffle_nc_x8,
  xnn_operator_type_channel_shuffle_nc_x32,
  xnn_operator_type_clamp_nc_f32,
  xnn_operator_type_clamp_nc_s8,
  xnn_operator_type_clamp_nc_u8,
  xnn_operator_type_ceiling_nc_f32,
  xnn_operator_type_constant_pad_nd_x8,
  xnn_operator_type_constant_pad_nd_x16,
  xnn_operator_type_constant_pad_nd_x32,
  xnn_operator_type_convert_nc_f16_f32,
  xnn_operator_type_convert_nc_f32_f16,
  xnn_operator_type_convert_nc_f32_qs8,
  xnn_operator_type_convert_nc_f32_qu8,
  xnn_operator_type_convert_nc_qs8_f32,
  xnn_operator_type_convert_nc_qu8_f32,
  xnn_operator_type_convolution_nchw_f32,
  xnn_operator_type_convolution_nhwc_f16,
  xnn_operator_type_convolution_nhwc_f32,
  xnn_operator_type_convolution_nhwc_qc8,
  xnn_operator_type_convolution_nhwc_qs8,
  xnn_operator_type_convolution_nhwc_qu8,
  xnn_operator_type_copy_nc_x8,
  xnn_operator_type_copy_nc_x16,
  xnn_operator_type_copy_nc_x32,
  xnn_operator_type_deconvolution_nhwc_f32,
  xnn_operator_type_deconvolution_nhwc_qs8,
  xnn_operator_type_deconvolution_nhwc_qu8,
  xnn_operator_type_depth_to_space_nchw2nhwc_x32,
  xnn_operator_type_depth_to_space_nhwc_x32,
  xnn_operator_type_divide_nd_f32,
  xnn_operator_type_elu_nc_f32,
  xnn_operator_type_elu_nc_qs8,
  xnn_operator_type_fully_connected_nc_f16,
  xnn_operator_type_fully_connected_nc_f32,
  xnn_operator_type_fully_connected_nc_qs8,
  xnn_operator_type_fully_connected_nc_qu8,
  xnn_operator_type_floor_nc_f32,
  xnn_operator_type_global_average_pooling_nwc_f16,
  xnn_operator_type_global_average_pooling_nwc_f32,
  xnn_operator_type_global_average_pooling_nwc_qs8,
  xnn_operator_type_global_average_pooling_nwc_qu8,
  xnn_operator_type_global_average_pooling_ncw_f32,
  xnn_operator_type_hardswish_nc_f16,
  xnn_operator_type_hardswish_nc_f32,
  xnn_operator_type_leaky_relu_nc_f32,
  xnn_operator_type_leaky_relu_nc_qu8,
  xnn_operator_type_max_pooling_nhwc_f16,
  xnn_operator_type_max_pooling_nhwc_f32,
  xnn_operator_type_max_pooling_nhwc_s8,
  xnn_operator_type_max_pooling_nhwc_u8,
  xnn_operator_type_maximum_nd_f32,
  xnn_operator_type_minimum_nd_f32,
  xnn_operator_type_multiply_nd_f16,
  xnn_operator_type_multiply_nd_f32,
  xnn_operator_type_multiply_nd_qs8,
  xnn_operator_type_multiply_nd_qu8,
  xnn_operator_type_negate_nc_f32,
  xnn_operator_type_prelu_nc_f16,
  xnn_operator_type_prelu_nc_f32,
  xnn_operator_type_resize_bilinear_nchw_f32,
  xnn_operator_type_resize_bilinear_nhwc_f32,
  xnn_operator_type_resize_bilinear_nhwc_s8,
  xnn_operator_type_resize_bilinear_nhwc_u8,
  xnn_operator_type_sigmoid_nc_f32,
  xnn_operator_type_sigmoid_nc_qs8,
  xnn_operator_type_sigmoid_nc_qu8,
  xnn_operator_type_softmax_nc_f32,
  xnn_operator_type_softmax_nc_qu8,
  xnn_operator_type_square_nc_f32,
  xnn_operator_type_square_root_nc_f32,
  xnn_operator_type_squared_difference_nd_f32,
  xnn_operator_type_subtract_nd_f32,
  xnn_operator_type_subtract_nd_qs8,
  xnn_operator_type_subtract_nd_qu8,
  xnn_operator_type_tanh_nc_qs8,
  xnn_operator_type_tanh_nc_qu8,
  xnn_operator_type_truncation_nc_f32,
  xnn_operator_type_unpooling_nhwc_x32,
};
124 
125 struct xnn_ukernel_conv2d {
126   union {
127     xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
128     xnn_conv_hwc_ukernel_function hwc_function;
129   };
130   uint8_t output_height_tile;
131   uint8_t output_channel_tile;
132 };
133 
134 struct xnn_ukernel_dwconv {
135   union {
136     xnn_dwconv_unipass_ukernel_function unipass_function;
137     xnn_dwconv_multipass_ukernel_function multipass_function;
138   };
139   uint8_t primary_tile;
140   uint8_t incremental_tile;
141 };
142 
// Micro-kernel descriptor for direct 2D depthwise convolution.
// The only variant stored is chw_function. NOTE(review): the single-member
// union presumably mirrors the other xnn_ukernel_* structs so more variants
// can be added later; confirm.
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv2d_chw_ukernel_function chw_function;
  };
  // Output width tile of the kernel (name-based; confirm exact semantics).
  uint8_t output_width_tile;
};
150 
151 struct xnn_ukernel_gemm {
152   struct xnn_hmp_gemm_ukernel general_case;
153   struct xnn_hmp_gemm_ukernel mr1_case;
154 #if XNN_PLATFORM_JIT
155   struct xnn_code_buffer general_code_buffer;
156   struct xnn_code_buffer mr1_code_buffer;
157 #endif  // XNN_PLATFORM_JIT
158   uint8_t mr;
159   uint8_t nr;
160   uint8_t kr;
161   uint8_t sr;
162 };
163 
164 struct xnn_ukernel_igemm {
165   struct xnn_hmp_igemm_ukernel general_case;
166   struct xnn_hmp_igemm_ukernel mr1_case;
167   struct xnn_hmp_gemm_ukernel gemm_case;
168 #if XNN_PLATFORM_JIT
169   struct xnn_code_buffer general_code_buffer;
170   struct xnn_code_buffer mr1_code_buffer;
171 #endif  // XNN_PLATFORM_JIT
172   uint8_t mr;
173   uint8_t nr;
174   uint8_t kr;
175   uint8_t sr;
176 };
177 
// SpMM micro-kernel descriptor (sparse matrix multiplication, per the name).
// NOTE(review): presumably paired with the sparse-weight fields
// (num_nonzero_values etc.) of struct xnn_operator; confirm.
struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  // Row tile of the kernel.
  uint8_t mr;
};
182 
// VMULCADDC micro-kernel descriptor. NOTE(review): judging by the name, this is
// a per-channel multiply-then-add kernel; confirm against the kernel definition.
struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  // Row tile of the kernel.
  uint8_t mr;
};
187 
188 struct xnn_ukernel_vbinary {
189   xnn_vbinary_ukernel_function op_function;
190   xnn_vbinary_ukernel_function opc_function;
191   xnn_vbinary_ukernel_function ropc_function;
192 };
193 
// Elementwise unary-operation micro-kernel descriptor.
struct xnn_ukernel_vunary {
  xnn_vunary_ukernel_function function;
};
197 
// Micro-kernel selected for an operator: a type tag plus an anonymous union of
// per-kind descriptors that share storage.
// NOTE(review): `type` presumably tells which union member is active, but the
// mapping is not one-to-one. average_pooling, pixelwise_average_pooling and
// subconv2d have no same-named member, and dwconv2d, vbinary and vunary have
// no same-named enumerator. Confirm at the setup sites.
struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_conv2d conv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
    struct xnn_ukernel_vbinary vbinary;
    struct xnn_ukernel_vunary vunary;
  };
};
212 
// Run state of an operator, stored in xnn_operator::state.
// Values are explicit but identical to the implicit 0, 1, 2 sequence.
// NOTE(review): "skip" presumably marks an operator with no work to run after
// setup; confirm.
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready   = 1,
  xnn_run_state_skip    = 2,
};
218 
// Per-subconvolution parameters. Field order is significant for layout and is
// kept exactly as before; only adjacent same-typed fields are grouped.
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  // Extent of the slice handled by this subconvolution.
  size_t slice_width, slice_height;
  // Strides used to step through the indirection buffer.
  size_t indirection_y_stride, indirection_x_stride;
  // Equals kernel_size * mr * sizeof(void*), i.e. a byte count of mr rows of
  // kernel_size indirection pointers.
  size_t scaled_kernel_size;
};
231 
// Internal state of an operator. It holds the operator's geometry and channel
// configuration, tensor descriptions, packed weights, quantization and
// micro-kernel parameters, the selected micro-kernel, and the compute
// descriptors with their per-kind context.
// NOTE(review): which fields a given operator uses depends on `type`; the group
// descriptions below are based on field names and types.
struct xnn_operator {
  size_t batch_size;
  // Window geometry: padding, kernel size, stride and dilation.
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  // Channel configuration (grouped and ungrouped).
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  // Constant-pad configuration: channels of padding before/after, and the fill
  // value stored as 32 bits.
  size_t pad_before_channels;
  size_t pad_after_channels;
  uint32_t pad_value;

  // Input tensor description. input2 is a second input pointer
  // (NOTE(review): presumably the second operand of binary operators; confirm).
  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void* input2;
  const void** indirection_buffer;

  // Output tensor description.
  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  void* packed_weights;
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  // Quantization parameters.
  float input_scale;
  float output_scale;
  int32_t input_zero_point;
  uint8_t output_zero_point;
  uint8_t output_min;
  uint8_t output_max;

  // Values remembered from a previous setup. NOTE(review): presumably compared
  // on the next setup to decide whether the indirection buffer must be
  // rebuilt; confirm.
  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  size_t last_output_height;
  size_t last_output_width;
  void* last_output;

  // NOTE(review): presumably the block size of depth-to-space operators; confirm.
  uint32_t block_size;

  // Auxiliary buffers owned or referenced by the operator.
  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;
  uint32_t flags;

  // Micro-kernel parameters; the member in use depends on the operator type.
  union {
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    // Parameters for Global Average Pooling in CHW layout
    union xnn_f32_gavgpool_params f32_gavgpool;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_scaleminmax_params f16_scaleminmax;
    // Pixelwise Average Pooling normally uses f32_minmax, but f32_scaleminmax is
    // also initialized in case the operator needs to switch to Global Average
    // Pooling.
    struct {
      union xnn_f32_minmax_params f32_minmax;
      union xnn_f32_scaleminmax_params f32_scaleminmax;
    };
    union xnn_f32_chw_params f32_chw;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qs8_conv_minmax_params qs8_conv_minmax;
    // Average Pooling normally uses qs8_avgpool, but qs8_gavgpool is also
    // initialized in case the operator needs to switch to Global Average
    // Pooling. Both use xnn_qs8_avgpool_minmax_params.
    struct {
      union xnn_qs8_avgpool_minmax_params qs8_avgpool;
      union xnn_qs8_avgpool_minmax_params qs8_gavgpool;
    };
    // Quantized Add/Subtract parameters are sensitive to the order of inputs,
    // so an extra copy with the reversed order is kept (qs8_raddsub).
    struct {
      union xnn_qs8_addsub_minmax_params qs8_addsub;
      union xnn_qs8_addsub_minmax_params qs8_raddsub;
    };
    // Quantized Multiply: qs8_rmul is presumably the reversed-order copy,
    // mirroring qs8_raddsub; confirm.
    struct {
      union xnn_qs8_mul_minmax_params qs8_mul;
      union xnn_qs8_mul_minmax_params qs8_rmul;
    };
    // Unsigned counterpart of qs8_addsub/qs8_raddsub.
    struct {
      union xnn_qu8_addsub_minmax_params qu8_addsub;
      union xnn_qu8_addsub_minmax_params qu8_raddsub;
    };
    // Unsigned counterpart of qs8_mul/qs8_rmul.
    struct {
      union xnn_qu8_mul_minmax_params qu8_mul;
      union xnn_qu8_mul_minmax_params qu8_rmul;
    };
    union xnn_qu8_conv_minmax_params qu8_conv_minmax;
    // Average Pooling normally uses qu8_avgpool, but qu8_gavgpool is also
    // initialized in case the operator needs to switch to Global Average
    // Pooling. Both use xnn_qu8_avgpool_minmax_params.
    struct {
      union xnn_qu8_avgpool_minmax_params qu8_avgpool;
      union xnn_qu8_avgpool_minmax_params qu8_gavgpool;
    };
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  // Compute descriptors. NOTE(review): compute2 is presumably an optional
  // second pass for operators that need one; confirm.
  struct compute_parameters compute;
  struct compute_parameters compute2;
  // Per-kind context for the compute functions; only the member matching the
  // operator's kind is meaningful.
  union {
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_shuffle_context channel_shuffle;
    struct conv2d_context conv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct depthtospace2d_chw2hwc_context depthtospace2d_chw;
    struct depthtospace2d_hwc_context depthtospace2d_hwc;
    struct elementwise_binary_context elementwise_binary;
    struct gemm_context gemm;
    struct global_average_pooling_nwc_context global_average_pooling_nwc;
    struct global_average_pooling_ncw_context global_average_pooling_ncw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pad_context pad;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct resize_bilinear_context resize_bilinear;
    struct resize_bilinear_chw_context resize_bilinear_chw;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct subgemm_context subgemm;
    struct f32_three_pass_softmax_context f32_three_pass_softmax;
    struct u8_softmax_context u8_softmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  // Run state (see enum xnn_run_state).
  enum xnn_run_state state;
};
401