• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 #include <stdint.h>
17 
18 #include <algorithm>
19 
20 #include "ruy/profiler/instrumentation.h"  // from @ruy
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
23 #include "tensorflow/lite/kernels/internal/quantization_util.h"
24 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
25 #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
26 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
27 #include "tensorflow/lite/kernels/internal/tensor.h"
28 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29 #include "tensorflow/lite/kernels/kernel_util.h"
30 
31 namespace tflite {
32 namespace ops {
33 namespace builtin {
34 namespace squared_difference {
35 
// Tensor indices into the node's input/output tensor lists.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Per-node state, allocated in Init() and populated in Prepare().
struct OpData {
  // True when input1 and input2 have different shapes, so evaluation must
  // broadcast them to the common output shape.
  bool requires_broadcast;
  // Fixed-point parameters (offsets, multipliers, shifts, activation range)
  // for the int8 path; only filled in when the inputs are kTfLiteInt8.
  ArithmeticParams arithmetic_params;
};
44 
// Returns (input1 - input2)^2 for any arithmetic type T.
template <typename T>
T SquaredDifference(T input1, T input2) {
  const T delta = input1 - input2;
  return delta * delta;
}
50 
Init(TfLiteContext * context,const char * buffer,size_t length)51 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
52   auto* data = new OpData;
53   data->requires_broadcast = false;
54   return data;
55 }
56 
Free(TfLiteContext * context,void * buffer)57 void Free(TfLiteContext* context, void* buffer) {
58   delete reinterpret_cast<OpData*>(buffer);
59 }
60 
// Validates inputs, precomputes the int8 fixed-point parameters, and resizes
// the output tensor (broadcasting the input shapes when they differ).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Both inputs must share a type; the output inherits it.
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  // Ensure the quantization parameters are equivalent.
  if (input1->type == kTfLiteInt8) {
    const auto& input1_quantization_params = input1->params;
    const auto& input2_quantization_params = input2->params;
    const auto& output_quantization_params = output->params;
    // All zero points must fit in int8 range for the fixed-point math below.
    const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
    const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point <= integer_type_max);
    // Offsets are the negated zero points so evaluation can simply add them.
    data->arithmetic_params.input1_offset =
        -input1_quantization_params.zero_point;
    data->arithmetic_params.input2_offset =
        -input2_quantization_params.zero_point;
    data->arithmetic_params.output_offset =
        output_quantization_params.zero_point;

    // shift to make integer for scales.
    data->arithmetic_params.left_shift = 7;
    // Normalize both input scales by twice the larger one so each real
    // multiplier below is guaranteed to be < 1 (required by
    // QuantizeMultiplierSmallerThanOneExp).
    const double twice_max_input_scale =
        2 * std::max(input1_quantization_params.scale,
                     input2_quantization_params.scale);
    const double real_input1_multiplier =
        input1_quantization_params.scale / twice_max_input_scale;
    double real_input2_multiplier =
        input2_quantization_params.scale / twice_max_input_scale;
    // Note: `*` binds tighter than `<<`, so this is 1 << (left_shift * 2),
    // i.e. 1 << 14 — it undoes the left_shift applied to each input before
    // the difference is squared.
    const double real_output_multiplier =
        (twice_max_input_scale * twice_max_input_scale) /
        ((1 << data->arithmetic_params.left_shift * 2) *
         output_quantization_params.scale);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
        &data->arithmetic_params.input1_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
        &data->arithmetic_params.input2_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->arithmetic_params.output_multiplier,
        &data->arithmetic_params.output_shift);
    // Clamp quantized results to the full int8 range (no fused activation).
    data->arithmetic_params.quantized_activation_min =
        std::numeric_limits<int8_t>::min();
    data->arithmetic_params.quantized_activation_max =
        std::numeric_limits<int8_t>::max();
  }

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  // ResizeTensor takes ownership of output_size.
  return context->ResizeTensor(context, output, output_size);
}
146 
// Quantized (int8) squared difference for one element pair, using the
// fixed-point parameters precomputed in Prepare(): both inputs are
// de-offset, left-shifted, and rescaled to a common scale before the
// difference is squared and rescaled to the output's quantization.
inline int8_t SquaredDifference(int8_t x, int8_t y,
                                const ArithmeticParams& params) {
  // Remove the zero points (offsets are the negated zero points).
  const int32_t input1_val = params.input1_offset + x;
  const int32_t input2_val = params.input2_offset + y;
  // Pre-shift to preserve precision through the multiplier rescale.
  const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
  const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
  // Bring both inputs onto the common scale chosen in Prepare().
  const int32_t scaled_input1_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input1_val, params.input1_multiplier, params.input1_shift);
  const int32_t scaled_input2_val =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          shifted_input2_val, params.input2_multiplier, params.input2_shift);
  const int32_t raw_diff = scaled_input1_val - scaled_input2_val;

  // Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
  const int32_t squared_raw_diff = raw_diff * raw_diff;
  // Rescale to the output quantization and re-apply the output zero point.
  const int32_t raw_output =
      MultiplyByQuantizedMultiplierSmallerThanOneExp(
          squared_raw_diff, params.output_multiplier, params.output_shift) +
      params.output_offset;
  // Clamp to the activation range ([-128, 127] set in Prepare()).
  const int32_t clamped_output =
      std::min(params.quantized_activation_max,
               std::max(params.quantized_activation_min, raw_output));
  return static_cast<int8_t>(clamped_output);
}
172 
173 template <typename T>
EvalQuantizedSquaredDifference(TfLiteContext * context,TfLiteNode * node,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)174 void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
175                                     const OpData* data,
176                                     const TfLiteTensor* input1,
177                                     const TfLiteTensor* input2,
178                                     TfLiteTensor* output) {
179   const auto* op_data = static_cast<const OpData*>(node->user_data);
180   if (data->requires_broadcast) {
181     reference_integer_ops::BroadcastBinaryFunction4DSlow(
182         op_data->arithmetic_params, GetTensorShape(input1),
183         GetTensorData<T>(input1), GetTensorShape(input2),
184         GetTensorData<T>(input2), GetTensorShape(output),
185         GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams,
186         SquaredDifference);
187   } else {
188     const int flat_size = GetTensorShape(input1).FlatSize();
189     reference_integer_ops::ElementWise(
190         flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1),
191         GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output),
192         reference_integer_ops::CheckArithmeticParams, SquaredDifference);
193   }
194 }
195 
196 template <typename T>
EvalSquaredDifference(TfLiteContext * context,TfLiteNode * node,const OpData * data,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)197 void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
198                            const OpData* data, const TfLiteTensor* input1,
199                            const TfLiteTensor* input2, TfLiteTensor* output) {
200   if (data->requires_broadcast) {
201     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
202         GetTensorShape(input1), GetTensorData<T>(input1),
203         GetTensorShape(input2), GetTensorData<T>(input2),
204         GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
205   } else {
206     reference_ops::BinaryFunction<T, T, T>(
207         GetTensorShape(input1), GetTensorData<T>(input1),
208         GetTensorShape(input2), GetTensorData<T>(input2),
209         GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
210   }
211 }
212 
Eval(TfLiteContext * context,TfLiteNode * node)213 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
214   OpData* data = reinterpret_cast<OpData*>(node->user_data);
215   ruy::profiler::ScopeLabel label("SquaredDifference");
216 
217   const TfLiteTensor* input1;
218   TF_LITE_ENSURE_OK(context,
219                     GetInputSafe(context, node, kInputTensor1, &input1));
220   const TfLiteTensor* input2;
221   TF_LITE_ENSURE_OK(context,
222                     GetInputSafe(context, node, kInputTensor2, &input2));
223   TfLiteTensor* output;
224   TF_LITE_ENSURE_OK(context,
225                     GetOutputSafe(context, node, kOutputTensor, &output));
226 
227   if (output->type == kTfLiteFloat32) {
228     EvalSquaredDifference<float>(context, node, data, input1, input2, output);
229   } else if (output->type == kTfLiteInt32) {
230     EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
231   } else if (output->type == kTfLiteInt8) {
232     EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
233                                            output);
234   } else {
235     TF_LITE_KERNEL_LOG(
236         context,
237         "SquaredDifference only supports FLOAT32 and INT32 now, got %d.",
238         output->type);
239     return kTfLiteError;
240   }
241 
242   return kTfLiteOk;
243 }
244 
245 }  // namespace squared_difference
246 
Register_SQUARED_DIFFERENCE()247 TfLiteRegistration* Register_SQUARED_DIFFERENCE() {
248   static TfLiteRegistration r = {
249       squared_difference::Init, squared_difference::Free,
250       squared_difference::Prepare, squared_difference::Eval};
251   return &r;
252 }
253 
254 }  // namespace builtin
255 }  // namespace ops
256 }  // namespace tflite
257