/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fully_connected {

// This file has three implementations of FullyConnected.
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
};

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
};
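
// For the quantized paths, the requantization applied to each int32
// accumulator value is, roughly:
//   acc = SaturatingRoundingDoublingHighMul(acc, output_multiplier);
//   acc = RoundingDivideByPOT(acc, output_shift);
//   acc += output_zero_point;
//   acc = Clamp(acc, output_activation_min, output_activation_max);
// output_multiplier and output_shift are computed in Prepare() below.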

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kShuffledInputWorkspaceTensor = 1;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  gemm_support::IncrementUsageCounter(context);
  auto* op_data = new OpData();
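  // Reserve indices for two scratch tensors used by the hybrid path (float
  // inputs with quantized weights): a quantized copy of the input and the
  // per-batch scaling factors.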
  context->AddTensors(context, /*tensors_to_add=*/2,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  gemm_support::DecrementUsageCounter(context);
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
  // Shuffled formats need a workspace to store the shuffled input activations.
  const int expected_outputs_count =
      params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
                                                                          : 2;
  TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  int input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    input_size *= input->dims->data[i];
  }

  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
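  // The filter is 2-D with shape [num_units, input_size]; the input is treated
  // as a flat [batch_size, input_size] matrix regardless of its original rank.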
  const int batch_size = input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
  }

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  TfLiteType data_type = input->type;
  if (data_type != kTfLiteFloat32 && data_type != kTfLiteInt32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
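    // real_multiplier is input_scale * filter_scale / output_scale. Below it
    // is re-expressed as a fixed-point multiplier and an exponent; the
    // exponent is negated so that output_shift denotes a right shift.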
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = -exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs), we first need to quantize the inputs. Allocate a temporary
  // buffer to store the intermediate quantized values.
  if (input->type == kTfLiteFloat32 &&
      (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(2);
    node->temporaries->data[0] = data->scratch_tensor_index;

    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
    input_quantized->type = filter->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    // TODO(raziel): add this logic to ResizeTensor.
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = data->scratch_tensor_index + 1;
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
  }

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));
  return kTfLiteOk;
}

TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
                     TfLiteFullyConnectedParams* params, OpData* data,
                     const TfLiteTensor* input, const TfLiteTensor* filter,
                     const TfLiteTensor* bias, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
                                          output->data.f);
  } else {
    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
  }

  // Compute output += weight * input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter->data.f, num_units, input_size, input->data.f, batch_size,
      output->data.f, /*result_stride=*/1);

  // Apply activation function
  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
                                        params->activation, output->data.f);

  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors, TfLiteTensor* output) {
  // Check the types for this hybrid Op.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context,
                 filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
  }
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);

  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
                                          output->data.f);
  } else {
    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
  }

  // Skip the matrix multiplication when the input is all zeros.
  if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) {
    tensor_utils::ApplyActivationToVector(output->data.f,
                                          batch_size * num_units,
                                          params->activation, output->data.f);
    return kTfLiteOk;
  }

  // Quantize input from float to int8 + quantization params (scaling factor).
  float unused_min, unused_max;
  float* scaling_factors_ptr = scaling_factors->data.f;
  int8_t* quant_data;
  int8_t* filter_data;
  if (filter->type == kTfLiteUInt8) {
    quant_data = reinterpret_cast<int8_t*>(input_quantized->data.uint8);
    filter_data = reinterpret_cast<int8_t*>(filter->data.uint8);
  } else {
    quant_data = input_quantized->data.int8;
    filter_data = filter->data.int8;
  }

  // Quantize each batch independently.
  for (int b = 0; b < batch_size; ++b) {
    const int offset = b * input_size;
    tensor_utils::SymmetricQuantizeFloats(input->data.f + offset, input_size,
                                          quant_data + offset, &unused_min,
                                          &unused_max, &scaling_factors_ptr[b]);
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

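  // The accumulation below is done in int32 over int8 inputs and weights; the
  // combined per-batch scaling factor (input_scale * filter_scale) converts
  // each accumulator back to a float that is added into 'output'.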
  // Compute output += weight * quantized_input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
      batch_size, output->data.f,
      /*result_stride=*/1);

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
                                        params->activation, output->data.f);
  return kTfLiteOk;
}

#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
  if (params->activation == kTfLiteActNone) {                        \
    macro_name(target_namespace, kNone);                             \
  }                                                                  \
  if (params->activation == kTfLiteActRelu) {                        \
    macro_name(target_namespace, kRelu);                             \
  }                                                                  \
  if (params->activation == kTfLiteActRelu6) {                       \
    macro_name(target_namespace, kRelu6);                            \
  }

namespace {
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                        const TfLiteTensor* filter, const TfLiteTensor* bias,
                        TfLiteTensor* output,
                        gemmlowp::GemmContext* gemm_context) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int32_t>(bias),
      GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context);
}
}  // namespace

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);

  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -filter->params.zero_point;
  int32_t output_offset = output->params.zero_point;
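  // The kernels add these offsets to the stored quantized values, so the input
  // and filter zero points are negated here (value - zero_point), while the
  // output zero point is added back after requantization.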
#define TF_LITE_FULLY_CONNECTED(type, output_data_type)                  \
  {                                                                      \
    FullyConnectedParams op_params;                                      \
    op_params.input_offset = input_offset;                               \
    op_params.weights_offset = filter_offset;                            \
    op_params.output_offset = output_offset;                             \
    op_params.output_multiplier = data->output_multiplier;               \
    op_params.output_shift = -data->output_shift;                        \
    op_params.quantized_activation_min = data->output_activation_min;    \
    op_params.quantized_activation_max = data->output_activation_max;    \
    type::FullyConnected(                                                \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<output_data_type>(output), \
        gemm_context);                                                   \
  }
  // Only the Pie path supports quantized models and float inputs/outputs.
  if (input->type == kTfLiteFloat32) {
    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
    return EvalHybrid(context, node, params, data, input, filter, bias,
                      input_quantized, scaling_factors, output);
  } else if (kernel_type == kReference) {
    switch (output->type) {
      case kTfLiteUInt8:
        TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
        break;
      case kTfLiteInt8:
        FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
        break;
      default:
        context->ReportError(context,
                             "Quantized FullyConnected expects output data "
                             "type uint8, int8 or int16");
        return kTfLiteError;
    }
  } else {
    switch (output->type) {
      case kTfLiteUInt8:
        TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
        break;
      case kTfLiteInt8:
        FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
        break;
      default:
        context->ReportError(context,
                             "Quantized FullyConnected expects output data "
                             "type uint8, int8 or int16");
        return kTfLiteError;
    }
  }
#undef TF_LITE_FULLY_CONNECTED

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                   TfLiteFullyConnectedParams* params,
                                   OpData* data, const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output,
                                   TfLiteTensor* shuffled_input_workspace) {
  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);

  // TODO(b/110697972) decide more consistently if / how / where we want
  // to perform this kind of runtime data type check.
  if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 ||
      bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 ||
      shuffled_input_workspace->type != kTfLiteUInt8) {
    context->ReportError(context, "Unexpected data type");
    return kTfLiteError;
  }
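
  // The shuffled weights format (kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8)
  // expects uint8 weights pre-shuffled into 4x16 blocks; the kernel shuffles
  // the input activations into 'shuffled_input_workspace' and emits int16
  // output.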

#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                           \
  {                                                                      \
    FullyConnectedParams op_params;                                      \
    op_params.output_multiplier = data->output_multiplier;               \
    op_params.output_shift = -data->output_shift;                        \
    op_params.quantized_activation_min = data->output_activation_min;    \
    op_params.quantized_activation_max = data->output_activation_max;    \
    type::ShuffledFullyConnected(                                        \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<int16_t>(output),          \
        GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \
  }
  if (kernel_type == kReference) {
    TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
  } else {
    TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
  }
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
#define TF_LITE_FULLY_CONNECTED(type)                                         \
  {                                                                           \
    FullyConnectedParams op_params;                                           \
    op_params.float_activation_min = output_activation_min;                   \
    op_params.float_activation_max = output_activation_max;                   \
    type::FullyConnected(op_params, GetTensorShape(input),                    \
                         GetTensorData<float>(input), GetTensorShape(filter), \
                         GetTensorData<float>(filter), GetTensorShape(bias),  \
                         GetTensorData<float>(bias), GetTensorShape(output),  \
                         GetTensorData<float>(output));                       \
  }
  if (kernel_type == kReference) {
    TF_LITE_FULLY_CONNECTED(reference_ops);
  } else if (kernel_type == kLegacyPie) {
    return EvalPie(context, node, params, data, input, filter, bias, output);
  } else {
    TF_LITE_FULLY_CONNECTED(optimized_ops);
  }
#undef TF_LITE_FULLY_CONNECTED

  return kTfLiteOk;
}

#undef TF_LITE_MACRO_DISPATCH

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (filter->type) {  // Already know in/out types are the same.
    case kTfLiteFloat32:
      return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                    bias, output);
    case kTfLiteUInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        TfLiteTensor* shuffled_input_workspace =
            GetOutput(context, node, kShuffledInputWorkspaceTensor);
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
      } else if (params->weights_format ==
                 kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    case kTfLiteInt8:
      if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    default:
      context->ReportError(context, "Type %d not currently supported.",
                           filter->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kReference>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kGenericOptimized>};
  return &r;
}

// Legacy path for PIE clients.
TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kLegacyPie>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED() {
  return Register_FULLY_CONNECTED_GENERIC_OPT();
}
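
// For illustration only: a client could register a specific variant with a
// MutableOpResolver instead of relying on the default registration above,
// e.g.
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
//                       tflite::ops::builtin::Register_FULLY_CONNECTED_REF());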

}  // namespace builtin
}  // namespace ops
}  // namespace tflite