/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace batch_to_space_nd {

// This file has two implementations of BatchToSpaceND.
enum KernelType {
  kReference,
  kGenericOptimized,
};

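// Collects the tensors this op operates on: the value tensor to rearrange,
// the block_shape and crops parameter tensors, and the output tensor.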
struct BatchToSpaceNDContext {
  BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    crops = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* crops;
  TfLiteTensor* output;
};

// Currently, only 4D NHWC input/output tensors are supported.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
const int kInputDimensionNum = 4;
const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;

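// Computes the output shape from the input shape, block_shape, and crops.
// A worked example of the formulas below: an input of shape [4, 2, 2, 1]
// with block_shape [2, 2] and zero crops gives
//   batch  = 4 / (2 * 2)   = 1
//   height = 2 * 2 - 0 - 0 = 4
//   width  = 2 * 2 - 0 - 0 = 4
// so the output is resized to [1, 4, 4, 1].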
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BatchToSpaceNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int* block_shape = GetTensorData<int32>(op_context->block_shape);
  const int* crops = GetTensorData<int32>(op_context->crops);

  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
                    kBlockSizeDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    kSpatialDimensionNum);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
                    kSpatialDimensionNum);

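  // The crops tensor is logically [[top, bottom], [left, right]]; it is read
  // here as a flat array of four int32 values, each of which must be
  // non-negative.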
  TF_LITE_ENSURE(context, crops[0] >= 0);
  TF_LITE_ENSURE(context, crops[1] >= 0);
  TF_LITE_ENSURE(context, crops[2] >= 0);
  TF_LITE_ENSURE(context, crops[3] >= 0);

  // The batch count must be a multiple of (block_shape[0] * block_shape[1]).
  TF_LITE_ENSURE_EQ(context,
                    input_size->data[0] % (block_shape[0] * block_shape[1]), 0);

  const int output_batch_size =
      input_size->data[0] / (block_shape[0] * block_shape[1]);

  const int crops_top = crops[0];
  const int crops_bottom = crops[1];
  const int crops_left = crops[2];
  const int crops_right = crops[3];
  const int output_height =
      input_size->data[1] * block_shape[0] - crops_top - crops_bottom;
  const int output_width =
      input_size->data[2] * block_shape[1] - crops_left - crops_right;

  const int output_channel_size = input_size->data[3];

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
  output_size->data[0] = output_batch_size;
  output_size->data[1] = output_height;
  output_size->data[2] = output_width;
  output_size->data[3] = output_channel_size;

  return context->ResizeTensor(context, op_context->output, output_size);
}

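// Validates tensor ranks and types. The output can only be resized here when
// block_shape and crops are constant; otherwise the output is marked dynamic
// and resized in Eval.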
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  BatchToSpaceNDContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
                    kInputDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.crops)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BatchToSpaceNDContext op_context(context, node);

  // Resize the output tensor if it is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

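// Dispatches to reference_ops::BatchToSpaceND or
// optimized_ops::BatchToSpaceND with the scalar type selected by the switch
// below.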
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                        \
  type::BatchToSpaceND(GetTensorShape(op_context.input),               \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.crops),               \
                       GetTensorData<int32_t>(op_context.crops),       \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
      }
      break;
    default:
      context->ReportError(
          context, "Type %d is currently not supported by BatchToSpace.",
          op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_BATCH_TO_SPACE_ND
  return kTfLiteOk;
}

}  // namespace batch_to_space_nd

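// In TfLiteRegistration, the first four members are the init, free, prepare,
// and invoke callbacks. This kernel keeps no per-node state, so init and
// free are left as nullptr.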
TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
  return &r;
}

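// The default registration uses the generic optimized kernel.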
TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
  return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite