/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/softmax.h"

#include <algorithm>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace micro {

namespace xtensa {
namespace hifimini {

// Quantized softmax with int8 input and int8/int16 output.
template <typename OutputT = int8_t>
inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const int8_t* input_data,
                    const RuntimeShape& output_shape, OutputT* output_data) {
  const int32_t input_beta_multiplier = params.input_multiplier;
  const int32_t input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
  using FixedPointAccum =
      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

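  // Each row along the trailing dimension is normalized independently:
  // find the row max, accumulate exp(x - max) in fixed point, then scale
  // each exponent by the reciprocal of the sum.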
  for (int i = 0; i < outer_size; ++i) {
    int8_t max_in_row = -128;
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

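    // Accumulate exp(scaled_diff) for every element that is close enough to
    // the row max to matter; differences below diff_min are skipped as
    // negligible.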
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

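    // Invert the sum once so each output can be computed with a multiply
    // instead of a divide.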
    int num_bits_over_unit;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

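    // Scale each exponent by the reciprocal, shift the result into the
    // output type's range, and clamp to its representable limits.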
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);

        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
        const int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
            (shifted_scale * exp_in_0).raw(),
            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));
        // TODO(b/148494470): Handle int32 shifts properly:
        const int32_t shifted_output =
            unsat_output -
            (static_cast<int32_t>(std::numeric_limits<OutputT>::max()) + 1);
        output_data[i * depth + c] = static_cast<OutputT>(std::max(
            std::min(shifted_output,
                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
      } else {
        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
      }
    }
  }
}

}  // namespace hifimini
}  // namespace xtensa

namespace activations {
namespace {

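// Softmax quantization parameters computed once in SoftmaxPrepare and reused
// on every invocation through node->user_data.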
struct OpData {
  int32_t input_multiplier = 0;
  int input_left_shift = 0;
  int32_t input_range_radius = 0;
  int diff_min = 0;
};

// This size will work for both the hotword (1) and ambient music (0):
static OpData kStaticOpData;

TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
                                    const TfLiteTensor* input,
                                    TfLiteTensor* output,
                                    const TfLiteSoftmaxParams* params,
                                    OpData* data) {
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    if (input->type == kTfLiteUInt8) {
      TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    } else {
      if (output->type == kTfLiteInt16) {
        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -32768);
        // NOTE: Current int16 softmax output does not require symmetric
        // scaling - so no need to verify scale here.
      } else {
        TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
        TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
      }
    }

    static const int kScaledDiffIntegerBits = 5;

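    // Fold beta into the input scale and express the product as a fixed-point
    // multiplier plus left shift; diff_min is the cutoff below which
    // (input - max) differences are ignored by the kernel.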
    tflite::PreprocessSoftmaxScaling(
        static_cast<double>(params->beta),
        static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
        &data->input_multiplier, &data->input_left_shift);
    data->diff_min = -1.0 * tflite::CalculateInputRadius(
                                kScaledDiffIntegerBits, data->input_left_shift);
  }
  return kTfLiteOk;
}

}  // namespace

void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                        TfLiteSoftmaxParams* params, OpData* data) {
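  // Collapse the 2D {batch, depth} input into the 4D shape expected by the
  // kernel, with softmax applied along the last dimension.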
  const int batch_size = input->dims->data[0];
  const int input_size = input->dims->data[1];
  const int32_t shape_data[4] = {batch_size, 1, 1, input_size};
  RuntimeShape shape(4, shape_data);
  SoftmaxParams op_params;
  op_params.input_multiplier = data->input_multiplier;
  op_params.input_left_shift = data->input_left_shift;
  op_params.diff_min = data->diff_min;

  if (output->type == kTfLiteInt16) {
    xtensa::hifimini::Softmax(op_params, shape, GetTensorData<int8_t>(input),
                              shape, GetTensorData<int16_t>(output));
  } else {
    xtensa::hifimini::Softmax(op_params, shape, GetTensorData<int8_t>(input),
                              shape, GetTensorData<int8_t>(output));
  }
}

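// This kernel allocates no persistent buffers, so Init and Free are no-ops.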
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;
}

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

  // TODO(b/132070898): Use statically slotted OpData structures until a
  // scratch memory API is ready.
  OpData* op_data = &kStaticOpData;
  node->user_data = op_data;

  TF_LITE_ENSURE_STATUS(
      CalculateSoftmaxOpData(context, input, output, params, op_data));

  return kTfLiteOk;
}

TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

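  // Only 2D int8 inputs are supported by this port; anything else is
  // rejected with an error at runtime.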
  switch (input->type) {
    case kTfLiteInt8: {
      if (NumDimensions(input) == 2) {
        Softmax2DQuantized(input, output, params, op_data);
        return kTfLiteOk;
      }
      context->ReportError(context,
                           "Only 2D tensors supported currently, got %dD.",
                           NumDimensions(input));
      return kTfLiteError;
    }
    default:
      context->ReportError(context, "Only int8_t supported currently, got %d.",
                           input->type);
      return kTfLiteError;
  }
}

}  // namespace activations

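// Kernel registration for the SOFTMAX op: wires up Init, Free, Prepare, and
// Eval.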
TfLiteRegistration* Register_SOFTMAX() {
  static TfLiteRegistration r = {activations::Init, activations::Free,
                                 activations::SoftmaxPrepare,
                                 activations::SoftmaxEval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite