/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SVDF op that compresses a fully connected op via low-rank matrix
// factorization. See https://research.google.com/pubs/archive/43813.pdf for
// details.
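//
// Rough sketch of the per-invocation computation (the reference kernels in
// internal/reference/svdf.h are authoritative):
//   1. Project the input through weights_feature and push the result into a
//      per-filter ring buffer held in the state tensor.
//   2. For each filter, take the dot product of weights_time with the last
//      memory_size entries of that filter's buffer.
//   3. Sum the `rank` filter responses backing each output unit, add the
//      bias, and apply the fused activation.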

#include "tensorflow/lite/kernels/internal/reference/svdf.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace svdf {

namespace {

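// Per-instance op data, created in Init and released in Free. Each
// effective_scale_{1,2} pair holds a real-valued rescale factor decomposed by
// QuantizeMultiplier into a fixed-point int32 multiplier (a) and a
// power-of-two shift (b); see the full-integer branch of Prepare.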
struct OpData {
  int scratch_tensor_index;
  bool float_weights_time_initialized;
  int32_t effective_scale_1_a;
  int effective_scale_1_b;
  int32_t effective_scale_2_a;
  int effective_scale_2_b;
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
// This is a variable tensor, and will be modified by this op.
constexpr int kStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->float_weights_time_initialized = false;
  // Note: 6 scratch tensors are needed when is_hybrid_op, 2 in the
  // full-integer case, and only 1 otherwise; 6 are reserved here to cover
  // the worst case.
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));

  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);

  // Check that the tensor parameters are consistent with one another and with
  // the input configuration.
  const int rank = params->rank;
  const int batch_size = input->dims->data[0];
  const int num_filters = weights_feature->dims->data[0];
  TF_LITE_ENSURE(context, rank != 0);
  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
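  // Each output unit is backed by `rank` filters, hence the divisibility
  // requirement: num_filters = num_units * rank.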
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];
  TF_LITE_ENSURE_EQ(context, input->dims->data[1],
                    weights_feature->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);

  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
  }

  const TfLiteTensor* state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check the shape of the input state tensor.
  TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 0), batch_size);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 1),
                    memory_size * num_filters);
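  // The state flattens, per batch, the last memory_size activations of each
  // of the num_filters filters, hence the [batch_size,
  // memory_size * num_filters] shape.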

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
  const bool is_full_integer = input->type == kTfLiteInt8;

  // Resize scratch.
  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else if (is_full_integer) {
    node->temporaries = TfLiteIntArrayCreate(2);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
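  // Temporary-tensor layout: index 0 is always the scratch buffer. The hybrid
  // path adds [1] input_quantized, [2] scaling_factors,
  // [3] float_weights_time, [4] zero_points and [5] row_sums; the
  // full-integer path adds [1] output_temp.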
  node->temporaries->data[0] = scratch_tensor_index;

  TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
  scratch_size_array->data[0] = batch_size;
  scratch_size_array->data[1] = num_filters;

  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));

  // The scratch buffer is int32 for the full-integer SVDF and float32 for the
  // hybrid and float cases.
  if (is_full_integer) {
    scratch_tensor->type = kTfLiteInt32;
  } else {
    scratch_tensor->type = kTfLiteFloat32;
  }
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
                                                   scratch_size_array));

  if (is_hybrid_op) {
    op_data->compute_row_sums = true;
    // Tell interpreter to allocate temporary tensors to store quantized values
    // of input tensors.
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &input_quantized));
    input_quantized->type = weights_feature->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    // Tell interpreter to allocate temporary tensors to store scaling factors.
    node->temporaries->data[2] = scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    // Used to store dequantized weights_time matrix for hybrid computation of
    // matmul(state, weights_time), which occurs in floating point.
    node->temporaries->data[3] = scratch_tensor_index + 3;
    TfLiteTensor* float_weights_time;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &float_weights_time));
    float_weights_time->type = kTfLiteFloat32;
    float_weights_time->name = "Svdf_float_weights_time";
    // Persistent so that we can compute the dequantized weights only once.
    float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
    if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) {
      TfLiteIntArray* float_weights_time_size =
          TfLiteIntArrayCopy(weights_time->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, float_weights_time,
                                              float_weights_time_size));
    }

    node->temporaries->data[4] = scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteFloat32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = zero_points_dims[0];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }

    node->temporaries->data[5] = scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteFloat32;
    row_sums->name = "Svdf_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_filters};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  if (is_full_integer) {
    // Allocate one extra tensor for the intermediate [num_units, batch_size]
    // result.
    TfLiteIntArray* output_temp_size_array = TfLiteIntArrayCreate(2);
    output_temp_size_array->data[0] = num_units;
    output_temp_size_array->data[1] = batch_size;
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* output_temp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
    output_temp->type = kTfLiteInt32;
    output_temp->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
                                                     output_temp_size_array));

    // Calculate effective scales.
    TF_LITE_ENSURE(context, input->quantization.type != kTfLiteNoQuantization);
    auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_feature->quantization.type != kTfLiteNoQuantization);
    auto* weights_feature_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_feature->quantization.params);
    TF_LITE_ENSURE(context, state->quantization.type != kTfLiteNoQuantization);
    auto* state_params =
        reinterpret_cast<TfLiteAffineQuantization*>(state->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_time->quantization.type != kTfLiteNoQuantization);
    auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_time->quantization.params);
    TF_LITE_ENSURE(context, output->quantization.type != kTfLiteNoQuantization);
    auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    const double effective_scale_1 = input_params->scale->data[0] *
                                     weights_feature_params->scale->data[0] /
                                     state_params->scale->data[0];
    const double effective_scale_2 = state_params->scale->data[0] *
                                     weight_time_params->scale->data[0] /
                                     output_params->scale->data[0];
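    // QuantizeMultiplier decomposes each effective scale into a fixed-point
    // int32 multiplier `a` and a power-of-two shift `b` such that
    // scale ~= (a / 2^31) * 2^b, letting the kernel rescale int32
    // accumulators without floating point.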
    QuantizeMultiplier(effective_scale_1, &op_data->effective_scale_1_a,
                       &op_data->effective_scale_1_b);
    QuantizeMultiplier(effective_scale_2, &op_data->effective_scale_2_a,
                       &op_data->effective_scale_2_b);
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

  TfLiteTensor* scratch;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &scratch));

  TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
  TF_LITE_ENSURE(context, state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (weights_feature->type) {
    case kTfLiteFloat32: {
      reference_ops::EvalFloatSVDF(
          params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(weights_feature),
          GetTensorData<float>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<float>(weights_time), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorData<float>(scratch),
          GetTensorData<float>(state), GetTensorShape(output),
          GetTensorData<float>(output));
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      if (input->type == kTfLiteFloat32) {
        TfLiteTensor* input_quantized;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                    &input_quantized));
        TfLiteTensor* scaling_factors;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                    &scaling_factors));
        TfLiteTensor* float_weights_time;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                    &float_weights_time));
        TfLiteTensor* zero_points;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
                                                    &zero_points));
        TfLiteTensor* row_sums;
        TF_LITE_ENSURE_OK(
            context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
        // Dequantize weights time.
        // TODO(alanchiao): this dequantization initialization only needs to
        // happen once per model and should theoretically be placed in either
        // Init or Prepare. However, TFLite doesn't allocate float_weights_time
        // until the Eval function.
        // TODO(alanchiao): refactor logic out into dequantize function.
        if (!op_data->float_weights_time_initialized) {
          const float dequantization_scale = weights_time->params.scale;
          const int8_t* weights_time_ptr = GetTensorData<int8_t>(weights_time);
          float* float_weights_time_ptr =
              GetTensorData<float>(float_weights_time);
          for (int i = 0; i < NumElements(float_weights_time); ++i) {
            float_weights_time_ptr[i] =
                weights_time_ptr[i] * dequantization_scale;
          }
          op_data->float_weights_time_initialized = true;
        }

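        // When inputs are asymmetrically quantized, the accumulators must be
        // corrected for each batch's input zero point; the correction uses
        // precomputed row sums of the weights, recomputed only while
        // compute_row_sums is set.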
        int32_t* zero_points_ptr = nullptr;
        int32_t* row_sums_ptr = nullptr;
        if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
          zero_points_ptr = GetTensorData<int32_t>(zero_points);
          row_sums_ptr = GetTensorData<int32_t>(row_sums);
        }

        reference_ops::EvalHybridSVDF(
            params, GetTensorShape(input), GetTensorData<float>(input),
            GetTensorShape(weights_feature),
            GetTensorData<int8_t>(weights_feature),
            weights_feature->params.scale, GetTensorShape(float_weights_time),
            GetTensorData<float>(float_weights_time), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorData<float>(scratch),
            GetTensorData<float>(scaling_factors),
            GetTensorData<int8_t>(input_quantized), GetTensorData<float>(state),
            GetTensorShape(output), GetTensorData<float>(output),
            zero_points_ptr, row_sums_ptr, &op_data->compute_row_sums);
        return kTfLiteOk;
      }
      auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
          input->quantization.params);
      auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
          output->quantization.params);
      TfLiteTensor* output_temp;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));

      // Currently supports only ReLU.
      // TODO(jianlijianli): support other activations.
      TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);

      reference_ops::EvalIntegerSVDF(
          params, GetTensorShape(input), GetTensorData<int8_t>(input),
          GetTensorShape(weights_feature),
          GetTensorData<int8_t>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<int16_t>(weights_time), GetTensorShape(bias),
          GetTensorData<int32_t>(bias), GetTensorData<int16_t>(state),
          GetTensorShape(output), GetTensorData<int8_t>(output),
          GetTensorData<int32_t>(scratch), GetTensorData<int32_t>(output_temp),
          op_data->effective_scale_1_a, op_data->effective_scale_1_b,
          op_data->effective_scale_2_a, op_data->effective_scale_2_b,
          input_params->zero_point->data[0],
          output_params->zero_point->data[0]);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(weights_feature->type));
      return kTfLiteError;
  }
}

}  // namespace svdf

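// Wires the op's lifecycle callbacks (Init/Free/Prepare/Eval) into a
// TfLiteRegistration; the builtin op resolver maps the SVDF builtin code to
// this registration.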
TfLiteRegistration* Register_SVDF() {
  static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare,
                                 svdf::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite