• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/common.h"
17 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
18 #include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
19 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/micro/kernels/kernel_util.h"
22 
23 namespace tflite {
24 namespace ops {
25 namespace micro {
26 namespace l2norm {
27 
namespace {

// Kernel variants for L2 normalization. Only the reference implementation
// (kReference) is registered in this file (see Register_L2NORM_REF);
// kGenericOptimized is not used here.
enum KernelType {
  kReference,
  kGenericOptimized,
};

// Indices of the node's single input tensor and single output tensor.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

}  // namespace
40 
Prepare(TfLiteContext * context,TfLiteNode * node)41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42   TFLITE_DCHECK(node->user_data != nullptr);
43   TFLITE_DCHECK(node->builtin_data != nullptr);
44 
45   auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
46   L2NormalizationParams* data =
47       static_cast<L2NormalizationParams*>(node->user_data);
48 
49   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
50   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
51 
52   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
53   TF_LITE_ENSURE(context, input != nullptr);
54   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
55   TF_LITE_ENSURE(context, output != nullptr);
56 
57   TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
58 
59   TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
60                               output->type == kTfLiteUInt8 ||
61                               output->type == kTfLiteInt8);
62   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
63 
64   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
65     data->input_zero_point = input->params.zero_point;
66   } else if (output->type == kTfLiteFloat32) {
67     data->input_zero_point = 0;
68   }
69 
70   // TODO(ahentz): For some reason our implementations don't support
71   // activations.
72   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
73 
74   return kTfLiteOk;
75 }
76 
Init(TfLiteContext * context,const char * buffer,size_t length)77 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
78   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
79   return context->AllocatePersistentBuffer(context,
80                                            sizeof(L2NormalizationParams));
81 }
82 
Eval(TfLiteContext * context,TfLiteNode * node)83 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
84   TFLITE_DCHECK(node->user_data != nullptr);
85   const L2NormalizationParams& data =
86       *(static_cast<const L2NormalizationParams*>(node->user_data));
87 
88   const TfLiteEvalTensor* input =
89       tflite::micro::GetEvalInput(context, node, kInputTensor);
90   TfLiteEvalTensor* output =
91       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
92 
93   // TODO(b/143912164): instead of hardcode the epsilon here, we should read it
94   // from tensorflow, i.e., adding a params.
95   // We don't compute epsilon for quantized kernel:
96   //
97   // epsilon_float = (epsilon_quant - zp) * scale
98   // so
99   // espsilon_quant = epsilon_float / scale + zp
100   // We know epsilon_float is just a very small number to avoid division by
101   // zero error, and scale is > 1, so the integer value of epsilon for quant
102   // is just dominated by the zero point.
103   // Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
104   // of input value squared is zero case well.
105   // So we don't even need to do handle the epsilon for quantized kernel case.
106   const float epsilon = 1e-6f;
107   if (output->type == kTfLiteFloat32) {
108     reference_ops::L2Normalization(data, tflite::micro::GetTensorShape(input),
109                                    tflite::micro::GetTensorData<float>(input),
110                                    tflite::micro::GetTensorShape(output),
111                                    tflite::micro::GetTensorData<float>(output),
112                                    epsilon);
113   } else if (output->type == kTfLiteUInt8) {
114     reference_ops::L2Normalization(
115         data, tflite::micro::GetTensorShape(input),
116         tflite::micro::GetTensorData<uint8_t>(input),
117         tflite::micro::GetTensorShape(output),
118         tflite::micro::GetTensorData<uint8_t>(output));
119   } else if (output->type == kTfLiteInt8) {
120     const auto input_shape = tflite::micro::GetTensorShape(input);
121     const auto output_shape = tflite::micro::GetTensorShape(output);
122     const int trailing_dim = input_shape.DimensionsCount() - 1;
123     const int depth =
124         MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
125     const int outer_size =
126         MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
127     reference_integer_ops::L2Normalization(
128         data.input_zero_point, outer_size, depth,
129         tflite::micro::GetTensorData<int8_t>(input),
130         tflite::micro::GetTensorData<int8_t>(output));
131   } else {
132     TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
133                        TfLiteTypeGetName(output->type));
134     return kTfLiteError;
135   }
136 
137   return kTfLiteOk;
138 }
139 
140 }  // namespace l2norm
141 
Register_L2NORM_REF()142 TfLiteRegistration Register_L2NORM_REF() {
143   return {/*init=*/l2norm::Init,
144           /*free=*/nullptr,
145           /*prepare=*/l2norm::Prepare,
146           /*invoke=*/l2norm::Eval,
147           /*profiling_string=*/nullptr,
148           /*builtin_code=*/0,
149           /*custom_name=*/nullptr,
150           /*version=*/0};
151 }
152 
Register_L2_NORMALIZATION()153 TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }
154 
155 }  // namespace micro
156 }  // namespace ops
157 }  // namespace tflite
158