• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stddef.h>
12 #include <stdint.h>
13 
14 #include <pthreadpool.h>
15 
16 #include <xnnpack/params.h>
17 #include <xnnpack/compute.h>
18 
19 
// Tag stored in `struct xnn_ukernel::type`, naming which micro-kernel family
// the operator was configured with.
//
// Every enumerator carries its value explicitly; the numbers are exactly the
// implicit 0-based sequence of this alphabetical list. When inserting a new
// entry, renumber the entries that follow it (or append at the end).
enum xnn_ukernel_type {
  xnn_ukernel_type_none = 0,
  xnn_ukernel_type_add = 1,
  xnn_ukernel_type_argmax_pooling = 2,
  xnn_ukernel_type_average_pooling = 3,
  xnn_ukernel_type_binary_elementwise = 4,
  xnn_ukernel_type_channel_shuffle = 5,
  xnn_ukernel_type_clamp = 6,
  xnn_ukernel_type_dconv2d_hwc2spchw = 7,
  xnn_ukernel_type_dwconv = 8,
  xnn_ukernel_type_gemm = 9,
  xnn_ukernel_type_global_average_pooling = 10,
  xnn_ukernel_type_hswish = 11,
  xnn_ukernel_type_igemm = 12,
  xnn_ukernel_type_lut = 13,
  xnn_ukernel_type_max_pooling = 14,
  xnn_ukernel_type_pad = 15,
  xnn_ukernel_type_pixelwise_average_pooling = 16,
  xnn_ukernel_type_prelu = 17,
  xnn_ukernel_type_sigmoid = 18,
  xnn_ukernel_type_softmax = 19,
  xnn_ukernel_type_spmm = 20,
  xnn_ukernel_type_subconv2d = 21,
  xnn_ukernel_type_unpooling = 22,
  xnn_ukernel_type_vmulcaddc = 23,
};
46 
// Tag stored in `struct xnn_operator::type`, naming the concrete operator.
// Names read as <operation>_<layout>_<datatype> (layouts such as nc / nd /
// nhwc / nchw / nwc / ncw; datatypes such as f32 / q8 / u8 / x8 / x32).
//
// Values are written out explicitly and equal the implicit 0-based sequence
// of this list; renumber following entries when inserting in the middle.
enum xnn_operator_type {
  xnn_operator_type_none = 0,
  xnn_operator_type_add_nc_f32 = 1,
  xnn_operator_type_add_nd_f32 = 2,
  xnn_operator_type_add_nc_q8 = 3,
  xnn_operator_type_argmax_pooling_nhwc_f32 = 4,
  xnn_operator_type_average_pooling_nhwc_f32 = 5,
  xnn_operator_type_average_pooling_nhwc_q8 = 6,
  xnn_operator_type_channel_pad_nc_x32 = 7,
  xnn_operator_type_channel_shuffle_nc_x32 = 8,
  xnn_operator_type_channel_shuffle_nc_x8 = 9,
  xnn_operator_type_clamp_nc_f32 = 10,
  xnn_operator_type_clamp_nc_u8 = 11,
  xnn_operator_type_convolution_nhwc_f32 = 12,
  xnn_operator_type_convolution_nhwc_q8 = 13,
  xnn_operator_type_convolution_nchw_f32 = 14,
  xnn_operator_type_deconvolution_nhwc_f32 = 15,
  xnn_operator_type_deconvolution_nhwc_q8 = 16,
  xnn_operator_type_divide_nd_f32 = 17,
  xnn_operator_type_fully_connected_nc_f32 = 18,
  xnn_operator_type_fully_connected_nc_q8 = 19,
  xnn_operator_type_global_average_pooling_nwc_f32 = 20,
  xnn_operator_type_global_average_pooling_nwc_q8 = 21,
  xnn_operator_type_global_average_pooling_ncw_f32 = 22,
  xnn_operator_type_hardswish_nc_f32 = 23,
  xnn_operator_type_leaky_relu_nc_q8 = 24,
  xnn_operator_type_max_pooling_nhwc_f32 = 25,
  xnn_operator_type_max_pooling_nhwc_u8 = 26,
  xnn_operator_type_maximum_nd_f32 = 27,
  xnn_operator_type_minimum_nd_f32 = 28,
  xnn_operator_type_multiply_nd_f32 = 29,
  xnn_operator_type_prelu_nc_f32 = 30,
  xnn_operator_type_resize_bilinear_nhwc_f32 = 31,
  xnn_operator_type_sigmoid_nc_f32 = 32,
  xnn_operator_type_sigmoid_nc_q8 = 33,
  xnn_operator_type_softmax_nc_f32 = 34,
  xnn_operator_type_softmax_nc_q8 = 35,
  xnn_operator_type_subtract_nd_f32 = 36,
  xnn_operator_type_unpooling_nhwc_x32 = 37,
};
87 
88 struct xnn_ukernel_dconv2d {
89   union {
90     xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function;
91     xnn_conv_hwc_ukernel_function hwc_function;
92   };
93   uint8_t output_height_tile;
94   uint8_t output_channel_tile;
95 };
96 
97 struct xnn_ukernel_dwconv {
98   union {
99     xnn_dwconv_up_ukernel_function unipass_function;
100     xnn_dwconv_mp_ukernel_function multipass_function;
101   };
102   uint8_t mr;
103   uint8_t qr;
104 };
105 
106 // Direct 2D Depthwise Convolution
107 struct xnn_ukernel_dwconv2d {
108   union {
109     xnn_dwconv_spchw_ukernel_function spchw_function;
110   };
111   uint8_t input_width_tile;
112   uint8_t output_width_tile;
113 };
114 
115 struct xnn_ukernel_gemm {
116   xnn_gemm_ukernel_function default_function;
117   xnn_gemm_ukernel_function mr1_function;
118   uint8_t mr;
119   uint8_t nr;
120   uint8_t kr;
121 };
122 
123 struct xnn_ukernel_igemm {
124   xnn_igemm_ukernel_function default_function;
125   xnn_igemm_ukernel_function mr1_function;
126   uint8_t mr;
127   uint8_t nr;
128   uint8_t kr;
129 };
130 
131 struct xnn_ukernel_spmm {
132   xnn_spmm_ukernel_function function;
133   uint8_t mr;
134 };
135 
136 struct xnn_ukernel_vmulcaddc {
137   xnn_vmulcaddc_ukernel_function function;
138   uint8_t mr;
139 };
140 
141 struct xnn_ukernel {
142   enum xnn_ukernel_type type;
143   union {
144     struct xnn_ukernel_dconv2d dconv2d;
145     struct xnn_ukernel_dwconv dwconv;
146     struct xnn_ukernel_dwconv2d dwconv2d;
147     struct xnn_ukernel_gemm gemm;
148     struct xnn_ukernel_igemm igemm;
149     struct xnn_ukernel_spmm spmm;
150     struct xnn_ukernel_vmulcaddc vmulcaddc;
151   };
152 };
153 
// Execution state recorded in `struct xnn_operator::state`.
// `invalid` is 0, so zero-filled operator storage reads as invalid
// (NOTE(review): `ready` / `skip` presumably set by setup and checked at run
// time — confirm).
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready = 1,
  xnn_run_state_skip = 2,
};
159 
// Per-subconvolution bookkeeping; `struct xnn_operator::subconvolution_buffer`
// points at an array of these.
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width, slice_height;
  size_t indirection_y_stride, indirection_x_stride;
  // Precomputed product: kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};
