1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/compatibility.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace space_to_batch_nd {
30
// Selects which of the two SpaceToBatchND implementations Eval<> dispatches
// to: the reference kernels or the generic optimized kernels.
enum KernelType { kReference, kGenericOptimized };
36
37 struct SpaceToBatchNDContext {
SpaceToBatchNDContexttflite::ops::builtin::space_to_batch_nd::SpaceToBatchNDContext38 SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
39 input = GetInput(context, node, 0);
40 block_shape = GetInput(context, node, 1);
41 paddings = GetInput(context, node, 2);
42 output = GetOutput(context, node, 0);
43 }
44 const TfLiteTensor* input;
45 const TfLiteTensor* block_shape;
46 const TfLiteTensor* paddings;
47 TfLiteTensor* output;
48 };
49
// Supported input ranks. Only 3D (NHC) and 4D (NHWC) tensors are accepted.
// Per the original design, a 3D input is treated as a 4D NHWC tensor with
// W = 1. A 4D input must therefore have exactly 2 spatial dimensions.
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
constexpr int kInputMinDimensionNum = 3;
constexpr int kInputMaxDimensionNum = 4;
56
ResizeOutputTensor(TfLiteContext * context,SpaceToBatchNDContext * op_context)57 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
58 SpaceToBatchNDContext* op_context) {
59 TfLiteIntArray* input_size = op_context->input->dims;
60 const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
61 const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
62
63 int spatial_dims_num = input_size->size - 2;
64 // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
65 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
66 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
67 spatial_dims_num);
68 // Paddings should be a 2D tensor with dimension [spatial_dims_num, 2].
69 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 2);
70 TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[0],
71 spatial_dims_num);
72 TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[1], 2);
73
74 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
75
76 // Ensures the input height and width (with padding) is a multiple of block
77 // shape height and width.
78 int output_batch_size = input_size->data[0];
79 for (int dim = 0; dim < spatial_dims_num; ++dim) {
80 int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
81 paddings_data[dim * 2 + 1]);
82 TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
83 output_size->data[dim + 1] = final_dim_size / block_shape[dim];
84 output_batch_size *= block_shape[dim];
85 }
86
87 output_size->data[0] = output_batch_size;
88 output_size->data[input_size->size - 1] =
89 input_size->data[input_size->size - 1];
90
91 return context->ResizeTensor(context, op_context->output, output_size);
92 }
93
Prepare(TfLiteContext * context,TfLiteNode * node)94 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
95 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
96 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
97
98 SpaceToBatchNDContext op_context(context, node);
99 TF_LITE_ENSURE(context,
100 NumDimensions(op_context.input) >= kInputMinDimensionNum);
101 TF_LITE_ENSURE(context,
102 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
103 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
104 op_context.output->type);
105
106 if (!IsConstantTensor(op_context.block_shape) ||
107 !IsConstantTensor(op_context.paddings)) {
108 SetTensorToDynamic(op_context.output);
109 return kTfLiteOk;
110 }
111 return ResizeOutputTensor(context, &op_context);
112 }
113
114 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)115 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
116 SpaceToBatchNDContext op_context(context, node);
117
118 // Resize the output tensor if the output tensor is dynamic.
119 if (IsDynamicTensor(op_context.output)) {
120 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
121 }
122
123 #define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
124 tflite::SpaceToBatchParams op_params; \
125 op_params.output_offset = pad_value; \
126 type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
127 GetTensorData<scalar>(op_context.input), \
128 GetTensorShape(op_context.block_shape), \
129 GetTensorData<int32_t>(op_context.block_shape), \
130 GetTensorShape(op_context.paddings), \
131 GetTensorData<int32_t>(op_context.paddings), \
132 GetTensorShape(op_context.output), \
133 GetTensorData<scalar>(op_context.output))
134 switch (op_context.input->type) { // Already know in/out types are same.
135 case kTfLiteFloat32:
136 if (kernel_type == kReference) {
137 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
138 } else {
139 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
140 }
141 break;
142 case kTfLiteUInt8:
143 if (kernel_type == kReference) {
144 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
145 op_context.output->params.zero_point);
146 } else {
147 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
148 op_context.output->params.zero_point);
149 }
150 break;
151 case kTfLiteInt8:
152 if (kernel_type == kReference) {
153 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
154 op_context.output->params.zero_point);
155 } else {
156 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
157 op_context.output->params.zero_point);
158 }
159 break;
160 case kTfLiteInt32:
161 if (kernel_type == kReference) {
162 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
163 } else {
164 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
165 }
166 break;
167 case kTfLiteInt64:
168 if (kernel_type == kReference) {
169 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
170 } else {
171 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
172 }
173 break;
174 default:
175 context->ReportError(
176 context, "Type %d is currently not supported by SpaceToBatch.",
177 op_context.input->type);
178 return kTfLiteError;
179 }
180 #undef TF_LITE_SPACE_TO_BATCH_ND
181 return kTfLiteOk;
182 }
183
184 } // namespace space_to_batch_nd
185
Register_SPACE_TO_BATCH_ND_REF()186 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
187 static TfLiteRegistration r = {
188 nullptr, nullptr, space_to_batch_nd::Prepare,
189 space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
190 return &r;
191 }
192
Register_SPACE_TO_BATCH_ND_GENERIC_OPT()193 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
194 static TfLiteRegistration r = {
195 nullptr, nullptr, space_to_batch_nd::Prepare,
196 space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
197 return &r;
198 }
199
Register_SPACE_TO_BATCH_ND()200 TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
201 // return Register_SPACE_TO_BATCH_ND_REF();
202 return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
203 }
204
205 } // namespace builtin
206 } // namespace ops
207 } // namespace tflite
208