• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <vector>
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/c_api_internal.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace batch_to_space_nd {
29 
// BatchToSpaceND ships with two kernels; the kernel is chosen at registration
// time through the template argument of Eval.
enum KernelType {
  kReference = 0,         // Uses reference_ops::BatchToSpaceND.
  kGenericOptimized = 1,  // Uses optimized_ops::BatchToSpaceND.
};
35 
36 struct BatchToSpaceNDContext {
BatchToSpaceNDContexttflite::ops::builtin::batch_to_space_nd::BatchToSpaceNDContext37   BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
38     input = GetInput(context, node, 0);
39     block_shape = GetInput(context, node, 1);
40     crops = GetInput(context, node, 2);
41     output = GetOutput(context, node, 0);
42   }
43   const TfLiteTensor* input;
44   const TfLiteTensor* block_shape;
45   const TfLiteTensor* crops;
46   TfLiteTensor* output;
47 };
48 
// Currently only 4D NHWC input/output tensors are supported, i.e. the tensors
// must have exactly two spatial dimensions (height and width).
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
constexpr int kInputDimensionNum = 4;      // Rank of the input tensor.
constexpr int kBlockSizeDimensionNum = 1;  // Rank of the block_shape tensor.
constexpr int kSpatialDimensionNum = 2;    // Number of spatial dimensions.
55 
ResizeOutputTensor(TfLiteContext * context,BatchToSpaceNDContext * op_context)56 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
57                                 BatchToSpaceNDContext* op_context) {
58   TfLiteIntArray* input_size = op_context->input->dims;
59   const int* block_shape = GetTensorData<int32>(op_context->block_shape);
60   const int* crops = GetTensorData<int32>(op_context->crops);
61 
62   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
63                     kBlockSizeDimensionNum);
64   TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
65                     kSpatialDimensionNum);
66   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
67                     kSpatialDimensionNum);
68 
69   TF_LITE_ENSURE(context, crops[0] >= 0);
70   TF_LITE_ENSURE(context, crops[1] >= 0);
71   TF_LITE_ENSURE(context, crops[2] >= 0);
72   TF_LITE_ENSURE(context, crops[3] >= 0);
73 
74   // Number of batch must be multiple of (block_shape[0] * block_shape[1]).
75   TF_LITE_ENSURE_EQ(context,
76                     input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
77 
78   const int output_batch_size =
79       input_size->data[0] / (block_shape[0] * block_shape[1]);
80 
81   const int crops_top = crops[0];
82   const int crops_bottom = crops[1];
83   const int crops_left = crops[2];
84   const int crops_right = crops[3];
85   const int output_height =
86       input_size->data[1] * block_shape[0] - crops_top - crops_bottom;
87   const int output_width =
88       input_size->data[2] * block_shape[1] - crops_left - crops_right;
89 
90   const int output_channel_size = input_size->data[3];
91 
92   TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
93   output_size->data[0] = output_batch_size;
94   output_size->data[1] = output_height;
95   output_size->data[2] = output_width;
96   output_size->data[3] = output_channel_size;
97 
98   return context->ResizeTensor(context, op_context->output, output_size);
99 }
100 
Prepare(TfLiteContext * context,TfLiteNode * node)101 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
102   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
103   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
104 
105   BatchToSpaceNDContext op_context(context, node);
106   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
107                     kInputDimensionNum);
108   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
109 
110   if (!IsConstantTensor(op_context.block_shape) ||
111       !IsConstantTensor(op_context.crops)) {
112     SetTensorToDynamic(op_context.output);
113     return kTfLiteOk;
114   }
115   return ResizeOutputTensor(context, &op_context);
116 }
117 
118 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)119 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
120   BatchToSpaceNDContext op_context(context, node);
121 
122   // Resize the output tensor if the output tensor is dynamic.
123   if (IsDynamicTensor(op_context.output)) {
124     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
125   }
126 
127 #define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                        \
128   type::BatchToSpaceND(GetTensorShape(op_context.input),               \
129                        GetTensorData<scalar>(op_context.input),        \
130                        GetTensorShape(op_context.block_shape),         \
131                        GetTensorData<int32_t>(op_context.block_shape), \
132                        GetTensorShape(op_context.crops),               \
133                        GetTensorData<int32_t>(op_context.crops),       \
134                        GetTensorShape(op_context.output),              \
135                        GetTensorData<scalar>(op_context.output))
136   switch (op_context.input->type) {  // Already know in/out types are same.
137     case kTfLiteFloat32:
138       if (kernel_type == kReference) {
139         TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
140       } else {
141         TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
142       }
143       break;
144     case kTfLiteUInt8:
145       if (kernel_type == kReference) {
146         TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
147       } else {
148         TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
149       }
150       break;
151     case kTfLiteInt8:
152       if (kernel_type == kReference) {
153         TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
154       } else {
155         TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
156       }
157       break;
158     case kTfLiteInt32:
159       if (kernel_type == kReference) {
160         TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
161       } else {
162         TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
163       }
164       break;
165     case kTfLiteInt64:
166       if (kernel_type == kReference) {
167         TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
168       } else {
169         TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
170       }
171       break;
172     default:
173       context->ReportError(
174           context, "Type %d is currently not supported by BatchToSpace.",
175           op_context.input->type);
176       return kTfLiteError;
177   }
178 #undef TF_LITE_BATCH_TO_SPACE_ND
179   return kTfLiteOk;
180 }
181 
182 }  // namespace batch_to_space_nd
183 
Register_BATCH_TO_SPACE_ND_REF()184 TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
185   static TfLiteRegistration r = {
186       nullptr, nullptr, batch_to_space_nd::Prepare,
187       batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
188   return &r;
189 }
190 
Register_BATCH_TO_SPACE_ND_GENERIC_OPT()191 TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
192   static TfLiteRegistration r = {
193       nullptr, nullptr, batch_to_space_nd::Prepare,
194       batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
195   return &r;
196 }
197 
Register_BATCH_TO_SPACE_ND()198 TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
199   return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
200 }
201 
202 }  // namespace builtin
203 }  // namespace ops
204 }  // namespace tflite
205