/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"

#include <cmath>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace activations {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
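
// Per-op quantization parameters for the int8 path, computed once in Prepare
// and reused on every invocation.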
struct OpData {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};
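
// Computes the fixed-point multiplier, shift, and saturation radius required
// by the int8 reference logistic kernel. For float32 there is nothing to
// precompute.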
TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context,
                                       TfLiteNode* node, OpData* data) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  if (input->type == kTfLiteInt8) {
    // The int8 logistic output covers [0, 1), so its zero point must sit at
    // the bottom of the int8 range.
    TF_LITE_ENSURE_EQ(context, output->params.zero_point,
                      std::numeric_limits<int8_t>::min());
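
    // The int8 kernel consumes the input scale as a Q0.31 fixed-point
    // multiplier plus a left shift: frexp() splits
    // scale * 2^(31 - kInputIntegerBits) into a fraction in [0.5, 1) and an
    // exponent, and the fraction is rounded into an int32.
    // input_range_radius is the quantized |x| beyond which the logistic
    // output saturates.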
    static constexpr int kInputIntegerBits = 4;
    const double input_real_multiplier =
        static_cast<double>(input->params.scale) *
        static_cast<double>(1 << (31 - kInputIntegerBits));

    data->input_zero_point = input->params.zero_point;

    const double q =
        std::frexp(input_real_multiplier, &data->input_left_shift);
    data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));

    data->input_range_radius =
        CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
  }
  return kTfLiteOk;
}
}  // namespace
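
// Init runs once per op instance. TFLM does not use the heap, so per-op
// state is carved out of the interpreter's persistent arena.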
void* LogisticInit(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
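
// Runs once before the first invocation and caches the quantization
// parameters in the OpData buffer allocated by LogisticInit.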
TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  return CalculateArithmeticOpData(context, node, data);
}
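
// Dispatches on the tensor type: float32 uses the floating-point reference
// kernel, int8 uses the fixed-point kernel with the parameters cached in
// OpData; anything else is rejected.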
TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpData* data = static_cast<OpData*>(node->user_data);

  if (input->type == kTfLiteFloat32) {
    switch (output->type) {
      case kTfLiteFloat32: {
        reference_ops::Logistic(tflite::micro::GetTensorShape(input),
                                tflite::micro::GetTensorData<float>(input),
                                tflite::micro::GetTensorShape(output),
                                tflite::micro::GetTensorData<float>(output));
        return kTfLiteOk;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else if (input->type == kTfLiteInt8) {
    switch (output->type) {
      case kTfLiteInt8: {
        reference_integer_ops::Logistic(
            data->input_zero_point, data->input_range_radius,
            data->input_multiplier, data->input_left_shift,
            NumElements(input->dims),
            tflite::micro::GetTensorData<int8_t>(input),
            tflite::micro::GetTensorData<int8_t>(output));
        return kTfLiteOk;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                           TfLiteTypeGetName(input->type),
                           TfLiteTypeGetName(output->type));
        return kTfLiteError;
    }
  } else {
    // TODO(b/141211002): Also support other data types once we have supported
    // temporary tensors in TFLM.
    TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                       TfLiteTypeGetName(input->type),
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace activations
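
// Registration record handed to the op resolver; only init, prepare, and
// invoke are provided, the remaining fields keep their defaults.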
TfLiteRegistration Register_LOGISTIC() {
  return {/*init=*/activations::LogisticInit,
          /*free=*/nullptr,
          /*prepare=*/activations::LogisticPrepare,
          /*invoke=*/activations::LogisticEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}
}  // namespace micro
}  // namespace ops
}  // namespace tflite