/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fully_connected {

namespace {
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
  if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
      sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
    return true;
  }

  return false;
}

static const int kDimMetadataSizeRandomSparse = 2;
static const int kDimMetadataSizeBlockSparse = 3;

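// The sparsity ledger is a compact uint8 encoding of the CSR metadata in
// dimension 1 of the filter: for each row it stores one byte holding the
// number of non-zero blocks, followed by one byte per block index, so its
// total size is array_indices->size + array_segments->size - 1.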
TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, TfLiteTensor* ledger) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  ledger->type = kTfLiteUInt8;
  ledger->allocation_type = kTfLiteArenaRwPersistent;
  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
                         sparsity->dim_metadata[1].array_segments->size - 1;
  return context->ResizeTensor(context, ledger, ledger_size);
}

TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, uint8_t* ledger_data) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
  int output_data_ptr = 0;

  for (int i = 0; i < array_segments->size - 1; i++) {
    int row_start = array_segments->data[i];
    int row_end = array_segments->data[i + 1];
    if (row_end - row_start > UINT8_MAX) {
      return kTfLiteError;
    }
    // Copy num of non-zero blocks in row i.
    ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
    output_data_ptr++;

    for (int j = row_start; j < row_end; j++) {
      if (array_indices->data[j] > UINT8_MAX) {
        return kTfLiteError;
      }
      // Copy indices of non-zero blocks in row i.
      ledger_data[output_data_ptr] =
          static_cast<uint8_t>(array_indices->data[j]);
      output_data_ptr++;
    }
  }
  return kTfLiteOk;
}

}  // namespace

// This file has several implementations of FullyConnected.
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
};

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
  bool compute_row_sums = false;
  // Only used for sparse hybrid fully connected kernels.
  bool ledger_initialized;
};

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kShuffledInputWorkspaceTensor = 1;

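// Verifies that the input/filter/bias/output types form one of the supported
// combinations: pure float32, fully quantized (uint8/int8/int16 inputs and
// outputs), hybrid (float32 input and output with quantized weights), or the
// shuffled 4x16 weights format (uint8 input and weights, int16 output).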
inline TfLiteStatus CheckTypes(TfLiteContext* context,
                               const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output,
                               TfLiteFullyConnectedParams* params) {
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_shuffled =
      is_quantized && (params->weights_format ==
                       kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8);

  // optional bias tensor.
  const bool is_optional_bias_float = !bias || (bias->type == kTfLiteFloat32);
  const bool is_optional_bias_int =
      !bias || (bias->type == kTfLiteInt32) || (bias->type == kTfLiteInt64);

  if (is_quantized) {
    if (is_shuffled) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    } else if (is_hybrid) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
    } else {
      TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                                  input->type == kTfLiteInt8 ||
                                  input->type == kTfLiteInt16);
      TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
                                  output->type == kTfLiteInt8 ||
                                  output->type == kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    }
  } else {
    // Only float32 is supported currently
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
  }

  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
  // Shuffled formats need a workspace to store the shuffled input activations.
  const int expected_outputs_count =
      params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
                                                                          : 2;
  TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check that the data types of all tensors match as required.
  TF_LITE_ENSURE_STATUS(
      CheckTypes(context, input, filter, bias, output, params));

  // Check that the tensor parameters are consistent with one another and with
  // the input configuration.
  int input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    input_size *= input->dims->data[i];
  }

  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
  const int batch_size = input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  if (bias) {
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
  }

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs) first we need to quantize the inputs. Allocate a temporary
  // buffer to store the intermediate quantized values.
  // Additionally, we allocate a temporary buffer to store the accumulated
  // quantized values prior to multiplication by the scaling factor.
  const bool is_hybrid =
      (input->type == kTfLiteFloat32 &&
       (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
  const bool is_sparse = filter->sparsity != nullptr;
  if (is_hybrid) {
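    // The hybrid path uses the following temporaries: [0] quantized input,
    // [1] per-batch scaling factors, [2] int32 accumulator scratch, [3] input
    // zero points, [4] filter row sums, and, for sparse filters only, [5] the
    // sparsity ledger.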
    TfLiteIntArrayFree(node->temporaries);
    data->compute_row_sums = true;
    if (is_sparse) {
      node->temporaries = TfLiteIntArrayCreate(6);
    } else {
      node->temporaries = TfLiteIntArrayCreate(5);
    }
    node->temporaries->data[0] = data->scratch_tensor_index;

    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = filter->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                     input_quantized_size));

    node->temporaries->data[1] = data->scratch_tensor_index + 1;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;

    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    node->temporaries->data[2] = data->scratch_tensor_index + 2;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = num_units;
      accum_size->data[1] = batch_size;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    node->temporaries->data[3] = data->scratch_tensor_index + 3;
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    input_offsets->type = kTfLiteInt32;
    input_offsets->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
      input_offsets_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
                                                       input_offsets_size));
    }
    node->temporaries->data[4] = data->scratch_tensor_index + 4;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }

    if (is_sparse) {
      data->ledger_initialized = false;
      node->temporaries->data[5] = data->scratch_tensor_index + 5;
      TfLiteTensor* filter_ledger =
          &context->tensors[node->temporaries->data[5]];
      auto status =
          CreateLedgerTensor(filter->sparsity, context, filter_ledger);
      if (status != kTfLiteOk) return status;
    }
  }

  // Resize output.
  TfLiteIntArray* output_size_array = nullptr;
  if (params->keep_num_dims) {
    // When the number of dimensions is kept, the filter operates along the
    // last dimension. In other words, for an input tensor with shape
    // [batch_size, ..., n_inputs] and a filter of shape [n_inputs, n_units]
    // this Op produces an output of shape [batch_size, ..., n_units].
    TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1],
                      SizeOfDimension(filter, 1));
    output_size_array = TfLiteIntArrayCopy(input->dims);
    output_size_array->data[output_size_array->size - 1] = num_units;
  } else {
    // Otherwise, the output is (potentially flattened to) a 2-D matrix.
    output_size_array = TfLiteIntArrayCreate(2);
    output_size_array->data[0] = batch_size;
    output_size_array->data[1] = num_units;
  }
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check for supported activation types.
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_pie = kernel_type == kLegacyPie;

  // The Pie and hybrid paths support all kinds of fused activations; the
  // other paths only support clipping activations.
  if (!is_pie && !is_hybrid) {
    TF_LITE_ENSURE(context, params->activation == kTfLiteActNone ||
                                params->activation == kTfLiteActRelu ||
                                params->activation == kTfLiteActReluN1To1 ||
                                params->activation == kTfLiteActRelu6);
  }
  return PrepareImpl(context, node);
}

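// Legacy Pie float path: computes output = activation(filter * input + bias)
// using the portable tensor_utils routines.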
TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
                     TfLiteFullyConnectedParams* params, OpData* data,
                     const TfLiteTensor* input, const TfLiteTensor* filter,
                     const TfLiteTensor* bias, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Compute output += weight * input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetTensorData<float>(filter), num_units, input_size,
      GetTensorData<float>(input), batch_size, GetTensorData<float>(output));

  // Apply activation function
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));

  return kTfLiteOk;
}

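// Evaluates the batch slice [thread_start, thread_end) of the hybrid path:
// the float input is quantized on the fly, multiplied against the quantized
// weights, and the accumulated result is scaled back to float before the
// fused activation is applied.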
void EvalHybridImpl(TfLiteContext* context, TfLiteNode* node,
                    TfLiteFullyConnectedParams* params, OpData* data,
                    const TfLiteTensor* input, const TfLiteTensor* filter,
                    const TfLiteTensor* bias, int thread_start, int thread_end,
                    TfLiteTensor* input_quantized,
                    TfLiteTensor* scaling_factors, TfLiteTensor* accum_scratch,
                    TfLiteTensor* row_sums, TfLiteTensor* input_offsets,
                    TfLiteTensor* output) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("Hybrid Kernel");
  const auto& input_shape = GetTensorShape(input);
  const auto& output_shape = GetTensorShape(output);
  const auto& filter_shape = GetTensorShape(filter);
  const int input_dims_count = input_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int filter_dims_count = filter_shape.DimensionsCount();
  const int batch_size = thread_end - thread_start;
  const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
                                      input_shape, input_dims_count - 1);
  const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int per_thread_input_size = batch_size * input_depth;

  const bool is_sparse = filter->sparsity != nullptr;

  const float* per_thread_input =
      GetTensorData<float>(input) + thread_start * input_depth;
  float* per_thread_output =
      GetTensorData<float>(output) + thread_start * output_depth;

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
                                          output_depth, batch_size,
                                          per_thread_output);
  } else {
    std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
  }

  // Save matrix multiplication computation for all zero input.
  if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
    tensor_utils::ApplyActivationToVector(
        per_thread_output, batch_size * output_depth, params->activation,
        per_thread_output);
    return;
  }

  // Quantize the float input and compute the per-batch quantization params
  // (scaling factors, and zero points when asymmetric).
  float* scaling_factors_ptr =
      GetTensorData<float>(scaling_factors) + thread_start;
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data =
      GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
                                    quant_data, scaling_factors_ptr,
                                    input_offset_ptr,
                                    params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  // Compute output += weight * quantized_input
  int32_t* scratch =
      GetTensorData<int32_t>(accum_scratch) + thread_start * output_depth;
  if (is_sparse) {
    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
        output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
        per_thread_output);
  } else {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        filter_data, output_depth, input_depth, quant_data, scaling_factors_ptr,
        batch_size, per_thread_output, /*per_channel_scale=*/nullptr,
        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
        CpuBackendContext::GetFromContext(context));
  }

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(per_thread_output,
                                        batch_size * output_depth,
                                        params->activation, per_thread_output);
}

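// Wraps one batch slice of EvalHybridImpl so it can run as a task on the
// cpu_backend threadpool.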
struct HybridFullyConnectedTask : cpu_backend_threadpool::Task {
  HybridFullyConnectedTask(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           const int thread_start, const int thread_end,
                           TfLiteTensor* input_quantized,
                           TfLiteTensor* scaling_factors,
                           TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                           TfLiteTensor* input_offsets, TfLiteTensor* output)
      : context(context),
        node(node),
        params(params),
        data(data),
        input(input),
        filter(filter),
        bias(bias),
        thread_start(thread_start),
        thread_end(thread_end),
        input_quantized(input_quantized),
        scaling_factors(scaling_factors),
        accum_scratch(accum_scratch),
        row_sums(row_sums),
        input_offsets(input_offsets),
        output(output) {}

  void Run() override {
    EvalHybridImpl(context, node, params, data, input, filter, bias,
                   thread_start, thread_end, input_quantized, scaling_factors,
                   accum_scratch, row_sums, input_offsets, output);
  }

 private:
  TfLiteContext* context;
  TfLiteNode* node;
  TfLiteFullyConnectedParams* params;
  OpData* data;
  const TfLiteTensor* input;
  const TfLiteTensor* filter;
  const TfLiteTensor* bias;
  const int thread_start;
  const int thread_end;
  TfLiteTensor* input_quantized;
  TfLiteTensor* scaling_factors;
  TfLiteTensor* accum_scratch;
  TfLiteTensor* row_sums;
  TfLiteTensor* input_offsets;
  TfLiteTensor* output;
};

// The multi-threaded kernel slices the workload along the batch dimension. If
// there are not enough batches, the number of threads used equals the batch
// size.
// TODO(b/173442777): If needed, we can improve this later with slicing along
// the row dimension of the weight.
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
  const auto& output_shape = GetTensorShape(output);
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  const bool is_sparse = filter->sparsity != nullptr;
  if (is_sparse) {
    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
    if (!data->ledger_initialized) {
      PopulateLedgerData(filter->sparsity, context,
                         GetTensorData<uint8_t>(filter_ledger));
      data->ledger_initialized = true;
    }
  }

  const int max_threads = cpu_backend_context->max_num_threads();
  const int batches =
      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
  const int thread_count = std::max(1, std::min(batches, max_threads));
  if (thread_count == 1) {
    EvalHybridImpl(context, node, params, data, input, filter, bias, 0, batches,
                   input_quantized, scaling_factors, accum_scratch, row_sums,
                   input_offsets, output);
    return kTfLiteOk;
  }

  std::vector<HybridFullyConnectedTask> tasks;
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // This makes sure the workload is relatively balanced when batches is not
    // a multiple of thread_count. The first mod(batches, thread_count) tasks
    // need to process one more batch than the rest.
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;

    tasks.emplace_back(context, node, params, data, input, filter, bias,
                       thread_start, thread_end, input_quantized,
                       scaling_factors, accum_scratch, row_sums, input_offsets,
                       output);
    thread_start = thread_end;
  }
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
  return kTfLiteOk;
}

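// The helpers below run the int8 and int16 integer kernels with the
// quantization parameters computed in Prepare().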
namespace {
template <KernelType kernel_type>
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                        const TfLiteTensor* filter, const TfLiteTensor* bias,
                        TfLiteTensor* output,
                        CpuBackendContext* cpu_backend_context) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output));
  } else {
    optimized_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output),
        cpu_backend_context);
  }
}
}  // namespace

namespace {
template <KernelType kernel_type>
void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
                         const TfLiteTensor* filter, const TfLiteTensor* bias,
                         TfLiteTensor* output) {
  FullyConnectedParams op_params;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int64_t>(bias),
      GetTensorShape(output), GetTensorData<int16_t>(output));
}
}  // namespace

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -filter->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  // Float32 inputs with quantized weights take the on-the-fly hybrid path.
  if (input->type == kTfLiteFloat32) {
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    return EvalHybrid(context, node, params, data, input, filter, bias,
                      input_quantized, scaling_factors, accum_scratch, row_sums,
                      input_offsets, output);
  } else {
    FullyConnectedParams op_params;
    op_params.input_offset = input_offset;
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
    op_params.lhs_cacheable = IsConstantTensor(filter);
    op_params.rhs_cacheable = IsConstantTensor(input);
    switch (output->type) {
      case kTfLiteUInt8:
        if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      case kTfLiteInt8:
        FullyConnectedInt8<kernel_type>(
            data, input, filter, bias, output,
            CpuBackendContext::GetFromContext(context));
        break;
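      // An int16 output comes either from the 16x8 integer kernel (int16
      // input) or from the uint8-input path below.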
      case kTfLiteInt16:
        if (input->type == kTfLiteInt16) {
          FullyConnectedInt16<kernel_type>(data, input, filter, bias, output);
        } else if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      default:
        context->ReportError(context,
                             "Quantized FullyConnected expects output data "
                             "type uint8, int8 or int16");
        return kTfLiteError;
    }
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                   TfLiteFullyConnectedParams* params,
                                   OpData* data, const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output,
                                   TfLiteTensor* shuffled_input_workspace) {
  // TODO(b/110697972) decide more consistently if / how / where we want
  // to perform this kind of runtime data type checks.
  if (shuffled_input_workspace->type != kTfLiteUInt8) {
    context->ReportError(context, "Unexpected data type");
    return kTfLiteError;
  }

#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                           \
  {                                                                      \
    type::ShuffledFullyConnected(                                        \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<int16_t>(output),          \
        GetTensorData<uint8_t>(shuffled_input_workspace),                \
        CpuBackendContext::GetFromContext(context));                     \
  }
  FullyConnectedParams op_params;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace));
  } else {
    optimized_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace),
        CpuBackendContext::GetFromContext(context));
  }
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED

  return kTfLiteOk;
}

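// Float-weights path: dispatches to the reference or optimized kernels (dense
// or sparse) or to the legacy Pie implementation.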
template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  if (kernel_type == kReference) {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      reference_ops::FullyConnectedSparseWeight(
          sparsity, op_params, GetTensorShape(input),
          GetTensorData<float>(input), GetTensorShape(filter),
          GetTensorData<float>(filter), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorShape(output),
          GetTensorData<float>(output));
    } else {
      reference_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output));
    }
  } else if (kernel_type == kLegacyPie) {
    return EvalPie(context, node, params, data, input, filter, bias, output);
  } else {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      if (!SupportedSparsityFormat(sparsity)) {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }

      if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
        // Random sparse.
        optimized_ops::FullyConnectedSparseWeight(
            sparsity, op_params, GetTensorShape(input),
            GetTensorData<float>(input), GetTensorShape(filter),
            GetTensorData<float>(filter), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorShape(output),
            GetTensorData<float>(output));
      } else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
                 sparsity.dim_metadata[2].dense_size == 4) {
        // Block sparse with block size of 1x4.
        optimized_ops::FullyConnectedSparseWeight1x4(
            sparsity, op_params, GetTensorShape(input),
            GetTensorData<float>(input), GetTensorShape(filter),
            GetTensorData<float>(filter), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorShape(output),
            GetTensorData<float>(output),
            CpuBackendContext::GetFromContext(context));
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }

    } else {
      op_params.lhs_cacheable = IsConstantTensor(filter);
      op_params.rhs_cacheable = IsConstantTensor(input);
      optimized_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output),
          CpuBackendContext::GetFromContext(context));
    }
  }

  return kTfLiteOk;
}

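// Main entry point for Eval: selects the float, quantized, or shuffled
// evaluation path based on the filter type and the weights format.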
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (filter->type) {
    case kTfLiteFloat32:
      return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                    bias, output);
    case kTfLiteUInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        TfLiteTensor* shuffled_input_workspace;
        TF_LITE_ENSURE_OK(
            context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
                                   &shuffled_input_workspace));
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
      } else if (params->weights_format ==
                 kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    case kTfLiteInt8:
      if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    default:
      context->ReportError(context,
                           "Filter data type %s currently not supported.",
                           TfLiteTypeGetName(filter->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kReference>,
      fully_connected::Eval<fully_connected::kReference>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kGenericOptimized>,
      fully_connected::Eval<fully_connected::kGenericOptimized>};
  return &r;
}

// Legacy path for PIE clients.
TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kLegacyPie>,
      fully_connected::Eval<fully_connected::kLegacyPie>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED() {
  return Register_FULLY_CONNECTED_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite