• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/common.h"
19 #include "tensorflow/lite/kernels/internal/quantization_util.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 #include "tensorflow/lite/micro/kernels/softmax.h"
23 
24 namespace tflite {
25 
26 namespace {
// Number of int16_t entries in each lookup table that SoftmaxPrepare allocates
// and fills (exp and 1 / (1 + x)) when the input tensor is int16.
constexpr int kInt16LUTArraySize = 513;
29 
CalculateSoftmaxParams(TfLiteContext * context,const TfLiteTensor * input,TfLiteTensor * output,const TfLiteSoftmaxParams * params,SoftmaxParams * op_data)30 TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
31                                     const TfLiteTensor* input,
32                                     TfLiteTensor* output,
33                                     const TfLiteSoftmaxParams* params,
34                                     SoftmaxParams* op_data) {
35   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
36       input->type == kTfLiteInt16) {
37     if (input->type == kTfLiteUInt8) {
38       TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
39       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
40     } else if (input->type == kTfLiteInt16) {
41       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
42       TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
43                           (0.001f * 1.f / 32768));
44     } else {  // input->type == kTfLiteInt8
45       TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
46       if (output->type == kTfLiteInt16) {
47         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
48         TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
49                             (0.001f * 1.f / 65536));
50       } else {  // output->type == kTfLiteint8
51         TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
52         TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
53         TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
54       }
55     }
56 
57     static const int kScaledDiffIntegerBits = 5;
58 
59     // Calculate input_multiplier and input_left_shift
60     if (input->type == kTfLiteInt16) {
61       int input_left_shift;
62       double input_scale_beta_rescale =
63           static_cast<double>(input->params.scale) *
64           static_cast<double>(params->beta) /
65           (10.0 / 65535.0);  // scale the input_diff such that [-65535, 0]
66                              // correspond to [-10.0, 0.0]
67       QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
68                          &input_left_shift);
69       op_data->input_left_shift = input_left_shift;
70     } else {
71       int input_left_shift;
72       tflite::PreprocessSoftmaxScaling(
73           static_cast<double>(params->beta),
74           static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
75           &op_data->input_multiplier, &input_left_shift);
76       op_data->input_left_shift = input_left_shift;
77       op_data->diff_min =
78           -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
79                                               op_data->input_left_shift);
80     }
81   } else {
82     TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
83     TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
84     op_data->beta = static_cast<double>(params->beta);
85   }
86   return kTfLiteOk;
87 }
88 
89 }  // namespace
90 
SoftmaxInit(TfLiteContext * context,const char * buffer,size_t length)91 void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
92   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
93   return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
94 }
95 
SoftmaxPrepare(TfLiteContext * context,TfLiteNode * node)96 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
97   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
98   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
99   const TfLiteTensor* input = GetInput(context, node, 0);
100   TF_LITE_ENSURE(context, input != nullptr);
101   TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
102   TfLiteTensor* output = GetOutput(context, node, 0);
103   TF_LITE_ENSURE(context, output != nullptr);
104 
105   TF_LITE_ENSURE(context, node->user_data != nullptr);
106   SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
107   // Only allocate LUTs for KTfLiteInt16 data type
108   if (input->type == kTfLiteInt16) {
109     void* raw_exp_lut = context->AllocatePersistentBuffer(
110         context, sizeof(int16_t) * kInt16LUTArraySize);
111     TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
112     op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
113     void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
114         context, sizeof(int16_t) * kInt16LUTArraySize);
115     TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
116     op_data->one_over_one_plus_x_lut =
117         reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
118   }
119 
120   if (output->type == kTfLiteInt16) {
121     TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
122                                 input->type == kTfLiteUInt8 ||
123                                 input->type == kTfLiteInt16);
124   } else {
125     TF_LITE_ENSURE_EQ(context, input->type, output->type);
126   }
127 
128   // Populate LUT if required
129   if (input->type == kTfLiteInt16) {
130     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
131     // exp LUT only used on negative values
132     // we consider exp(-10.0) is insignificant to accumulation
133     gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
134             op_data->exp_lut, kInt16LUTArraySize);
135     gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
136             op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
137     op_data->zero_point = output->params.zero_point;
138     op_data->scale = output->params.scale;
139   }
140 
141   auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
142   return CalculateSoftmaxParams(context, input, output, params, op_data);
143 }
144 
145 }  // namespace tflite
146