1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <string.h>

#include <limits>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace pad {
29
// Pad is implemented twice in this file; Eval() is templated on which
// implementation family it dispatches to.
enum KernelType {
  kReference,         // Kernels from reference_ops.
  kGenericOptimized,  // Kernels from optimized_ops (int8 still uses
                      // reference_ops, see Eval()).
};
35
36 struct PadContext {
PadContexttflite::ops::builtin::pad::PadContext37 PadContext(TfLiteContext* context, TfLiteNode* node) {
38 input = GetInput(context, node, 0);
39 paddings = GetInput(context, node, 1);
40 if (NumInputs(node) == 3) {
41 constant_values = GetOptionalInputTensor(context, node, 2);
42 } else {
43 constant_values = nullptr;
44 }
45 output = GetOutput(context, node, 0);
46 dims = NumDimensions(input);
47
48 resizing_category = ResizingCategory::kGenericResize;
49 const int paddings_total = GetTensorShape(paddings).FlatSize();
50 const int32* paddings_data = GetTensorData<int32>(paddings);
51 // Paddings will be a n,2 array, and we need to detect 4D arrays with the
52 // pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
53 if (IsConstantTensor(paddings) && paddings_total == 8 &&
54 (paddings_data[0] == 0 && paddings_data[1] == 0) &&
55 (paddings_data[6] == 0 && paddings_data[7] == 0)) {
56 resizing_category = ResizingCategory::kImageStyle;
57 }
58 }
59 const TfLiteTensor* constant_values;
60 const TfLiteTensor* input;
61 const TfLiteTensor* paddings;
62 TfLiteTensor* output;
63 int dims;
64 ResizingCategory resizing_category;
65 };
66
67 // Resizes output array based on the input size and padding size. This function
68 // is callable from both Prepare() and Eval() as long as the caller ensures the
69 // paddings data is present.
ResizeOutputTensor(TfLiteContext * context,PadContext * op_context)70 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
71 PadContext* op_context) {
72 // Ensures the paddings array is dims x 2.
73 TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
74 op_context->dims);
75 TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
76
77 // Determines the size of the output tensor.
78 TfLiteIntArray* input_size = op_context->input->dims;
79 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
80 const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
81
82 for (int idx = 0; idx < op_context->dims; ++idx) {
83 int before_padding = *paddings_data++;
84 int after_padding = *paddings_data++;
85
86 TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
87 "Pad value has to be greater than equal to 0.");
88
89 output_size->data[idx] =
90 (input_size->data[idx] + before_padding + after_padding);
91 }
92
93 return context->ResizeTensor(context, op_context->output, output_size);
94 }
95
Prepare(TfLiteContext * context,TfLiteNode * node)96 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
97 TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
98 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
99
100 PadContext op_context(context, node);
101 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
102 if (op_context.constant_values != nullptr) {
103 TF_LITE_ENSURE_EQ(context, op_context.input->type,
104 op_context.constant_values->type);
105 }
106
107 // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
108 TF_LITE_ENSURE(context, op_context.dims <= 4);
109
110 // Exit early if paddings is a non-const tensor. Set output tensor to
111 // dynamic so output size can be determined in Eval.
112 if (!IsConstantTensor(op_context.paddings)) {
113 SetTensorToDynamic(op_context.output);
114 return kTfLiteOk;
115 }
116 return ResizeOutputTensor(context, &op_context);
117 }
118
119 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)120 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
121 PadContext op_context(context, node);
122
123 if (op_context.constant_values != nullptr) {
124 // Ensure that constant_values is a scalar.
125 TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
126 }
127
128 // Resize the output tensor if the output tensor is dynamic.
129 if (IsDynamicTensor(op_context.output)) {
130 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
131 }
132
133 // TODO(nupurgarg): Change kernel implementation to take in int* instead of
134 // vector<int> to remove malloc from Eval().
135 // Create before and after padding arrays that are accepted by the kernel.
136 std::vector<int> before_padding;
137 std::vector<int> after_padding;
138 const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
139
140 // TODO(nupurgarg): Change kernel implementation to use padding arrays in
141 // forward order (depth, width, height, batch).
142 // Build paddings in order of int[] = {batch, height, width, depth} to match
143 // kernel implementation of Pad in reference_ops.h and optimized_ops.h.
144 for (int idx = op_context.dims - 1; idx >= 0; --idx) {
145 before_padding.push_back(paddings_data[idx * 2]);
146 after_padding.push_back(paddings_data[idx * 2 + 1]);
147 }
148
149 #define TF_LITE_PAD(type, op_name, scalar, pad_value) \
150 TF_LITE_ENSURE(context, before_padding.size() <= 4); \
151 TF_LITE_ENSURE(context, after_padding.size() <= 4); \
152 tflite::PadParams op_params; \
153 op_params.left_padding_count = before_padding.size(); \
154 op_params.right_padding_count = after_padding.size(); \
155 for (int i = 0; i < op_context.dims; ++i) { \
156 op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
157 op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
158 } \
159 const scalar pad_value_copy = pad_value; \
160 \
161 type::op_name(op_params, GetTensorShape(op_context.input), \
162 GetTensorData<scalar>(op_context.input), &pad_value_copy, \
163 GetTensorShape(op_context.output), \
164 GetTensorData<scalar>(op_context.output))
165 switch (op_context.input->type) {
166 case kTfLiteFloat32: {
167 float pad_value = op_context.constant_values == nullptr
168 ? 0.f
169 : *GetTensorData<float>(op_context.constant_values);
170 if (kernel_type == kReference) {
171 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
172 TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
173 } else {
174 TF_LITE_PAD(reference_ops, Pad, float, pad_value);
175 }
176 } else if (kernel_type == kGenericOptimized) {
177 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
178 TF_LITE_PAD(optimized_ops, PadImageStyle, float, pad_value);
179 } else {
180 TF_LITE_PAD(optimized_ops, Pad, float, pad_value);
181 }
182 }
183 } break;
184 case kTfLiteUInt8: {
185 uint8_t pad_value;
186 if (op_context.constant_values == nullptr) {
187 // Quantized Pad requires that 0 is represented in the quantized
188 // range.
189 TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
190 std::numeric_limits<uint8_t>::min());
191 TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
192 std::numeric_limits<uint8_t>::max());
193 pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
194 } else {
195 // Quantized Pad requires that 'constant_values' is represented in the
196 // same quantized range as the input and output tensors.
197 TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
198 op_context.constant_values->params.zero_point);
199 TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
200 op_context.constant_values->params.scale);
201 pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
202 }
203 if (kernel_type == kReference) {
204 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
205 TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
206 } else {
207 TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
208 }
209 } else if (kernel_type == kGenericOptimized) {
210 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
211 TF_LITE_PAD(optimized_ops, PadImageStyle, uint8_t, pad_value);
212 } else {
213 TF_LITE_PAD(optimized_ops, Pad, uint8_t, pad_value);
214 }
215 }
216 } break;
217 case kTfLiteInt8: {
218 int8_t pad_value;
219 if (op_context.constant_values == nullptr) {
220 // Quantized Pad requires that 0 is represented in the quantized
221 // range.
222 TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
223 std::numeric_limits<int8_t>::min());
224 TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
225 std::numeric_limits<int8_t>::max());
226 pad_value = static_cast<int8_t>(op_context.output->params.zero_point);
227 } else {
228 // Quantized Pad requires that 'constant_values' is represented in the
229 // same quantized range as the input and output tensors.
230 TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
231 op_context.constant_values->params.zero_point);
232 TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
233 op_context.constant_values->params.scale);
234 pad_value = *GetTensorData<int8_t>(op_context.constant_values);
235 }
236 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
237 TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value);
238 } else {
239 TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value);
240 }
241 } break;
242 case kTfLiteInt32: {
243 int32_t pad_value =
244 op_context.constant_values == nullptr
245 ? 0
246 : *GetTensorData<int32_t>(op_context.constant_values);
247 if (kernel_type == kReference) {
248 TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
249 } else if (kernel_type == kGenericOptimized) {
250 TF_LITE_PAD(optimized_ops, Pad, int32_t, pad_value);
251 }
252 } break;
253 case kTfLiteInt64: {
254 int64_t pad_value =
255 op_context.constant_values == nullptr
256 ? 0L
257 : *GetTensorData<int64_t>(op_context.constant_values);
258 if (kernel_type == kReference) {
259 TF_LITE_PAD(reference_ops, Pad, int64_t, pad_value);
260 } else if (kernel_type == kGenericOptimized) {
261 TF_LITE_PAD(optimized_ops, Pad, int64_t, pad_value);
262 }
263 } break;
264 default:
265 context->ReportError(context,
266 "Type %d is currently not supported by Pad.",
267 op_context.input->type);
268 return kTfLiteError;
269 }
270 #undef TF_LITE_PAD
271 return kTfLiteOk;
272 }
273
274 } // namespace pad
275
Register_PAD_REF()276 TfLiteRegistration* Register_PAD_REF() {
277 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
278 pad::Eval<pad::kReference>};
279 return &r;
280 }
281
Register_PAD_GENERIC_OPT()282 TfLiteRegistration* Register_PAD_GENERIC_OPT() {
283 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
284 pad::Eval<pad::kGenericOptimized>};
285 return &r;
286 }
287
Register_PAD()288 TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
289
290 // Also register Pad as PadV2.
Register_PADV2_REF()291 TfLiteRegistration* Register_PADV2_REF() {
292 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
293 pad::Eval<pad::kReference>};
294 return &r;
295 }
296
Register_PADV2_GENERIC_OPT()297 TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
298 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
299 pad::Eval<pad::kGenericOptimized>};
300 return &r;
301 }
302
Register_PADV2()303 TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }
304
305 } // namespace builtin
306 } // namespace ops
307 } // namespace tflite
308