172 
173 struct xnn_operator {
174   size_t batch_size;
175   uint32_t padding_top;
176   uint32_t padding_right;
177   uint32_t padding_bottom;
178   uint32_t padding_left;
179   uint32_t kernel_height;
180   uint32_t kernel_width;
181   uint32_t stride_height;
182   uint32_t stride_width;
183   uint32_t dilation_height;
184   uint32_t dilation_width;
185   uint32_t groups;
186   size_t group_channels;
187   size_t group_input_channels;
188   size_t group_output_channels;
189   size_t channels;
190 
191   size_t pad_before_channels;
192   size_t pad_after_channels;
193   uint32_t pad_value;
194 
195   size_t input_height;
196   size_t input_width;
197   size_t input_pixel_stride;
198   const void* input;
199   const void** indirection_buffer;
200 
201   size_t input2_pixel_stride;
202   const void* input2;
203 
204   size_t output_height;
205   size_t output_width;
206   size_t output_pixel_stride;
207   void* output;
208 
209   void* packed_weights;
210   // Total number of non-zero kernel elements when weights use sparse representation.
211   size_t num_nonzero_values;
212   // Total number of non-zero kernel blocks when weights use sparse representation.
213   size_t num_nonzero_blocks;
214   // Total number of output channel blocks when weights use sparse representation.
215   size_t num_output_channel_blocks;
216   // Input channel corresponding to the first non-zero kernel element.
217   size_t first_input_channel;
218 
219   float input_scale;
220   float output_scale;
221   uint8_t input_zero_point;
222   uint8_t kernel_zero_point;
223   uint8_t output_zero_point;
224   uint8_t output_min;
225   uint8_t output_max;
226 
227   size_t valid_batch_size;
228   size_t last_input_height;
229   size_t last_input_width;
230   const void* last_input;
231   size_t last_output_height;
232   size_t last_output_width;
233   void* last_output;
234 
235   void* zero_buffer;
236   void* lookup_table;
237   void* pixelwise_buffer;
238   struct subconvolution_params* subconvolution_buffer;
239   uint32_t flags;
240 
241   union {
242     union xnn_f32_avgpool_params f32_avgpool_params;
243     union xnn_f32_gavgpool_params f32_gavgpool_params;
244     union xnn_f32_hswish_params f32_hswish_params;
245     union xnn_f32_output_params f32_output_params;
246     union xnn_f32_spchw_params f32_spchw_params;
247     union xnn_q8_add_params q8_add_params;
248     union xnn_q8_avgpool_params q8_avgpool_params;
249     union xnn_q8_gemm_params q8_gemm_params;
250     union xnn_u8_output_params u8_output_params;
251   };
252   enum xnn_operator_type type;
253   struct xnn_ukernel ukernel;
254 
255   struct compute_parameters compute;
256   struct compute_parameters compute2;
257   union {
258     struct add_contiguous_context add_contiguous;
259     struct add_strided_context add_strided;
260     struct argmax_pooling_context argmax_pooling;
261     struct average_pooling_context average_pooling;
262     struct channel_pad_context channel_pad;
263     struct channel_shuffle_context channel_shuffle;
264     struct dconv2d_context dconv2d;
265     struct dwconv2d_context dwconv2d;
266     struct dwconv_context dwconv;
267     struct elementwise_binary_context elementwise_binary;
268     struct gemm_context gemm;
269     struct global_average_pooling_nwc_context global_average_pooling_nwc;
270     struct global_average_pooling_ncw_context global_average_pooling_ncw;
271     struct igemm_context igemm;
272     struct lut_contiguous_context lut_contiguous;
273     struct lut_strided_context lut_strided;
274     struct max_pooling_context max_pooling;
275     struct pixelwise_average_pooling_context pixelwise_average_pooling;
276     struct prelu_context prelu;
277     struct resize_bilinear_context resize_bilinear;
278     struct spmm_context spmm;
279     struct subconv_context subconv;
280     struct f32_three_pass_softmax_context f32_three_pass_softmax;
281     struct u8_softmax_context u8_softmax;
282     struct univector_contiguous_context univector_contiguous;
283     struct univector_strided_context univector_strided;
284     struct unpooling_context unpooling;
285     struct vmulcaddc_context vmulcaddc;
286   } context;
287 
288   enum xnn_run_state state;
289 };
290