• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
17 
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
23 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26 #include "tensorflow/lite/micro/kernels/kernel_util.h"
27 
28 namespace tflite {
29 namespace ops {
30 namespace micro {
31 namespace activations {
32 namespace {
// The node's single input and single output both sit at index 0 of the
// respective tensor lists (read via GetInput/GetOutput and
// GetEvalInput/GetEvalOutput).
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};

// Per-node state. Written by CalculateArithmeticOpData only when the input is
// int8, and read only by LogisticEval's int8 path; for float32 the fields are
// left unset.
struct OpData {
  // Zero point of the int8 input tensor (input->params.zero_point).
  int32_t input_zero_point;
  // Value returned by CalculateInputRadius. NOTE(review): presumably the
  // |input - zero_point| beyond which the int8 kernel saturates — confirm in
  // reference_integer_ops::Logistic.
  int32_t input_range_radius;
  // Q31 mantissa (from std::frexp) of input_scale * 2^27.
  int32_t input_multiplier;
  // Power-of-two exponent (from std::frexp) paired with input_multiplier.
  int input_left_shift;
};
42 
CalculateArithmeticOpData(TfLiteContext * context,TfLiteNode * node,OpData * data)43 TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
44                                        OpData* data) {
45   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
46   TF_LITE_ENSURE(context, input != nullptr);
47   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
48   TF_LITE_ENSURE(context, output != nullptr);
49 
50   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
51   if (input->type == kTfLiteInt8) {
52     TF_LITE_ENSURE_EQ(context, output->params.zero_point,
53                       std::numeric_limits<int8_t>::min());
54 
55     static constexpr int kInputIntegerBits = 4;
56     const double input_real_multiplier =
57         static_cast<double>(input->params.scale) *
58         static_cast<double>(1 << (31 - kInputIntegerBits));
59 
60     data->input_zero_point = input->params.zero_point;
61 
62     const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
63     data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
64 
65     data->input_range_radius =
66         CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
67   }
68   return kTfLiteOk;
69 }
70 }  // namespace
71 
LogisticInit(TfLiteContext * context,const char * buffer,size_t length)72 void* LogisticInit(TfLiteContext* context, const char* buffer, size_t length) {
73   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
74   return context->AllocatePersistentBuffer(context, sizeof(OpData));
75 }
76 
LogisticPrepare(TfLiteContext * context,TfLiteNode * node)77 TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node) {
78   TFLITE_DCHECK(node->user_data != nullptr);
79   OpData* data = static_cast<OpData*>(node->user_data);
80 
81   return CalculateArithmeticOpData(context, node, data);
82 }
83 
LogisticEval(TfLiteContext * context,TfLiteNode * node)84 TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
85   const TfLiteEvalTensor* input =
86       tflite::micro::GetEvalInput(context, node, kInputTensor);
87   TfLiteEvalTensor* output =
88       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
89 
90   TFLITE_DCHECK(node->user_data != nullptr);
91   OpData* data = static_cast<OpData*>(node->user_data);
92 
93   if (input->type == kTfLiteFloat32) {
94     switch (output->type) {
95       case kTfLiteFloat32: {
96         reference_ops::Logistic(tflite::micro::GetTensorShape(input),
97                                 tflite::micro::GetTensorData<float>(input),
98                                 tflite::micro::GetTensorShape(output),
99                                 tflite::micro::GetTensorData<float>(output));
100         return kTfLiteOk;
101       }
102       default:
103         TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
104                            TfLiteTypeGetName(input->type),
105                            TfLiteTypeGetName(output->type));
106         return kTfLiteError;
107     }
108   } else if (input->type == kTfLiteInt8) {
109     switch (output->type) {
110       case kTfLiteInt8: {
111         reference_integer_ops::Logistic(
112             data->input_zero_point, data->input_range_radius,
113             data->input_multiplier, data->input_left_shift,
114             NumElements(input->dims),
115             tflite::micro::GetTensorData<int8_t>(input),
116             tflite::micro::GetTensorData<int8_t>(output));
117         return kTfLiteOk;
118       }
119       default:
120         TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
121                            TfLiteTypeGetName(input->type),
122                            TfLiteTypeGetName(output->type));
123         return kTfLiteError;
124     }
125   } else {
126     // TODO(b/141211002): Also support other data types once we have supported
127     // temporary tensors in TFLM.
128     TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
129                        TfLiteTypeGetName(input->type),
130                        TfLiteTypeGetName(output->type));
131     return kTfLiteError;
132   }
133   return kTfLiteOk;
134 }
135 
136 }  // namespace activations
137 
Register_LOGISTIC()138 TfLiteRegistration Register_LOGISTIC() {
139   return {/*init=*/activations::LogisticInit,
140           /*free=*/nullptr,
141           /*prepare=*/activations::LogisticPrepare,
142           /*invoke=*/activations::LogisticEval,
143           /*profiling_string=*/nullptr,
144           /*builtin_code=*/0,
145           /*custom_name=*/nullptr,
146           /*version=*/0};
147 }
148 }  // namespace micro
149 }  // namespace ops
150 }  // namespace tflite
151