1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cmath>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/softmax.h"
23
24 namespace tflite {
25
26 namespace {
// Number of entries in each int16 lookup table (exp and 1/(1+x)) allocated
// and populated in SoftmaxPrepare. 513 == 2^9 + 1 entries; presumably one
// extra endpoint for interpolation inside gen_lut — TODO confirm against
// the gen_lut implementation.
const int kInt16LUTArraySize = 513;
29
// Validates the quantization parameters of `input`/`output` and fills
// `op_data` with the precomputed values (input multiplier/shift, diff_min,
// or float beta) consumed by the softmax eval kernels.
//
// Enforced quantization contract (each TF_LITE_ENSURE* macro reports an
// error through `context` and returns kTfLiteError on violation):
//   * uint8  output: zero_point == 0
//   * int16  output (int16 input): zero_point == 0, scale ~= 1/32768
//   * int8 -> int16: zero_point == -32768, scale ~= 1/65536
//   * int8 -> int8:  zero_point == -128,   scale == 1/256
// Non-quantized inputs must be float32 in and out.
//
// Returns kTfLiteOk on success.
TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context,
                                    const TfLiteTensor* input,
                                    TfLiteTensor* output,
                                    const TfLiteSoftmaxParams* params,
                                    SoftmaxParams* op_data) {
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    if (input->type == kTfLiteUInt8) {
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt8);
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    } else if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
      // NEAR (not EQ): the serialized scale may differ from 1/32768 by
      // float rounding; allow a 0.1% tolerance.
      TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
                          (0.001f * 1.f / 32768));
    } else {  // input->type == kTfLiteInt8
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt8);
      if (output->type == kTfLiteInt16) {
        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
        TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 65536,
                            (0.001f * 1.f / 65536));
      } else {  // output->type == kTfLiteInt8
        TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
      }
    }

    static const int kScaledDiffIntegerBits = 5;

    // Calculate input_multiplier and input_left_shift
    if (input->type == kTfLiteInt16) {
      int input_left_shift;
      double input_scale_beta_rescale =
          static_cast<double>(input->params.scale) *
          static_cast<double>(params->beta) /
          (10.0 / 65535.0);  // scale the input_diff such that [-65535, 0]
                             // correspond to [-10.0, 0.0]
      QuantizeMultiplier(input_scale_beta_rescale, &op_data->input_multiplier,
                         &input_left_shift);
      op_data->input_left_shift = input_left_shift;
    } else {
      int input_left_shift;
      tflite::PreprocessSoftmaxScaling(
          static_cast<double>(params->beta),
          static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
          &op_data->input_multiplier, &input_left_shift);
      op_data->input_left_shift = input_left_shift;
      // Input diffs below diff_min contribute negligibly to the softmax sum
      // and are clamped by the eval kernel.
      op_data->diff_min =
          -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
                                              op_data->input_left_shift);
    }
  } else {
    // Float path: no precomputation beyond stashing beta.
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
    op_data->beta = static_cast<double>(params->beta);
  }
  return kTfLiteOk;
}
88
89 } // namespace
90
SoftmaxInit(TfLiteContext * context,const char * buffer,size_t length)91 void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) {
92 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
93 return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams));
94 }
95
// Prepare for the softmax op: validates tensor counts/types, allocates and
// populates the int16 lookup tables when needed, and computes the kernel
// parameters via CalculateSoftmaxParams into node->user_data.
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);

  // user_data was allocated in SoftmaxInit.
  TF_LITE_ENSURE(context, node->user_data != nullptr);
  SoftmaxParams* op_data = static_cast<SoftmaxParams*>(node->user_data);
  // Only allocate LUTs for KTfLiteInt16 data type
  if (input->type == kTfLiteInt16) {
    void* raw_exp_lut = context->AllocatePersistentBuffer(
        context, sizeof(int16_t) * kInt16LUTArraySize);
    TF_LITE_ENSURE(context, raw_exp_lut != nullptr);
    op_data->exp_lut = reinterpret_cast<int16_t*>(raw_exp_lut);
    void* one_over_one_plus_x_lut = context->AllocatePersistentBuffer(
        context, sizeof(int16_t) * kInt16LUTArraySize);
    TF_LITE_ENSURE(context, one_over_one_plus_x_lut != nullptr);
    op_data->one_over_one_plus_x_lut =
        reinterpret_cast<int16_t*>(one_over_one_plus_x_lut);
  }

  // int16 output is allowed for any quantized input type; otherwise input
  // and output types must match exactly.
  if (output->type == kTfLiteInt16) {
    TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
                                input->type == kTfLiteUInt8 ||
                                input->type == kTfLiteInt16);
  } else {
    TF_LITE_ENSURE_EQ(context, input->type, output->type);
  }

  // Populate LUT if required
  if (input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    // exp LUT only used on negative values
    // we consider exp(-10.0) is insignificant to accumulation
    gen_lut([](float value) { return std::exp(value); }, -10.0f, 0.0f,
            op_data->exp_lut, kInt16LUTArraySize);
    gen_lut([](float value) { return 1.0f / (1.0f + value); }, 0.0f, 1.0f,
            op_data->one_over_one_plus_x_lut, kInt16LUTArraySize);
    op_data->zero_point = output->params.zero_point;
    op_data->scale = output->params.scale;
  }

  auto* params = static_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  return CalculateSoftmaxParams(context, input, output, params, op_data);
}
144
145 } // namespace tflite
146