• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace space_to_batch_nd {
30 
// Selects which of the two SpaceToBatchND implementations Eval dispatches to.
enum KernelType {
  kReference = 0,         // reference_ops::SpaceToBatchND
  kGenericOptimized = 1,  // optimized_ops::SpaceToBatchND
};
36 
37 struct SpaceToBatchNDContext {
SpaceToBatchNDContexttflite::ops::builtin::space_to_batch_nd::SpaceToBatchNDContext38   SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
39     input = GetInput(context, node, 0);
40     block_shape = GetInput(context, node, 1);
41     paddings = GetInput(context, node, 2);
42     output = GetOutput(context, node, 0);
43   }
44   const TfLiteTensor* input;
45   const TfLiteTensor* block_shape;
46   const TfLiteTensor* paddings;
47   TfLiteTensor* output;
48 };
49 
// Currently, only 3D NHC and 4D NHWC input/output tensors are supported.
// A 3D input is treated as 4D NHWC by inserting W=1.
// A 4D input must have exactly 2 spatial dimensions (H and W).
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
constexpr int kInputMinDimensionNum = 3;
constexpr int kInputMaxDimensionNum = 4;
56 
ResizeOutputTensor(TfLiteContext * context,SpaceToBatchNDContext * op_context)57 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
58                                 SpaceToBatchNDContext* op_context) {
59   TfLiteIntArray* input_size = op_context->input->dims;
60   const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
61   const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
62 
63   int spatial_dims_num = input_size->size - 2;
64   // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
65   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
66   TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
67                     spatial_dims_num);
68   // Paddings should be a 2D tensor with dimension [spatial_dims_num, 2].
69   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 2);
70   TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[0],
71                     spatial_dims_num);
72   TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[1], 2);
73 
74   TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
75 
76   // Ensures the input height and width (with padding) is a multiple of block
77   // shape height and width.
78   int output_batch_size = input_size->data[0];
79   for (int dim = 0; dim < spatial_dims_num; ++dim) {
80     int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
81                           paddings_data[dim * 2 + 1]);
82     TF_LITE_ENSURE(context, block_shape[dim] != 0);
83     TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
84     output_size->data[dim + 1] = final_dim_size / block_shape[dim];
85     output_batch_size *= block_shape[dim];
86   }
87 
88   output_size->data[0] = output_batch_size;
89   output_size->data[input_size->size - 1] =
90       input_size->data[input_size->size - 1];
91 
92   return context->ResizeTensor(context, op_context->output, output_size);
93 }
94 
Prepare(TfLiteContext * context,TfLiteNode * node)95 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
96   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
97   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
98 
99   SpaceToBatchNDContext op_context(context, node);
100   TF_LITE_ENSURE(context,
101                  NumDimensions(op_context.input) >= kInputMinDimensionNum);
102   TF_LITE_ENSURE(context,
103                  NumDimensions(op_context.input) <= kInputMaxDimensionNum);
104   TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
105                           op_context.output->type);
106 
107   if (!IsConstantTensor(op_context.block_shape) ||
108       !IsConstantTensor(op_context.paddings)) {
109     SetTensorToDynamic(op_context.output);
110     return kTfLiteOk;
111   }
112   return ResizeOutputTensor(context, &op_context);
113 }
114 
115 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)116 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
117   SpaceToBatchNDContext op_context(context, node);
118 
119   // Resize the output tensor if the output tensor is dynamic.
120   if (IsDynamicTensor(op_context.output)) {
121     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
122   }
123 
124 #define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value)             \
125   tflite::SpaceToBatchParams op_params;                                \
126   op_params.output_offset = pad_value;                                 \
127   type::SpaceToBatchND(op_params, GetTensorShape(op_context.input),    \
128                        GetTensorData<scalar>(op_context.input),        \
129                        GetTensorShape(op_context.block_shape),         \
130                        GetTensorData<int32_t>(op_context.block_shape), \
131                        GetTensorShape(op_context.paddings),            \
132                        GetTensorData<int32_t>(op_context.paddings),    \
133                        GetTensorShape(op_context.output),              \
134                        GetTensorData<scalar>(op_context.output))
135   switch (op_context.input->type) {  // Already know in/out types are same.
136     case kTfLiteFloat32:
137       if (kernel_type == kReference) {
138         TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
139       } else {
140         TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
141       }
142       break;
143     case kTfLiteUInt8:
144       if (kernel_type == kReference) {
145         TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
146                                   op_context.output->params.zero_point);
147       } else {
148         TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
149                                   op_context.output->params.zero_point);
150       }
151       break;
152     case kTfLiteInt8:
153       if (kernel_type == kReference) {
154         TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
155                                   op_context.output->params.zero_point);
156       } else {
157         TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
158                                   op_context.output->params.zero_point);
159       }
160       break;
161     case kTfLiteInt32:
162       if (kernel_type == kReference) {
163         TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
164       } else {
165         TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
166       }
167       break;
168     case kTfLiteInt64:
169       if (kernel_type == kReference) {
170         TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
171       } else {
172         TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
173       }
174       break;
175     default:
176       TF_LITE_KERNEL_LOG(context,
177                          "Type %d is currently not supported by SpaceToBatch.",
178                          op_context.input->type);
179       return kTfLiteError;
180   }
181 #undef TF_LITE_SPACE_TO_BATCH_ND
182   return kTfLiteOk;
183 }
184 
185 }  // namespace space_to_batch_nd
186 
Register_SPACE_TO_BATCH_ND_REF()187 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
188   static TfLiteRegistration r = {
189       nullptr, nullptr, space_to_batch_nd::Prepare,
190       space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
191   return &r;
192 }
193 
Register_SPACE_TO_BATCH_ND_GENERIC_OPT()194 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
195   static TfLiteRegistration r = {
196       nullptr, nullptr, space_to_batch_nd::Prepare,
197       space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
198   return &r;
199 }
200 
Register_SPACE_TO_BATCH_ND()201 TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
202   // return Register_SPACE_TO_BATCH_ND_REF();
203   return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
204 }
205 
206 }  // namespace builtin
207 }  // namespace ops
208 }  // namespace tflite
209