/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pad.h"

#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace pad {

// This file has two implementations of Pad.
enum KernelType {
  kReference,
  kGenericOptimized,
};
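// kReference is the portable implementation; kGenericOptimized is what
// Register_PAD() and Register_PADV2() return by default (see the registration
// functions at the end of this file).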

struct PadContext {
  PadContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    paddings = GetInput(context, node, 1);
    if (NumInputs(node) == 3) {
      constant_values = GetOptionalInputTensor(context, node, 2);
    } else {
      constant_values = nullptr;
    }
    output = GetOutput(context, node, 0);
    dims = NumDimensions(input);

    resizing_category = ResizingCategory::kGenericResize;
    const int paddings_total = GetTensorShape(paddings).FlatSize();
    const int32* paddings_data = GetTensorData<int32>(paddings);
    // Paddings will be an n x 2 array, and we need to detect 4-D inputs with
    // the pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
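    // For example, for an NHWC image tensor, paddings of
    // { {0,0}, {1,1}, {2,2}, {0,0} } pad only the spatial (H, W) dimensions,
    // which the kernels can handle as an "image style" pad.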
    if (IsConstantTensor(paddings) && paddings_total == 8 &&
        (paddings_data[0] == 0 && paddings_data[1] == 0) &&
        (paddings_data[6] == 0 && paddings_data[7] == 0)) {
      resizing_category = ResizingCategory::kImageStyle;
    }
  }
  const TfLiteTensor* constant_values;
  const TfLiteTensor* input;
  const TfLiteTensor* paddings;
  TfLiteTensor* output;
  int dims;
  ResizingCategory resizing_category;
};

// Resizes output array based on the input size and padding size. This function
// is callable from both Prepare() and Eval() as long as the caller ensures the
// paddings data is present.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                PadContext* op_context) {
  // Ensures the paddings array is dims x 2.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
                    op_context->dims);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);

  // Ensures all elements of the paddings array are non-negative.
  const int32* paddings_data = GetTensorData<int32>(op_context->paddings);

  for (int idx = 0; idx < op_context->dims; ++idx) {
    int before_padding = *paddings_data++;
    int after_padding = *paddings_data++;

    TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
                       "Pad value has to be greater than or equal to 0.");
  }

  // Determines the size of the output tensor.
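  // For example, an input of shape [2, 3] with paddings { {1,1}, {2,2} }
  // yields an output of shape [2+1+1, 3+2+2] = [4, 7].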
  TfLiteIntArray* input_size = op_context->input->dims;
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
  paddings_data = GetTensorData<int32>(op_context->paddings);

  for (int idx = 0; idx < op_context->dims; ++idx) {
    int before_padding = *paddings_data++;
    int after_padding = *paddings_data++;

    output_size->data[idx] =
        (input_size->data[idx] + before_padding + after_padding);
  }

  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  PadContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
                          op_context.output->type);
  if (op_context.constant_values != nullptr) {
    TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
                            op_context.constant_values->type);
  }

  // Ensure we do not exceed maximum dimension count.
  TF_LITE_ENSURE(
      context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());

  // Exit early if paddings is a non-constant tensor or the input is unranked.
  // Set the output tensor to dynamic so its size can be determined in Eval.
  if (NumDimensions(op_context.input) == 0 ||
      !IsConstantTensor(op_context.paddings)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

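// Shared implementation for the quantized integer types (uint8/int8/int16):
// selects the fill value (the output zero point, or `constant_values` when
// provided), validates the quantization parameters, and calls the optimized
// kernels.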
template <typename integer_type>
TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context,
                     const tflite::PadParams& op_params) {
  integer_type pad_value;
  if (op_context.constant_values == nullptr) {
    // Quantized Pad requires that 0 is represented in the quantized
    // range.
    TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
                                std::numeric_limits<integer_type>::min());
    TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
                                std::numeric_limits<integer_type>::max());
    pad_value =
        static_cast<integer_type>(op_context.output->params.zero_point);
  } else {
    // Quantized Pad requires that 'constant_values' is represented in the
    // same quantized range as the input and output tensors.
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
                      op_context.constant_values->params.zero_point);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
                      op_context.constant_values->params.scale);
    pad_value = *GetTensorData<integer_type>(op_context.constant_values);
  }
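  // The pad kernels take the fill value by pointer, so keep a stable local
  // copy.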
  const integer_type pad_value_copy = pad_value;
  if (op_context.resizing_category == ResizingCategory::kImageStyle) {
    optimized_ops::PadImageStyle(
        op_params, GetTensorShape(op_context.input),
        GetTensorData<integer_type>(op_context.input), &pad_value_copy,
        GetTensorShape(op_context.output),
        GetTensorData<integer_type>(op_context.output));
  } else {
    optimized_ops::Pad(op_params, GetTensorShape(op_context.input),
                       GetTensorData<integer_type>(op_context.input),
                       &pad_value_copy, GetTensorShape(op_context.output),
                       GetTensorData<integer_type>(op_context.output));
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  PadContext op_context(context, node);

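  // For PadV2 the optional third input (`constant_values`) supplies the fill
  // value and must be a scalar; for plain Pad it is null and the fill value
  // defaults to zero (or the zero point for quantized types).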
  if (op_context.constant_values != nullptr) {
    // Ensure that constant_values is a scalar.
    TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
  }

  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

  // Create before and after padding arrays that are accepted by the kernel.
  const int32* paddings_data = GetTensorData<int32>(op_context.paddings);

  TF_LITE_ENSURE(
      context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());

  tflite::PadParams op_params;
  op_params.left_padding_count = op_context.dims;
  op_params.right_padding_count = op_context.dims;

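  // paddings_data is a flattened n x 2 array laid out as
  // {before_0, after_0, before_1, after_1, ...}.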
  for (int idx = op_context.dims - 1; idx >= 0; --idx) {
    op_params.left_padding[idx] = paddings_data[idx * 2];
    op_params.right_padding[idx] = paddings_data[idx * 2 + 1];
  }

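// TF_LITE_PAD declares a local copy of `pad_value` and expands to a kernel
// call of the form:
//   type::op_name(op_params, input shape, input data, &pad_value_copy,
//                 output shape, output data);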
#define TF_LITE_PAD(type, op_name, scalar, pad_value)                     \
  const scalar pad_value_copy = pad_value;                                \
                                                                          \
  type::op_name(op_params, GetTensorShape(op_context.input),              \
                GetTensorData<scalar>(op_context.input), &pad_value_copy, \
                GetTensorShape(op_context.output),                        \
                GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      float pad_value = op_context.constant_values == nullptr
                            ? 0.f
                            : *GetTensorData<float>(op_context.constant_values);
      if (kernel_type == kReference) {
        if (op_context.resizing_category == ResizingCategory::kImageStyle) {
          TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
        } else {
          TF_LITE_PAD(reference_ops, Pad, float, pad_value);
        }
      } else if (kernel_type == kGenericOptimized) {
        if (op_context.resizing_category == ResizingCategory::kImageStyle) {
          TF_LITE_PAD(optimized_ops, PadImageStyle, float, pad_value);
        } else {
          TF_LITE_PAD(optimized_ops, Pad, float, pad_value);
        }
      }
    } break;
    case kTfLiteUInt8: {
      TF_LITE_ENSURE_OK(context,
                        EvalInt<uint8_t>(context, op_context, op_params));
    } break;
    case kTfLiteInt8: {
      TF_LITE_ENSURE_OK(context,
                        EvalInt<int8_t>(context, op_context, op_params));
    } break;
    case kTfLiteInt16: {
      TF_LITE_ENSURE_OK(context,
                        EvalInt<int16_t>(context, op_context, op_params));
    } break;
    case kTfLiteInt32: {
      int32_t pad_value =
          op_context.constant_values == nullptr
              ? 0
              : *GetTensorData<int32_t>(op_context.constant_values);
      if (kernel_type == kReference) {
        TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
      } else if (kernel_type == kGenericOptimized) {
        TF_LITE_PAD(optimized_ops, Pad, int32_t, pad_value);
      }
    } break;
    case kTfLiteInt64: {
      int64_t pad_value =
          op_context.constant_values == nullptr
              ? 0L
              : *GetTensorData<int64_t>(op_context.constant_values);
      if (kernel_type == kReference) {
        TF_LITE_PAD(reference_ops, Pad, int64_t, pad_value);
      } else if (kernel_type == kGenericOptimized) {
        TF_LITE_PAD(optimized_ops, Pad, int64_t, pad_value);
      }
    } break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by Pad.",
                         TfLiteTypeGetName(op_context.input->type));
      return kTfLiteError;
  }
#undef TF_LITE_PAD
  return kTfLiteOk;
}

}  // namespace pad

TfLiteRegistration* Register_PAD_REF() {
  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
                                 pad::Eval<pad::kReference>};
  return &r;
}

TfLiteRegistration* Register_PAD_GENERIC_OPT() {
  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
                                 pad::Eval<pad::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
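// Illustrative usage (assumption: a tflite::MutableOpResolver named
// `resolver` in client code, not part of this file): a client can pick a
// specific kernel explicitly, e.g.
//   resolver.AddBuiltin(BuiltinOperator_PAD, Register_PAD_REF());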

// Also register Pad as PadV2.
TfLiteRegistration* Register_PADV2_REF() {
  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
                                 pad::Eval<pad::kReference>};
  return &r;
}

TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
                                 pad::Eval<pad::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite