1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
17
18 #include <limits>
19
20 #include "fixedpoint/fixedpoint.h"
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/cppmath.h"
23 #include "tensorflow/lite/kernels/internal/quantization_util.h"
24 #include "tensorflow/lite/kernels/internal/types.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26
27 namespace tflite {
28 namespace reference_ops {
29
Softmax(const SoftmaxParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & output_shape,float * output_data)30 inline void Softmax(const SoftmaxParams& params,
31 const RuntimeShape& input_shape, const float* input_data,
32 const RuntimeShape& output_shape, float* output_data) {
33 const int trailing_dim = input_shape.DimensionsCount() - 1;
34 const int outer_size =
35 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
36 const int depth =
37 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
38
39 for (int i = 0; i < outer_size; ++i) {
40 // Find max element value which we'll use to ensure numerical stability
41 // taking advantage of the following equality:
42 // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
43 float max = std::numeric_limits<float>::lowest();
44 for (int c = 0; c < depth; ++c) {
45 max = std::max(max, input_data[i * depth + c]);
46 }
47
48 // Compute sum.
49 float sum = 0.f;
50 for (int c = 0; c < depth; ++c) {
51 const float exp_c = std::exp((input_data[i * depth + c] - max) *
52 static_cast<float>(params.beta));
53 output_data[i * depth + c] = exp_c;
54 sum += exp_c;
55 }
56
57 // Compute result.
58 for (int c = 0; c < depth; ++c) {
59 output_data[i * depth + c] = output_data[i * depth + c] / sum;
60 }
61 }
62 }
63
// Quantized softmax with int8_t/uint8_t input and int8_t/uint8_t/int16_t
// output.
//
// Computes softmax entirely in 32-bit fixed point using gemmlowp's
// FixedPoint helpers: per-row max subtraction, exp() of the (non-positive)
// rescaled differences, accumulation of the exponentials, then one
// reciprocal and a per-element multiply.
template <typename InputT, typename OutputT>
inline void Softmax(const SoftmaxParams& params,
                    const RuntimeShape& input_shape, const InputT* input_data,
                    const RuntimeShape& output_shape, OutputT* output_data) {
  // Precomputed by the op's Prepare step: fixed-point multiplier/shift that
  // rescale raw input differences by beta, and the minimum difference below
  // which exp() is treated as zero.
  const int32_t input_beta_multiplier = params.input_multiplier;
  const int32_t input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  // Q5.26: scaled (input - max) differences fed to exp().
  using FixedPointScaledDiff =
      gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
  // Q12.19: running sum of exponentials (each exp() result is <= 1, and up
  // to `depth` of them are accumulated).
  using FixedPointAccum =
      gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
  // Q0.31: pure fractional values (exp() outputs and the reciprocal scale).
  using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Row max, so every difference below is <= 0 and exp() stays in (0, 1].
    InputT max_in_row = std::numeric_limits<InputT>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    // First pass: accumulate sum(exp(scaled_diff)) over the row. Elements
    // with input_diff < diff_min contribute negligibly and are skipped.
    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        // Rescale the raw difference by beta into the Q5.26 domain.
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        // exp() yields a Q0.31 value in (0, 1]; rescale it into the Q12.19
        // accumulator before adding.
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

    // Compute 1/sum_of_exps as a Q0.31 value plus a bit-shift correction
    // (num_bits_over_unit) that is folded into the final shift below.
    int num_bits_over_unit;
    FixedPoint0 shifted_scale = FixedPoint0::FromRaw(GetReciprocal(
        sum_of_exps.raw(), kAccumulationIntegerBits, &num_bits_over_unit));

    // Second pass: recompute each exponential (cheaper than caching it) and
    // multiply by the reciprocal to get the normalized probability.
    for (int c = 0; c < depth; ++c) {
      int32_t input_diff =
          static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min) {
        const int32_t input_diff_rescaled =
            MultiplyByQuantizedMultiplierGreaterThanOne(
                input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);

        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
        // Shift the Q0.31 product down to the output type's bit width,
        // with rounding; the reciprocal's correction shift is folded in.
        int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
            (shifted_scale * exp_in_0).raw(),
            num_bits_over_unit + 31 - (sizeof(OutputT) * 8));

        // Re-bias from [0, 2^bits) into the output type's numeric range
        // (e.g. [-128, 127] for int8_t).
        const int32_t shifted_output =
            unsat_output +
            static_cast<int32_t>(std::numeric_limits<OutputT>::min());

        // Saturate to the representable range of OutputT.
        output_data[i * depth + c] = static_cast<OutputT>(std::max(
            std::min(shifted_output,
                     static_cast<int32_t>(std::numeric_limits<OutputT>::max())),
            static_cast<int32_t>(std::numeric_limits<OutputT>::min())));
      } else {
        // Skipped elements get probability ~0, i.e. the minimum quantized
        // output value.
        output_data[i * depth + c] = std::numeric_limits<OutputT>::min();
      }
    }
  }
}
146
147 // Computes exp(input - max_input)
SoftMaxCalculateExp(const SoftmaxParams & params,const int16_t * input_data,const int depth,int16_t max_in_row,int i,int c)148 inline int16_t SoftMaxCalculateExp(const SoftmaxParams& params,
149 const int16_t* input_data, const int depth,
150 int16_t max_in_row, int i, int c) {
151 int32_t input_diff = input_data[i * depth + c] - max_in_row;
152 // scale the input_diff such that [-65535, 0] correspond to [-10.0, 0.0]
153 // exp lut generated with range [-10, 0], as exp(-10) is negligible.
154 int32_t scaled_diff = MultiplyByQuantizedMultiplier(
155 input_diff, params.input_multiplier, params.input_left_shift);
156 // recenter to [-32768, 32767]
157 int32_t sym_scaled_diff = scaled_diff + 32767;
158 int16_t sat_sym_scaled_diff =
159 std::min(std::max(sym_scaled_diff, static_cast<int32_t>(-32768)),
160 static_cast<int32_t>(32767));
161 // apply the exp() LUT activation function
162 return generic_int16_table_lookup(sat_sym_scaled_diff, params.exp_lut);
163 }
// Quantized softmax with int16_t input and int16_t output.
//
// Uses two lookup tables from `params`: exp() over [-10, 0] and 1/(1+x),
// both operating on symmetric int16 inputs. Output is Q0.15, where
// [0, 32767] corresponds to [0.0, 1.0].
inline void SoftmaxInt16(const SoftmaxParams& params,
                         const RuntimeShape& input_shape,
                         const int16_t* input_data,
                         const RuntimeShape& output_shape,
                         int16_t* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    // Find the largest element
    int16_t max_in_row = std::numeric_limits<int16_t>::min();
    for (int c = 0; c < depth; ++c) {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    // This loops computes the exp values and their sum. We will need the exp
    // values later on in the function so we cache them in the output_data
    // buffer. This is an optimization done to avoid calculating the exp values
    // twice making use of the output_data buffer as scratch memory.
    int32_t sum_of_exps = 0;  // Q16.15 fixed point format.
    int16_t* exp_results_Q015 = output_data + i * depth;
    for (int c = 0; c < depth; ++c) {
      exp_results_Q015[c] =
          SoftMaxCalculateExp(params, input_data, depth, max_in_row, i, c);
      sum_of_exps += exp_results_Q015[c];
    }

    // Compute the reciprocal 1/sum_of_exps.
    // headroom_plus_one counts the leading zero bits of the (positive) sum;
    // left-shifting by (headroom_plus_one - 1) normalizes the sum so its top
    // set bit sits just below the sign bit.
    uint8_t headroom_plus_one =
        CountLeadingZeros(static_cast<uint32_t>(sum_of_exps));
    // Normalize, then round-shift right by 14 so shifted_sum lands in the
    // input range expected by the reciprocal LUT. The (1 << 13) term is the
    // rounding bias for the >> 14.
    int32_t shifted_sum =
        ((static_cast<int64_t>(sum_of_exps) << (headroom_plus_one - 1)) +
         (1 << 13)) >>
        14;
    // since the LUT computes 1/(1 + x) we need to first compute x = (sum - 1).
    // also, the LUT expects a symmetrical input, so we must also recenter x
    // from [0, 65535] to [-32768, 32767].
    int32_t sym_shifted_sum = shifted_sum + (-((1 << 15) + (1 << 16)));
    // Saturate into the int16 LUT index range.
    int16_t sat_sym_shifted_sum = static_cast<int16_t>(
        std::min(std::max(sym_shifted_sum, static_cast<int32_t>(-32768)),
                 static_cast<int32_t>(32767)));
    // apply 1/(1 + x) LUT activation function
    int16_t reciprocal_scale_Q015 = generic_int16_table_lookup(
        sat_sym_shifted_sum, params.one_over_one_plus_x_lut);

    // Rescale the exp_result with reciprocal
    // range of output is [0, 32767] correspond to [0.0, 1.0]
    for (int c = 0; c < depth; ++c) {
      // right_shift undoes both the Q0.15 x Q0.15 product scaling and the
      // normalization shift applied to the sum above; `round` is its
      // rounding bias.
      uint8_t right_shift = 31 - headroom_plus_one;
      int64_t round = 1 << (right_shift - 1);
      int32_t result = (static_cast<int64_t>(exp_results_Q015[c]) *
                            static_cast<int64_t>(reciprocal_scale_Q015) +
                        round) >>
                       right_shift;
      // Clamp to the valid Q0.15 probability range [0, 32767].
      output_data[i * depth + c] = static_cast<int16_t>(
          std::min(std::max(result, static_cast<int32_t>(0)),
                   static_cast<int32_t>(32767)));
    }
  }
}
228
229 } // namespace reference_ops
230 } // namespace tflite
231
232 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
233