/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace space_to_batch_nd {

// This file has two implementations of SpaceToBatchND.
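// kReference dispatches to the portable kernels in reference_ops;
// kGenericOptimized dispatches to the kernels in optimized_ops.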
enum KernelType {
  kReference,
  kGenericOptimized,
};

struct SpaceToBatchNDContext {
  SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    paddings = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* paddings;
  TfLiteTensor* output;
};

// Currently, only 4D NHWC input/output tensors are supported.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
const int kInputDimensionNum = 4;
const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;

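// Example: an input of shape [1, 4, 4, 3] with block_shape = [2, 2] and zero
// paddings yields an output of shape [4, 2, 2, 3]: the batch dimension is
// multiplied by 2 * 2 and each spatial dimension is divided by its block
// size.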
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                SpaceToBatchNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
  const int32* paddings_data = GetTensorData<int32>(op_context->paddings);

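  // Validate the control tensors: block_shape must be a 1D tensor with one
  // entry per spatial dimension, and paddings must be a 2D tensor.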
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
                    kBlockSizeDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    kSpatialDimensionNum);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings),
                    kSpatialDimensionNum);

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);

  // Ensures the input height and width (with padding) are each a multiple of
  // the corresponding block shape dimension.
  for (int dim = 0; dim < kSpatialDimensionNum; ++dim) {
    int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
                          paddings_data[dim * 2 + 1]);
    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
    output_size->data[dim + 1] = final_dim_size / block_shape[dim];
  }

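  // The output batch is the input batch multiplied by the product of the
  // block shape entries; the channel dimension is unchanged.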
  const int output_batch_size =
      input_size->data[0] * block_shape[0] * block_shape[1];
  const int output_channel_size = input_size->data[3];

  output_size->data[0] = output_batch_size;
  output_size->data[3] = output_channel_size;

  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  SpaceToBatchNDContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
                    kInputDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

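  // If block_shape or paddings are not constant tensors, the output shape
  // cannot be computed until eval time, so mark the output as dynamic.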
  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.paddings)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  SpaceToBatchNDContext op_context(context, node);

  // Resize the output tensor if it is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

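// The macro below expands to a call into the selected implementation
// (reference_ops or optimized_ops), parameterized by the scalar type and by
// the value used to fill the padded region of the output.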
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value)             \
  tflite::SpaceToBatchParams op_params;                                \
  op_params.output_offset = pad_value;                                 \
  type::SpaceToBatchND(op_params, GetTensorShape(op_context.input),    \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.paddings),            \
                       GetTensorData<int32_t>(op_context.paddings),    \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
      }
      break;
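    // Quantized types pad with the output zero point, which is the quantized
    // representation of the real value 0.0.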
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
      }
      break;
    default:
      context->ReportError(
          context, "Type %d is currently not supported by SpaceToBatch.",
          op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_SPACE_TO_BATCH_ND
  return kTfLiteOk;
}

}  // namespace space_to_batch_nd

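// A TfLiteRegistration lists {init, free, prepare, invoke}; this op keeps no
// per-instance state, so init and free are left null.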
TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
  // return Register_SPACE_TO_BATCH_ND_REF();
  return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite