/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace space_to_depth {

// This file has two implementations of SpaceToDepth. Note that SpaceToDepth
// only works on 4D tensors.
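// For example, with block_size = 2 an input of shape [1, 4, 4, 1] becomes an
// output of shape [1, 2, 2, 4]: each non-overlapping 2x2 spatial block of the
// input is flattened into the depth dimension of a single output element.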
enum KernelType {
  kReference,
  kGenericOptimized,
};

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

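// Checks that the node has exactly one 4D input of a supported type, that
// both spatial dimensions are evenly divisible by block_size, and resizes the
// output to {batch, height / block_size, width / block_size,
// depth * block_size * block_size}.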
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);

  auto data_type = output->type;
  TF_LITE_ENSURE(context,
                 data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
                     data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
                     data_type == kTfLiteInt64);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

  const int block_size = params->block_size;
  TF_LITE_ENSURE(context, block_size > 0);
  const int input_height = input->dims->data[1];
  const int input_width = input->dims->data[2];
  int output_height = input_height / block_size;
  int output_width = input_width / block_size;

  TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
  TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
  output_size->data[0] = input->dims->data[0];
  output_size->data[1] = output_height;
  output_size->data[2] = output_width;
  output_size->data[3] = input->dims->data[3] * block_size * block_size;

  return context->ResizeTensor(context, output, output_size);
}

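// Rearranges spatial blocks of the input into the depth dimension of the
// output. The kernel_type template parameter selects between the reference
// and the generic optimized implementation at compile time.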
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

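// Invokes type::SpaceToDepth (reference_ops or optimized_ops) for the given
// scalar type.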
#define TF_LITE_SPACE_TO_DEPTH(type, scalar)                               \
  tflite::SpaceToDepthParams op_params;                                    \
  op_params.block_size = params->block_size;                               \
  type::SpaceToDepth(op_params, GetTensorShape(input),                     \
                     GetTensorData<scalar>(input), GetTensorShape(output), \
                     GetTensorData<scalar>(output))
  switch (input->type) {  // In/out types already checked to match in Prepare.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
      } else {
        TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
      } else {
        TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t);
      } else {
        TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
      } else {
        TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
      } else {
        TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
      }
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
#undef TF_LITE_SPACE_TO_DEPTH

  return kTfLiteOk;
}

}  // namespace space_to_depth

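// Registration entry points. Both kernels share Prepare; they differ only in
// which Eval instantiation is used.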
TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_depth::Prepare,
      space_to_depth::Eval<space_to_depth::kReference>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_depth::Prepare,
      space_to_depth::Eval<space_to_depth::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_DEPTH() {
  return Register_SPACE_TO_DEPTH_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite