/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

// For the TfLite transpose_conv implementation, input tensor 0 corresponds to
// the OutputShapeTensor. However, since TFLM does not support dynamic tensors,
// the TFLM implementation ignores input tensor 0 and the only inputs we care
// about are kFilterTensor, kInputTensor and kBiasTensor.
constexpr int kFilterTensor = 1;
constexpr int kInputTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kOutputTensor = 0;

// Conv is quantized along dimension 0:
// https://www.tensorflow.org/lite/performance/quantization_spec
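// In this kernel, dimension 0 of the filter tensor is used as the number of
// output channels (see the uses of kConvQuantizedDimension below).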
constexpr int kConvQuantizedDimension = 0;

struct OpData {
  ConvParams params;

  // A scratch buffer is required for quantized implementations.
  int scratch_buffer_index;

  // Multiplier and shift arrays are required for the int8 implementation.
  int32_t* per_channel_output_multiplier;
  int32_t* per_channel_output_shift;
};

inline PaddingType RuntimePaddingType(TfLitePadding padding) {
  switch (padding) {
    case TfLitePadding::kTfLitePaddingSame:
      return PaddingType::kSame;
    case TfLitePadding::kTfLitePaddingValid:
      return PaddingType::kValid;
    case TfLitePadding::kTfLitePaddingUnknown:
    default:
      return PaddingType::kNone;
  }
}

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             const TfLiteConvParams* params, int width,
                             int height, int filter_width, int filter_height,
                             int out_width, int out_height,
                             const TfLiteType data_type, OpData* data) {
  bool has_bias = node->inputs->size == 4;
  // Check number of inputs/outputs
  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 3);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Matching GetWindowedOutputSize in TensorFlow.
  auto padding = params->padding;
  TfLitePaddingValues padding_values = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      params->dilation_height_factor, params->dilation_width_factor, height,
      width, filter_height, filter_width, padding, &out_height, &out_width);

  data->params.padding_type = RuntimePaddingType(padding);
  data->params.padding_values.width = padding_values.width;
  data->params.padding_values.height = padding_values.height;

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (data_type != kTfLiteFloat32) {
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    TF_LITE_ENSURE(context, input != nullptr);
    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
    TF_LITE_ENSURE(context, filter != nullptr);
    const TfLiteTensor* bias =
        GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    TF_LITE_ENSURE(context, output != nullptr);
    int output_channels = filter->dims->data[kConvQuantizedDimension];

    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, filter, bias, output, params->activation,
        &data->params.output_multiplier, &data->params.output_shift,
        &data->params.quantized_activation_min,
        &data->params.quantized_activation_max,
        data->per_channel_output_multiplier,
        reinterpret_cast<int*>(data->per_channel_output_shift),
        output_channels));
  }
  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
  TF_LITE_ENSURE(context, filter != nullptr);

  int input_width = input->dims->data[2];
  int input_height = input->dims->data[1];
  int filter_width = filter->dims->data[2];
  int filter_height = filter->dims->data[1];
  int output_width = output->dims->data[2];
  int output_height = output->dims->data[1];

  // Dynamically allocate per-channel quantization parameters.
  const int num_channels = filter->dims->data[kConvQuantizedDimension];
  data->per_channel_output_multiplier =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));
  data->per_channel_output_shift =
      static_cast<int32_t*>(context->AllocatePersistentBuffer(
          context, num_channels * sizeof(int32_t)));

  // Quantized kernels use an int32 scratch buffer.
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
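    // The reference integer kernel accumulates into one int32_t per output
    // element, so the scratch buffer is sized from the output's flat size.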
    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
    TFLITE_DCHECK(context->RequestScratchBufferInArena(
                      context,
                      GetTensorShape(output).FlatSize() * sizeof(int32_t),
                      &(data->scratch_buffer_index)) == kTfLiteOk);
  }

  // All per-channel quantized tensors need valid zero point and scale arrays.
  if (input->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);

    const auto* affine_quantization =
        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, affine_quantization->zero_point);

    TF_LITE_ENSURE(context,
                   affine_quantization->scale->size == 1 ||
                       affine_quantization->scale->size ==
                           filter->dims->data[kConvQuantizedDimension]);
    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                      affine_quantization->zero_point->size);
  }

  TF_LITE_ENSURE_STATUS(CalculateOpData(
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

  // Offsets (zero points)
  data->params.input_offset = -input->params.zero_point;
  data->params.weights_offset = -filter->params.zero_point;
  data->params.output_offset = output->params.zero_point;

  // Stride + dilation
  data->params.stride_width = params->stride_width;
  data->params.stride_height = params->stride_height;
  data->params.dilation_width_factor = params->dilation_width_factor;
  data->params.dilation_height_factor = params->dilation_height_factor;

  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  data->params.float_activation_min = output_activation_min;
  data->params.float_activation_max = output_activation_max;
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFilterTensor);
  const TfLiteEvalTensor* bias =
      (NumInputs(node) == 4)
          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
          : nullptr;
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  TF_LITE_ENSURE_EQ(context, input->type, output->type);
  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                     "Hybrid models are not supported on TFLite Micro.");

  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32: {
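      // The trailing GetTensorShape(nullptr)/nullptr pair is the reference
      // kernel's optional im2col buffer, which is unused here.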
      reference_ops::TransposeConv(
          data.params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetTensorData<float>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
      break;
    }
    case kTfLiteInt8: {
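      // Fetch the int32_t accumulator scratch buffer that Prepare() requested
      // from the arena.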
      int32_t* scratch_buffer = static_cast<int32_t*>(
          context->GetScratchBuffer(context, data.scratch_buffer_index));
      reference_integer_ops::TransposeConv(
          data.params, data.per_channel_output_multiplier,
          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetTensorData<int32_t>(bias),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace

TfLiteRegistration Register_TRANSPOSE_CONV() {
  return {/*init=*/Init,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite