// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once


#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>


enum xnn_parallelization_type {
  xnn_parallelization_type_invalid = 0,
  xnn_parallelization_type_1d,
  xnn_parallelization_type_1d_tile_1d,
  xnn_parallelization_type_2d,
  xnn_parallelization_type_2d_tile_1d,
  xnn_parallelization_type_2d_tile_2d,
  xnn_parallelization_type_3d_tile_2d,
  xnn_parallelization_type_4d_tile_2d,
  xnn_parallelization_type_5d_tile_2d,
  xnn_parallelization_type_6d_tile_2d,
};

struct compute_parameters {
  enum xnn_parallelization_type type;
  union {
    pthreadpool_task_1d_t task_1d;
    pthreadpool_task_1d_tile_1d_t task_1d_tile_1d;
    pthreadpool_task_2d_t task_2d;
    pthreadpool_task_2d_tile_1d_t task_2d_tile_1d;
    pthreadpool_task_2d_tile_2d_t task_2d_tile_2d;
    pthreadpool_task_3d_tile_2d_t task_3d_tile_2d;
    pthreadpool_task_4d_tile_2d_t task_4d_tile_2d;
    pthreadpool_task_5d_tile_2d_t task_5d_tile_2d;
    pthreadpool_task_6d_tile_2d_t task_6d_tile_2d;
  };
  size_t range[6];
  size_t tile[2];
};
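
// Illustrative sketch (assumed dispatch code, not part of this header): a
// runtime dispatcher is expected to switch on `type` and forward `range` and
// `tile` to the matching pthreadpool call, e.g. for the 2-D doubly-tiled case:
//
//   case xnn_parallelization_type_2d_tile_2d:
//     pthreadpool_parallelize_2d_tile_2d(
//         threadpool, compute->task_2d_tile_2d, context,
//         compute->range[0], compute->range[1],
//         compute->tile[0], compute->tile[1],
//         0 /* flags */);
//     break;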

struct gemm_context {
  size_t k_scaled;
  const void* a;
  size_t a_stride;
  const void* packed_w;
  size_t w_stride;
  size_t wg_stride;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t cg_stride;
  uint32_t log2_csize;
  xnn_gemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

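// Note: the [restrict static 1] array-parameter syntax used below is C99-only
// and invalid in C++, which is why the kernel-wrapper declarations in this
// header are guarded with #ifndef __cplusplus.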
#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_ggemm(
      const struct gemm_context context[restrict static 1],
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_gemm(
      const struct gemm_context context[restrict static 1],
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif
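
// Illustrative sketch (assumed setup code; `op`, `batch_size`,
// `group_output_channels`, `mr`, and `nr` are hypothetical locals): a dense
// GEMM would typically be dispatched as a 2-D doubly-tiled computation over
// the M (rows of A/C) and N (columns of B/C) dimensions:
//
//   op->compute.type = xnn_parallelization_type_2d_tile_2d;
//   op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
//   op->compute.range[0] = batch_size;             // M dimension
//   op->compute.range[1] = group_output_channels;  // N dimension
//   op->compute.tile[0] = mr;  // rows per micro-kernel invocation
//   op->compute.tile[1] = nr;  // columns per micro-kernel invocation
//
// xnn_compute_ggemm adds a leading group dimension, matching
// xnn_parallelization_type_3d_tile_2d.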

// Context for Sparse Matrix-Dense Matrix Multiplication.
// C [MxN] := A [MxK] * B [KxN] + bias [N]
// A and C are dense matrices with row-major storage, B is a sparse matrix.
struct spmm_context {
  // N dimension of the B and C matrices.
  // Corresponds to number of output channels in 1x1 convolution.
  size_t n;
  // Input matrix A.
  const void* a;
  // Packed bias elements and non-zero filter elements.
  const void* packed_weights;
  // Input pointer increments, in bytes, after each processed non-zero weight.
  const int32_t* input_increments;
  // Number of non-zero filter elements for each N (output channel) dimension.
  const uint32_t* output_channel_nonzeros;
  // Output matrix C.
  void* c;
  // Stride, in bytes, between matrices A corresponding to different images in batched 1x1 convolution.
  size_t batched_a_stride;
  // Stride, in bytes, between matrices C corresponding to different images in batched 1x1 convolution.
  size_t batched_c_stride;
  // Micro-kernel function pointer.
  xnn_spmm_ukernel_function ukernel;
  // Output activation parameters.
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_spmm(
    const struct spmm_context context[restrict static 1],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size);
#endif
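
// Illustrative note: the (batch_index, mr_block_start, mr_block_size)
// signature of xnn_compute_spmm matches a 2-D dispatch with a 1-D tile
// (xnn_parallelization_type_2d_tile_1d): one dimension iterates batch images,
// the other tiles the M (output pixel) dimension of C.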

struct igemm_context {
  size_t ks;
  size_t ks_scaled;
  size_t kc;
  size_t w_stride;
  const void** indirect_a;
  size_t a_offset;
  void* zero;
  const void* packed_w;
  void* c;
  size_t cm_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_gigemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_igemm(
      const struct igemm_context context[restrict static 1],
      size_t batch_index,
      size_t mr_block_start,
      size_t nr_block_start,
      size_t mr_block_size,
      size_t nr_block_size);
#endif
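
// Illustrative sketch (assumed setup code, analogous to the GEMM case above):
// the grouped indirect GEMM adds batch and group dimensions ahead of the M/N
// blocking, matching xnn_parallelization_type_4d_tile_2d:
//
//   op->compute.type = xnn_parallelization_type_4d_tile_2d;
//   op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_gigemm;
//   op->compute.range[0] = batch_size;
//   op->compute.range[1] = groups;
//   op->compute.range[2] = output_size;            // M: output pixels
//   op->compute.range[3] = group_output_channels;  // N
//   op->compute.tile[0] = mr;
//   op->compute.tile[1] = nr;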

struct subconv_context {
  const struct subconvolution_params* subconvolution_params;
  size_t kc;
  size_t a_offset;
  void* zero;
  size_t cx_stride;
  size_t cy_stride;
  size_t cn_stride;
  size_t ga_stride;
  size_t gw_stride;
  size_t gc_stride;
  size_t ba_stride;
  size_t bc_stride;
  uint32_t log2_csize;
  xnn_igemm_ukernel_function ukernel;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_gsubconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t group_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);

  XNN_PRIVATE void xnn_compute_subconv2d(
      const struct subconv_context context[restrict static 1],
      size_t batch_index,
      size_t subkernel_index,
      size_t slice_y,
      size_t slice_x_start,
      size_t nr_block_start,
      size_t slice_x_max,
      size_t nr_block_size);
#endif

struct dconv2d_context {
  size_t input_height;
  size_t input_width;
  const void* input;
  size_t input_batch_stride;
  const void* zero;
  const void* packed_weights;
  void* output;
  size_t output_batch_stride;
  size_t input_padding_top;
  size_t output_channels;
  size_t output_height_stride;
  size_t output_channel_stride;
  union {
    xnn_conv_hwc2spchw_ukernel_function hwc2spchw_ukernel;
  };
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dconv2d_hwc2spchw(
      const struct dconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t output_y_start,
      size_t output_y_slice);
#endif
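
// Illustrative note: the (batch_index, output_y_start, output_y_slice)
// signature of xnn_compute_dconv2d_hwc2spchw matches a 2-D dispatch with a
// 1-D tile (xnn_parallelization_type_2d_tile_1d) over batch images and
// slices of output rows.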

struct dwconv_context {
  size_t groups;
  const void** indirection_buffer;
  size_t indirection_buffer_row_stride;
  size_t indirection_buffer_col_stride;
  const void* packed_weights;
  void* output;
  size_t output_width;
  size_t output_row_stride;
  size_t output_col_increment;
  union {
    union xnn_q8_gemm_params q8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_dwconv_up_ukernel_function unipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv_unipass(
      const struct dwconv_context context[restrict static 1],
      size_t output_y);
#endif

struct dwconv2d_context {
  size_t output_height;
  size_t input_width;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  const void* packed_weights;
  size_t weights_channel_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  size_t input_tuple_stride;
  size_t output_tuple_stride;
  size_t input_pixel_stride;
  size_t output_pixel_stride;
  union {
    union xnn_f32_spchw_params f32;
  } params;
  union {
    xnn_dwconv_spchw_ukernel_function spchw_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_dwconv2d_spchw(
      const struct dwconv2d_context context[restrict static 1],
      size_t batch_index,
      size_t channel);
#endif
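
// Illustrative note: xnn_compute_dwconv2d_spchw processes one (batch, channel)
// pair per invocation, i.e. a plain 2-D dispatch
// (xnn_parallelization_type_2d) over batch images and channels.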

struct max_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  xnn_maxpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_max_pooling(
      const struct max_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct unpooling_context {
  const void* input;
  size_t input_height_stride;
  size_t input_width_stride;
  const uint32_t* index;
  size_t index_height_stride;
  size_t index_width_stride;
  void** indirect_output;
  size_t indirect_output_height_stride;
  size_t indirect_output_width_stride;
  size_t pooling_size;
  size_t channels;
  uint32_t fill_value;
  xnn_unpool_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_unpooling(
      const struct unpooling_context context[restrict static 1],
      size_t input_y,
      size_t input_x);
#endif

struct argmax_pooling_context {
  const void** indirect_input;
  size_t indirect_input_height_stride;
  size_t input_offset;
  size_t input_batch_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  uint32_t* index;
  size_t index_batch_stride;
  size_t index_height_stride;
  size_t pooling_size;
  size_t channels;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_argmaxpool_up_ukernel_function unipass_ukernel;
    xnn_argmaxpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_argmax_pooling_unipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_argmax_pooling_multipass(
      const struct argmax_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif
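
// Illustrative note (assumed selection logic): the unipass kernel is intended
// for pooling windows small enough to be reduced in a single pass over the
// micro-kernel's tile, while larger pooling_size values use the multipass
// kernel, which accumulates intermediate results across several passes. The
// same unipass/multipass split recurs in the average pooling contexts below.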

struct average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_avgpool_up_ukernel_function unipass_ukernel;
    xnn_avgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_average_pooling_unipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_average_pooling_multipass(
      const struct average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct pixelwise_average_pooling_context {
  const void** indirect_input;
  size_t indirect_input_batch_stride;
  size_t indirect_input_height_stride;
  const void* pixelwise_buffer;
  size_t pixelwise_buffer_height_stride;
  void* output;
  size_t output_batch_stride;
  size_t output_height_stride;
  size_t output_width;
  size_t pooling_size;
  size_t channels;
  const void* zero;
  size_t input_increment;
  size_t output_increment;
  union {
    union xnn_u8_output_params u8;
    union xnn_f32_output_params f32;
  } params;
  union {
    xnn_pavgpool_up_ukernel_function unipass_ukernel;
    xnn_pavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);

  XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass(
      const struct pixelwise_average_pooling_context context[restrict static 1],
      size_t batch_index,
      size_t output_y);
#endif

struct global_average_pooling_nwc_context {
  const void* input;
  const void* zero;
  size_t input_pixel_stride;
  size_t input_batch_stride;
  size_t input_elements;
  size_t channels;
  void* output;
  size_t output_batch_stride;
  union {
    union xnn_q8_avgpool_params q8;
    union xnn_f32_avgpool_params f32;
  } params;
  union {
    xnn_gavgpool_up_ukernel_function unipass_ukernel;
    xnn_gavgpool_mp_ukernel_function multipass_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_unipass(
      const struct global_average_pooling_nwc_context context[restrict static 1],
      size_t batch_index);

  XNN_PRIVATE void xnn_compute_global_average_pooling_nwc_multipass(
      const struct global_average_pooling_nwc_context context[restrict static 1],
      size_t batch_index);
#endif

struct global_average_pooling_ncw_context {
  size_t input_elements;
  const void* input;
  size_t input_channel_stride;
  size_t input_batch_stride;
  void* output;
  size_t output_channel_stride;
  size_t output_batch_stride;
  xnn_gavgpool_spchw_ukernel_function ukernel;
  union {
    union xnn_f32_gavgpool_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_global_average_pooling_ncw(
      const struct global_average_pooling_ncw_context context[restrict static 1],
      size_t batch_index,
      size_t channels_start,
      size_t channels_slice);
#endif

struct resize_bilinear_context {
  // Number of channels multiplied by sizeof(input element).
  size_t scaled_channels;
  // Indirection buffer with pointers related to rows of input pixels.
  const void** indirect_input;
  // Offset, in bytes, to be added to pointers in indirection buffer.
  size_t input_offset;
  // Stride, in bytes, between images of consecutive batches in the input.
  size_t input_batch_stride;
  // Packed pairs of (x, y) linear interpolation coefficients.
  const void* packed_weights;
  // Pointer to the output tensor.
  void* output;
  // Stride, in bytes, between adjacent pixels in the output.
  size_t output_pixel_stride;
  // Stride, in bytes, between images of consecutive batches in the output.
  size_t output_batch_stride;
  // log2(sizeof(weight element)).
  uint32_t log2_wsize;
  // Pointer to BILINEAR micro-kernel function.
  xnn_bilinear_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_resize_bilinear(
      const struct resize_bilinear_context context[restrict static 1],
      size_t batch_index,
      size_t pixel_start,
      size_t pixel_range);
#endif
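
// Illustrative note: for each output pixel, a bilinear micro-kernel reads four
// neighboring input pixels (via indirect_input) and blends them with the
// packed (alpha_x, alpha_y) coefficients:
//
//   top    = top_left    + alpha_x * (top_right    - top_left)
//   bottom = bottom_left + alpha_x * (bottom_right - bottom_left)
//   output = top         + alpha_y * (bottom       - top)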

struct add_strided_context {
  size_t n;
  const void* a;
  size_t a_stride;
  const void* b;
  size_t b_stride;
  void* y;
  size_t y_stride;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vadd_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_add_strided(
      const struct add_strided_context context[restrict static 1],
      size_t batch_index,
      size_t batch_range);
#endif

struct add_contiguous_context {
  const void* a;
  const void* b;
  void* y;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vadd_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_add_contiguous(
      const struct add_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct elementwise_binary_context {
  const void* a;
  size_t a_stride[XNN_MAX_TENSOR_DIMS - 1];
  const void* b;
  size_t b_stride[XNN_MAX_TENSOR_DIMS - 1];
  void* y;
  size_t y_stride[XNN_MAX_TENSOR_DIMS - 1];
  size_t elements;
  union {
    union xnn_q8_add_params q8;
    union xnn_f32_output_params f32;
  } params;
  xnn_vbinary_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_elementwise_binary_5d(
      const struct elementwise_binary_context context[restrict static 1],
      size_t i, size_t j, size_t k, size_t l, size_t m, size_t l_range, size_t m_range);
#endif
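
// Illustrative note: the per-dimension strides above support broadcasting
// under the convention (assumed here) that a broadcast dimension gets a
// stride of 0, so the same a or b data is re-read while the other operand
// and y advance normally.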

struct channel_shuffle_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  size_t n;
  size_t m;
  union {
    xnn_zipc_ukernel_function fixed_ukernel;
    xnn_zipv_ukernel_function variable_ukernel;
  };
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_shuffle_fixed(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);

  XNN_PRIVATE void xnn_compute_channel_shuffle_variable(
      const struct channel_shuffle_context context[restrict static 1],
      size_t index);
#endif

struct lut_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_strided(
      const struct lut_strided_context context[restrict static 1],
      size_t batch_index);
#endif

struct lut_contiguous_context {
  const void* x;
  size_t x_stride;
  const void* t;
  void* y;
  size_t y_stride;
  xnn_x8_lut_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_lut_contiguous(
      const struct lut_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct univector_strided_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_strided(
      const struct univector_strided_context context[restrict static 1],
      size_t batch_index,
      size_t batch_range);
#endif

struct univector_contiguous_context {
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_univector_ukernel_function ukernel;
  union {
    union xnn_u8_output_params u8_output;
    union xnn_f32_output_params f32_output;
    union xnn_f32_hswish_params f32_hswish;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_univector_contiguous(
      const struct univector_contiguous_context context[restrict static 1],
      size_t offset,
      size_t size);
#endif

struct prelu_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_prelu_ukernel_function ukernel;
  union xnn_f32_output_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_prelu(
      const struct prelu_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

struct vmulcaddc_context {
  size_t n;
  const void* x;
  size_t x_stride;
  const void* w;
  void* y;
  size_t y_stride;
  xnn_vmulcaddc_ukernel_function ukernel;
  union {
    union xnn_f32_output_params f32;
  } params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_vmulcaddc(
      const struct vmulcaddc_context context[restrict static 1],
      size_t batch_start,
      size_t batch_size);
#endif

struct channel_pad_context {
  size_t n;
  size_t l;
  size_t r;
  uint32_t c;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_pad_ukernel_function ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_channel_pad(
      const struct channel_pad_context context[restrict static 1],
      size_t batch_start,
      size_t batch_range);
#endif

struct u8_softmax_context {
  size_t n;
  const uint8_t* x;
  size_t x_stride;
  const uint32_t* t;
  uint8_t* y;
  size_t y_stride;
  xnn_u8_rmax_ukernel_function rmax_ukernel;
  xnn_u8_lut32norm_ukernel_function lut_norm_ukernel;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_u8_softmax(
      const struct u8_softmax_context context[restrict static 1],
      size_t batch_index);
#endif
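
// Illustrative note: the quantized softmax is a two-pass computation: the
// rmax micro-kernel finds the per-row maximum, and the lut32norm micro-kernel
// then maps each element through the 256-entry table t (expected to hold
// fixed-point exponentials) and normalizes by the accumulated sum.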

struct f32_three_pass_softmax_context {
  size_t n;
  const void* x;
  size_t x_stride;
  void* y;
  size_t y_stride;
  xnn_f32_rmax_ukernel_function rmax_ukernel;
  xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax_ukernel;
  xnn_vbinary_ukernel_function vmulc_ukernel;
  union xnn_f32_output_params params;
};

#ifndef __cplusplus
  XNN_PRIVATE void xnn_compute_f32_three_pass_softmax(
      const struct f32_three_pass_softmax_context context[restrict static 1],
      size_t batch_index);
#endif
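
// Illustrative note: the three passes are (1) rmax_ukernel computes the row
// maximum, (2) raddstoreexpminusmax_ukernel stores exp(x - max) and reduces
// its sum, and (3) vmulc_ukernel multiplies the stored exponentials by
// 1 / sum to normalize. Subtracting the maximum first keeps exp() within a
// numerically safe range.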
782