/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace space_to_batch_nd {

// This file has two implementations of SpaceToBatchND.
enum KernelType {
  kReference,
  kGenericOptimized,
};

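// Gathers the tensors used by the op: the input, the 1-D block_shape, the
// 2-D paddings, and the output.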
struct SpaceToBatchNDContext {
  SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    paddings = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* paddings;
  TfLiteTensor* output;
};

// Currently, only 3D NHC and 4D NHWC input/output tensors are supported.
// A 3D input is extended to 4D NHWC by inserting a W dimension of size 1.
// The 4D case therefore has exactly two spatial dimensions.
// TODO(b/149952582): Support arbitrary dimensions in SpaceToBatchND.
const int kInputMinDimensionNum = 3;
const int kInputMaxDimensionNum = 4;

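// Worked example (illustrative values): an input of shape [1, 4, 4, 1] with
// block_shape = [2, 2] and paddings = [[0, 0], [0, 0]] produces an output of
// shape [1 * 2 * 2, 4 / 2, 4 / 2, 1] = [4, 2, 2, 1].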
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                SpaceToBatchNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
  const int32* paddings_data = GetTensorData<int32>(op_context->paddings);

  int spatial_dims_num = input_size->size - 2;
  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    spatial_dims_num);
  // Paddings should be a 2D tensor with dimension [spatial_dims_num, 2].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 2);
  TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[0],
                    spatial_dims_num);
  TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[1], 2);

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);

  // Ensures that each padded spatial dimension is a multiple of the
  // corresponding block shape dimension.
  int output_batch_size = input_size->data[0];
  for (int dim = 0; dim < spatial_dims_num; ++dim) {
    int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
                          paddings_data[dim * 2 + 1]);
    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
    output_size->data[dim + 1] = final_dim_size / block_shape[dim];
    output_batch_size *= block_shape[dim];
  }

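  // The batch dimension grows by the product of all block_shape values; the
  // channel dimension is left unchanged.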
  output_size->data[0] = output_batch_size;
  output_size->data[input_size->size - 1] =
      input_size->data[input_size->size - 1];

  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  SpaceToBatchNDContext op_context(context, node);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) >= kInputMinDimensionNum);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
                          op_context.output->type);

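  // If block_shape or paddings are not constant, the output shape cannot be
  // computed at Prepare time; mark the output dynamic and resize it in Eval.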
  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.paddings)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  SpaceToBatchNDContext op_context(context, node);

  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

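  // The macro below dispatches to either the reference or the optimized
  // SpaceToBatchND implementation. pad_value is written into the padded
  // regions: the output zero point for quantized types, 0 otherwise.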
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value)             \
  tflite::SpaceToBatchParams op_params;                                \
  op_params.output_offset = pad_value;                                 \
  type::SpaceToBatchND(op_params, GetTensorShape(op_context.input),    \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.paddings),            \
                       GetTensorData<int32_t>(op_context.paddings),    \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
      }
      break;
    default:
      context->ReportError(
          context, "Type %d is currently not supported by SpaceToBatch.",
          op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_SPACE_TO_BATCH_ND
  return kTfLiteOk;
}

}  // namespace space_to_batch_nd

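// Registration helpers: the REF variant wires up the reference kernel, the
// GENERIC_OPT variant the optimized kernel; the default registration below
// uses the optimized kernel.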
TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
  // return Register_SPACE_TO_BATCH_ND_REF();
  return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite