/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fully_connected {

// This file has three implementations of FullyConnected.
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
};

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
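  // For example (illustrative values only): a real multiplier of 0.375 can be
  // decomposed as 0.75 * 2^-1, i.e. a fixed-point multiplier of 0.75 (stored
  // as the Q31 integer 0.75 * 2^31) with an exponent of -1; Prepare() stores
  // the negated exponent, so output_shift would be 1 in that case.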
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
};

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kShuffledInputWorkspaceTensor = 1;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  gemm_support::IncrementUsageCounter(context);
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/2,
                      &op_data->scratch_tensor_index);
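  // The two temporaries are used only by the hybrid (float input, quantized
  // weights) path: one holds the quantized input, the other the per-batch
  // scaling factors. See Prepare().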
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  gemm_support::DecrementUsageCounter(context);
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
  // Shuffled formats need a workspace to store the shuffled input activations.
  const int expected_outputs_count =
      params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
                                                                          : 2;
  TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  int input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    input_size *= input->dims->data[i];
  }

  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
  const int batch_size = input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
  }

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  TfLiteType data_type = input->type;
  if (data_type != kTfLiteFloat32 && data_type != kTfLiteInt32) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
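    // GetQuantizedConvolutionMultipler effectively computes
    //   real_multiplier = input_scale * filter_scale / output_scale,
    // which QuantizeMultiplier below splits into a Q31 fixed-point multiplier
    // and a power-of-two exponent.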
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = -exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs), we first need to quantize the inputs. Allocate a temporary
  // buffer to store the intermediate quantized values.
  if (input->type == kTfLiteFloat32 &&
      (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(2);
    node->temporaries->data[0] = data->scratch_tensor_index;

    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
    input_quantized->type = filter->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    // TODO(raziel): add this logic to ResizeTensor.
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = data->scratch_tensor_index + 1;
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
  }

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));
  return kTfLiteOk;
}

TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
                     TfLiteFullyConnectedParams* params, OpData* data,
                     const TfLiteTensor* input, const TfLiteTensor* filter,
                     const TfLiteTensor* bias, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];
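  // Shapes in this path: the filter is [num_units, input_size], the input is
  // treated as [batch_size, input_size] and the output as
  // [batch_size, num_units].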

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
                                          output->data.f);
  } else {
    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
  }

  // Compute output += weight * input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter->data.f, num_units, input_size, input->data.f, batch_size,
      output->data.f, /*result_stride=*/1);

  // Apply activation function
  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
                                        params->activation, output->data.f);

  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors, TfLiteTensor* output) {
  // Check the types for this hybrid Op.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context,
                 filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
  }
  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);

  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
                                          output->data.f);
  } else {
    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
  }

  // Skip the matrix multiplication entirely if the input is all zeros.
  if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) {
    tensor_utils::ApplyActivationToVector(output->data.f,
                                          batch_size * num_units,
                                          params->activation, output->data.f);
    return kTfLiteOk;
  }

  // Quantize the input from float to int8, recording the per-batch scaling
  // factor used for the quantization.
  float unused_min, unused_max;
  float* scaling_factors_ptr = scaling_factors->data.f;
  int8_t* quant_data;
  int8_t* filter_data;
  if (filter->type == kTfLiteUInt8) {
    quant_data = reinterpret_cast<int8_t*>(input_quantized->data.uint8);
    filter_data = reinterpret_cast<int8_t*>(filter->data.uint8);
  } else {
    quant_data = input_quantized->data.int8;
    filter_data = filter->data.int8;
  }

  // Quantize each batch independently.
  for (int b = 0; b < batch_size; ++b) {
    const int offset = b * input_size;
    tensor_utils::SymmetricQuantizeFloats(input->data.f + offset, input_size,
                                          quant_data + offset, &unused_min,
                                          &unused_max, &scaling_factors_ptr[b]);
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }
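  // After this loop, scaling_factors_ptr[b] holds input_scale_b * filter_scale,
  // the factor that converts each batch's integer accumulator back to float
  // when it is accumulated into the output below.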

  // Compute output += weight * quantized_input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
      batch_size, output->data.f,
      /*result_stride=*/1);

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
                                        params->activation, output->data.f);
  return kTfLiteOk;
}

#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
  if (params->activation == kTfLiteActNone) {                        \
    macro_name(target_namespace, kNone);                             \
  }                                                                  \
  if (params->activation == kTfLiteActRelu) {                        \
    macro_name(target_namespace, kRelu);                             \
  }                                                                  \
  if (params->activation == kTfLiteActRelu6) {                       \
    macro_name(target_namespace, kRelu6);                            \
  }

namespace {
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                        const TfLiteTensor* filter, const TfLiteTensor* bias,
                        TfLiteTensor* output,
                        gemmlowp::GemmContext* gemm_context) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = -data->output_shift;
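  // OpData stores the negated exponent (a right shift), while
  // FullyConnectedParams expects the signed exponent itself, hence the
  // negation above.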
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int32_t>(bias),
      GetTensorShape(output), GetTensorData<int8_t>(output), gemm_context);
}
}  // namespace

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);

  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -filter->params.zero_point;
  int32_t output_offset = output->params.zero_point;
#define TF_LITE_FULLY_CONNECTED(type, output_data_type)                  \
  {                                                                      \
    FullyConnectedParams op_params;                                      \
    op_params.input_offset = input_offset;                               \
    op_params.weights_offset = filter_offset;                            \
    op_params.output_offset = output_offset;                             \
    op_params.output_multiplier = data->output_multiplier;               \
    op_params.output_shift = -data->output_shift;                        \
    op_params.quantized_activation_min = data->output_activation_min;    \
    op_params.quantized_activation_max = data->output_activation_max;    \
    type::FullyConnected(                                                 \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),  \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),           \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),               \
        GetTensorShape(output), GetTensorData<output_data_type>(output),  \
        gemm_context);                                                    \
  }
  // Only the Pie path supports quantized models and float inputs/outputs.
  if (input->type == kTfLiteFloat32) {
    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
    return EvalHybrid(context, node, params, data, input, filter, bias,
                      input_quantized, scaling_factors, output);
  } else if (kernel_type == kReference) {
    switch (output->type) {
      case kTfLiteUInt8:
        TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
        break;
      case kTfLiteInt8:
        FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
        break;
      default:
        context->ReportError(
            context,
            "Quantized FullyConnected expects output data type uint8, int8 or "
            "int16");
        return kTfLiteError;
    }
  } else {
    switch (output->type) {
      case kTfLiteUInt8:
        TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
        break;
      case kTfLiteInt8:
        FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
        break;
      default:
        context->ReportError(
            context,
            "Quantized FullyConnected expects output data type uint8, int8 or "
            "int16");
        return kTfLiteError;
    }
  }
#undef TF_LITE_FULLY_CONNECTED

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                   TfLiteFullyConnectedParams* params,
                                   OpData* data, const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output,
                                   TfLiteTensor* shuffled_input_workspace) {
  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);

  // TODO(b/110697972) decide more consistently if / how / where we want
  // to perform this kind of runtime data type check.
  if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 ||
      bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 ||
      shuffled_input_workspace->type != kTfLiteUInt8) {
    context->ReportError(context, "Unexpected data type");
    return kTfLiteError;
  }
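  // Rough sketch of the shuffled format (see the ShuffledFullyConnected
  // comments in the reference/optimized ops for the authoritative layout): the
  // uint8 weights are pre-arranged into small blocks so the GEMM can read them
  // contiguously, and the extra output tensor is used as scratch space for the
  // correspondingly reordered input activations.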

#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                            \
  {                                                                       \
    FullyConnectedParams op_params;                                       \
    op_params.output_multiplier = data->output_multiplier;                \
    op_params.output_shift = -data->output_shift;                         \
    op_params.quantized_activation_min = data->output_activation_min;     \
    op_params.quantized_activation_max = data->output_activation_max;     \
    type::ShuffledFullyConnected(                                         \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),  \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),           \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),               \
        GetTensorShape(output), GetTensorData<int16_t>(output),           \
        GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context);  \
  }
  if (kernel_type == kReference) {
    TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
  } else {
    TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
  }
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
#define TF_LITE_FULLY_CONNECTED(type)                                         \
  {                                                                           \
    FullyConnectedParams op_params;                                           \
    op_params.float_activation_min = output_activation_min;                   \
    op_params.float_activation_max = output_activation_max;                   \
    type::FullyConnected(op_params, GetTensorShape(input),                    \
                         GetTensorData<float>(input), GetTensorShape(filter), \
                         GetTensorData<float>(filter), GetTensorShape(bias),  \
                         GetTensorData<float>(bias), GetTensorShape(output),  \
                         GetTensorData<float>(output));                       \
  }
  if (kernel_type == kReference) {
    TF_LITE_FULLY_CONNECTED(reference_ops);
  } else if (kernel_type == kLegacyPie) {
    return EvalPie(context, node, params, data, input, filter, bias, output);
  } else {
    TF_LITE_FULLY_CONNECTED(optimized_ops);
  }
#undef TF_LITE_FULLY_CONNECTED

  return kTfLiteOk;
}

#undef TF_LITE_MACRO_DISPATCH

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (filter->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                    bias, output);
    case kTfLiteUInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        TfLiteTensor* shuffled_input_workspace =
            GetOutput(context, node, kShuffledInputWorkspaceTensor);
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
      } else if (params->weights_format ==
                 kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    case kTfLiteInt8:
      if (params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    default:
      context->ReportError(context, "Type %d not currently supported.",
                           filter->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kReference>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kGenericOptimized>};
  return &r;
}

// Legacy path for PIE clients.
TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
      fully_connected::Eval<fully_connected::kLegacyPie>};
  return &r;
}

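// Note: the builtin op resolver (not part of this file) typically registers
// the FULLY_CONNECTED builtin via Register_FULLY_CONNECTED(), i.e. the generic
// optimized kernel; the REF and PIE variants have to be registered explicitly
// by clients that want them.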
TfLiteRegistration* Register_FULLY_CONNECTED() {
  return Register_FULLY_CONNECTED_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite