/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

// For the TfLite transpose_conv implementation, input tensor 0 corresponds to
// the OutputShapeTensor. However, since TFLM does not support dynamic tensors,
// the TFLM implementation ignores input tensor 0 and the only inputs we care
// about are kFilterTensor, kInputTensor and kBiasTensor.
constexpr int kFilterTensor = 1;
constexpr int kInputTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kOutputTensor = 0;

// TRANSPOSE_CONV is per-channel quantized along the filter's dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
constexpr int kConvQuantizedDimension = 0;

struct OpData {
  ConvParams params;

  // A scratch buffer is required for quantized implementations.
  int scratch_buffer_index;

  // Multiplier and shift arrays are required for the int8 implementation.
  int32_t* per_channel_output_multiplier;
  int32_t* per_channel_output_shift;
};

inline PaddingType RuntimePaddingType(TfLitePadding padding) {
  switch (padding) {
    case TfLitePadding::kTfLitePaddingSame:
      return PaddingType::kSame;
    case TfLitePadding::kTfLitePaddingValid:
      return PaddingType::kValid;
    case TfLitePadding::kTfLitePaddingUnknown:
    default:
      return PaddingType::kNone;
  }
}

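// Computes padding and, for quantized types, the per-tensor and per-channel
// output multipliers, shifts and activation bounds consumed by Eval.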
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             const TfLiteConvParams* params, int width,
                             int height, int filter_width, int filter_height,
                             int out_width, int out_height,
                             const TfLiteType data_type, OpData* data) {
  bool has_bias = node->inputs->size == 4;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 3);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params->padding;
  TfLitePaddingValues padding_values = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      params->dilation_height_factor, params->dilation_width_factor, height,
      width, filter_height, filter_width, padding, &out_height, &out_width);

  data->params.padding_type = RuntimePaddingType(padding);
  data->params.padding_values.width = padding_values.width;
  data->params.padding_values.height = padding_values.height;

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    TF_LITE_ENSURE(context, input != nullptr);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    TF_LITE_ENSURE(context, filter != nullptr);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    TF_LITE_ENSURE(context, output != nullptr);
    int output_channels = filter->dims->data[kConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->params.output_multiplier, &data->params.output_shift,
        &data->params.quantized_activation_min,
        &data->params.quantized_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift),
        output_channels));
  }
  return kTfLiteOk;
}

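// Allocates this op's OpData from the arena; TFLM kernels take persistent
// allocations from the arena rather than the heap.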
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
  TF_LITE_ENSURE(context, filter != nullptr);

  int input_width = input->dims->data[2];
  int input_height = input->dims->data[1];
  int filter_width = filter->dims->data[2];
  int filter_height = filter->dims->data[1];
  int output_width = output->dims->data[2];
  int output_height = output->dims->data[1];

  // Dynamically allocate per-channel quantization parameters.
  const int num_channels = filter->dims->data[kConvQuantizedDimension];
  data->per_channel_output_multiplier =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

  // Quantized kernels use an int32 scratch buffer.
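  // Transpose conv scatters each input element across several output
  // locations, so the reference kernels accumulate one int32_t per output
  // element in this buffer before requantizing.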
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
    TFLITE_DCHECK(context->RequestScratchBufferInArena(
                      context,
                      GetTensorShape(output).FlatSize() * sizeof(int32_t),
                      &(data->scratch_buffer_index)) == kTfLiteOk);
  }

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);

    TF_LITE_ENSURE(context,
                   affine_quantization->scale->size == 1 ||
                       affine_quantization->scale->size ==
                           filter->dims->data[kConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

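  // The input and filter zero points are stored negated so the kernels can
  // add the offsets during the multiply-accumulate.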
  // Offsets (zero points)
  data->params.input_offset = -input->params.zero_point;
  data->params.weights_offset = -filter->params.zero_point;
  data->params.output_offset = output->params.zero_point;

  // Stride + dilation
  data->params.stride_width = params->stride_width;
  data->params.stride_height = params->stride_height;
  data->params.dilation_width_factor = params->dilation_width_factor;
  data->params.dilation_height_factor = params->dilation_height_factor;

  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  data->params.float_activation_min = output_activation_min;
  data->params.float_activation_max = output_activation_max;
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFilterTensor);
  const TfLiteEvalTensor* bias =
      (NumInputs(node) == 4)
          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
          : nullptr;
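  // GetTensorShape/GetTensorData below tolerate a null bias tensor, yielding
  // an empty shape and nullptr, which the reference kernels treat as no bias.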
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  TF_LITE_ENSURE_EQ(context, input->type, output->type);
  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                     "Hybrid models are not supported on TFLite Micro.");

  switch (input->type) {  // Already know in/out types are same.
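    // The trailing {GetTensorShape(nullptr), nullptr} arguments in both cases
    // are the reference kernels' optional im2col shape and buffer, which this
    // implementation does not use.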
    case kTfLiteFloat32: {
      reference_ops::TransposeConv(
          data.params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetTensorData<float>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
      break;
    }
    case kTfLiteInt8: {
      int32_t* scratch_buffer = static_cast<int32_t*>(
          context->GetScratchBuffer(context, data.scratch_buffer_index));
      reference_integer_ops::TransposeConv(
          data.params, data.per_channel_output_multiplier,
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetTensorData<int32_t>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace

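// A minimal registration sketch (assuming this tree's MicroMutableOpResolver
// exposes AddTransposeConv(), which wires up the registration returned
// below):
//
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddTransposeConv();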
TfLiteRegistration Register_TRANSPOSE_CONV() {
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite