// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack/allocator.h>
#include <xnnpack/cache.h>
#include <xnnpack/compute.h>
#include <xnnpack/operator-type.h>
#include <xnnpack/params.h>
#include <xnnpack/ukernel-type.h>

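// Direct 2D Convolution on HWC-layout input (hwc2chw_function additionally writes CHW-layout output).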
struct xnn_ukernel_conv2d {
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};

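// Depthwise Convolution: unipass_function covers the whole kernel within the primary tile;
// multipass_function continues in steps of incremental_tile for larger kernels.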
struct xnn_ukernel_dwconv {
  union {
    xnn_dwconv_unipass_ukernel_function unipass_function;
    xnn_dwconv_multipass_ukernel_function multipass_function;
  };
  uint8_t primary_tile;
  uint8_t incremental_tile;
};

// Direct 2D Depthwise Convolution
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv2d_chw_ukernel_function chw_function;
  };
  uint8_t output_width_tile;
};

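// GEMM microkernel cases indexed by the number of output rows handled (up to XNN_MAX_MR).
// mr and nr are the output tile height and width; kr and sr describe the packing of the
// reduction (K) dimension expected by the microkernel.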
struct xnn_ukernel_gemm {
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

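// Indirect GEMM microkernel cases; a matching set of direct GEMM cases is kept alongside for
// code paths that can use them.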
struct xnn_ukernel_igemm {
  struct xnn_hmp_igemm_ukernel igemm_cases[XNN_MAX_MR];
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

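// Sparse Matrix-Dense Matrix Multiplication (weights in sparse representation; see the
// num_nonzero_* bookkeeping in struct xnn_operator below).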
struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  uint8_t mr;
};

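// Channelwise multiply-add: scales each channel and adds a per-channel constant.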
struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  uint8_t mr;
};

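// Elementwise binary operation: op_function takes two full operands, opc_function takes an
// operand and a broadcast scalar, and ropc_function is the reversed-operand scalar variant
// (needed for non-commutative operations such as subtraction and division).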
struct xnn_ukernel_vbinary {
  xnn_vbinary_ukernel_function op_function;
  xnn_vbinary_ukernel_function opc_function;
  xnn_vbinary_ukernel_function ropc_function;
};

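// Elementwise unary operation over a single input.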
struct xnn_ukernel_vunary {
  xnn_vunary_ukernel_function function;
};

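// Tagged union of the per-operator-class microkernel descriptors above, discriminated by `type`.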
struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_conv2d conv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
    struct xnn_ukernel_vbinary vbinary;
    struct xnn_ukernel_vunary vunary;
  };
};

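// Run state of an operator: invalid until setup succeeds, ready once setup has produced a
// runnable configuration, or skip when setup determined there is no work to do.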
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};

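// Per-slice parameters for operators (e.g. deconvolution) that decompose into subconvolutions.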
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};

struct xnn_operator {
  size_t batch_size;
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  uint32_t pad_value;

  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void* input2;
  const void** indirection_buffer;

  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  union {
    // Pointer to allocated packed weights. Use this if weights_cache is NULL.
    void* pointer;
    // Offset into the weights cache where the packed weights are. Only valid if weights_cache is not NULL.
    size_t offset;
  } packed_weights;
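  // Note: the packed_weights() helper defined below resolves either representation to a pointer.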
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  float input_scale;
  float output_scale;
  int32_t input_zero_point;

  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  size_t last_output_height;
  size_t last_output_width;
  void* last_output;

  uint32_t block_size;

  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;
  uint32_t flags;

  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_elu_params f16_elu;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    // Parameters for Global Average Pooling in CHW layout
    union xnn_f32_gavgpool_params f32_gavgpool;
    union xnn_f32_hswish_params f32_hswish;
    // Pixelwise Average Pooling normally uses f16_minmax_params, but also initializes
    // f16_scaleminmax_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_f16_minmax_params f16_minmax;
      union xnn_f16_scaleminmax_params f16_scaleminmax;
    };
    // Pixelwise Average Pooling normally uses f32_minmax_params, but also initializes
    // f32_scaleminmax_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_f32_minmax_params f32_minmax;
      union xnn_f32_scaleminmax_params f32_scaleminmax;
    };
    union xnn_f32_chw_params f32_chw;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qs8_conv_minmax_params qs8_conv_minmax;
    // Average Pooling normally uses qs8_avgpool_params, but also initializes qs8_gavgpool_params
    // in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qs8_avgpool_minmax_params qs8_avgpool;
      union xnn_qs8_avgpool_minmax_params qs8_gavgpool;
    };
    // Quantized Add parameters are sensitive to order of inputs, so we initialize an extra copy with the reversed order.
    struct {
      union xnn_qs8_add_minmax_params qs8_add;
      union xnn_qs8_add_minmax_params qs8_radd;
    };
    struct {
      union xnn_qs8_mul_minmax_params qs8_mul;
      union xnn_qs8_mul_minmax_params qs8_rmul;
    };
    struct {
      union xnn_qu8_add_minmax_params qu8_add;
      union xnn_qu8_add_minmax_params qu8_radd;
    };
    struct {
      union xnn_qu8_mul_minmax_params qu8_mul;
      union xnn_qu8_mul_minmax_params qu8_rmul;
    };
    union xnn_qu8_conv_minmax_params qu8_conv_minmax;
    // Average Pooling normally uses qu8_avgpool_params, but also initializes qu8_gavgpool_params
    // in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qu8_avgpool_minmax_params qu8_avgpool;
      union xnn_qu8_avgpool_minmax_params qu8_gavgpool;
    };
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
  size_t num_post_operation_params;
  void* post_operation_params;
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  struct compute_parameters compute;
  struct compute_parameters compute2;
  union {
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_shuffle_context channel_shuffle;
    struct conv2d_context conv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct elementwise_binary_context elementwise_binary;
    struct gemm_context gemm;
    struct global_average_pooling_nwc_context global_average_pooling_nwc;
    struct global_average_pooling_ncw_context global_average_pooling_ncw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pad_context pad;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct resize_bilinear_context resize_bilinear;
    struct resize_bilinear_chw_context resize_bilinear_chw;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct subgemm_context subgemm;
    struct transpose_context transpose;
    struct floating_point_softmax_context floating_point_softmax;
    struct u8_softmax_context u8_softmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  struct xnn_code_cache* code_cache;
  struct xnn_weights_cache* weights_cache;
  enum xnn_run_state state;
};

static inline void* packed_weights(struct xnn_operator* op) {
  if (op->weights_cache == NULL) {
    return op->packed_weights.pointer;
  } else {
    return (void*) ((uintptr_t) op->weights_cache->cache.weights.start + op->packed_weights.offset);
  }
}

static inline bool use_weights_cache(struct xnn_operator* op) {
  return op->weights_cache != NULL;
}

// Get a pointer to a region to pack weights into. If a weights cache is available, use it and
// return a pointer into the cache's buffer; otherwise, allocate and return a pointer to a new
// region. Returns NULL on error.
XNN_INTERNAL void* xnn_get_pointer_to_write_weights(
    xnn_operator_t op,
    size_t aligned_weights_size,
    int padding_byte);
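// A hypothetical usage sketch (the packing step is illustrative, not part of this API):
//   void* weights = xnn_get_pointer_to_write_weights(op, aligned_weights_size, /*padding_byte=*/0);
//   if (weights == NULL) {
//     ... report out-of-memory / cache reservation failure ...
//   }
//   ... pack kernel and bias into the returned region ...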

#ifdef __cplusplus
extern "C" {
#endif
XNN_INTERNAL size_t xnn_compute_convolution_output_dimension(
    size_t padded_input_dimension,
    size_t kernel_dimension,
    size_t dilation_dimension,
    size_t subsampling_dimension);
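// A sketch of the standard convolution shape arithmetic this is expected to compute
// (an assumption from common convention, not restated from the implementation):
//   effective_kernel = (kernel_dimension - 1) * dilation_dimension + 1
//   output = (padded_input_dimension - effective_kernel) / subsampling_dimension + 1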

XNN_INTERNAL size_t xnn_compute_deconvolution_output_dimension(
    size_t input_dimension,
    size_t output_padding_dimension,
    size_t adjustment_dimension,
    size_t kernel_dimension,
    size_t dilation_dimension,
    size_t stride_dimension);
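// Expected to invert the convolution arithmetic above (same caveat: a conventional sketch):
//   output = stride_dimension * (input_dimension - 1) + adjustment_dimension
//            + (kernel_dimension - 1) * dilation_dimension + 1 - output_padding_dimension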

XNN_INTERNAL size_t xnn_compute_unpooling_output_dimension(
    size_t input_dimension,
    size_t input_padding_dimension,
    size_t kernel_dimension);

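// Heuristically pick an MR value (at most max_mr) for the given batch size from the available
// microkernel cases; the exact policy is an implementation detail of the defining source file.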
XNN_INTERNAL uint32_t xnn_get_heuristic_mr_gemm(
    size_t batch_size,
    uint32_t max_mr,
    uint32_t nr,
    struct xnn_hmp_gemm_ukernel *gemm_cases);

XNN_INTERNAL uint32_t xnn_get_heuristic_mr_igemm(
    size_t batch_size,
    uint32_t max_mr,
    uint32_t nr,
    struct xnn_hmp_igemm_ukernel *igemm_cases);
#ifdef __cplusplus
}
#endif