/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace depthwise_conv {

constexpr int kDepthwiseConvScratchWorkspaceSize = 10 * 10 * 64;
constexpr int kDepthwiseConvAdjustedBiasLimit = 64;
// In cases such as depth multiplication, we want to be able to load data from
// the workspace that is beyond the valid range. Macro-block sizes are adjusted
// to allow for this.
constexpr int kWorkspaceExtension = 16;

#ifdef USE_NEON

#ifndef __aarch64__
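// Emulation of the AArch64-only vqtbl4q_s8 intrinsic: indexes a 64-byte table
// held in four q-registers. The 6-bit index is split on bit 3: bit 3 selects
// between two 32-byte vtbl4_s8 lookups (built from the low and high halves of
// each register), and deleting bit 3 forms the 5-bit index within each lookup.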
inline int8x16_t vqtbl4q_s8(int8x16x4_t a, int8x16_t b) {
  const uint8x16_t mask = vtstq_s8(b, vdupq_n_s8(8));

  // Delete bit 3 from the indices.
  const int8x16_t high_bits = vshrq_n_s8(b, 4);
  int8x16_t deleted_bit_3 = b;
  deleted_bit_3 = vsliq_n_s8(deleted_bit_3, high_bits, 3);

  int8x8x4_t repacked_data;

  // Calculate for lower indices.
  repacked_data.val[0] = vget_low_s8(a.val[0]);
  repacked_data.val[1] = vget_low_s8(a.val[1]);
  repacked_data.val[2] = vget_low_s8(a.val[2]);
  repacked_data.val[3] = vget_low_s8(a.val[3]);
  const int8x16_t output_for_lower =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Calculate for high indices.
  repacked_data.val[0] = vget_high_s8(a.val[0]);
  repacked_data.val[1] = vget_high_s8(a.val[1]);
  repacked_data.val[2] = vget_high_s8(a.val[2]);
  repacked_data.val[3] = vget_high_s8(a.val[3]);
  const int8x16_t output_for_higher =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Merge.
  int8x16_t output = vbslq_s8(mask, output_for_higher, output_for_lower);
  return output;
}
#endif  // !__aarch64__

// Convenience-compatibility functions.
// Compatibility: Intrinsics reflect a mixture of older and newer ARM
//     instructions. This actually results in ZIP1 / ZIP2 asm instructions, but
//     only one intrinsic is provided. Also, older instructions operated in
//     place, and it seems more defensive to assume that some versions of the
//     intrinsics might reflect this.
// Convenience: Callers in these kernels want both ZIP1 and ZIP2, and we do not
//     want the calling code to get cluttered with unpacking int8x16x2_t.
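// Example: after vzipq_s8_in_place(&a, &b), |a| holds the ZIP1 result
//     (interleaved low halves) of the original registers and |b| holds the
//     ZIP2 result (interleaved high halves).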
inline void vzipq_s8_in_place(int8x16_t* a, int8x16_t* b) {
  int8x16x2_t r8x16;
  r8x16 = vzipq_s8(*a, *b);
  *a = r8x16.val[0];
  *b = r8x16.val[1];
}

inline void vzipq_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vzipq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
  *b = vreinterpretq_s8_s16(r16x8.val[1]);
}

// Similar rationale to the zip-in_place functions, but callers only actually
// need the TRN1 asm instruction result.
inline void vtrn1_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vtrnq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
}

// Similar rationale to the zip-in_place functions, but callers only actually
// need the ZIP1 or ZIP2 asm instruction results.
inline int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[0];
}
inline int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[1];
}

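// Rotates a register pair right by one byte within each 32-bit lane: each lane
// of |left| is shifted right by 8 bits and the vacated byte is filled from the
// low byte of the corresponding lane of |right|, which is then itself shifted
// right by 8 bits.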
inline void biregister_rotate_8(int8x16_t* left, int8x16_t* right) {
  *left = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*left), 8));
  *left = vreinterpretq_s8_u32(vsliq_n_u32(vreinterpretq_u32_s8(*left),
                                           vreinterpretq_u32_s8(*right), 24));
  *right = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*right), 8));
}

#ifndef __aarch64__
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
  int32x4x2_t deinterleaved = vuzpq_s32(a, b);
  return vqaddq_s32(deinterleaved.val[0], deinterleaved.val[1]);
}
#endif  // !__aarch64__

#ifdef __ARM_FEATURE_DOTPROD
// The vdotq_lane_s32 intrinsic takes an int8x8_t for the rhs parameter,
// whereas the actual instruction selects among four 32-bit (4x8-bit packed)
// sub-registers, an unusual interpretation of "lane".
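// Here |lane| therefore selects one of the four 32-bit groups of |rhs|, i.e.
// bytes 4 * lane .. 4 * lane + 3.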
inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, const int lane) {
  switch (lane) {
    case 0:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 0);
    case 1:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 1);
    case 2:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)),
                            0);
    case 3:
    default:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)),
                            1);
  }
}

#else

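// Fallback implementations for targets without the dot-product extension:
// emulate SDOT with a widening multiply (vmull_s8) followed by pairwise
// additions, accumulating each group of four adjacent products into one
// 32-bit lane.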
inline int32x4_t vdotq_s32(int32x4_t acc, int8x16_t lhs, int8x16_t rhs) {
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), vget_low_s8(rhs)));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), vget_high_s8(rhs)));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}

inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, int lane) {
  int8x8_t lane_rhs;
  if (lane == 0) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 0));
  } else if (lane == 1) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 1));
  } else if (lane == 2) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 0));
  } else {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 1));
  }
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), lane_rhs));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), lane_rhs));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}

#endif  // !__ARM_FEATURE_DOTPROD
#endif  // ARM NEON

//  This structure is typically used for reducing the magnitude of outputs, and
//  the historical name reflects that.
template <DepthwiseConvOutputRounding output_rounding>
struct DivideByPOT {};

template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kAwayFromZero> {
  template <typename IntegerType>
  static inline IntegerType Run(IntegerType x, int exponent) {
    return RoundingDivideByPOT(x, exponent);
  }
  // Mult versions use the exponents directly, rather than negated.
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, int exponent) {
    return RoundingDivideByPOT(x, -exponent);
  }
};

#ifdef USE_NEON
template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kUpward> {
  template <typename IntegerType>
  static inline IntegerType Run(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32>(-exponent)));
  }
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, IntegerType exponent) {
    return vqrshlq_s32(x, exponent);
  }
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32>(exponent)));
  }
};
#endif  // ARM NEON
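
// Example:
//     DivideByPOT<DepthwiseConvOutputRounding::kAwayFromZero>::Run(x, 3)
//     divides x by 2^3 = 8, rounding to nearest with ties away from zero. The
//     kUpward / NEON specialization uses saturating rounding shifts
//     (vqrshlq_s32) instead, which round ties upward.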

// See CategorizeDotProductKernel for definitive taxonomy.
enum class DotProduct3x3KernelType {
  kNone = 0,  // Parameter combination is not supported for dot product kernels.
  kPlain,
  kWithDepthMultiplicationStride1,
  kWithDepthMultiplicationStride2,
  kStride2,
};

enum class QuantizationType {
  kNonPerChannelUint8 = 0,
  kPerChannelInt8 = 1,
};

template <QuantizationType quantization_type>
struct QuantizationTypeImpl {};

template <>
struct QuantizationTypeImpl<QuantizationType::kNonPerChannelUint8> {
  typedef uint8 ExternalType;

  static constexpr int kIntSymmetricZeroPoint = 128;
  static constexpr uint8 kUint8SignBit = 0x80;
};

template <>
struct QuantizationTypeImpl<QuantizationType::kPerChannelInt8> {
  typedef int8 ExternalType;

  static constexpr int kIntSymmetricZeroPoint = 0;
  static constexpr uint8 kUint8SignBit = 0x0;
};

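// Returns the dot-product kernel variant that can handle the given shapes and
// parameters, or kNone if unsupported. The requirements are a 3x3 filter,
// equal strides of 1 or 2, unit dilation, padding of at most 1, and either
// (depth multiplier == 1 with input depth a multiple of 8) or
// (input depth == 1 with depth multiplier > 1), plus checks on the weights
// offset and output shift.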
template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline DotProduct3x3KernelType CategorizeDotProductKernel(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    const RuntimeShape& output_shape, const DepthwiseParams& params,
    const int32* output_shift_ptr = nullptr) {
  constexpr int kSymmetricZeroPoint =
      QuantizationTypeImpl<quantization_type>::kIntSymmetricZeroPoint;
  const int padding =
      std::max(params.padding_values.width, params.padding_values.height);
  const int stride = params.stride_width;
  const int32 input_depth = input_shape.Dims(3);
  const int32 depth_multiplier = params.depth_multiplier;
  const int32 filter_height = filter_shape.Dims(1);
  const int32 filter_width = filter_shape.Dims(2);

  bool supported = stride == params.stride_height && stride <= 2 &&
                   padding <= 1 && filter_width == 3 && filter_height == 3 &&
                   params.dilation_width_factor == 1 &&
                   params.dilation_height_factor == 1 &&
                   (((input_depth % 8) == 0 && depth_multiplier == 1) ||
                    (input_depth == 1 && depth_multiplier > 1));

  if (!supported) {
    return DotProduct3x3KernelType::kNone;
  }

  if (params.weights_offset != -kSymmetricZeroPoint) {
    return DotProduct3x3KernelType::kNone;
  }

  if (quantization_type == QuantizationType::kPerChannelInt8) {
    if (output_shift_ptr == nullptr) {
      return DotProduct3x3KernelType::kNone;
    }
  } else if (params.output_shift > 0) {
    return DotProduct3x3KernelType::kNone;
  }

  if (params.depth_multiplier == 1) {
    if (stride == 1) {
      return DotProduct3x3KernelType::kPlain;
    } else if (stride == 2) {
      return DotProduct3x3KernelType::kStride2;
    } else {
      return DotProduct3x3KernelType::kNone;
    }
  } else {
    if (stride == 1) {
      return DotProduct3x3KernelType::kWithDepthMultiplicationStride1;
    } else if (stride == 2) {
      return DotProduct3x3KernelType::kWithDepthMultiplicationStride2;
    } else {
      return DotProduct3x3KernelType::kNone;
    }
  }
}

// Encapsulates constant parameters used in DepthwiseConv.
// 64-bit is used for types that will be added to 64-bit addresses in asm.
struct DepthwiseConvParams {
  int64_t input_depth;
  int64_t input_row_size;
  int64_t output_depth;
  int64_t output_row_size;
  int64_t filter_row_size;
  int32 input_offset;
  int32 output_offset;
  int32 filter_offset;
  int32 output_multiplier;
  int32 output_activation_min;
  int32 output_activation_max;
  int32 output_right_shift;
  int32 input_width;
  int32 input_height;
  int32 stride_width;
  int32 stride_height;
  int32 output_width;
  int32 output_height;
  float float_output_activation_min;
  float float_output_activation_max;
};

// Encapsulates constant parameters used in DepthwiseConv using dot-product ops.
// 64-bit is used for types that will be added to 64-bit addresses in asm.
//
// This structure is specifically designed for use in asm.
struct DepthwiseConvDotProdParams {
  int64_t input_depth;
  int64_t output_depth;
  int32 stride;
  int32 bias_increment;
  //
  int32 input_offset;
  int32 output_offset;
  int32 output_multiplier;
  int32 output_shift;
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  //
  int32 padding_left;
  int32 padding_right;
  int32 padding_top;
  int32 padding_bottom;
  //
  int32 depth_micro_repeats;
  //
  int32 width_macro_count;
  int32 input_width_overall_micro_repeats;
  int32 input_width_micro_repeats;
  int32 residual_width;
  int32 output_width_overall_micro_repeats;
  int32 output_width_micro_repeats;
  int32 output_residual_width;
  int32 workspace_width_micro_repeats;
  //
  int32 height_macro_count;
  int32 inbound_block_height;
  int32 outbound_block_height;
  int32 input_height_stride;
  int32 output_height_stride;
  int32 workspace_height_stride;
  //
  int32 four_over_stride;
  //
  const int32* output_multiplier_per_channel;
  const int32* output_shift_per_channel;
};

template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
          int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindow {};

template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
          int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindowPerChannel {};

enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartial {};

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartialPerChannel {};

// Copies a subset of the input designated by |input_ptr| into |output_ptr|
// with the specified output dimensions. Supports output depths of 64 only as
// this is the cache line size.
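// Note that the inner memcpy copies |output_depth| bytes per pixel, so the
// element type T is assumed to be one byte wide (uint8 / int8).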
template <typename T>
inline void ShuffleInput(const T* input_ptr, int64_t input_depth,
                         int32 input_width, int32 input_height,
                         int64_t output_depth, int32 output_width,
                         int32 output_height, T* output_ptr) {
  const int64_t input_row_size = input_depth * input_width;
  for (int32 y = 0; y < output_height; y++) {
    const T* ptr = input_ptr;
    for (int32 x = 0; x < output_width; x++) {
      memcpy(output_ptr, ptr, output_depth);
      output_ptr += output_depth;
      ptr += input_depth;
    }
    input_ptr += input_row_size;
  }
}

// Calculates the input size depending on stride and output.
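// For a 3x3 filter, producing |output| values at stride |stride| requires
// stride * (output - 1) + 3 input values; e.g. 4 outputs at stride 2 span
// 2 * 3 + 3 = 9 inputs.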
inline int32 get_shuffle_input_size(int32 stride, int32 output) {
  return stride * (output - 1) + 3;
}

// Indicates the input and output dimensions used when shuffling input
// activations.
struct ShuffleParams {
  int32 output_width;
  int32 output_height;
  int32 input_width;
  int32 input_height;

  ShuffleParams() = default;
  ShuffleParams(int32 output_width, int32 output_height, int32 stride_width,
                int32 stride_height)
      : output_width(output_width),
        output_height(output_height),
        input_width(get_shuffle_input_size(stride_width, output_width)),
        input_height(get_shuffle_input_size(stride_height, output_height)) {}
};

template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline bool Fast3x3FilterKernelSupported(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    int32 stride_width, int32 stride_height, int32 dilation_width_factor,
    int32 dilation_height_factor, int32 pad_width, int32 pad_height,
    int32 depth_multiplier, const RuntimeShape& output_shape,
    int32 output_shift, const int32* output_shift_ptr = nullptr) {
  const int32 input_height = input_shape.Dims(1);
  const int32 input_width = input_shape.Dims(2);
  const int32 input_depth = input_shape.Dims(3);
  const int32 filter_height = filter_shape.Dims(1);
  const int32 filter_width = filter_shape.Dims(2);
  const int32 output_height = output_shape.Dims(1);
  const int32 output_width = output_shape.Dims(2);

  bool supported =
      filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
      (stride_width == 1 || stride_width == 2) &&
      (stride_height == 1 || stride_height == 2) &&
      (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
      (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
      (input_depth % 8) == 0 && (output_shift <= 0) &&
      dilation_width_factor == 1 && dilation_height_factor == 1;

  if (!supported) {
    return false;
  }

  // Handle the case where padding is zero but the padding type is not kValid.
  // This would require special boundary-case handling that is not supported.

  const int32 out_x = output_width - 1;
  const int32 out_y = output_height - 1;

  const int32 in_x_origin = (out_x * stride_width) - pad_width;
  const int32 in_y_origin = (out_y * stride_height) - pad_height;

  const int32 in_x_end = in_x_origin + filter_width;
  const int32 in_y_end = in_y_origin + filter_height;

  // When padding is zero, supported only if the filter placed at the right and
  // bottom boundaries lies completely within the input.
  if (pad_width == 0 && pad_height == 0) {
    return in_x_end <= input_width && in_y_end <= input_height;
  }

  // Otherwise, with padding of 1, supported only if the bottom-right filter
  // position extends at most one element past the input width and height.
  supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1);

  if (!supported) {
    return false;
  }

  // Shapes with width 1 and height > 1, and vice versa, are not supported yet.
  if (input_width == 1) {
    supported = (input_width == input_height);
  } else if (input_height == 1) {
    supported = (input_width == input_height);
  }
  return supported;
}

// Permute filter data, and adjust bias data to account for symmetric input
// offset. Details are provided in the implementation of the
// kUseCModel3x3DotProduct version.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type>
struct ProcessPerDepth {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

// Copy a macro block of data from the input buffer into the workspace,
// permuting data within each micro block.
//
// (a) Copy a macro block of data, padding as required along the width and
//     height.
// (b) Transpose the data within each micro block.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication,
          int32 max_padding>
struct PackMacroBlock {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

// Apply filter to macro block of input data and store results. Details are
// provided in the implementation of the kUseCModel3x3DotProduct version.
//
// Parameters for repeats and residual sizes are in terms of outputs.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication, int32 stride>
struct KernelMacroBlock {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

#if defined(__aarch64__)
// Experiments suggest that a modest performance improvement is seen, at least
// on 855 chipset big cores, with cache hints.
template <typename T>
inline void PreloadInputBlock(
    const T* input_block_data,
    const DepthwiseConvDotProdParams* function_params) {
  // Preload.
  const int input_width_micro_repeats =
      function_params->input_width_micro_repeats;
  const int block_height = function_params->inbound_block_height;
  const int residual_width = function_params->residual_width;
  const int input_height_stride = function_params->input_height_stride;
  const int input_depth = function_params->input_depth;

  const int total_width = 4 * input_width_micro_repeats + residual_width;
  const T* row_ptr = input_block_data;
  for (int k_height = 0; k_height < block_height; ++k_height) {
    const T* ptr = row_ptr;
    for (int j = 0; j < total_width; ++j) {
      // Input data is loaded once.
      optimized_ops_preload_l1_keep(ptr);
      ptr += input_depth;
    }
    row_ptr += input_height_stride;
  }
}
#endif  // __aarch64__

}  // namespace depthwise_conv
}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_