1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/compatibility.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace space_to_batch_nd {
30
31 // This file has two implementations of SpaceToBatchND.
// Selects which SpaceToBatchND implementation Eval dispatches to: the
// reference kernels (reference_ops) or the generic optimized kernels
// (optimized_ops).
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};
36
37 struct SpaceToBatchNDContext {
SpaceToBatchNDContexttflite::ops::builtin::space_to_batch_nd::SpaceToBatchNDContext38 SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
39 input = GetInput(context, node, 0);
40 block_shape = GetInput(context, node, 1);
41 paddings = GetInput(context, node, 2);
42 output = GetOutput(context, node, 0);
43 }
44 const TfLiteTensor* input;
45 const TfLiteTensor* block_shape;
46 const TfLiteTensor* paddings;
47 TfLiteTensor* output;
48 };
49
// Currently, only 3D NHC and 4D NHWC input/output tensors are supported.
// A 3D input is extended to 4D NHWC by adding W=1.
// A 4D input needs to have exactly 2 spatial dimensions.
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
54 const int kInputMinDimensionNum = 3;
55 const int kInputMaxDimensionNum = 4;
56
ResizeOutputTensor(TfLiteContext * context,SpaceToBatchNDContext * op_context)57 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
58 SpaceToBatchNDContext* op_context) {
59 TfLiteIntArray* input_size = op_context->input->dims;
60 const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
61 const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
62
63 int spatial_dims_num = input_size->size - 2;
64 // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
65 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
66 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
67 spatial_dims_num);
68 // Paddings should be a 2D tensor with dimension [spatial_dims_num, 2].
69 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 2);
70 TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[0],
71 spatial_dims_num);
72 TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[1], 2);
73
74 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
75
76 // Ensures the input height and width (with padding) is a multiple of block
77 // shape height and width.
78 int output_batch_size = input_size->data[0];
79 for (int dim = 0; dim < spatial_dims_num; ++dim) {
80 int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
81 paddings_data[dim * 2 + 1]);
82 TF_LITE_ENSURE(context, block_shape[dim] != 0);
83 TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
84 output_size->data[dim + 1] = final_dim_size / block_shape[dim];
85 output_batch_size *= block_shape[dim];
86 }
87
88 output_size->data[0] = output_batch_size;
89 output_size->data[input_size->size - 1] =
90 input_size->data[input_size->size - 1];
91
92 return context->ResizeTensor(context, op_context->output, output_size);
93 }
94
Prepare(TfLiteContext * context,TfLiteNode * node)95 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
96 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
97 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
98
99 SpaceToBatchNDContext op_context(context, node);
100 TF_LITE_ENSURE(context,
101 NumDimensions(op_context.input) >= kInputMinDimensionNum);
102 TF_LITE_ENSURE(context,
103 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
104 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
105 op_context.output->type);
106
107 if (!IsConstantTensor(op_context.block_shape) ||
108 !IsConstantTensor(op_context.paddings)) {
109 SetTensorToDynamic(op_context.output);
110 return kTfLiteOk;
111 }
112 return ResizeOutputTensor(context, &op_context);
113 }
114
115 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)116 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
117 SpaceToBatchNDContext op_context(context, node);
118
119 // Resize the output tensor if the output tensor is dynamic.
120 if (IsDynamicTensor(op_context.output)) {
121 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
122 }
123
124 #define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
125 tflite::SpaceToBatchParams op_params; \
126 op_params.output_offset = pad_value; \
127 type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
128 GetTensorData<scalar>(op_context.input), \
129 GetTensorShape(op_context.block_shape), \
130 GetTensorData<int32_t>(op_context.block_shape), \
131 GetTensorShape(op_context.paddings), \
132 GetTensorData<int32_t>(op_context.paddings), \
133 GetTensorShape(op_context.output), \
134 GetTensorData<scalar>(op_context.output))
135 switch (op_context.input->type) { // Already know in/out types are same.
136 case kTfLiteFloat32:
137 if (kernel_type == kReference) {
138 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
139 } else {
140 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
141 }
142 break;
143 case kTfLiteUInt8:
144 if (kernel_type == kReference) {
145 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
146 op_context.output->params.zero_point);
147 } else {
148 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
149 op_context.output->params.zero_point);
150 }
151 break;
152 case kTfLiteInt8:
153 if (kernel_type == kReference) {
154 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
155 op_context.output->params.zero_point);
156 } else {
157 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
158 op_context.output->params.zero_point);
159 }
160 break;
161 case kTfLiteInt32:
162 if (kernel_type == kReference) {
163 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
164 } else {
165 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
166 }
167 break;
168 case kTfLiteInt64:
169 if (kernel_type == kReference) {
170 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
171 } else {
172 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
173 }
174 break;
175 default:
176 TF_LITE_KERNEL_LOG(context,
177 "Type %d is currently not supported by SpaceToBatch.",
178 op_context.input->type);
179 return kTfLiteError;
180 }
181 #undef TF_LITE_SPACE_TO_BATCH_ND
182 return kTfLiteOk;
183 }
184
185 } // namespace space_to_batch_nd
186
Register_SPACE_TO_BATCH_ND_REF()187 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
188 static TfLiteRegistration r = {
189 nullptr, nullptr, space_to_batch_nd::Prepare,
190 space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
191 return &r;
192 }
193
Register_SPACE_TO_BATCH_ND_GENERIC_OPT()194 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
195 static TfLiteRegistration r = {
196 nullptr, nullptr, space_to_batch_nd::Prepare,
197 space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
198 return &r;
199 }
200
Register_SPACE_TO_BATCH_ND()201 TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
202 // return Register_SPACE_TO_BATCH_ND_REF();
203 return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
204 }
205
206 } // namespace builtin
207 } // namespace ops
208 } // namespace tflite
209