1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
17
#include <cmath>
#include <cstring>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/types.h"
22
23 namespace tflite {
24 namespace reference_ops {
25
// TODO(b/135760455): Move this method to an anonymous namespace in a cc file.
// Produces the 4-D shape that BatchToSpaceND operates on.
//
// A shape that already has four dimensions is returned as-is. For any other
// shape a RuntimeShape(4, 1) is created and source dimensions 0, 1 and 2 are
// written to positions 0, 1 and 3 respectively; position 2 is never assigned
// here, so it keeps whatever value that constructor gave it (presumably the
// fill value 1 -- confirm against RuntimeShape).
inline RuntimeShape ExtendShapeBatchToSpace(const RuntimeShape& shape) {
  constexpr int kExtendedRank = 4;
  if (shape.DimensionsCount() != kExtendedRank) {
    RuntimeShape extended(kExtendedRank, 1);
    extended.SetDim(0, shape.Dims(0));
    extended.SetDim(1, shape.Dims(1));
    // Source dimension 2 lands in slot 3; slot 2 is deliberately skipped.
    extended.SetDim(3, shape.Dims(2));
    return extended;
  }
  return shape;
}
37
38 template <typename T>
BatchToSpaceND(const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const int32_t * block_shape_data,const RuntimeShape & unextended_input3_shape,const int32_t * crops_data,const RuntimeShape & unextended_output_shape,T * output_data)39 inline void BatchToSpaceND(const RuntimeShape& unextended_input1_shape,
40 const T* input1_data,
41 const RuntimeShape& unextended_input2_shape,
42 const int32_t* block_shape_data,
43 const RuntimeShape& unextended_input3_shape,
44 const int32_t* crops_data,
45 const RuntimeShape& unextended_output_shape,
46 T* output_data) {
47 ruy::profiler::ScopeLabel label("BatchToSpaceND");
48 TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
49 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
50 TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
51 unextended_output_shape.DimensionsCount());
52
53 const RuntimeShape input1_shape =
54 ExtendShapeBatchToSpace(unextended_input1_shape);
55 const RuntimeShape output_shape =
56 ExtendShapeBatchToSpace(unextended_output_shape);
57
58 const int output_width = output_shape.Dims(2);
59 const int output_height = output_shape.Dims(1);
60 const int output_batch_size = output_shape.Dims(0);
61
62 const int depth = input1_shape.Dims(3);
63 const int input_width = input1_shape.Dims(2);
64 const int input_height = input1_shape.Dims(1);
65 const int input_batch_size = input1_shape.Dims(0);
66
67 const int block_shape_height = block_shape_data[0];
68 const int block_shape_width =
69 unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
70 const int crops_top = crops_data[0];
71 const int crops_left =
72 unextended_input1_shape.DimensionsCount() == 4 ? crops_data[2] : 0;
73 for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
74 const int out_batch = in_batch % output_batch_size;
75 const int spatial_offset = in_batch / output_batch_size;
76 for (int in_h = 0; in_h < input_height; ++in_h) {
77 const int out_h = in_h * block_shape_height +
78 spatial_offset / block_shape_width - crops_top;
79 if (out_h < 0 || out_h >= output_height) {
80 continue;
81 }
82 for (int in_w = 0; in_w < input_width; ++in_w) {
83 const int out_w = in_w * block_shape_width +
84 spatial_offset % block_shape_width - crops_left;
85
86 if (out_w < 0 || out_w >= output_width) {
87 continue;
88 }
89 T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
90 const T* in =
91 input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
92 memcpy(out, in, depth * sizeof(T));
93 }
94 }
95 }
96 }
97
98 } // namespace reference_ops
99 } // namespace tflite
100
101 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_TO_SPACE_ND_H_
102