/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fully_connected {

namespace {
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
  if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
      sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
    return true;
  }

  return false;
}

static const int kDimMetadataSizeRandomSparse = 2;
static const int kDimMetadataSizeBlockSparse = 3;
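// A random (element-wise) sparse weight tensor carries two dim_metadata
// entries, while a block-sparse tensor carries a third entry describing the
// block dimension (see the dim_metadata_size checks in EvalFloat below).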

TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, TfLiteTensor* ledger) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  ledger->type = kTfLiteUInt8;
  ledger->allocation_type = kTfLiteArenaRwPersistent;
  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
                         sparsity->dim_metadata[1].array_segments->size - 1;
  return context->ResizeTensor(context, ledger, ledger_size);
}

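// Builds the ledger consumed by the sparse hybrid kernel: for each output row
// it stores the number of non-zero blocks in that row followed by the column
// index of each block, all as uint8_t values.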
TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, uint8_t* ledger_data) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
  int output_data_ptr = 0;

  for (int i = 0; i < array_segments->size - 1; i++) {
    int row_start = array_segments->data[i];
    int row_end = array_segments->data[i + 1];
    if (row_end - row_start > UINT8_MAX) {
      return kTfLiteError;
    }
    // Copy num of non-zero blocks in row i.
    ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
    output_data_ptr++;

    for (int j = row_start; j < row_end; j++) {
      if (array_indices->data[j] > UINT8_MAX) {
        return kTfLiteError;
      }
      // Copy indices of non-zero blocks in row i.
      ledger_data[output_data_ptr] =
          static_cast<uint8_t>(array_indices->data[j]);
      output_data_ptr++;
    }
  }
  return kTfLiteOk;
}

}  // namespace

// This file has three implementations of FullyConnected.
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
};

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
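  // For example (illustrative only): a real multiplier of roughly 0.375 would
  // typically be encoded as output_multiplier ~= 0.75 * 2^31 together with
  // output_shift = -1, i.e. an extra right shift by one bit.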
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
  bool compute_row_sums = false;
  // Only used for sparse hybrid fully connected kernels.
  bool ledger_initialized;
};

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kShuffledInputWorkspaceTensor = 1;

inline TfLiteStatus CheckTypes(TfLiteContext* context,
                               const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output,
                               TfLiteFullyConnectedParams* params) {
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_shuffled =
      is_quantized && (params->weights_format ==
                       kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8);

  // Optional bias tensor.
  const bool is_optional_bias_float = !bias || (bias->type == kTfLiteFloat32);
  const bool is_optional_bias_int =
      !bias || (bias->type == kTfLiteInt32) || (bias->type == kTfLiteInt64);

  if (is_quantized) {
    if (is_shuffled) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    } else if (is_hybrid) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
    } else {
      TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                                  input->type == kTfLiteInt8 ||
                                  input->type == kTfLiteInt16);
      TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
                                  output->type == kTfLiteInt8 ||
                                  output->type == kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    }
  } else {
    // Only float32 is supported currently.
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
  }

  return kTfLiteOk;
}


void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
  // Shuffled formats need a workspace to store the shuffled input activations.
  const int expected_outputs_count =
      params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
                                                                          : 2;
  TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check that the data types of all the input tensors match.
  TF_LITE_ENSURE_STATUS(
      CheckTypes(context, input, filter, bias, output, params));

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  int input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    input_size *= input->dims->data[i];
  }

  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
  TF_LITE_ENSURE(context, filter->dims->data[1] != 0);
  const int batch_size = input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  if (bias) {
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
  }

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
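    // Here real_multiplier is (roughly) input_scale * filter_scale /
    // output_scale; QuantizeMultiplier re-expresses it below as a fixed-point
    // multiplier plus a shift so the kernel can avoid floating point at
    // runtime.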
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = exponent;
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs), we first need to quantize the inputs. Allocate a temporary
  // buffer to store the intermediate quantized values.
  // Additionally, we allocate a temporary buffer to store the accumulated
  // quantized values prior to multiplication by the scaling factor.
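  // The temporaries set up below are, in order: (0) the quantized input,
  // (1) per-batch scaling factors, (2) an int32 accumulator scratch buffer,
  // (3) per-batch input offsets, (4) per-row weight sums, and, for sparse
  // weights only, (5) the filter ledger.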
  const bool is_hybrid =
      (input->type == kTfLiteFloat32 &&
       (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
  const bool is_sparse = filter->sparsity != nullptr;
  if (is_hybrid) {
    TfLiteIntArrayFree(node->temporaries);
    data->compute_row_sums = true;
    if (is_sparse) {
      node->temporaries = TfLiteIntArrayCreate(6);
    } else {
      node->temporaries = TfLiteIntArrayCreate(5);
    }
    node->temporaries->data[0] = data->scratch_tensor_index;

    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = filter->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                     input_quantized_size));

    node->temporaries->data[1] = data->scratch_tensor_index + 1;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;

    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }

    node->temporaries->data[2] = data->scratch_tensor_index + 2;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = num_units;
      accum_size->data[1] = batch_size;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    node->temporaries->data[3] = data->scratch_tensor_index + 3;
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    input_offsets->type = kTfLiteInt32;
    input_offsets->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
      input_offsets_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
                                                       input_offsets_size));
    }
    node->temporaries->data[4] = data->scratch_tensor_index + 4;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }

    if (is_sparse) {
      data->ledger_initialized = false;
      node->temporaries->data[5] = data->scratch_tensor_index + 5;
      TfLiteTensor* filter_ledger =
          &context->tensors[node->temporaries->data[5]];
      auto status =
          CreateLedgerTensor(filter->sparsity, context, filter_ledger);
      if (status != kTfLiteOk) return status;
    }
  }

  // Resize output.
  TfLiteIntArray* output_size_array = nullptr;
  if (params->keep_num_dims) {
    // When the number of dimensions is kept, the filter operates along the
    // last dimension. In other words, for an input tensor with shape
    // [batch_size, ..., n_inputs] and a filter of shape [n_units, n_inputs],
    // this Op produces an output of shape [batch_size, ..., n_units].
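    // For example, an input of shape [2, 3, 10] with a [16, 10] filter yields
    // a [2, 3, 16] output here, and a flattened [6, 16] output in the branch
    // below when keep_num_dims is false.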
    TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1],
                      SizeOfDimension(filter, 1));
    output_size_array = TfLiteIntArrayCopy(input->dims);
    output_size_array->data[output_size_array->size - 1] = num_units;
  } else {
    // Otherwise, the output is (potentially flattened to) a 2-D matrix.
    output_size_array = TfLiteIntArrayCreate(2);
    output_size_array->data[0] = batch_size;
    output_size_array->data[1] = num_units;
  }
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check for supported activation types.
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_pie = kernel_type == kLegacyPie;

  // The Pie and hybrid paths support all kinds of fused activations; otherwise
  // only clipping activations are supported.
  if (!is_pie && !is_hybrid) {
    TF_LITE_ENSURE(context, params->activation == kTfLiteActNone ||
                                params->activation == kTfLiteActRelu ||
                                params->activation == kTfLiteActReluN1To1 ||
                                params->activation == kTfLiteActRelu6);
  }
  return PrepareImpl(context, node);
}

TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
                     TfLiteFullyConnectedParams* params, OpData* data,
                     const TfLiteTensor* input, const TfLiteTensor* filter,
                     const TfLiteTensor* bias, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Compute output += weight * input
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetTensorData<float>(filter), num_units, input_size,
      GetTensorData<float>(input), batch_size, GetTensorData<float>(output));

  // Apply activation function
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));

  return kTfLiteOk;
}

TfLiteStatus EvalHybridDense(
    TfLiteContext* context, TfLiteNode* node,
    TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias,
    TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
    TfLiteTensor* input_offsets, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Skip the matrix multiplication when the input is all zeros.
  if (tensor_utils::IsZeroVector(GetTensorData<float>(input),
                                 total_input_size)) {
    tensor_utils::ApplyActivationToVector(
        GetTensorData<float>(output), batch_size * num_units,
        params->activation, GetTensorData<float>(output));
    return kTfLiteOk;
  }

  // Quantize the input from float to int8 and compute the quantization
  // parameters (per-batch scaling factors).
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  const float* input_ptr = GetTensorData<float>(input);
  tensor_utils::BatchQuantizeFloats(
      input_ptr, batch_size, input_size, quant_data, scaling_factors_ptr,
      input_offset_ptr, params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }
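  // After the loop above, scaling_factors_ptr[b] holds input_scale[b] *
  // filter_scale, which the accumulate routine below uses to dequantize the
  // int32 products back to float.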

  // Compute output += weight * quantized_input
  int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
      batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
      input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
      CpuBackendContext::GetFromContext(context));

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));
  return kTfLiteOk;
}

void EvalSparseHybridImpl(TfLiteContext* context, TfLiteNode* node,
                          TfLiteFullyConnectedParams* params, OpData* data,
                          const TfLiteTensor* input, const TfLiteTensor* filter,
                          const TfLiteTensor* bias, int thread_start,
                          int thread_end, TfLiteTensor* input_quantized,
                          TfLiteTensor* scaling_factors,
                          TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                          TfLiteTensor* input_offsets, TfLiteTensor* output) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("Sparse Hybrid Kernel");
  const auto& input_shape = GetTensorShape(input);
  const auto& output_shape = GetTensorShape(output);
  const auto& filter_shape = GetTensorShape(filter);
  const int input_dims_count = input_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int filter_dims_count = filter_shape.DimensionsCount();
  const int batch_size = thread_end - thread_start;
  const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
                                      input_shape, input_dims_count - 1);
  const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int per_thread_input_size = batch_size * input_depth;

  const float* per_thread_input =
      GetTensorData<float>(input) + thread_start * input_depth;
  float* per_thread_output =
      GetTensorData<float>(output) + thread_start * output_depth;

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
                                          output_depth, batch_size,
                                          per_thread_output);
  } else {
    std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
  }

  // Skip the matrix multiplication when the input is all zeros.
  if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
    tensor_utils::ApplyActivationToVector(
        per_thread_output, batch_size * output_depth, params->activation,
        per_thread_output);
    return;
  }

  // Quantize the input from float to int8 and compute the quantization
  // parameters (per-batch scaling factors).
  float* scaling_factors_ptr =
      GetTensorData<float>(scaling_factors) + thread_start;
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data =
      GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
  tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
                                    quant_data, scaling_factors_ptr,
                                    input_offset_ptr,
                                    params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  // Compute output += weight * quantized_input
  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
  tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
      GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
      output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
      per_thread_output);

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(per_thread_output,
                                        batch_size * output_depth,
                                        params->activation, per_thread_output);
}

struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task {
  SparseHybridFullyConnectedTask(
      TfLiteContext* context, TfLiteNode* node,
      TfLiteFullyConnectedParams* params, OpData* data,
      const TfLiteTensor* input, const TfLiteTensor* filter,
      const TfLiteTensor* bias, const int thread_start, const int thread_end,
      TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
      TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
      TfLiteTensor* input_offsets, TfLiteTensor* output)
      : context(context),
        node(node),
        params(params),
        data(data),
        input(input),
        filter(filter),
        bias(bias),
        thread_start(thread_start),
        thread_end(thread_end),
        input_quantized(input_quantized),
        scaling_factors(scaling_factors),
        accum_scratch(accum_scratch),
        row_sums(row_sums),
        input_offsets(input_offsets),
        output(output) {}

  void Run() override {
    EvalSparseHybridImpl(context, node, params, data, input, filter, bias,
                         thread_start, thread_end, input_quantized,
                         scaling_factors, accum_scratch, row_sums,
                         input_offsets, output);
  }

 private:
  TfLiteContext* context;
  TfLiteNode* node;
  TfLiteFullyConnectedParams* params;
  OpData* data;
  const TfLiteTensor* input;
  const TfLiteTensor* filter;
  const TfLiteTensor* bias;
  const int thread_start;
  const int thread_end;
  TfLiteTensor* input_quantized;
  TfLiteTensor* scaling_factors;
  TfLiteTensor* accum_scratch;
  TfLiteTensor* row_sums;
  TfLiteTensor* input_offsets;
  TfLiteTensor* output;
};

TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
  const auto& output_shape = GetTensorShape(output);
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  const bool is_dense = filter->sparsity == nullptr;
  if (is_dense) {
    return EvalHybridDense(context, node, params, data, input, filter, bias,
                           input_quantized, scaling_factors, accum_scratch,
                           row_sums, input_offsets, output);
  }

  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
  if (!data->ledger_initialized) {
    PopulateLedgerData(filter->sparsity, context,
                       GetTensorData<uint8_t>(filter_ledger));
    data->ledger_initialized = true;
  }

  // The multi-threaded kernel slices the workload along the batch dimension.
  // If there aren't enough batches of data, the number of threads used is
  // equal to the batch size.
  // TODO(b/173442777): If needed, we can improve this later with slicing along
  // the row dimension of the weight.
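  // For example, with 10 batches and 4 available threads the loop below
  // assigns the batch ranges [0, 3), [3, 6), [6, 8) and [8, 10) to the four
  // tasks.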
  const int max_threads = cpu_backend_context->max_num_threads();
  const int batches =
      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
  const int thread_count = std::max(1, std::min(batches, max_threads));

  std::vector<SparseHybridFullyConnectedTask> tasks;
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // This makes sure the workload is relatively balanced when batches is not
    // a multiple of thread_count. The first mod(batches, thread_count) tasks
    // need to process one more batch than the rest.
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;

    tasks.emplace_back(context, node, params, data, input, filter, bias,
                       thread_start, thread_end, input_quantized,
                       scaling_factors, accum_scratch, row_sums, input_offsets,
                       output);
    thread_start = thread_end;
  }
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
  return kTfLiteOk;
}

namespace {
template <KernelType kernel_type>
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                        const TfLiteTensor* filter, const TfLiteTensor* bias,
                        TfLiteTensor* output,
                        CpuBackendContext* cpu_backend_context) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
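  // Marking a constant tensor as cacheable lets the gemm backend reuse its
  // packed form across invocations instead of re-packing it on every call;
  // for a typical FullyConnected the weights are the constant operand.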
  if (kernel_type == kReference) {
    reference_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output));
  } else {
    optimized_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output),
        cpu_backend_context);
  }
}
}  // namespace

namespace {
template <KernelType kernel_type>
void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
                         const TfLiteTensor* filter, const TfLiteTensor* bias,
                         TfLiteTensor* output) {
  FullyConnectedParams op_params;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  reference_integer_ops::FullyConnected(
      op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
      GetTensorShape(filter), GetTensorData<int8_t>(filter),
      GetTensorShape(bias), GetTensorData<int64_t>(bias),
      GetTensorShape(output), GetTensorData<int16_t>(output));
}
}  // namespace

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -filter->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  // Hybrid case: quantized weights with float inputs/outputs (historically the
  // Pie path).
  if (input->type == kTfLiteFloat32) {
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    return EvalHybrid(context, node, params, data, input, filter, bias,
                      input_quantized, scaling_factors, accum_scratch, row_sums,
                      input_offsets, output);
  } else {
    FullyConnectedParams op_params;
    op_params.input_offset = input_offset;
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
    op_params.lhs_cacheable = IsConstantTensor(filter);
    op_params.rhs_cacheable = IsConstantTensor(input);
    switch (output->type) {
      case kTfLiteUInt8:
        if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      case kTfLiteInt8:
        FullyConnectedInt8<kernel_type>(
            data, input, filter, bias, output,
            CpuBackendContext::GetFromContext(context));
        break;
      case kTfLiteInt16:
        if (input->type == kTfLiteInt16) {
          FullyConnectedInt16<kernel_type>(data, input, filter, bias, output);
        } else if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      default:
        context->ReportError(context,
                             "Quantized FullyConnected expects output data "
                             "type uint8, int8 or int16");
        return kTfLiteError;
    }
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                   TfLiteFullyConnectedParams* params,
                                   OpData* data, const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output,
                                   TfLiteTensor* shuffled_input_workspace) {
  // TODO(b/110697972): decide more consistently if / how / where we want
  // to perform this kind of runtime data type check.
  if (shuffled_input_workspace->type != kTfLiteUInt8) {
    context->ReportError(context, "Unexpected data type");
    return kTfLiteError;
  }

#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                           \
  {                                                                      \
    type::ShuffledFullyConnected(                                        \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<int16_t>(output),          \
        GetTensorData<uint8_t>(shuffled_input_workspace),                \
        CpuBackendContext::GetFromContext(context));                     \
  }
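  // Note: this macro is currently unused; the reference and optimized shuffled
  // kernels below are invoked directly.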
  FullyConnectedParams op_params;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace));
  } else {
    optimized_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace),
        CpuBackendContext::GetFromContext(context));
  }
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  if (kernel_type == kReference) {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      reference_ops::FullyConnectedSparseWeight(
          sparsity, op_params, GetTensorShape(input),
          GetTensorData<float>(input), GetTensorShape(filter),
          GetTensorData<float>(filter), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorShape(output),
          GetTensorData<float>(output));
    } else {
      reference_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output));
    }
  } else if (kernel_type == kLegacyPie) {
    return EvalPie(context, node, params, data, input, filter, bias, output);
  } else {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      if (!SupportedSparsityFormat(sparsity)) {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }

      if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
        // Random sparse.
        optimized_ops::FullyConnectedSparseWeight(
            sparsity, op_params, GetTensorShape(input),
            GetTensorData<float>(input), GetTensorShape(filter),
            GetTensorData<float>(filter), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorShape(output),
            GetTensorData<float>(output));
      } else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
                 sparsity.dim_metadata[2].dense_size == 4) {
        // Block sparse with block size of 1x4.
        optimized_ops::FullyConnectedSparseWeight1x4(
            sparsity, op_params, GetTensorShape(input),
            GetTensorData<float>(input), GetTensorShape(filter),
            GetTensorData<float>(filter), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorShape(output),
            GetTensorData<float>(output),
            CpuBackendContext::GetFromContext(context));
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }

    } else {
      op_params.lhs_cacheable = IsConstantTensor(filter);
      op_params.rhs_cacheable = IsConstantTensor(input);
      optimized_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output),
          CpuBackendContext::GetFromContext(context));
    }
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  // Do nothing if expected output is empty.
  if (NumElements(output) == 0) {
    return kTfLiteOk;
  }

  switch (filter->type) {
    case kTfLiteFloat32:
      return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                    bias, output);
    case kTfLiteUInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        TfLiteTensor* shuffled_input_workspace;
        TF_LITE_ENSURE_OK(
            context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
                                   &shuffled_input_workspace));
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
      } else if (params->weights_format ==
                 kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    case kTfLiteInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        context->ReportError(context,
                             "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    default:
      context->ReportError(context,
                           "Filter data type %s currently not supported.",
                           TfLiteTypeGetName(filter->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kReference>,
      fully_connected::Eval<fully_connected::kReference>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kGenericOptimized>,
      fully_connected::Eval<fully_connected::kGenericOptimized>};
  return &r;
}

// Legacy path for PIE clients.
TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kLegacyPie>,
      fully_connected::Eval<fully_connected::kLegacyPie>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED() {
  return Register_FULLY_CONNECTED_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite