// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack/allocator.h>
#include <xnnpack/cache.h>
#include <xnnpack/compute.h>
#include <xnnpack/operator-type.h>
#include <xnnpack/params.h>
#include <xnnpack/ukernel-type.h>


struct xnn_ukernel_conv2d {
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};

struct xnn_ukernel_dwconv {
  union {
    xnn_dwconv_unipass_ukernel_function unipass_function;
    xnn_dwconv_multipass_ukernel_function multipass_function;
  };
  uint8_t primary_tile;
  uint8_t incremental_tile;
};

// Direct 2D Depthwise Convolution
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv2d_chw_ukernel_function chw_function;
  };
  uint8_t output_width_tile;
};

struct xnn_ukernel_gemm {
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_igemm {
  struct xnn_hmp_igemm_ukernel igemm_cases[XNN_MAX_MR];
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vbinary {
  xnn_vbinary_ukernel_function op_function;
  xnn_vbinary_ukernel_function opc_function;
  xnn_vbinary_ukernel_function ropc_function;
};

struct xnn_ukernel_vunary {
  xnn_vunary_ukernel_function function;
};

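// Tagged union of the microkernel descriptors above: 'type' identifies which
// member of the anonymous union is active.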
struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_conv2d conv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
    struct xnn_ukernel_vbinary vbinary;
    struct xnn_ukernel_vunary vunary;
  };
};

enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};

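// Per-slice parameters for the subconvolution (subconv) path used by deconvolution.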
struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*).
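  // For example (illustrative): a 3x3 kernel (kernel_size = 9) with mr = 4 on a
  // 64-bit target (sizeof(void*) = 8) gives 9 * 4 * 8 = 288 bytes.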
  size_t scaled_kernel_size;
};

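// Runtime state of an operator instance: shape and padding metadata, packed
// weights, microkernel parameters, and the per-invocation compute context.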
struct xnn_operator {
  size_t batch_size;
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  uint32_t pad_value;

  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void* input2;
  const void** indirection_buffer;

  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  union {
    // Pointer to allocated packed weights. Use this if weights_cache is NULL.
    void* pointer;
    // Offset into the weights cache where the packed weights are. Only valid if weights_cache is not NULL.
    size_t offset;
  } packed_weights;
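  // (Both cases are resolved uniformly by the packed_weights() helper defined below.)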
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  float input_scale;
  float output_scale;
  int32_t input_zero_point;

  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  size_t last_output_height;
  size_t last_output_width;
  void* last_output;

  uint32_t block_size;

  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;
  uint32_t flags;

  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_elu_params f16_elu;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    // Parameters for Global Average Pooling in CHW layout
    union xnn_f32_gavgpool_params f32_gavgpool;
    union xnn_f32_hswish_params f32_hswish;
    // Pixelwise Average Pooling normally uses f16_minmax_params, but also initializes
    // f16_scaleminmax_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_f16_minmax_params f16_minmax;
      union xnn_f16_scaleminmax_params f16_scaleminmax;
    };
    // Pixelwise Average Pooling normally uses f32_minmax_params, but also initializes
    // f32_scaleminmax_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_f32_minmax_params f32_minmax;
      union xnn_f32_scaleminmax_params f32_scaleminmax;
    };
    union xnn_f32_chw_params f32_chw;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qs8_conv_minmax_params qs8_conv_minmax;
    // Average Pooling normally uses qs8_avgpool_params, but also initializes
    // qs8_gavgpool_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qs8_avgpool_minmax_params qs8_avgpool;
      union xnn_qs8_avgpool_minmax_params qs8_gavgpool;
    };
    // Quantized Add parameters are sensitive to the order of inputs, so we initialize an extra copy with the reversed order.
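    // (For example, when the two inputs have different quantization scales, the
    // requantization parameters computed for add(a, b) do not apply to add(b, a).)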
    struct {
      union xnn_qs8_add_minmax_params qs8_add;
      union xnn_qs8_add_minmax_params qs8_radd;
    };
    struct {
      union xnn_qs8_mul_minmax_params qs8_mul;
      union xnn_qs8_mul_minmax_params qs8_rmul;
    };
    struct {
      union xnn_qu8_add_minmax_params qu8_add;
      union xnn_qu8_add_minmax_params qu8_radd;
    };
    struct {
      union xnn_qu8_mul_minmax_params qu8_mul;
      union xnn_qu8_mul_minmax_params qu8_rmul;
    };
    union xnn_qu8_conv_minmax_params qu8_conv_minmax;
    // Average Pooling normally uses qu8_avgpool_params, but also initializes
    // qu8_gavgpool_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qu8_avgpool_minmax_params qu8_avgpool;
      union xnn_qu8_avgpool_minmax_params qu8_gavgpool;
    };
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
  size_t num_post_operation_params;
  void* post_operation_params;
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  struct compute_parameters compute;
  struct compute_parameters compute2;
  union {
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_shuffle_context channel_shuffle;
    struct conv2d_context conv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct elementwise_binary_context elementwise_binary;
    struct gemm_context gemm;
    struct global_average_pooling_nwc_context global_average_pooling_nwc;
    struct global_average_pooling_ncw_context global_average_pooling_ncw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pad_context pad;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct resize_bilinear_context resize_bilinear;
    struct resize_bilinear_chw_context resize_bilinear_chw;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct subgemm_context subgemm;
    struct transpose_context transpose;
    struct floating_point_softmax_context floating_point_softmax;
    struct u8_softmax_context u8_softmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  struct xnn_code_cache* code_cache;
  struct xnn_weights_cache* weights_cache;
  enum xnn_run_state state;
};

static inline void* packed_weights(struct xnn_operator* op) {
  if (op->weights_cache == NULL) {
    return op->packed_weights.pointer;
  } else {
    return (void*) ((uintptr_t) op->weights_cache->cache.weights.start + op->packed_weights.offset);
  }
}
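
// Illustrative usage (hypothetical caller): the packed-weights location is
// resolved the same way whether or not a weights cache is in use:
//   const void* weights = packed_weights(op);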

static inline bool use_weights_cache(struct xnn_operator* op) {
  return op->weights_cache != NULL;
}

// Get a pointer to a region into which weights can be packed. If a weights cache is available, use it and return a
// pointer into the cache's buffer; otherwise, allocate and return a pointer to a new region. Returns NULL on error.
XNN_INTERNAL void* xnn_get_pointer_to_write_weights(
  xnn_operator_t op,
  size_t aligned_weights_size,
  int padding_byte);
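
// Illustrative usage (hypothetical packing code; pack_weights_into() is a stand-in
// for an operator-specific packing routine, not an XNNPACK function):
//   void* w = xnn_get_pointer_to_write_weights(op, aligned_size, /*padding_byte=*/0);
//   if (w == NULL) { /* allocation or cache failure */ }
//   pack_weights_into(w);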

#ifdef __cplusplus
extern "C" {
#endif
XNN_INTERNAL size_t xnn_compute_convolution_output_dimension(
  size_t padded_input_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t subsampling_dimension);
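// Sketch of the conventional dilated-convolution arithmetic (illustrative, assuming
// the usual floor-division convention; not necessarily the exact implementation):
//   effective_kernel = (kernel_dimension - 1) * dilation_dimension + 1
//   output = (padded_input_dimension - effective_kernel) / subsampling_dimension + 1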

XNN_INTERNAL size_t xnn_compute_deconvolution_output_dimension(
  size_t input_dimension,
  size_t output_padding_dimension,
  size_t adjustment_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t stride_dimension);
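// Sketch of the conventional transposed-convolution arithmetic (illustrative assumption):
//   effective_kernel = (kernel_dimension - 1) * dilation_dimension + 1
//   output = stride_dimension * (input_dimension - 1) + adjustment_dimension
//            + effective_kernel - output_padding_dimension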

XNN_INTERNAL size_t xnn_compute_unpooling_output_dimension(
  size_t input_dimension,
  size_t input_padding_dimension,
  size_t kernel_dimension);
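// Sketch (illustrative assumption): unpooling expands each input element into a
// full kernel window and then trims the padding:
//   output = input_dimension * kernel_dimension - input_padding_dimension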

XNN_INTERNAL uint32_t xnn_get_heuristic_mr_gemm(
  size_t batch_size,
  uint32_t max_mr,
  uint32_t nr,
  struct xnn_hmp_gemm_ukernel *gemm_cases);

XNN_INTERNAL uint32_t xnn_get_heuristic_mr_igemm(
  size_t batch_size,
  uint32_t max_mr,
  uint32_t nr,
  struct xnn_hmp_igemm_ukernel *igemm_cases);
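
// Note on the two heuristics above (illustrative reading of the signatures): each
// returns a row-tile count mr in [1, max_mr] for which a microkernel is present in
// the supplied cases array; the exact selection policy is an implementation detail.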
#ifdef __cplusplus
}
#endif