1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 
22 #include "tensorflow/lite/c/builtin_op_data.h"
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/cpu_backend_context.h"
25 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
26 #include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
27 #include "tensorflow/lite/kernels/internal/quantization_util.h"
28 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
29 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
30 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
31 #include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
32 #include "tensorflow/lite/kernels/internal/tensor.h"
33 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
34 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
35 #include "tensorflow/lite/kernels/internal/types.h"
36 #include "tensorflow/lite/kernels/kernel_util.h"
37 
38 namespace tflite {
39 namespace ops {
40 namespace builtin {
41 namespace fully_connected {
42 
43 namespace {
44 bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
45   if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
46       sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
47     return true;
48   }
49 
50   return false;
51 }
52 
53 static const int kDimMetadataSizeRandomSparse = 2;
54 static const int kDimMetadataSizeBlockSparse = 3;
55 
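// The sparsity "ledger" is a flat uint8 buffer that, for every dense row of
// the weights, stores the number of non-zero blocks in that row followed by
// the column indices of those blocks. Its length is therefore the total
// number of non-zero blocks (array_indices) plus one count entry per row
// (array_segments->size - 1).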
56 TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
57                                 TfLiteContext* context, TfLiteTensor* ledger) {
58   TF_LITE_ENSURE(context, sparsity != nullptr);
59   ledger->type = kTfLiteUInt8;
60   ledger->allocation_type = kTfLiteArenaRwPersistent;
61   TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
62   ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
63                          sparsity->dim_metadata[1].array_segments->size - 1;
64   return context->ResizeTensor(context, ledger, ledger_size);
65 }
66 
67 TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
68                                 TfLiteContext* context, uint8_t* ledger_data) {
69   TF_LITE_ENSURE(context, sparsity != nullptr);
70   const auto* array_segments = sparsity->dim_metadata[1].array_segments;
71   const auto* array_indices = sparsity->dim_metadata[1].array_indices;
72   int output_data_ptr = 0;
73 
74   for (int i = 0; i < array_segments->size - 1; i++) {
75     int row_start = array_segments->data[i];
76     int row_end = array_segments->data[i + 1];
77     if (row_end - row_start > UINT8_MAX) {
78       return kTfLiteError;
79     }
80     // Store the number of non-zero blocks in row i.
81     ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
82     output_data_ptr++;
83 
84     for (int j = row_start; j < row_end; j++) {
85       if (array_indices->data[j] > UINT8_MAX) {
86         return kTfLiteError;
87       }
88       // Copy indices of non-zero blocks in row i.
89       ledger_data[output_data_ptr] =
90           static_cast<uint8_t>(array_indices->data[j]);
91       output_data_ptr++;
92     }
93   }
94   return kTfLiteOk;
95 }
96 
97 }  // namespace
98 
99 // This file has three implementations of FullyConnected
100 enum KernelType {
101   kReference,
102   kGenericOptimized,
103   kLegacyPie,  // Legacy path used by the PIE team and related clients.
104 };
105 
106 struct OpData {
107   // The scaling factor from input to output (aka the 'real multiplier') can
108   // be represented as a fixed point multiplier plus a left shift.
109   int32_t output_multiplier;
110   int output_shift;
111   // The range of the fused activation layer. For example for kNone and
112   // uint8_t these would be 0 and 255.
113   int32_t output_activation_min;
114   int32_t output_activation_max;
115   // The index of the temporary tensor where the quantized inputs are cached.
116   int scratch_tensor_index;
117   bool compute_row_sums = false;
118   // Only used for sparse hybrid fully connected kernels.
119   bool ledger_initialized;
120 };
121 
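// Input tensor indices (input, weights, optional bias) and output tensor
// indices (output, plus a workspace output used only by the shuffled
// weights format).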
122 constexpr int kInputTensor = 0;
123 constexpr int kWeightsTensor = 1;
124 constexpr int kBiasTensor = 2;
125 constexpr int kOutputTensor = 0;
126 constexpr int kShuffledInputWorkspaceTensor = 1;
127 
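// Validates the allowed (input, weights, bias, output) type combinations for
// the float, hybrid, shuffled-uint8 and fully quantized paths below.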
128 inline TfLiteStatus CheckTypes(TfLiteContext* context,
129                                const TfLiteTensor* input,
130                                const TfLiteTensor* filter,
131                                const TfLiteTensor* bias, TfLiteTensor* output,
132                                TfLiteFullyConnectedParams* params) {
133   const bool is_quantized =
134       ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
135   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
136   const bool is_shuffled =
137       is_quantized && (params->weights_format ==
138                        kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8);
139 
140   // optional bias tensor.
141   const bool is_optional_bias_float = !bias || (bias->type == kTfLiteFloat32);
142   const bool is_optional_bias_int =
143       !bias || (bias->type == kTfLiteInt32) || (bias->type == kTfLiteInt64);
144 
145   if (is_quantized) {
146     if (is_shuffled) {
147       TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt8);
148       TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteUInt8);
149       TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
150       TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
151     } else if (is_hybrid) {
152       TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
153       TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
154       TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
155     } else {
156       TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
157                                   input->type == kTfLiteInt8 ||
158                                   input->type == kTfLiteInt16);
159       TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
160                                   output->type == kTfLiteInt8 ||
161                                   output->type == kTfLiteInt16);
162       TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
163     }
164   } else {
165     // Only float32 is supported currently
166     TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
167     TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
168     TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
169     TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
170   }
171 
172   return kTfLiteOk;
173 }
174 
175 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
176   // This is a builtin op, so we don't use the contents in 'buffer', if any.
177   // Instead, we allocate a new object to carry information from Prepare() to
178   // Eval().
179   auto* op_data = new OpData();
180   context->AddTensors(context, /*tensors_to_add=*/6,
181                       &op_data->scratch_tensor_index);
182   return op_data;
183 }
184 
185 void Free(TfLiteContext* context, void* buffer) {
186   delete reinterpret_cast<OpData*>(buffer);
187 }
188 
189 TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
190   auto* params =
191       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
192   OpData* data = reinterpret_cast<OpData*>(node->user_data);
193   // Check we have all the inputs and outputs we need.
194   TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
195   // Shuffled formats need a workspace to store the shuffled input activations.
196   const int expected_outputs_count =
197       params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
198                                                                           : 2;
199   TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
200 
201   const TfLiteTensor* input;
202   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
203   const TfLiteTensor* filter;
204   TF_LITE_ENSURE_OK(context,
205                     GetInputSafe(context, node, kWeightsTensor, &filter));
206   const TfLiteTensor* bias =
207       (node->inputs->size == 3)
208           ? GetOptionalInputTensor(context, node, kBiasTensor)
209           : nullptr;
210   TfLiteTensor* output;
211   TF_LITE_ENSURE_OK(context,
212                     GetOutputSafe(context, node, kOutputTensor, &output));
213 
214   // Check for a proper data type match among all input tensors
215   TF_LITE_ENSURE_STATUS(
216       CheckTypes(context, input, filter, bias, output, params));
217 
218   // Check that all tensor parameters are consistent with each other and with
219   // the input configuration.
220   int input_size = 1;
221   for (int i = 0; i < input->dims->size; i++) {
222     input_size *= input->dims->data[i];
223   }
224 
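  // The weights are expected to be a 2-D tensor of shape
  // [num_units, input_depth]; the batch size is inferred by flattening all
  // leading input dimensions over input_depth.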
225   TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
226   TF_LITE_ENSURE(context, filter->dims->data[1] != 0);
227   const int batch_size = input_size / filter->dims->data[1];
228   const int num_units = filter->dims->data[0];
229 
230   if (bias) {
231     TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
232   }
233 
234   // Note that quantized inference requires that all tensors have their
235   // parameters set. This is usually done during quantized training.
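  // The real multiplier (input_scale * filter_scale / output_scale) is
  // re-expressed below as a 32-bit fixed-point multiplier plus a
  // power-of-two shift.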
236   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
237       input->type == kTfLiteInt16) {
238     double real_multiplier = 0.0;
239     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
240         context, input, filter, bias, output, &real_multiplier));
241     int exponent;
242     QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
243     data->output_shift = exponent;
244     TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
245         context, params->activation, output, &data->output_activation_min,
246         &data->output_activation_max));
247   }
248 
249   if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
250     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
251     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
252   }
253 
254   // If we have to perform on-the-fly quantization (with quantized weights and
255   // float inputs) first we need to quantize the inputs. Allocate a temporary
256   // buffer to store the intermediate quantized values.
257   // Additionally, we allocate a temporary buffer to store the accumulated
258   // quantized values prior to multiplication by the scaling factor.
259   const bool is_hybrid =
260       (input->type == kTfLiteFloat32 &&
261        (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
262   const bool is_sparse = filter->sparsity != nullptr;
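  // Hybrid execution uses up to six temporaries allocated in Init():
  //   0: quantized input, 1: per-batch scaling factors, 2: int32 accumulator
  //   scratch, 3: per-batch input offsets, 4: filter row sums,
  //   5: sparsity ledger (sparse weights only).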
263   if (is_hybrid) {
264     TfLiteIntArrayFree(node->temporaries);
265     data->compute_row_sums = true;
266     if (is_sparse) {
267       node->temporaries = TfLiteIntArrayCreate(6);
268     } else {
269       node->temporaries = TfLiteIntArrayCreate(5);
270     }
271     node->temporaries->data[0] = data->scratch_tensor_index;
272 
273     TfLiteTensor* input_quantized;
274     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
275                                                 &input_quantized));
276     input_quantized->type = filter->type;
277     input_quantized->allocation_type = kTfLiteArenaRw;
278 
279     TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
280     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
281                                                      input_quantized_size));
282 
283     node->temporaries->data[1] = data->scratch_tensor_index + 1;
284     TfLiteTensor* scaling_factors;
285     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
286                                                 &scaling_factors));
287     scaling_factors->type = kTfLiteFloat32;
288     scaling_factors->allocation_type = kTfLiteArenaRw;
289 
290     int scaling_dims[1] = {batch_size};
291     if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
292       TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
293       scaling_factors_size->data[0] = batch_size;
294       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
295                                                        scaling_factors_size));
296     }
297 
298     node->temporaries->data[2] = data->scratch_tensor_index + 2;
299     TfLiteTensor* accum_scratch;
300     TF_LITE_ENSURE_OK(
301         context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
302     accum_scratch->type = kTfLiteInt32;
303     accum_scratch->allocation_type = kTfLiteArenaRw;
304     int accum_scratch_dims[2] = {num_units, batch_size};
305     if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
306                                    accum_scratch_dims)) {
307       TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
308       accum_size->data[0] = num_units;
309       accum_size->data[1] = batch_size;
310       TF_LITE_ENSURE_OK(
311           context, context->ResizeTensor(context, accum_scratch, accum_size));
312     }
313 
314     node->temporaries->data[3] = data->scratch_tensor_index + 3;
315     TfLiteTensor* input_offsets;
316     TF_LITE_ENSURE_OK(
317         context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
318     input_offsets->type = kTfLiteInt32;
319     input_offsets->allocation_type = kTfLiteArenaRw;
320     if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
321       TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
322       input_offsets_size->data[0] = batch_size;
323       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
324                                                        input_offsets_size));
325     }
326     node->temporaries->data[4] = data->scratch_tensor_index + 4;
327     TfLiteTensor* row_sums;
328     TF_LITE_ENSURE_OK(context,
329                       GetTemporarySafe(context, node, /*index=*/4, &row_sums));
330     row_sums->type = kTfLiteInt32;
331     row_sums->allocation_type = kTfLiteArenaRwPersistent;
332     int row_sums_dims[1] = {num_units};
333     if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
334       TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
335       row_sums_size->data[0] = row_sums_dims[0];
336       TF_LITE_ENSURE_OK(
337           context, context->ResizeTensor(context, row_sums, row_sums_size));
338     }
339 
340     if (is_sparse) {
341       data->ledger_initialized = false;
342       node->temporaries->data[5] = data->scratch_tensor_index + 5;
343       TfLiteTensor* filter_ledger =
344           &context->tensors[node->temporaries->data[5]];
345       auto status =
346           CreateLedgerTensor(filter->sparsity, context, filter_ledger);
347       if (status != kTfLiteOk) return status;
348     }
349   }
350 
351   // Resize output.
352   TfLiteIntArray* output_size_array = nullptr;
353   if (params->keep_num_dims) {
354     // When the number of dimensions is kept, the filter operates along the last
355     // dimension. In other words, for an input tensor with shape
356     // [batch_size, ..., n_inputs] and a filter of shape [n_inputs, n_units]
357     // this Op produces an output of shape [batch_size, ..., n_units].
358     TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1],
359                       SizeOfDimension(filter, 1));
360     output_size_array = TfLiteIntArrayCopy(input->dims);
361     output_size_array->data[output_size_array->size - 1] = num_units;
362   } else {
363     // Otherwise, the output is (potentially flattened to) a 2-D matrix.
364     output_size_array = TfLiteIntArrayCreate(2);
365     output_size_array->data[0] = batch_size;
366     output_size_array->data[1] = num_units;
367   }
368   TF_LITE_ENSURE_OK(context,
369                     context->ResizeTensor(context, output, output_size_array));
370 
371   return kTfLiteOk;
372 }
373 
374 template <KernelType kernel_type>
375 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
376   // Check for supported activation types.
377   auto* params =
378       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
379   const TfLiteTensor* filter;
380   TF_LITE_ENSURE_OK(context,
381                     GetInputSafe(context, node, kWeightsTensor, &filter));
382   const TfLiteTensor* input;
383   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
384   const bool is_quantized =
385       ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
386   const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
387   const bool is_pie = kernel_type == kLegacyPie;
388 
389   // The Pie and hybrid paths support all kinds of fused activations; otherwise
390   // only clipping activations are supported.
391   if (!is_pie && !is_hybrid) {
392     TF_LITE_ENSURE(context, params->activation == kTfLiteActNone ||
393                                 params->activation == kTfLiteActRelu ||
394                                 params->activation == kTfLiteActReluN1To1 ||
395                                 params->activation == kTfLiteActRelu6);
396   }
397   return PrepareImpl(context, node);
398 }
399 
400 TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
401                      TfLiteFullyConnectedParams* params, OpData* data,
402                      const TfLiteTensor* input, const TfLiteTensor* filter,
403                      const TfLiteTensor* bias, TfLiteTensor* output) {
404   int total_input_size = 1;
405   for (int i = 0; i < input->dims->size; i++) {
406     total_input_size *= input->dims->data[i];
407   }
408 
409   int input_size = filter->dims->data[1];
410   const int batch_size = total_input_size / filter->dims->data[1];
411   const int num_units = filter->dims->data[0];
412 
413   // Output = bias if bias tensor exists.
414   if (bias) {
415     tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
416                                           batch_size,
417                                           GetTensorData<float>(output));
418   } else {
419     std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
420   }
421 
422   // Compute output += weight * input
423   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
424       GetTensorData<float>(filter), num_units, input_size,
425       GetTensorData<float>(input), batch_size, GetTensorData<float>(output));
426 
427   // Apply activation function
428   tensor_utils::ApplyActivationToVector(
429       GetTensorData<float>(output), batch_size * num_units, params->activation,
430       GetTensorData<float>(output));
431 
432   return kTfLiteOk;
433 }
434 
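// Dense hybrid evaluation: the float input is quantized to int8 per batch,
// the integer matrix multiplication is accumulated and rescaled back to float
// using the per-batch scaling factors times the filter scale, and the fused
// activation is applied to the float output.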
435 TfLiteStatus EvalHybridDense(
436     TfLiteContext* context, TfLiteNode* node,
437     TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input,
438     const TfLiteTensor* filter, const TfLiteTensor* bias,
439     TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
440     TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
441     TfLiteTensor* input_offsets, TfLiteTensor* output) {
442   int total_input_size = 1;
443   for (int i = 0; i < input->dims->size; i++) {
444     total_input_size *= input->dims->data[i];
445   }
446 
447   const int input_size = filter->dims->data[1];
448   const int batch_size = total_input_size / filter->dims->data[1];
449   const int num_units = filter->dims->data[0];
450 
451   // Output = bias if bias tensor exists.
452   if (bias) {
453     tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
454                                           batch_size,
455                                           GetTensorData<float>(output));
456   } else {
457     std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
458   }
459 
460   // Save matrix multiplication computation for all zero input.
461   if (tensor_utils::IsZeroVector(GetTensorData<float>(input),
462                                  total_input_size)) {
463     tensor_utils::ApplyActivationToVector(
464         GetTensorData<float>(output), batch_size * num_units,
465         params->activation, GetTensorData<float>(output));
466     return kTfLiteOk;
467   }
468 
469   // Quantize input from float to int8 + quantization params (scaling factor).
470   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
471   int32_t* input_offset_ptr = nullptr;
472   int32_t* row_sums_ptr = nullptr;
473   if (params->asymmetric_quantize_inputs) {
474     input_offset_ptr = GetTensorData<int32_t>(input_offsets);
475     row_sums_ptr = GetTensorData<int32_t>(row_sums);
476   }
477   int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
478   const int8_t* filter_data = GetTensorData<int8_t>(filter);
479   const float* input_ptr = GetTensorData<float>(input);
480   tensor_utils::BatchQuantizeFloats(
481       input_ptr, batch_size, input_size, quant_data, scaling_factors_ptr,
482       input_offset_ptr, params->asymmetric_quantize_inputs);
483   for (int b = 0; b < batch_size; ++b) {
484     // Incorporate scaling of the filter.
485     scaling_factors_ptr[b] *= filter->params.scale;
486   }
487 
488   // Compute output += weight * quantized_input
489   int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
490   tensor_utils::MatrixBatchVectorMultiplyAccumulate(
491       filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
492       batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
493       input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
494       CpuBackendContext::GetFromContext(context));
495 
496   // Apply activation function to floats.
497   tensor_utils::ApplyActivationToVector(
498       GetTensorData<float>(output), batch_size * num_units, params->activation,
499       GetTensorData<float>(output));
500   return kTfLiteOk;
501 }
502 
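// Sparse hybrid evaluation for one slice of batches [thread_start,
// thread_end); used as the per-thread body of the multi-threaded kernel below.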
503 void EvalSparseHybridImpl(TfLiteContext* context, TfLiteNode* node,
504                           TfLiteFullyConnectedParams* params, OpData* data,
505                           const TfLiteTensor* input, const TfLiteTensor* filter,
506                           const TfLiteTensor* bias, int thread_start,
507                           int thread_end, TfLiteTensor* input_quantized,
508                           TfLiteTensor* scaling_factors,
509                           TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
510                           TfLiteTensor* input_offsets, TfLiteTensor* output) {
511   ruy::profiler::ScopeLabel label("FullyConnected");
512   ruy::profiler::ScopeLabel inner_label("Sparse Hybrid Kernel");
513   const auto& input_shape = GetTensorShape(input);
514   const auto& output_shape = GetTensorShape(output);
515   const auto& filter_shape = GetTensorShape(filter);
516   const int input_dims_count = input_shape.DimensionsCount();
517   const int output_dims_count = output_shape.DimensionsCount();
518   const int filter_dims_count = filter_shape.DimensionsCount();
519   const int batch_size = thread_end - thread_start;
520   const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
521                                       input_shape, input_dims_count - 1);
522   const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
523                                        output_shape, output_dims_count - 1);
524   const int per_thread_input_size = batch_size * input_depth;
525 
526   const float* per_thread_input =
527       GetTensorData<float>(input) + thread_start * input_depth;
528   float* per_thread_output =
529       GetTensorData<float>(output) + thread_start * output_depth;
530 
531   // Output = bias if bias tensor exists.
532   if (bias) {
533     tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
534                                           output_depth, batch_size,
535                                           per_thread_output);
536   } else {
537     std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
538   }
539 
540   // Save matrix multiplication computation for all zero input.
541   if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
542     tensor_utils::ApplyActivationToVector(
543         per_thread_output, batch_size * output_depth, params->activation,
544         per_thread_output);
545     return;
546   }
547 
548   // Quantize input from float to int8 + quantization params (scaling factor).
549   float* scaling_factors_ptr =
550       GetTensorData<float>(scaling_factors) + thread_start;
551   int32_t* input_offset_ptr = nullptr;
552   int32_t* row_sums_ptr = nullptr;
553   if (params->asymmetric_quantize_inputs) {
554     input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
555     row_sums_ptr = GetTensorData<int32_t>(row_sums);
556   }
557   int8_t* quant_data =
558       GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
559   tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
560                                     quant_data, scaling_factors_ptr,
561                                     input_offset_ptr,
562                                     params->asymmetric_quantize_inputs);
563   for (int b = 0; b < batch_size; ++b) {
564     // Incorporate scaling of the filter.
565     scaling_factors_ptr[b] *= filter->params.scale;
566   }
567 
568   // Compute output += weight * quantized_input
569   TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
570   tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
571       GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
572       output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
573       per_thread_output);
574 
575   // Apply activation function to floats.
576   tensor_utils::ApplyActivationToVector(per_thread_output,
577                                         batch_size * output_depth,
578                                         params->activation, per_thread_output);
579 }
580 
581 struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task {
582   SparseHybridFullyConnectedTask(
583       TfLiteContext* context, TfLiteNode* node,
584       TfLiteFullyConnectedParams* params, OpData* data,
585       const TfLiteTensor* input, const TfLiteTensor* filter,
586       const TfLiteTensor* bias, const int thread_start, const int thread_end,
587       TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
588       TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
589       TfLiteTensor* input_offsets, TfLiteTensor* output)
590       : context(context),
591         node(node),
592         params(params),
593         data(data),
594         input(input),
595         filter(filter),
596         bias(bias),
597         thread_start(thread_start),
598         thread_end(thread_end),
599         input_quantized(input_quantized),
600         scaling_factors(scaling_factors),
601         accum_scratch(accum_scratch),
602         row_sums(row_sums),
603         input_offsets(input_offsets),
604         output(output) {}
605 
606   void Run() override {
607     EvalSparseHybridImpl(context, node, params, data, input, filter, bias,
608                          thread_start, thread_end, input_quantized,
609                          scaling_factors, accum_scratch, row_sums,
610                          input_offsets, output);
611   }
612 
613  private:
614   TfLiteContext* context;
615   TfLiteNode* node;
616   TfLiteFullyConnectedParams* params;
617   OpData* data;
618   const TfLiteTensor* input;
619   const TfLiteTensor* filter;
620   const TfLiteTensor* bias;
621   const int thread_start;
622   const int thread_end;
623   TfLiteTensor* input_quantized;
624   TfLiteTensor* scaling_factors;
625   TfLiteTensor* accum_scratch;
626   TfLiteTensor* row_sums;
627   TfLiteTensor* input_offsets;
628   TfLiteTensor* output;
629 };
630 
631 TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
632                         TfLiteFullyConnectedParams* params, OpData* data,
633                         const TfLiteTensor* input, const TfLiteTensor* filter,
634                         const TfLiteTensor* bias, TfLiteTensor* input_quantized,
635                         TfLiteTensor* scaling_factors,
636                         TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
637                         TfLiteTensor* input_offsets, TfLiteTensor* output) {
638   const auto& output_shape = GetTensorShape(output);
639   CpuBackendContext* cpu_backend_context =
640       CpuBackendContext::GetFromContext(context);
641   const bool is_dense = filter->sparsity == nullptr;
642   if (is_dense) {
643     return EvalHybridDense(context, node, params, data, input, filter, bias,
644                            input_quantized, scaling_factors, accum_scratch,
645                            row_sums, input_offsets, output);
646   }
647 
648   TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
649   if (!data->ledger_initialized) {
650     PopulateLedgerData(filter->sparsity, context,
651                        GetTensorData<uint8_t>(filter_ledger));
652     data->ledger_initialized = true;
653   }
654 
655   // The multi-threaded kernel slices the workload along the batch dimension. If
656   // there are not enough batches, the number of threads used equals
657   // the batch size.
658   // TODO(b/173442777): If needed, we can improve this later with slicing along
659   // the row dimension of the weight.
660   const int max_threads = cpu_backend_context->max_num_threads();
661   const int batches =
662       FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
663   const int thread_count = std::max(1, std::min(batches, max_threads));
664 
665   std::vector<SparseHybridFullyConnectedTask> tasks;
666   tasks.reserve(thread_count);
667   int thread_start = 0;
668   for (int i = 0; i < thread_count; ++i) {
669     // This makes sure the workload is relatively balanced when batches is not
670     // a multiple of thread_count. The first mod(batches, thread_count) tasks
671     // need to process one more batch than the rest.
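    // For example, batches = 7 and thread_count = 3 yields slices of 3, 2 and 2.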
672     int thread_end = thread_start + batches / thread_count;
673     if (i < batches % thread_count) thread_end++;
674 
675     tasks.emplace_back(context, node, params, data, input, filter, bias,
676                        thread_start, thread_end, input_quantized,
677                        scaling_factors, accum_scratch, row_sums, input_offsets,
678                        output);
679     thread_start = thread_end;
680   }
681   cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
682                                   cpu_backend_context);
683   return kTfLiteOk;
684 }
685 
686 namespace {
687 template <KernelType kernel_type>
688 void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
689                         const TfLiteTensor* filter, const TfLiteTensor* bias,
690                         TfLiteTensor* output,
691                         CpuBackendContext* cpu_backend_context) {
692   FullyConnectedParams op_params;
693   op_params.input_offset = -input->params.zero_point;
694   op_params.weights_offset = -filter->params.zero_point;
695   op_params.output_offset = output->params.zero_point;
696   op_params.output_multiplier = data->output_multiplier;
697   op_params.output_shift = data->output_shift;
698   op_params.quantized_activation_min = data->output_activation_min;
699   op_params.quantized_activation_max = data->output_activation_max;
700   op_params.lhs_cacheable = IsConstantTensor(filter);
701   op_params.rhs_cacheable = IsConstantTensor(input);
702   if (kernel_type == kReference) {
703     reference_integer_ops::FullyConnected(
704         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
705         GetTensorShape(filter), GetTensorData<int8_t>(filter),
706         GetTensorShape(bias), GetTensorData<int32_t>(bias),
707         GetTensorShape(output), GetTensorData<int8_t>(output));
708   } else {
709     optimized_integer_ops::FullyConnected(
710         op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
711         GetTensorShape(filter), GetTensorData<int8_t>(filter),
712         GetTensorShape(bias), GetTensorData<int32_t>(bias),
713         GetTensorShape(output), GetTensorData<int8_t>(output),
714         cpu_backend_context);
715   }
716 }
717 }  // namespace
718 
719 namespace {
720 template <KernelType kernel_type>
721 void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
722                          const TfLiteTensor* filter, const TfLiteTensor* bias,
723                          TfLiteTensor* output) {
724   FullyConnectedParams op_params;
725   op_params.weights_offset = -filter->params.zero_point;
726   op_params.output_multiplier = data->output_multiplier;
727   op_params.output_shift = data->output_shift;
728   op_params.quantized_activation_min = data->output_activation_min;
729   op_params.quantized_activation_max = data->output_activation_max;
730   reference_integer_ops::FullyConnected(
731       op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
732       GetTensorShape(filter), GetTensorData<int8_t>(filter),
733       GetTensorShape(bias), GetTensorData<int64_t>(bias),
734       GetTensorShape(output), GetTensorData<int16_t>(output));
735 }
736 }  // namespace
737 
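// Quantized evaluation: float inputs with quantized weights are routed to the
// hybrid path; otherwise the kernel dispatches on the output type (uint8,
// int8 or int16).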
738 template <KernelType kernel_type>
739 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
740                            TfLiteFullyConnectedParams* params, OpData* data,
741                            const TfLiteTensor* input,
742                            const TfLiteTensor* filter, const TfLiteTensor* bias,
743                            TfLiteTensor* output) {
744   int32_t input_offset = -input->params.zero_point;
745   int32_t filter_offset = -filter->params.zero_point;
746   int32_t output_offset = output->params.zero_point;
747   // Quantized weights with float inputs/outputs are handled by the hybrid path.
748   if (input->type == kTfLiteFloat32) {
749     TfLiteTensor* input_quantized;
750     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
751                                                 &input_quantized));
752     TfLiteTensor* scaling_factors;
753     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
754                                                 &scaling_factors));
755     TfLiteTensor* accum_scratch;
756     TF_LITE_ENSURE_OK(
757         context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
758     TfLiteTensor* input_offsets;
759     TF_LITE_ENSURE_OK(
760         context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
761     TfLiteTensor* row_sums;
762     TF_LITE_ENSURE_OK(context,
763                       GetTemporarySafe(context, node, /*index=*/4, &row_sums));
764     return EvalHybrid(context, node, params, data, input, filter, bias,
765                       input_quantized, scaling_factors, accum_scratch, row_sums,
766                       input_offsets, output);
767   } else {
768     FullyConnectedParams op_params;
769     op_params.input_offset = input_offset;
770     op_params.weights_offset = filter_offset;
771     op_params.output_offset = output_offset;
772     op_params.output_multiplier = data->output_multiplier;
773     op_params.output_shift = data->output_shift;
774     op_params.quantized_activation_min = data->output_activation_min;
775     op_params.quantized_activation_max = data->output_activation_max;
776     op_params.lhs_cacheable = IsConstantTensor(filter);
777     op_params.rhs_cacheable = IsConstantTensor(input);
778     switch (output->type) {
779       case kTfLiteUInt8:
780         if (kernel_type == kReference) {
781           reference_ops::FullyConnected(
782               op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
783               GetTensorShape(filter), GetTensorData<uint8_t>(filter),
784               GetTensorShape(bias), GetTensorData<int32_t>(bias),
785               GetTensorShape(output), GetTensorData<uint8_t>(output));
786         } else {
787           optimized_ops::FullyConnected(
788               op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
789               GetTensorShape(filter), GetTensorData<uint8_t>(filter),
790               GetTensorShape(bias), GetTensorData<int32_t>(bias),
791               GetTensorShape(output), GetTensorData<uint8_t>(output),
792               CpuBackendContext::GetFromContext(context));
793         }
794         break;
795       case kTfLiteInt8:
796         FullyConnectedInt8<kernel_type>(
797             data, input, filter, bias, output,
798             CpuBackendContext::GetFromContext(context));
799         break;
800       case kTfLiteInt16:
801         if (input->type == kTfLiteInt16) {
802           FullyConnectedInt16<kernel_type>(data, input, filter, bias, output);
803         } else if (kernel_type == kReference) {
804           reference_ops::FullyConnected(
805               op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
806               GetTensorShape(filter), GetTensorData<uint8_t>(filter),
807               GetTensorShape(bias), GetTensorData<int32_t>(bias),
808               GetTensorShape(output), GetTensorData<int16_t>(output));
809         } else {
810           optimized_ops::FullyConnected(
811               op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
812               GetTensorShape(filter), GetTensorData<uint8_t>(filter),
813               GetTensorShape(bias), GetTensorData<int32_t>(bias),
814               GetTensorShape(output), GetTensorData<int16_t>(output),
815               CpuBackendContext::GetFromContext(context));
816         }
817         break;
818       default:
819         context->ReportError(context,
820                              "Quantized FullyConnected expects output data "
821                              "type uint8, int8 or int16");
822         return kTfLiteError;
823     }
824   }
825 
826   return kTfLiteOk;
827 }
828 
829 template <KernelType kernel_type>
830 TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
831                                    TfLiteFullyConnectedParams* params,
832                                    OpData* data, const TfLiteTensor* input,
833                                    const TfLiteTensor* filter,
834                                    const TfLiteTensor* bias,
835                                    TfLiteTensor* output,
836                                    TfLiteTensor* shuffled_input_workspace) {
837   // TODO(b/110697972) decide more consistently if / how / where we want
838   // to perform this kind of runtime data type check.
839   if (shuffled_input_workspace->type != kTfLiteUInt8) {
840     context->ReportError(context, "Unexpected data type");
841     return kTfLiteError;
842   }
843 
844 #define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                           \
845   {                                                                      \
846     type::ShuffledFullyConnected(                                        \
847         op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
848         GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
849         GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
850         GetTensorShape(output), GetTensorData<int16_t>(output),          \
851         GetTensorData<uint8_t>(shuffled_input_workspace),                \
852         CpuBackendContext::GetFromContext(context));                     \
853   }
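  // Note: the helper macro above is currently unused; the reference and
  // optimized calls below are written out explicitly.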
854   FullyConnectedParams op_params;
855   op_params.output_multiplier = data->output_multiplier;
856   op_params.output_shift = data->output_shift;
857   op_params.quantized_activation_min = data->output_activation_min;
858   op_params.quantized_activation_max = data->output_activation_max;
859   op_params.lhs_cacheable = IsConstantTensor(filter);
860   op_params.rhs_cacheable = IsConstantTensor(input);
861   if (kernel_type == kReference) {
862     reference_ops::ShuffledFullyConnected(
863         op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
864         GetTensorShape(filter), GetTensorData<uint8_t>(filter),
865         GetTensorShape(bias), GetTensorData<int32_t>(bias),
866         GetTensorShape(output), GetTensorData<int16_t>(output),
867         GetTensorData<uint8_t>(shuffled_input_workspace));
868   } else {
869     optimized_ops::ShuffledFullyConnected(
870         op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
871         GetTensorShape(filter), GetTensorData<uint8_t>(filter),
872         GetTensorShape(bias), GetTensorData<int32_t>(bias),
873         GetTensorShape(output), GetTensorData<int16_t>(output),
874         GetTensorData<uint8_t>(shuffled_input_workspace),
875         CpuBackendContext::GetFromContext(context));
876   }
877 #undef TF_LITE_SHUFFLED_FULLY_CONNECTED
878 
879   return kTfLiteOk;
880 }
881 
882 template <KernelType kernel_type>
883 TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
884                        TfLiteFullyConnectedParams* params, OpData* data,
885                        const TfLiteTensor* input, const TfLiteTensor* filter,
886                        const TfLiteTensor* bias, TfLiteTensor* output) {
887   float output_activation_min, output_activation_max;
888   CalculateActivationRange(params->activation, &output_activation_min,
889                            &output_activation_max);
890   if (kernel_type == kReference) {
891     FullyConnectedParams op_params;
892     op_params.float_activation_min = output_activation_min;
893     op_params.float_activation_max = output_activation_max;
894     if (filter->sparsity != nullptr) {
895       const auto& sparsity = *filter->sparsity;
896       reference_ops::FullyConnectedSparseWeight(
897           sparsity, op_params, GetTensorShape(input),
898           GetTensorData<float>(input), GetTensorShape(filter),
899           GetTensorData<float>(filter), GetTensorShape(bias),
900           GetTensorData<float>(bias), GetTensorShape(output),
901           GetTensorData<float>(output));
902     } else {
903       reference_ops::FullyConnected(
904           op_params, GetTensorShape(input), GetTensorData<float>(input),
905           GetTensorShape(filter), GetTensorData<float>(filter),
906           GetTensorShape(bias), GetTensorData<float>(bias),
907           GetTensorShape(output), GetTensorData<float>(output));
908     }
909   } else if (kernel_type == kLegacyPie) {
910     return EvalPie(context, node, params, data, input, filter, bias, output);
911   } else {
912     FullyConnectedParams op_params;
913     op_params.float_activation_min = output_activation_min;
914     op_params.float_activation_max = output_activation_max;
915     if (filter->sparsity != nullptr) {
916       const auto& sparsity = *filter->sparsity;
917       if (!SupportedSparsityFormat(sparsity)) {
918         TF_LITE_KERNEL_LOG(context,
919                            "Unsupported sparse fully-connected weight format.");
920         return kTfLiteError;
921       }
922 
923       if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
924         // Random sparse.
925         optimized_ops::FullyConnectedSparseWeight(
926             sparsity, op_params, GetTensorShape(input),
927             GetTensorData<float>(input), GetTensorShape(filter),
928             GetTensorData<float>(filter), GetTensorShape(bias),
929             GetTensorData<float>(bias), GetTensorShape(output),
930             GetTensorData<float>(output));
931       } else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
932                  sparsity.dim_metadata[2].dense_size == 4) {
933         // Block sparse with block size of 1x4.
934         optimized_ops::FullyConnectedSparseWeight1x4(
935             sparsity, op_params, GetTensorShape(input),
936             GetTensorData<float>(input), GetTensorShape(filter),
937             GetTensorData<float>(filter), GetTensorShape(bias),
938             GetTensorData<float>(bias), GetTensorShape(output),
939             GetTensorData<float>(output),
940             CpuBackendContext::GetFromContext(context));
941       } else {
942         TF_LITE_KERNEL_LOG(context,
943                            "Unsupported sparse fully-connected weight format.");
944         return kTfLiteError;
945       }
946 
947     } else {
948       op_params.lhs_cacheable = IsConstantTensor(filter);
949       op_params.rhs_cacheable = IsConstantTensor(input);
950       optimized_ops::FullyConnected(
951           op_params, GetTensorShape(input), GetTensorData<float>(input),
952           GetTensorShape(filter), GetTensorData<float>(filter),
953           GetTensorShape(bias), GetTensorData<float>(bias),
954           GetTensorShape(output), GetTensorData<float>(output),
955           CpuBackendContext::GetFromContext(context));
956     }
957   }
958 
959   return kTfLiteOk;
960 }
961 
962 template <KernelType kernel_type>
963 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
964   auto* params =
965       reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
966   OpData* data = reinterpret_cast<OpData*>(node->user_data);
967 
968   const TfLiteTensor* input;
969   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
970   const TfLiteTensor* filter;
971   TF_LITE_ENSURE_OK(context,
972                     GetInputSafe(context, node, kWeightsTensor, &filter));
973   const TfLiteTensor* bias =
974       (node->inputs->size == 3)
975           ? GetOptionalInputTensor(context, node, kBiasTensor)
976           : nullptr;
977   TfLiteTensor* output;
978   TF_LITE_ENSURE_OK(context,
979                     GetOutputSafe(context, node, kOutputTensor, &output));
980   // Do nothing if expected output is empty.
981   if (NumElements(output) == 0) {
982     return kTfLiteOk;
983   }
984 
985   switch (filter->type) {
986     case kTfLiteFloat32:
987       return EvalFloat<kernel_type>(context, node, params, data, input, filter,
988                                     bias, output);
989     case kTfLiteUInt8:
990       if (params->weights_format ==
991           kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
992         TfLiteTensor* shuffled_input_workspace;
993         TF_LITE_ENSURE_OK(
994             context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
995                                    &shuffled_input_workspace));
996         return EvalShuffledQuantized<kernel_type>(context, node, params, data,
997                                                   input, filter, bias, output,
998                                                   shuffled_input_workspace);
999       } else if (params->weights_format ==
1000                  kTfLiteFullyConnectedWeightsFormatDefault) {
1001         return EvalQuantized<kernel_type>(context, node, params, data, input,
1002                                           filter, bias, output);
1003       } else {
1004         context->ReportError(context,
1005                              "Unhandled fully-connected weights format");
1006         return kTfLiteError;
1007       }
1008     case kTfLiteInt8:
1009       if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
1010         return EvalQuantized<kernel_type>(context, node, params, data, input,
1011                                           filter, bias, output);
1012       } else {
1013         context->ReportError(context,
1014                              "Unhandled fully-connected weights format");
1015         return kTfLiteError;
1016       }
1017     default:
1018       context->ReportError(context,
1019                            "Filter data type %s currently not supported.",
1020                            TfLiteTypeGetName(filter->type));
1021       return kTfLiteError;
1022   }
1023   return kTfLiteOk;
1024 }
1025 
1026 }  // namespace fully_connected
1027 
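// Registrations: each variant wires Init/Free/Prepare/Eval for a specific
// KernelType; the remaining TfLiteRegistration fields are left
// default-initialized.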
1028 TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
1029   static TfLiteRegistration r = {
1030       fully_connected::Init, fully_connected::Free,
1031       fully_connected::Prepare<fully_connected::kReference>,
1032       fully_connected::Eval<fully_connected::kReference>};
1033   return &r;
1034 }
1035 
1036 TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
1037   static TfLiteRegistration r = {
1038       fully_connected::Init, fully_connected::Free,
1039       fully_connected::Prepare<fully_connected::kGenericOptimized>,
1040       fully_connected::Eval<fully_connected::kGenericOptimized>};
1041   return &r;
1042 }
1043 
1044 // Legacy path for PIE clients.
1045 TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
1046   static TfLiteRegistration r = {
1047       fully_connected::Init, fully_connected::Free,
1048       fully_connected::Prepare<fully_connected::kLegacyPie>,
1049       fully_connected::Eval<fully_connected::kLegacyPie>};
1050   return &r;
1051 }
1052 
1053 TfLiteRegistration* Register_FULLY_CONNECTED() {
1054   return Register_FULLY_CONNECTED_GENERIC_OPT();
1055 }
1056 
1057 }  // namespace builtin
1058 }  // namespace ops
1059 }  // namespace tflite
1060