/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace batch_to_space_nd {

// This file has two implementations of BatchToSpaceND: a portable reference
// kernel and a generic optimized kernel.
enum KernelType {
  kReference,
  kGenericOptimized,
};

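// Convenience wrapper that gathers this op's input and output tensors once,
// so Prepare and Eval can refer to them by name.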
struct BatchToSpaceNDContext {
  BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    crops = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* crops;
  TfLiteTensor* output;
};

// Currently, only 3D NHC or 4D NHWC input/output tensors are supported.
// A 3D input is converted to 4D NH1C by inserting W=1.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
const int kInputMinDimensionNum = 3;
const int kInputMaxDimensionNum = 4;

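// Worked example of the shape arithmetic below (an illustrative sketch):
// input [4, 2, 2, 3] with block_shape [2, 2] and crops [[0, 0], [0, 0]]
// yields output [1, 4, 4, 3]: the batch becomes 4 / (2 * 2) = 1, each
// spatial dimension becomes 2 * 2 - 0 - 0 = 4, and the channel dimension is
// carried through unchanged.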
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BatchToSpaceNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int* block_shape = GetTensorData<int32_t>(op_context->block_shape);
  const int* crops = GetTensorData<int32_t>(op_context->crops);

  int spatial_dims_num = input_size->size - 2;
  // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    spatial_dims_num);
  // Crops should be a 2D tensor with dimension [spatial_dims_num, 2].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), 2);
  TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[0],
                    spatial_dims_num);
  TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[1], 2);

  for (int i = 0; i < spatial_dims_num * 2; ++i) {
    TF_LITE_ENSURE(context, crops[i] >= 0);
  }

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
  int output_batch_size = input_size->data[0];
  for (int dim = 0; dim < spatial_dims_num; ++dim) {
    // The batch count must be a multiple of block_shape[dim].
    TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape[dim], 0);
    output_batch_size = output_batch_size / block_shape[dim];
    output_size->data[dim + 1] = input_size->data[dim + 1] * block_shape[dim] -
                                 crops[dim * 2] - crops[dim * 2 + 1];
  }

  output_size->data[0] = output_batch_size;
  output_size->data[input_size->size - 1] =
      input_size->data[input_size->size - 1];

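  // context->ResizeTensor takes ownership of output_size, so it must not be
  // freed here.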
  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  BatchToSpaceNDContext op_context(context, node);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) >= kInputMinDimensionNum);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

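  // If block_shape or crops are not compile-time constants, the output shape
  // cannot be computed yet; mark the output dynamic and resize it in Eval.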
  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.crops)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

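// Eval is templated on the kernel type so the reference and optimized
// variants share a single body and dispatch at compile time.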
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BatchToSpaceNDContext op_context(context, node);

  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

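// Helper macro: invokes type::BatchToSpaceND (reference_ops or
// optimized_ops) with the tensor buffers reinterpreted as the given scalar.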
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                        \
  type::BatchToSpaceND(GetTensorShape(op_context.input),               \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.crops),               \
                       GetTensorData<int32_t>(op_context.crops),       \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
      }
      break;
    default:
      context->ReportError(
          context, "Type %d is currently not supported by BatchToSpaceND.",
          op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_BATCH_TO_SPACE_ND
  return kTfLiteOk;
}

}  // namespace batch_to_space_nd

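// The TfLiteRegistration fields below are {init, free, prepare, invoke};
// this kernel keeps no per-instance state, so init and free are null.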
TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
  return &r;
}

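// The default registration resolves to the generic optimized kernel. A
// minimal usage sketch (assuming the standard MutableOpResolver API; not
// part of this file):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
//                       tflite::ops::builtin::Register_BATCH_TO_SPACE_ND());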
TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
  return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite