1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <vector>
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/c_api_internal.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace space_to_batch_nd {
29
// Selects which of the two SpaceToBatchND implementations a registration
// uses: the reference kernel or the generic optimized kernel.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};
35
36 struct SpaceToBatchNDContext {
SpaceToBatchNDContexttflite::ops::builtin::space_to_batch_nd::SpaceToBatchNDContext37 SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
38 input = GetInput(context, node, 0);
39 block_shape = GetInput(context, node, 1);
40 paddings = GetInput(context, node, 2);
41 output = GetOutput(context, node, 0);
42 }
43 const TfLiteTensor* input;
44 const TfLiteTensor* block_shape;
45 const TfLiteTensor* paddings;
46 TfLiteTensor* output;
47 };
48
// Only 4-D NHWC input/output tensors are supported for now, which means
// exactly two spatial dimensions (height and width).
// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
constexpr int kInputDimensionNum = 4;      // Required rank of the input.
constexpr int kBlockSizeDimensionNum = 1;  // Required rank of block_shape.
// Number of spatial dimensions; also used as the required rank of paddings.
constexpr int kSpatialDimensionNum = 2;
55
ResizeOutputTensor(TfLiteContext * context,SpaceToBatchNDContext * op_context)56 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
57 SpaceToBatchNDContext* op_context) {
58 TfLiteIntArray* input_size = op_context->input->dims;
59 const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
60 const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
61
62 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
63 kBlockSizeDimensionNum);
64 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
65 kSpatialDimensionNum);
66 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings),
67 kSpatialDimensionNum);
68
69 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
70
71 // Ensures the input height and width (with padding) is a multiple of block
72 // shape height and width.
73 for (int dim = 0; dim < kSpatialDimensionNum; ++dim) {
74 int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
75 paddings_data[dim * 2 + 1]);
76 TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
77 output_size->data[dim + 1] = final_dim_size / block_shape[dim];
78 }
79
80 const int output_batch_size =
81 input_size->data[0] * block_shape[0] * block_shape[1];
82 const int output_channel_size = input_size->data[3];
83
84 output_size->data[0] = output_batch_size;
85 output_size->data[3] = output_channel_size;
86
87 return context->ResizeTensor(context, op_context->output, output_size);
88 }
89
Prepare(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
92 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
93
94 SpaceToBatchNDContext op_context(context, node);
95 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
96 kInputDimensionNum);
97 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
98
99 if (!IsConstantTensor(op_context.block_shape) ||
100 !IsConstantTensor(op_context.paddings)) {
101 SetTensorToDynamic(op_context.output);
102 return kTfLiteOk;
103 }
104 return ResizeOutputTensor(context, &op_context);
105 }
106
107 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)108 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
109 SpaceToBatchNDContext op_context(context, node);
110
111 // Resize the output tensor if the output tensor is dynamic.
112 if (IsDynamicTensor(op_context.output)) {
113 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
114 }
115
116 #define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
117 tflite::SpaceToBatchParams op_params; \
118 op_params.output_offset = pad_value; \
119 type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
120 GetTensorData<scalar>(op_context.input), \
121 GetTensorShape(op_context.block_shape), \
122 GetTensorData<int32_t>(op_context.block_shape), \
123 GetTensorShape(op_context.paddings), \
124 GetTensorData<int32_t>(op_context.paddings), \
125 GetTensorShape(op_context.output), \
126 GetTensorData<scalar>(op_context.output))
127 switch (op_context.input->type) { // Already know in/out types are same.
128 case kTfLiteFloat32:
129 if (kernel_type == kReference) {
130 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
131 } else {
132 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
133 }
134 break;
135 case kTfLiteUInt8:
136 if (kernel_type == kReference) {
137 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
138 op_context.output->params.zero_point);
139 } else {
140 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
141 op_context.output->params.zero_point);
142 }
143 break;
144 case kTfLiteInt8:
145 if (kernel_type == kReference) {
146 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
147 op_context.output->params.zero_point);
148 } else {
149 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
150 op_context.output->params.zero_point);
151 }
152 break;
153 case kTfLiteInt32:
154 if (kernel_type == kReference) {
155 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
156 } else {
157 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
158 }
159 break;
160 case kTfLiteInt64:
161 if (kernel_type == kReference) {
162 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
163 } else {
164 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
165 }
166 break;
167 default:
168 context->ReportError(
169 context, "Type %d is currently not supported by SpaceToBatch.",
170 op_context.input->type);
171 return kTfLiteError;
172 }
173 #undef TF_LITE_SPACE_TO_BATCH_ND
174 return kTfLiteOk;
175 }
176
177 } // namespace space_to_batch_nd
178
Register_SPACE_TO_BATCH_ND_REF()179 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
180 static TfLiteRegistration r = {
181 nullptr, nullptr, space_to_batch_nd::Prepare,
182 space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
183 return &r;
184 }
185
Register_SPACE_TO_BATCH_ND_GENERIC_OPT()186 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
187 static TfLiteRegistration r = {
188 nullptr, nullptr, space_to_batch_nd::Prepare,
189 space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
190 return &r;
191 }
192
Register_SPACE_TO_BATCH_ND()193 TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
194 // return Register_SPACE_TO_BATCH_ND_REF();
195 return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
196 }
197
198 } // namespace builtin
199 } // namespace ops
200 } // namespace tflite
201