• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace space_to_depth {
30 
// This file provides two implementations of SpaceToDepth: a reference kernel
// and a generic optimized kernel. The variant is chosen through the
// KernelType template argument of Eval(). SpaceToDepth only accepts 4-D
// input tensors.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};

// Index of the node's single input tensor and of its single output tensor.
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};
40 
Prepare(TfLiteContext * context,TfLiteNode * node)41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42   auto* params =
43       reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
44 
45   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
46   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
47 
48   const TfLiteTensor* input;
49   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
50   TfLiteTensor* output;
51   TF_LITE_ENSURE_OK(context,
52                     GetOutputSafe(context, node, kOutputTensor, &output));
53 
54   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
55 
56   auto data_type = output->type;
57   TF_LITE_ENSURE(context,
58                  data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
59                      data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
60                      data_type == kTfLiteInt64);
61   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
62 
63   const int block_size = params->block_size;
64   TF_LITE_ENSURE(context, block_size > 0);
65   const int input_height = input->dims->data[1];
66   const int input_width = input->dims->data[2];
67   int output_height = input_height / block_size;
68   int output_width = input_width / block_size;
69 
70   TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
71   TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);
72 
73   TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
74   output_size->data[0] = input->dims->data[0];
75   output_size->data[1] = output_height;
76   output_size->data[2] = output_width;
77   output_size->data[3] = input->dims->data[3] * block_size * block_size;
78 
79   return context->ResizeTensor(context, output, output_size);
80 }
81 
82 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)83 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
84   auto* params =
85       reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
86 
87   const TfLiteTensor* input;
88   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
89   TfLiteTensor* output;
90   TF_LITE_ENSURE_OK(context,
91                     GetOutputSafe(context, node, kOutputTensor, &output));
92 
93 #define TF_LITE_SPACE_TO_DEPTH(type, scalar)                               \
94   tflite::SpaceToDepthParams op_params;                                    \
95   op_params.block_size = params->block_size;                               \
96   type::SpaceToDepth(op_params, GetTensorShape(input),                     \
97                      GetTensorData<scalar>(input), GetTensorShape(output), \
98                      GetTensorData<scalar>(output))
99   switch (input->type) {  // Already know in/out types are same.
100     case kTfLiteFloat32:
101       if (kernel_type == kReference) {
102         TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
103       } else {
104         TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
105       }
106       break;
107     case kTfLiteUInt8:
108       if (kernel_type == kReference) {
109         TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
110       } else {
111         TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
112       }
113       break;
114     case kTfLiteInt8:
115       if (kernel_type == kReference) {
116         TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t);
117       } else {
118         TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t);
119       }
120       break;
121     case kTfLiteInt32:
122       if (kernel_type == kReference) {
123         TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
124       } else {
125         TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
126       }
127       break;
128     case kTfLiteInt64:
129       if (kernel_type == kReference) {
130         TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
131       } else {
132         TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
133       }
134       break;
135     default:
136       TF_LITE_KERNEL_LOG(context, "Type '%s' not currently supported.",
137                          TfLiteTypeGetName(input->type));
138       return kTfLiteError;
139   }
140 #undef TF_LITE_SPACE_TO_DEPTH
141 
142   return kTfLiteOk;
143 }
144 
145 }  // namespace space_to_depth
146 
Register_SPACE_TO_DEPTH_REF()147 TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
148   static TfLiteRegistration r = {
149       nullptr, nullptr, space_to_depth::Prepare,
150       space_to_depth::Eval<space_to_depth::kReference>};
151   return &r;
152 }
153 
Register_SPACE_TO_DEPTH_GENERIC_OPT()154 TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
155   static TfLiteRegistration r = {
156       nullptr, nullptr, space_to_depth::Prepare,
157       space_to_depth::Eval<space_to_depth::kGenericOptimized>};
158   return &r;
159 }
160 
Register_SPACE_TO_DEPTH()161 TfLiteRegistration* Register_SPACE_TO_DEPTH() {
162   return Register_SPACE_TO_DEPTH_GENERIC_OPT();
163 }
164 
165 }  // namespace builtin
166 }  // namespace ops
167 }  // namespace tflite
168