/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {
namespace depthwise_conv {

constexpr int kDepthwiseConvScratchWorkspaceSize = 10 * 10 * 64;
constexpr int kDepthwiseConvAdjustedBiasLimit = 64;
// In cases such as depth multiplication, we want to be able to load data from
// the workspace that is beyond the valid range. Macro-block sizes are adjusted
// to allow for this.
constexpr int kWorkspaceExtension = 16;

#ifdef USE_NEON

#ifndef __aarch64__
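// 32-bit-ARM fallback for the AArch64-only vqtbl4q_s8 intrinsic (table lookup
// across a 64-byte table). It is built from two vtbl4_s8 lookups over the low
// and high halves of the table registers; bit 3 of each index selects which of
// the two partial results is kept.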
inline int8x16_t vqtbl4q_s8(int8x16x4_t a, int8x16_t b) {
  const uint8x16_t mask = vtstq_s8(b, vdupq_n_s8(8));

  // Delete bit 3 from the indices.
  const int8x16_t high_bits = vshrq_n_s8(b, 4);
  int8x16_t deleted_bit_3 = b;
  deleted_bit_3 = vsliq_n_s8(deleted_bit_3, high_bits, 3);

  int8x8x4_t repacked_data;

  // Calculate for lower indices.
  repacked_data.val[0] = vget_low_s8(a.val[0]);
  repacked_data.val[1] = vget_low_s8(a.val[1]);
  repacked_data.val[2] = vget_low_s8(a.val[2]);
  repacked_data.val[3] = vget_low_s8(a.val[3]);
  const int8x16_t output_for_lower =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Calculate for high indices.
  repacked_data.val[0] = vget_high_s8(a.val[0]);
  repacked_data.val[1] = vget_high_s8(a.val[1]);
  repacked_data.val[2] = vget_high_s8(a.val[2]);
  repacked_data.val[3] = vget_high_s8(a.val[3]);
  const int8x16_t output_for_higher =
      vcombine_s8(vtbl4_s8(repacked_data, vget_low_s8(deleted_bit_3)),
                  vtbl4_s8(repacked_data, vget_high_s8(deleted_bit_3)));

  // Merge.
  int8x16_t output = vbslq_s8(mask, output_for_higher, output_for_lower);
  return output;
}
#endif  // !__aarch64__

// Convenience-compatibility functions.
// Compatibility: The intrinsics reflect a mixture of older and newer ARM
//     instructions. vzipq_s8 actually results in the ZIP1 / ZIP2 asm
//     instructions, but only a single intrinsic is provided. Also, older
//     instructions operated in place, and it seems more defensive to assume
//     that some versions of the intrinsics might reflect this.
// Convenience: Callers in these kernels want both the ZIP1 and ZIP2 results,
//     and we do not want the calling code cluttered with unpacking of
//     int8x16x2_t.
inline void vzipq_s8_in_place(int8x16_t* a, int8x16_t* b) {
  int8x16x2_t r8x16;
  r8x16 = vzipq_s8(*a, *b);
  *a = r8x16.val[0];
  *b = r8x16.val[1];
}

inline void vzipq_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vzipq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
  *b = vreinterpretq_s8_s16(r16x8.val[1]);
}

// Similar rationale to the zip-in-place functions, but callers only actually
// need the TRN1 asm instruction result.
inline void vtrn1_s8x2_in_place(int8x16_t* a, int8x16_t* b) {
  int16x8x2_t r16x8;
  r16x8 = vtrnq_s16(vreinterpretq_s16_s8(*a), vreinterpretq_s16_s8(*b));
  *a = vreinterpretq_s8_s16(r16x8.val[0]);
}

// Similar rationale to the zip-in-place functions, but callers only actually
// need the ZIP1 or ZIP2 asm instruction results.
inline int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[0];
}
inline int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
  return vzipq_s8(a, b).val[1];
}

inline void biregister_rotate_8(int8x16_t* left, int8x16_t* right) {
  *left = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*left), 8));
  *left = vreinterpretq_s8_u32(vsliq_n_u32(vreinterpretq_u32_s8(*left),
                                           vreinterpretq_u32_s8(*right), 24));
  *right = vreinterpretq_s8_u32(vshrq_n_u32(vreinterpretq_u32_s8(*right), 8));
}

#ifndef __aarch64__
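// 32-bit-ARM substitute for the vpaddq_s32 pairwise add used below, built here
// from a de-interleave followed by a saturating add.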
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
  int32x4x2_t deinterleaved = vuzpq_s32(a, b);
  return vqaddq_s32(deinterleaved.val[0], deinterleaved.val[1]);
}
#endif  // !__aarch64__

#ifdef __ARM_FEATURE_DOTPROD
// The vdotq_lane_s32 intrinsic takes an int8x8_t for the rhs parameter,
// whereas the actual instruction selects from among four 32-bit (4x8-bit
// packed) sub-registers, an unusual interpretation of "lane".
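//
// For illustration, an accumulation against the third packed sub-register of
// |rhs| would look like:
//   acc = vdotq_four_lane_s32(acc, lhs, rhs, 2);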
inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, const int lane) {
  switch (lane) {
    case 0:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 0);
    case 1:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_low_s8(rhs)), 1);
    case 2:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)),
                            0);
    case 3:
    default:
      return vdotq_lane_s32(acc, lhs, vreinterpret_s32_s8(vget_high_s8(rhs)),
                            1);
  }
}

#else

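// Fallback implementations for targets without the dot-product extension: each
// 32-bit lane of the accumulator receives the sum of the four products of the
// corresponding int8 lanes, mirroring the semantics of the SDOT-based versions
// above.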
inline int32x4_t vdotq_s32(int32x4_t acc, int8x16_t lhs, int8x16_t rhs) {
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), vget_low_s8(rhs)));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), vget_high_s8(rhs)));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}

inline int32x4_t vdotq_four_lane_s32(int32x4_t acc, int8x16_t lhs,
                                     int8x16_t rhs, int lane) {
  int8x8_t lane_rhs;
  if (lane == 0) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 0));
  } else if (lane == 1) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_low_s8(rhs)), 1));
  } else if (lane == 2) {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 0));
  } else {
    lane_rhs = vreinterpret_s8_s32(
        vdup_lane_s32(vreinterpret_s32_s8(vget_high_s8(rhs)), 1));
  }
  int32x4_t sum0 = vpaddlq_s16(vmull_s8(vget_low_s8(lhs), lane_rhs));
  int32x4_t sum1 = vpaddlq_s16(vmull_s8(vget_high_s8(lhs), lane_rhs));
  int32x4_t sum = vpaddq_s32(sum0, sum1);
  return vaddq_s32(acc, sum);
}

#endif  // !__ARM_FEATURE_DOTPROD
#endif  // ARM NEON

// This structure is typically used for reducing the magnitude of outputs, and
// the historical name reflects that.
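//
// For illustration, with a 32-bit integer x,
//   DivideByPOT<DepthwiseConvOutputRounding::kAwayFromZero>::Run(x, 3)
// divides x by 2^3 = 8 with rounding.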
template <DepthwiseConvOutputRounding output_rounding>
struct DivideByPOT {};

template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kAwayFromZero> {
  template <typename IntegerType>
  static inline IntegerType Run(IntegerType x, int exponent) {
    return RoundingDivideByPOT(x, exponent);
  }
  // Mult versions use the exponents directly, rather than negated.
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, int exponent) {
    return RoundingDivideByPOT(x, -exponent);
  }
};

#ifdef USE_NEON
template <>
struct DivideByPOT<DepthwiseConvOutputRounding::kUpward> {
  template <typename IntegerType>
  static inline IntegerType Run(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32>(-exponent)));
  }
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, IntegerType exponent) {
    return vqrshlq_s32(x, exponent);
  }
  template <typename IntegerType>
  static inline IntegerType RunMult(IntegerType x, int exponent) {
    return vqrshlq_s32(x, vdupq_n_s32(static_cast<int32>(exponent)));
  }
};
#endif  // ARM NEON

// See CategorizeDotProductKernel for definitive taxonomy.
enum class DotProduct3x3KernelType {
  kNone = 0,  // Parameter combination is not supported for dot product kernels.
  kPlain,
  kWithDepthMultiplicationStride1,
  kWithDepthMultiplicationStride2,
  kStride2,
};

enum class QuantizationType {
  kNonPerChannelUint8 = 0,
  kPerChannelInt8 = 1,
};

template <QuantizationType quantization_type>
struct QuantizationTypeImpl {};

template <>
struct QuantizationTypeImpl<QuantizationType::kNonPerChannelUint8> {
  typedef uint8 ExternalType;

  static constexpr int kIntSymmetricZeroPoint = 128;
  static constexpr uint8 kUint8SignBit = 0x80;
};

template <>
struct QuantizationTypeImpl<QuantizationType::kPerChannelInt8> {
  typedef int8 ExternalType;

  static constexpr int kIntSymmetricZeroPoint = 0;
  static constexpr uint8 kUint8SignBit = 0x0;
};

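// Determines which 3x3 dot-product kernel variant, if any, can handle the
// given parameter combination.
//
// For illustration, a 3x3 filter with stride 1, unit dilation, no depth
// multiplication, and an input depth that is a multiple of 8 maps to
// DotProduct3x3KernelType::kPlain, provided the offset and shift checks below
// also pass.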
template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline DotProduct3x3KernelType CategorizeDotProductKernel(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    const RuntimeShape& output_shape, const DepthwiseParams& params,
    const int32* output_shift_ptr = nullptr) {
  constexpr int kSymmetricZeroPoint =
      QuantizationTypeImpl<quantization_type>::kIntSymmetricZeroPoint;
  const int padding =
      std::max(params.padding_values.width, params.padding_values.height);
  const int stride = params.stride_width;
  const int32 input_depth = input_shape.Dims(3);
  const int32 depth_multiplier = params.depth_multiplier;
  const int32 filter_height = filter_shape.Dims(1);
  const int32 filter_width = filter_shape.Dims(2);

  bool supported = stride == params.stride_height && stride <= 2 &&
                   padding <= 1 && filter_width == 3 && filter_height == 3 &&
                   params.dilation_width_factor == 1 &&
                   params.dilation_height_factor == 1 &&
                   (((input_depth % 8) == 0 && depth_multiplier == 1) ||
                    (input_depth == 1 && depth_multiplier > 1));

  if (!supported) {
    return DotProduct3x3KernelType::kNone;
  }

  if (params.weights_offset != -kSymmetricZeroPoint) {
    return DotProduct3x3KernelType::kNone;
  }

  if (quantization_type == QuantizationType::kPerChannelInt8) {
    if (output_shift_ptr == nullptr) {
      return DotProduct3x3KernelType::kNone;
    }
  } else if (params.output_shift > 0) {
    return DotProduct3x3KernelType::kNone;
  }

  if (params.depth_multiplier == 1) {
    if (stride == 1) {
      return DotProduct3x3KernelType::kPlain;
    } else if (stride == 2) {
      return DotProduct3x3KernelType::kStride2;
    } else {
      return DotProduct3x3KernelType::kNone;
    }
  } else {
    if (stride == 1) {
      return DotProduct3x3KernelType::kWithDepthMultiplicationStride1;
    } else if (stride == 2) {
      return DotProduct3x3KernelType::kWithDepthMultiplicationStride2;
    } else {
      return DotProduct3x3KernelType::kNone;
    }
  }
}

// Encapsulates constant parameters used in DepthwiseConv.
// 64-bit is used for types that will be added to 64-bit addresses in asm.
struct DepthwiseConvParams {
  int64_t input_depth;
  int64_t input_row_size;
  int64_t output_depth;
  int64_t output_row_size;
  int64_t filter_row_size;
  int32 input_offset;
  int32 output_offset;
  int32 filter_offset;
  int32 output_multiplier;
  int32 output_activation_min;
  int32 output_activation_max;
  int32 output_right_shift;
  int32 input_width;
  int32 input_height;
  int32 stride_width;
  int32 stride_height;
  int32 output_width;
  int32 output_height;
  float float_output_activation_min;
  float float_output_activation_max;
};

// Encapsulates constant parameters used in DepthwiseConv using dot-product ops.
// 64-bit is used for types that will be added to 64-bit addresses in asm.
//
// This structure is specifically designed for use in asm.
struct DepthwiseConvDotProdParams {
  int64_t input_depth;
  int64_t output_depth;
  int32 stride;
  int32 bias_increment;
  //
  int32 input_offset;
  int32 output_offset;
  int32 output_multiplier;
  int32 output_shift;
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  //
  int32 padding_left;
  int32 padding_right;
  int32 padding_top;
  int32 padding_bottom;
  //
  int32 depth_micro_repeats;
  //
  int32 width_macro_count;
  int32 input_width_overall_micro_repeats;
  int32 input_width_micro_repeats;
  int32 residual_width;
  int32 output_width_overall_micro_repeats;
  int32 output_width_micro_repeats;
  int32 output_residual_width;
  int32 workspace_width_micro_repeats;
  //
  int32 height_macro_count;
  int32 inbound_block_height;
  int32 outbound_block_height;
  int32 input_height_stride;
  int32 output_height_stride;
  int32 workspace_height_stride;
  //
  int32 four_over_stride;
  //
  const int32* output_multiplier_per_channel;
  const int32* output_shift_per_channel;
};

template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
          int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindow {};

template <DepthwiseConvOutputRounding output_rounding, int32 kDepth,
          int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvWindowPerChannel {};

enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartial {};

template <DepthwiseConvOutputRounding output_rounding, EdgeType kEdgeType,
          int kPadWidth, int kPadHeight>
struct DepthwiseConvPartialPerChannel {};

// Copies a subset of the input designated by |input_ptr| into |output_ptr|
// with the specified output dimensions. Only an output depth of 64 is
// supported, since this is the cache-line size.
template <typename T>
inline void ShuffleInput(const T* input_ptr, int64_t input_depth,
                         int32 input_width, int32 input_height,
                         int64_t output_depth, int32 output_width,
                         int32 output_height, T* output_ptr) {
  const int64_t input_row_size = input_depth * input_width;
  for (int32 y = 0; y < output_height; y++) {
    const T* ptr = input_ptr;
    for (int32 x = 0; x < output_width; x++) {
      memcpy(output_ptr, ptr, output_depth);
      output_ptr += output_depth;
      ptr += input_depth;
    }
    input_ptr += input_row_size;
  }
}


// Calculates the input extent required for a given stride and output extent,
// for a 3x3 filter.
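// For example, with stride 2 and an output extent of 4, the required input
// extent is 2 * (4 - 1) + 3 = 9.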
inline int32 get_shuffle_input_size(int32 stride, int32 output) {
  return stride * (output - 1) + 3;
}

// Indicates the input and output dimensions used when shuffling input
// activations.
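//
// For illustration, ShuffleParams(/*output_width=*/4, /*output_height=*/4,
// /*stride_width=*/2, /*stride_height=*/2) yields input_width = 9 and
// input_height = 9.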
struct ShuffleParams {
  int32 output_width;
  int32 output_height;
  int32 input_width;
  int32 input_height;

  ShuffleParams() = default;
  ShuffleParams(int32 output_width, int32 output_height, int32 stride_width,
                int32 stride_height)
      : output_width(output_width),
        output_height(output_height),
        input_width(get_shuffle_input_size(stride_width, output_width)),
        input_height(get_shuffle_input_size(stride_height, output_height)) {}
};

template <
    QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8>
inline bool Fast3x3FilterKernelSupported(
    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
    int32 stride_width, int32 stride_height, int32 dilation_width_factor,
    int32 dilation_height_factor, int32 pad_width, int32 pad_height,
    int32 depth_multiplier, const RuntimeShape& output_shape,
    int32 output_shift, const int32* output_shift_ptr = nullptr) {
  const int32 input_height = input_shape.Dims(1);
  const int32 input_width = input_shape.Dims(2);
  const int32 input_depth = input_shape.Dims(3);
  const int32 filter_height = filter_shape.Dims(1);
  const int32 filter_width = filter_shape.Dims(2);
  const int32 output_height = output_shape.Dims(1);
  const int32 output_width = output_shape.Dims(2);

  bool supported =
      filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
      (stride_width == 1 || stride_width == 2) &&
      (stride_height == 1 || stride_height == 2) &&
      (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
      (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
      (input_depth % 8) == 0 && (output_shift <= 0) &&
      dilation_width_factor == 1 && dilation_height_factor == 1;

  if (!supported) {
    return false;
  }

  // Handle the case where padding is zero but the padding type is not kValid.
  // This would require special boundary-case handling that is not supported.

  const int32 out_x = output_width - 1;
  const int32 out_y = output_height - 1;

  const int32 in_x_origin = (out_x * stride_width) - pad_width;
  const int32 in_y_origin = (out_y * stride_height) - pad_height;

  const int32 in_x_end = in_x_origin + filter_width;
  const int32 in_y_end = in_y_origin + filter_height;

  // When padding is zero, this is supported only if the filter at the right
  // and bottom boundaries lies completely within the input.
  if (pad_width == 0 && pad_height == 0) {
    return in_x_end <= input_width && in_y_end <= input_height;
  }

  // Otherwise, when padding is 1, this is supported only if the bottom-right
  // filter position lies at most one element past the input width and height.
  supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1);

  if (!supported) {
    return false;
  }

  // Shapes with width 1 and height > 1, and vice versa, are not supported yet.
  if (input_width == 1) {
    supported = (input_width == input_height);
  } else if (input_height == 1) {
    supported = (input_width == input_height);
  }
  return supported;
}

// Permute filter data, and adjust bias data to account for symmetric input
// offset. Details are provided in the implementation of the
// kUseCModel3x3DotProduct version.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type>
struct ProcessPerDepth {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

// Copy a macro block of data from the input buffer into the workspace,
// permuting data within each micro block.
//
// (a) Copy a macro block of data, padding as required along the width and
//     height.
// (b) Transpose the data within each micro block.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication,
          int32 max_padding>
struct PackMacroBlock {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

// Apply filter to macro block of input data and store results. Details are
// provided in the implementation of the kUseCModel3x3DotProduct version.
//
// Parameters for repeats and residual sizes are in terms of outputs.
//
// See the comments preceding DepthwiseConvDotProduct3x3() for further notes.
template <DepthwiseConvImplementation implementation,
          QuantizationType quantization_type,
          DepthwiseConvDepthMultiplication depth_multiplication, int32 stride>
struct KernelMacroBlock {
  // Routine is contained in a static Run() method. No default template version
  // is supplied, so that all implementations are deliberate choices of template
  // specialization.
  //
  // Note that the signature of the Run() method will be designed for the asm
  // implementation rather than conforming to style.
};

#if defined(__aarch64__)
// Experiments suggest that a modest performance improvement is seen, at least
// on 855 chipset big cores, with cache hints.
template <typename T>
inline void PreloadInputBlock(
    const T* input_block_data,
    const DepthwiseConvDotProdParams* function_params) {
  // Preload.
  const int input_width_micro_repeats =
      function_params->input_width_micro_repeats;
  const int block_height = function_params->inbound_block_height;
  const int residual_width = function_params->residual_width;
  const int input_height_stride = function_params->input_height_stride;
  const int input_depth = function_params->input_depth;

  const int total_width = 4 * input_width_micro_repeats + residual_width;
  const T* row_ptr = input_block_data;
  for (int k_height = 0; k_height < block_height; ++k_height) {
    const T* ptr = row_ptr;
    for (int j = 0; j < total_width; ++j) {
      // Input data is loaded once.
      optimized_ops_preload_l1_keep(ptr);
      ptr += input_depth;
    }
    row_ptr += input_height_stride;
  }
}
#endif  // __aarch64__

}  // namespace depthwise_conv
}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_3X3_FILTER_COMMON_H_