/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stddef.h>
#include <stdint.h>

#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
// NOLINTNEXTLINE - This header file shouldn't go to the top.
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
// NOLINTNEXTLINE - This header file shouldn't go to the top.
#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace transpose_conv {

// This file has 2 implementations of TransposeConv.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
};

constexpr int kOutputShapeTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kOutputTensor = 0;

const int kTensorNotAllocated = -1;

struct OpData {
  // IDs are the arbitrary identifiers used by TF Lite to identify and access
  // memory buffers.
  int col2im_id = kTensorNotAllocated;
  int transposed_weights_id = kTensorNotAllocated;
  int scratch_tensor_id = kTensorNotAllocated;

  // col2im is the temporary tensor allocated and used in the optimized path
  // for storing the col2im data: the GEMM result of input_matrix x
  // filter_matrix.
  int32_t col2im_index;

  // TfLiteConverter will transpose weights from HWOI to OHWI order.
  // In the optimized path, we transpose them back to HWOI; this temporary
  // tensor is allocated for storing the transposed weights.
  int32_t transposed_weights_index;

  // Scratch tensor is used in the quantized path for storing accumulation
  // results.
  int32_t scratch_tensor_index;

  TfLitePaddingValues padding;
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;

  // Per-channel output multiplier and shift.
  std::vector<int32_t> per_channel_output_multiplier;
  std::vector<int32_t> per_channel_output_shift;

  // The range of the fused activation layer. For example, for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;

  bool has_col2im = false;
  bool weights_are_transposed = false;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

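// Resizes `tensor_to_resize` to the shape carried by the 1-D int32
// `shape_tensor`.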
TfLiteStatus ResizeTensor(TfLiteContext* context,
                          const TfLiteTensor* shape_tensor,
                          TfLiteTensor* tensor_to_resize) {
  // Currently only support int32 for output shape.
  if (shape_tensor->type != kTfLiteInt32) {
    TF_LITE_KERNEL_LOG(context, "Output shape is %s, not int32.",
                       TfLiteTypeGetName(shape_tensor->type));
    return kTfLiteError;
  }

  TfLiteIntArray* shape = TfLiteIntArrayCreate(NumElements(shape_tensor));
  for (int i = 0; i < shape->size; ++i) {
    shape->data[i] = GetTensorData<int32_t>(shape_tensor)[i];
  }

  return context->ResizeTensor(context, tensor_to_resize, shape);
}

// Allocate temporary tensors if necessary.
template <KernelType kernel_type>
static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
                                                       TfLiteType input_type,
                                                       TfLiteType weights_type,
                                                       TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  int temporaries_count = 0;

  // Allocate col2im tensor. Currently it's only used for optimized kernels.
  if (kernel_type == kGenericOptimized) {
    if (data->col2im_id == kTensorNotAllocated) {
      context->AddTensors(context, 1, &data->col2im_id);
    }
    data->col2im_index = temporaries_count;
    data->has_col2im = true;
    ++temporaries_count;
  }

  // Allocate transposed_weights tensor. Currently it's only used for optimized
  // float kernels.
  if (kernel_type == kGenericOptimized) {
    if (data->transposed_weights_id == kTensorNotAllocated) {
      context->AddTensors(context, 1, &data->transposed_weights_id);
    }
    data->transposed_weights_index = temporaries_count;
    data->weights_are_transposed = true;
    ++temporaries_count;
  }

  // Allocate scratch buffer tensor
  if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 ||
      input_type == kTfLiteInt16) {
    if (data->scratch_tensor_id == kTensorNotAllocated) {
      context->AddTensors(context, 1, &data->scratch_tensor_id);
    }
    data->scratch_tensor_index = temporaries_count;
    ++temporaries_count;
  }

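  // (Re)create the node's temporaries array with exactly the number of
  // temporary tensors requested above.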
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(temporaries_count);

  return kTfLiteOk;
}

TfLiteStatus ResizeCol2ImTensor(TfLiteContext* context,
                                const TfLiteTensor* output_shape,
                                const TfLiteTensor* weights,
                                const TfLiteTensor* input,
                                TfLiteTensor* col2im) {
  if (output_shape->type != kTfLiteInt32) {
    TF_LITE_KERNEL_LOG(context, "col2im shape is %s, not int32.",
                       TfLiteTypeGetName(output_shape->type));
    return kTfLiteError;
  }
  TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
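  // col2im is a 2-D GEMM buffer: one row per input spatial position
  // (H_in * W_in) and one column per filter element across all output
  // channels (O * K_h * K_w, with weights stored in OHWI order).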
  TfLiteIntArray* col2im_shape_array = TfLiteIntArrayCreate(2);
  const RuntimeShape& input_shape = GetTensorShape(input);
  const RuntimeShape& weights_shape = GetTensorShape(weights);
  col2im_shape_array->data[0] = input_shape.Dims(1) * input_shape.Dims(2);
  col2im_shape_array->data[1] =
      weights_shape.Dims(0) * weights_shape.Dims(1) * weights_shape.Dims(2);

  col2im->type = input->type == kTfLiteFloat32 ? kTfLiteFloat32 : kTfLiteInt32;
  col2im->allocation_type = kTfLiteDynamic;
  return context->ResizeTensor(context, col2im, col2im_shape_array);
}

TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
                                       const TfLiteTensor* weights,
                                       TfLiteTensor* transposed_weights) {
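  // Weights arrive from the converter in OHWI order; build the HWOI shape
  // used for the transposed copy.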
  TfLiteIntArray* transposed_weights_shape_array = TfLiteIntArrayCreate(4);
  const RuntimeShape& input_shape = GetTensorShape(weights);
  transposed_weights_shape_array->data[0] = input_shape.Dims(1);
  transposed_weights_shape_array->data[1] = input_shape.Dims(2);
  transposed_weights_shape_array->data[2] = input_shape.Dims(0);
  transposed_weights_shape_array->data[3] = input_shape.Dims(3);

  transposed_weights->type = weights->type;
  transposed_weights->allocation_type = kTfLiteDynamic;
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, transposed_weights,
                                              transposed_weights_shape_array));

  // Transpose the weights from OHWI order to HWOI order.
  TransposeParams transpose_params;
  transpose_params.perm_count = 4;
  transpose_params.perm[0] = 1;
  transpose_params.perm[1] = 2;
  transpose_params.perm[2] = 0;
  transpose_params.perm[3] = 3;

  if (weights->type == kTfLiteFloat32) {
    optimized_ops::Transpose(transpose_params, input_shape,
                             GetTensorData<float>(weights),
                             GetTensorShape(transposed_weights),
                             GetTensorData<float>(transposed_weights));
  } else if (weights->type == kTfLiteUInt8) {
    optimized_ops::Transpose(transpose_params, input_shape,
                             GetTensorData<uint8>(weights),
                             GetTensorShape(transposed_weights),
                             GetTensorData<uint8>(transposed_weights));
  } else if (weights->type == kTfLiteInt8) {
    // int16 transpose_conv also uses int8 weights.
    optimized_ops::Transpose(transpose_params, input_shape,
                             GetTensorData<int8>(weights),
                             GetTensorShape(transposed_weights),
                             GetTensorData<int8>(transposed_weights));
  } else {
    TF_LITE_KERNEL_LOG(
        context,
        "Only float32, uint8, int8, int16 are supported currently, got %s.",
        TfLiteTypeGetName(weights->type));
    return kTfLiteError;
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  bool has_bias = NumInputs(node) == 4;

  // Sanity checks on op
  TF_LITE_ENSURE(context, has_bias || NumInputs(node) == 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Retrieve tensors
  const TfLiteTensor* output_shape;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &weights));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));
  const TfLiteTensor* bias = nullptr;

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Tensor sanity checks
  TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8 ||
                     input->type == kTfLiteInt8 || input->type == kTfLiteInt16);

  if (has_bias) {
    bias = GetOptionalInputTensor(context, node, kBiasTensor);
    if (bias) {
      if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
        TF_LITE_ENSURE_TYPES_EQ(context, bias->type, kTfLiteInt32);
        if (input->type == kTfLiteInt8) {
          TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
        }
      } else if (input->type == kTfLiteInt16) {
        TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64);
        TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
      } else {
        TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input->type);
      }
      TF_LITE_ENSURE_EQ(context, NumElements(bias),
                        SizeOfDimension(weights, 0));
    }
  }

  if (input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteInt8);
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, weights->type, input->type);
  }
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
  // Ensure that weights and inputs have the same channel dimension.
  // Note: TOCO will reorder weights in the following format: OHWI.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
                    SizeOfDimension(weights, 3));

  // Allocate col2im, transposed_weights & scratch tensors.
  TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired<kernel_type>(
      context, input->type, weights->type, node));

  OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteTensor* col2im = nullptr;
  if (data->has_col2im) {
    node->temporaries->data[data->col2im_index] = data->col2im_id;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, user_data->col2im_index, &col2im));
  }

  if (!IsConstantTensor(output_shape)) {
    // Defer resizing until Eval().
    SetTensorToDynamic(output);
    if (data->has_col2im) {
      SetTensorToDynamic(col2im);
    }
  } else {
    TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output));
    if (data->has_col2im) {
      TF_LITE_ENSURE_STATUS(
          ResizeCol2ImTensor(context, output_shape, weights, input, col2im));
    }
  }

  if (data->weights_are_transposed) {
    node->temporaries->data[data->transposed_weights_index] =
        data->transposed_weights_id;
    TfLiteTensor* transposed_weights;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, user_data->transposed_weights_index,
                         &transposed_weights));
    if (!IsConstantTensor(weights)) {
      SetTensorToDynamic(transposed_weights);
    } else {
      ResizeAndTransposeWeights(context, weights, transposed_weights);
    }
  }

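  // Quantized paths (uint8/int8/int16) additionally need a scratch
  // accumulator tensor and per-channel quantization parameters.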
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    node->temporaries->data[data->scratch_tensor_index] =
        data->scratch_tensor_id;
    TfLiteTensor* scratch_buffer;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, data->scratch_tensor_index,
                                  &scratch_buffer));
    if (input->type == kTfLiteInt16) {
      scratch_buffer->type = kTfLiteInt64;
    } else {
      scratch_buffer->type = kTfLiteInt32;
    }

    scratch_buffer->allocation_type = kTfLiteDynamic;
    if (!IsConstantTensor(output_shape)) {
      SetTensorToDynamic(scratch_buffer);
    } else {
      TF_LITE_ENSURE_STATUS(
          ResizeTensor(context, output_shape, scratch_buffer));
    }

    TF_LITE_ENSURE_EQ(context, weights->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            weights->quantization.params);
    const int channels_out = weights->dims->data[0];
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    TF_LITE_ENSURE(context, (affine_quantization->scale->size == 1 ||
                             affine_quantization->scale->size == channels_out));

    data->per_channel_output_multiplier.resize(channels_out);
    data->per_channel_output_shift.resize(channels_out);
    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
        context, input, weights, bias, output, kTfLiteActNone,
        &data->output_multiplier, &data->output_shift,
        &data->output_activation_min, &data->output_activation_max,
        data->per_channel_output_multiplier.data(),
        data->per_channel_output_shift.data(), channels_out));
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
               const OpData* data, const TfLiteTensor* input,
               const TfLiteTensor* weights, const TfLiteTensor* bias,
               const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
               TfLiteTensor* output) {
  tflite::ConvParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width_offset = data->padding.width_offset;
  op_params.padding_values.height_offset = data->padding.height_offset;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;

  switch (kernel_type) {
    case kReference: {
      reference_ops::TransposeConv(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(weights), GetTensorData<float>(weights),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output),
          GetTensorShape(col2im), GetTensorData<float>(col2im));
      break;
    }
    case kGenericOptimized: {
      optimized_ops::TransposeConvV2(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(transposed_weights),
          GetTensorData<float>(transposed_weights), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorShape(output),
          GetTensorData<float>(output), GetTensorShape(col2im),
          GetTensorData<float>(col2im),
          CpuBackendContext::GetFromContext(context));
      break;
    }
  }
}

template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context,
                   const TfLiteTransposeConvParams* params, OpData* data,
                   const TfLiteTensor* input, const TfLiteTensor* weights,
                   const TfLiteTensor* transposed_weights,
                   const TfLiteTensor* bias, TfLiteTensor* col2im,
                   TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
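  // Input and filter zero points are negated so they can be added directly to
  // the quantized values; the output zero point is applied as-is.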
  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -weights->params.zero_point;
  int32_t output_offset = output->params.zero_point;

  tflite::ConvParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width_offset = data->padding.width_offset;
  op_params.padding_values.height_offset = data->padding.height_offset;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  op_params.input_offset = input_offset;
  op_params.output_offset = output_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = -data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  switch (kernel_type) {
    case kReference: {
      reference_ops::TransposeConv(
          op_params, GetTensorShape(input), GetTensorData<uint8>(input),
          GetTensorShape(weights), GetTensorData<uint8>(weights),
          GetTensorShape(bias), GetTensorData<int32_t>(bias),
          GetTensorShape(output), GetTensorData<uint8>(output),
          GetTensorShape(col2im), GetTensorData<uint8>(col2im),
          GetTensorData<int32_t>(scratch_buffer));
      break;
    }
    case kGenericOptimized: {
      optimized_ops::TransposeConvV2(
          op_params, GetTensorShape(input), GetTensorData<uint8>(input),
          GetTensorShape(transposed_weights),
          GetTensorData<uint8>(transposed_weights), GetTensorShape(bias),
          GetTensorData<int32>(bias), GetTensorShape(output),
          GetTensorData<uint8>(output), GetTensorShape(col2im),
          GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
          CpuBackendContext::GetFromContext(context));
      break;
    }
  }
}

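// Per-channel quantized (int8) path: uses the per-channel output multipliers
// and shifts computed in Prepare().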
template <KernelType kernel_type>
void EvalQuantizedPerChannel(
    TfLiteContext* context, const TfLiteTransposeConvParams* params,
    OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights,
    const TfLiteTensor* transposed_weights, const TfLiteTensor* bias,
    TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
  tflite::ConvParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width_offset = data->padding.width_offset;
  op_params.padding_values.height_offset = data->padding.height_offset;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  // Need to flip the sign of input offset to add it directly to the quantized
  // buffer.
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  switch (kernel_type) {
    case kReference: {
      reference_integer_ops::TransposeConv(
          op_params, data->per_channel_output_multiplier.data(),
          data->per_channel_output_shift.data(), GetTensorShape(input),
          GetTensorData<int8>(input), GetTensorShape(weights),
          GetTensorData<int8>(weights), GetTensorShape(bias),
          GetTensorData<int32>(bias), GetTensorShape(output),
          GetTensorData<int8>(output), GetTensorShape(col2im),
          GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
      break;
    }
    case kGenericOptimized: {
      optimized_integer_ops::TransposeConvV2(
          op_params, data->per_channel_output_multiplier.data(),
          data->per_channel_output_shift.data(), GetTensorShape(input),
          GetTensorData<int8>(input), GetTensorShape(transposed_weights),
          GetTensorData<int8>(transposed_weights), GetTensorShape(bias),
          GetTensorData<int32>(bias), GetTensorShape(output),
          GetTensorData<int8>(output), GetTensorShape(col2im),
          GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
          CpuBackendContext::GetFromContext(context));
      break;
    }
  }
}

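// 16x8 quantized path: int16 activations with int8 weights, int64 bias and
// accumulators.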
void EvalQuantizedPerChannel16x8(
    TfLiteContext* context, const TfLiteTransposeConvParams* params,
    OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights,
    const TfLiteTensor* transposed_weights, const TfLiteTensor* bias,
    TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) {
  tflite::ConvParams op_params;
  op_params.padding_type = PaddingType::kSame;
  op_params.padding_values.width = data->padding.width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width_offset = data->padding.width_offset;
  op_params.padding_values.height_offset = data->padding.height_offset;
  op_params.stride_width = params->stride_width;
  op_params.stride_height = params->stride_height;
  // Need to flip the sign of input offset to add it directly to the quantized
  // buffer.
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  // Need to add optimized kernel
  reference_integer_ops::TransposeConv(
      op_params, data->per_channel_output_multiplier.data(),
      data->per_channel_output_shift.data(), GetTensorShape(input),
      GetTensorData<int16>(input), GetTensorShape(weights),
      GetTensorData<int8>(weights), GetTensorShape(bias),
      GetTensorData<int64_t>(bias), GetTensorShape(output),
      GetTensorData<int16>(output), GetTensorShape(col2im),
      GetTensorData<int8>(col2im), GetTensorData<int64_t>(scratch_buffer));
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Retrieve tensors (All should be allocated by now)
  const TfLiteTensor* output_shape;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
  const TfLiteTensor* weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &weights));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));
  const TfLiteTensor* bias =
      (NumInputs(node) == 4)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteTensor* col2im = data->has_col2im
                             ? GetTemporary(context, node, data->col2im_index)
                             : nullptr;
  TfLiteTensor* transposed_weights =
      data->weights_are_transposed
          ? GetTemporary(context, node, data->transposed_weights_index)
          : nullptr;
  const auto* params =
      reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);

  // Resize any deferred dynamic tensors
  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output));
  }
  if (data->has_col2im && IsDynamicTensor(col2im)) {
    TF_LITE_ENSURE_OK(context, ResizeCol2ImTensor(context, output_shape,
                                                  weights, input, col2im));
  }

  // Get height and width of the output image.
  const int width = SizeOfDimension(output, 2);
  const int height = SizeOfDimension(output, 1);
  const int filter_width = SizeOfDimension(weights, 2);
  const int filter_height = SizeOfDimension(weights, 1);

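  // Recompute the padding values from the (possibly just-resized) output
  // dimensions, the strides, and the filter size.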
  int unused_output_height, unused_output_width;
  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width, 1, 1, height, width,
      filter_height, filter_width, params->padding, &unused_output_height,
      &unused_output_width);

  // Currently support float32, uint8, int8, int16.
  switch (input->type) {
    case kTfLiteFloat32: {
      // Transposed weights are used only in the GenericOptimized path.
      if (data->weights_are_transposed) {
        if (!IsConstantTensor(weights)) {
          ResizeAndTransposeWeights(context, weights, transposed_weights);
        }
      }
      EvalFloat<kernel_type>(context, params, data, input, weights, bias,
                             transposed_weights, col2im, output);
      break;
    }
    case kTfLiteUInt8: {
      TfLiteTensor* scratch_buffer;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
      }
      if (data->weights_are_transposed) {
        if (!IsConstantTensor(weights)) {
          ResizeAndTransposeWeights(context, weights, transposed_weights);
        }
      }
      EvalQuantized<kernel_type>(context, params, data, input, weights,
                                 transposed_weights, bias, col2im, output,
                                 scratch_buffer);
      break;
    }
    case kTfLiteInt8: {
      TfLiteTensor* scratch_buffer;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
      }
      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
        ResizeAndTransposeWeights(context, weights, transposed_weights);
      }
      EvalQuantizedPerChannel<kernel_type>(context, params, data, input,
                                           weights, transposed_weights, bias,
                                           col2im, output, scratch_buffer);
      break;
    }
    case kTfLiteInt16: {
      TfLiteTensor* scratch_buffer;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, data->scratch_tensor_index,
                                    &scratch_buffer));
      if (IsDynamicTensor(scratch_buffer)) {
        TF_LITE_ENSURE_OK(context,
                          ResizeTensor(context, output_shape, scratch_buffer));
      }
      if (data->weights_are_transposed && !IsConstantTensor(weights)) {
        ResizeAndTransposeWeights(context, weights, transposed_weights);
      }
      EvalQuantizedPerChannel16x8(context, params, data, input, weights,
                                  transposed_weights, bias, col2im, output,
                                  scratch_buffer);
      break;
    }
    default:
      context->ReportError(context, "Type '%s' is not currently supported.",
                           TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace transpose_conv

TfLiteRegistration* Register_TRANSPOSECONV_REF() {
  static TfLiteRegistration r = {
      transpose_conv::Init, transpose_conv::Free,
      transpose_conv::Prepare<transpose_conv::kReference>,
      transpose_conv::Eval<transpose_conv::kReference>};
  return &r;
}

TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() {
  static TfLiteRegistration r = {
      transpose_conv::Init, transpose_conv::Free,
      transpose_conv::Prepare<transpose_conv::kGenericOptimized>,
      transpose_conv::Eval<transpose_conv::kGenericOptimized>};
  return &r;
}

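// The default TRANSPOSE_CONV registration uses the generic optimized kernel.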
TfLiteRegistration* Register_TRANSPOSE_CONV() {
  return Register_TRANSPOSECONV_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite