1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/common.h"
17 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
18 #include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
19 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/micro/kernels/kernel_util.h"
22
23 namespace tflite {
24 namespace ops {
25 namespace micro {
26 namespace l2norm {
27
namespace {

// NOTE(review): KernelType is not referenced within this translation unit;
// it presumably mirrors the reference/optimized split of the full TFLite
// L2Norm kernel — confirm before removing.
enum KernelType {
  kReference,
  kGenericOptimized,
};

// Tensor indices: this op has exactly one input and one output.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

}  // namespace
40
Prepare(TfLiteContext * context,TfLiteNode * node)41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42 TFLITE_DCHECK(node->user_data != nullptr);
43 TFLITE_DCHECK(node->builtin_data != nullptr);
44
45 auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
46 L2NormalizationParams* data =
47 static_cast<L2NormalizationParams*>(node->user_data);
48
49 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
50 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
51
52 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
53 TF_LITE_ENSURE(context, input != nullptr);
54 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
55 TF_LITE_ENSURE(context, output != nullptr);
56
57 TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
58
59 TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
60 output->type == kTfLiteUInt8 ||
61 output->type == kTfLiteInt8);
62 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
63
64 if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
65 data->input_zero_point = input->params.zero_point;
66 } else if (output->type == kTfLiteFloat32) {
67 data->input_zero_point = 0;
68 }
69
70 // TODO(ahentz): For some reason our implementations don't support
71 // activations.
72 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
73
74 return kTfLiteOk;
75 }
76
Init(TfLiteContext * context,const char * buffer,size_t length)77 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
78 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
79 return context->AllocatePersistentBuffer(context,
80 sizeof(L2NormalizationParams));
81 }
82
Eval(TfLiteContext * context,TfLiteNode * node)83 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
84 TFLITE_DCHECK(node->user_data != nullptr);
85 const L2NormalizationParams& data =
86 *(static_cast<const L2NormalizationParams*>(node->user_data));
87
88 const TfLiteEvalTensor* input =
89 tflite::micro::GetEvalInput(context, node, kInputTensor);
90 TfLiteEvalTensor* output =
91 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
92
93 // TODO(b/143912164): instead of hardcode the epsilon here, we should read it
94 // from tensorflow, i.e., adding a params.
95 // We don't compute epsilon for quantized kernel:
96 //
97 // epsilon_float = (epsilon_quant - zp) * scale
98 // so
99 // espsilon_quant = epsilon_float / scale + zp
100 // We know epsilon_float is just a very small number to avoid division by
101 // zero error, and scale is > 1, so the integer value of epsilon for quant
102 // is just dominated by the zero point.
103 // Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
104 // of input value squared is zero case well.
105 // So we don't even need to do handle the epsilon for quantized kernel case.
106 const float epsilon = 1e-6f;
107 if (output->type == kTfLiteFloat32) {
108 reference_ops::L2Normalization(data, tflite::micro::GetTensorShape(input),
109 tflite::micro::GetTensorData<float>(input),
110 tflite::micro::GetTensorShape(output),
111 tflite::micro::GetTensorData<float>(output),
112 epsilon);
113 } else if (output->type == kTfLiteUInt8) {
114 reference_ops::L2Normalization(
115 data, tflite::micro::GetTensorShape(input),
116 tflite::micro::GetTensorData<uint8_t>(input),
117 tflite::micro::GetTensorShape(output),
118 tflite::micro::GetTensorData<uint8_t>(output));
119 } else if (output->type == kTfLiteInt8) {
120 const auto input_shape = tflite::micro::GetTensorShape(input);
121 const auto output_shape = tflite::micro::GetTensorShape(output);
122 const int trailing_dim = input_shape.DimensionsCount() - 1;
123 const int depth =
124 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
125 const int outer_size =
126 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
127 reference_integer_ops::L2Normalization(
128 data.input_zero_point, outer_size, depth,
129 tflite::micro::GetTensorData<int8_t>(input),
130 tflite::micro::GetTensorData<int8_t>(output));
131 } else {
132 TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
133 TfLiteTypeGetName(output->type));
134 return kTfLiteError;
135 }
136
137 return kTfLiteOk;
138 }
139
140 } // namespace l2norm
141
Register_L2NORM_REF()142 TfLiteRegistration Register_L2NORM_REF() {
143 return {/*init=*/l2norm::Init,
144 /*free=*/nullptr,
145 /*prepare=*/l2norm::Prepare,
146 /*invoke=*/l2norm::Eval,
147 /*profiling_string=*/nullptr,
148 /*builtin_code=*/0,
149 /*custom_name=*/nullptr,
150 /*version=*/0};
151 }
152
Register_L2_NORMALIZATION()153 TfLiteRegistration Register_L2_NORMALIZATION() { return Register_L2NORM_REF(); }
154
155 } // namespace micro
156 } // namespace ops
157 } // namespace tflite
158