1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/compatibility.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace batch_to_space_nd {
29
// Selects which BatchToSpaceND kernel Eval dispatches to: the reference
// implementation or the generic optimized one.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};
35
36 struct BatchToSpaceNDContext {
BatchToSpaceNDContexttflite::ops::builtin::batch_to_space_nd::BatchToSpaceNDContext37 BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
38 input = GetInput(context, node, 0);
39 block_shape = GetInput(context, node, 1);
40 crops = GetInput(context, node, 2);
41 output = GetOutput(context, node, 0);
42 }
43 const TfLiteTensor* input;
44 const TfLiteTensor* block_shape;
45 const TfLiteTensor* crops;
46 TfLiteTensor* output;
47 };
48
// Currently, only 3D NHC or 4D NHWC input/output tensors are supported.
// A 3D input is converted to 4D by inserting W=1, giving NH1C.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
53 const int kInputMinDimensionNum = 3;
54 const int kInputMaxDimensionNum = 4;
55
ResizeOutputTensor(TfLiteContext * context,BatchToSpaceNDContext * op_context)56 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
57 BatchToSpaceNDContext* op_context) {
58 TfLiteIntArray* input_size = op_context->input->dims;
59 const int* block_shape = GetTensorData<int32>(op_context->block_shape);
60 const int* crops = GetTensorData<int32>(op_context->crops);
61
62 int spatial_dims_num = input_size->size - 2;
63 // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
64 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
65 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
66 spatial_dims_num);
67 // Crops should be a 2D tensor with dimension [spatial_dims_num, 2].
68 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), 2);
69 TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[0],
70 spatial_dims_num);
71 TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[1], 2);
72
73 for (int i = 0; i < spatial_dims_num * 2; ++i) {
74 TF_LITE_ENSURE(context, crops[i] >= 0);
75 }
76
77 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
78 int output_batch_size = input_size->data[0];
79 for (int dim = 0; dim < spatial_dims_num; ++dim) {
80 // Number of batch must be multiple of (block_shape[dim]).
81 TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape[dim], 0);
82 output_batch_size = output_batch_size / block_shape[dim];
83 output_size->data[dim + 1] = input_size->data[dim + 1] * block_shape[dim] -
84 crops[dim * 2] - crops[dim * 2 + 1];
85 }
86
87 output_size->data[0] = output_batch_size;
88 output_size->data[input_size->size - 1] =
89 input_size->data[input_size->size - 1];
90
91 return context->ResizeTensor(context, op_context->output, output_size);
92 }
93
Prepare(TfLiteContext * context,TfLiteNode * node)94 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
95 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
96 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
97
98 BatchToSpaceNDContext op_context(context, node);
99 TF_LITE_ENSURE(context,
100 NumDimensions(op_context.input) >= kInputMinDimensionNum);
101 TF_LITE_ENSURE(context,
102 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
103 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
104
105 if (!IsConstantTensor(op_context.block_shape) ||
106 !IsConstantTensor(op_context.crops)) {
107 SetTensorToDynamic(op_context.output);
108 return kTfLiteOk;
109 }
110 return ResizeOutputTensor(context, &op_context);
111 }
112
113 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)114 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
115 BatchToSpaceNDContext op_context(context, node);
116
117 // Resize the output tensor if the output tensor is dynamic.
118 if (IsDynamicTensor(op_context.output)) {
119 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
120 }
121
122 #define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
123 type::BatchToSpaceND(GetTensorShape(op_context.input), \
124 GetTensorData<scalar>(op_context.input), \
125 GetTensorShape(op_context.block_shape), \
126 GetTensorData<int32_t>(op_context.block_shape), \
127 GetTensorShape(op_context.crops), \
128 GetTensorData<int32_t>(op_context.crops), \
129 GetTensorShape(op_context.output), \
130 GetTensorData<scalar>(op_context.output))
131 switch (op_context.input->type) { // Already know in/out types are same.
132 case kTfLiteFloat32:
133 if (kernel_type == kReference) {
134 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
135 } else {
136 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
137 }
138 break;
139 case kTfLiteUInt8:
140 if (kernel_type == kReference) {
141 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
142 } else {
143 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
144 }
145 break;
146 case kTfLiteInt8:
147 if (kernel_type == kReference) {
148 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
149 } else {
150 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
151 }
152 break;
153 case kTfLiteInt32:
154 if (kernel_type == kReference) {
155 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
156 } else {
157 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
158 }
159 break;
160 case kTfLiteInt64:
161 if (kernel_type == kReference) {
162 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
163 } else {
164 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
165 }
166 break;
167 default:
168 context->ReportError(
169 context, "Type %d is currently not supported by BatchToSpace.",
170 op_context.input->type);
171 return kTfLiteError;
172 }
173 #undef TF_LITE_BATCH_TO_SPACE_ND
174 return kTfLiteOk;
175 }
176
177 } // namespace batch_to_space_nd
178
Register_BATCH_TO_SPACE_ND_REF()179 TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
180 static TfLiteRegistration r = {
181 nullptr, nullptr, batch_to_space_nd::Prepare,
182 batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
183 return &r;
184 }
185
Register_BATCH_TO_SPACE_ND_GENERIC_OPT()186 TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
187 static TfLiteRegistration r = {
188 nullptr, nullptr, batch_to_space_nd::Prepare,
189 batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
190 return &r;
191 }
192
Register_BATCH_TO_SPACE_ND()193 TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
194 return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
195 }
196
197 } // namespace builtin
198 } // namespace ops
199 } // namespace tflite
200