1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
16
17 namespace tflite {
18 namespace transpose_utils {
19
IsTranspose2DApplicable(const TransposeParams & params,const RuntimeShape & input_shape,int * dim0,int * dim1)20 bool IsTranspose2DApplicable(const TransposeParams& params,
21 const RuntimeShape& input_shape, int* dim0,
22 int* dim1) {
23 const int dims_cnt = input_shape.DimensionsCount();
24
25 if (dims_cnt == 2) {
26 *dim0 = input_shape.Dims(0);
27 *dim1 = input_shape.Dims(1);
28 return true;
29 }
30
31 const int first_perm = params.perm[0];
32 for (int i = 1; i < dims_cnt; ++i) {
33 int rebased = params.perm[i] - first_perm;
34 if (rebased < 0) {
35 rebased += dims_cnt;
36 }
37 if (rebased != i) {
38 return false;
39 }
40 }
41 *dim0 = 1;
42 *dim1 = 1;
43 for (int i = 0; i < dims_cnt; ++i) {
44 if (i < first_perm) {
45 *dim0 *= input_shape.Dims(i);
46 } else {
47 *dim1 *= input_shape.Dims(i);
48 }
49 }
50 return true;
51 }
52
RemoveOneSizeDimensions(RuntimeShape * input_shape,RuntimeShape * output_shape,TransposeParams * params)53 void RemoveOneSizeDimensions(RuntimeShape* input_shape,
54 RuntimeShape* output_shape,
55 TransposeParams* params) {
56 const int dims_cnt = input_shape->DimensionsCount();
57 TFLITE_DCHECK_EQ(params->perm_count, dims_cnt);
58
59 bool foundOneSizeDim = false;
60 for (int i = 0; i < dims_cnt; ++i) {
61 if (input_shape->Dims(i) == 1) {
62 foundOneSizeDim = true;
63 break;
64 }
65 }
66
67 // Return here if there is no one size dimension.
68 if (!foundOneSizeDim) return;
69
70 // Handle the case where all the dimension size is one.
71 if (input_shape->FlatSize() == 1) {
72 input_shape->Resize(1);
73 input_shape->SetDim(0, 1);
74 output_shape->Resize(1);
75 output_shape->SetDim(0, 1);
76 params->perm_count = 1;
77 params->perm[0] = 0;
78 return;
79 }
80
81 // Resize input shape.
82 int new_dims_cnt = 0;
83 for (int i = 0; i < dims_cnt; ++i) {
84 if (input_shape->Dims(i) == 1) {
85 continue;
86 }
87 input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
88 ++new_dims_cnt;
89 }
90 input_shape->Resize(new_dims_cnt);
91
92 // Resize output shape and re-calculate the perm parameter.
93 TransposeParams new_params;
94 new_dims_cnt = 0;
95 for (int i = 0; i < dims_cnt; ++i) {
96 if (output_shape->Dims(i) == 1) {
97 continue;
98 }
99 new_params.perm[new_dims_cnt] = params->perm[i];
100 output_shape->SetDim(new_dims_cnt, output_shape->Dims(i));
101 ++new_dims_cnt;
102 }
103 output_shape->Resize(new_dims_cnt);
104 new_params.perm_count = new_dims_cnt;
105
106 for (int i = 0; i < new_dims_cnt; ++i) {
107 int min_val_idx = -1;
108 for (int j = 0; j < new_dims_cnt; ++j) {
109 if (new_params.perm[j] >= i &&
110 (min_val_idx == -1 ||
111 new_params.perm[min_val_idx] > new_params.perm[j])) {
112 min_val_idx = j;
113 }
114 }
115 new_params.perm[min_val_idx] = i;
116 }
117 *params = new_params;
118 }
119
Flatten(const RuntimeShape & input_shape,const RuntimeShape & output_shape,const TransposeParams & params,RuntimeShape * non_flatten_input_shape,RuntimeShape * non_flatten_output_shape,TransposeParams * non_flatten_params)120 size_t Flatten(const RuntimeShape& input_shape,
121 const RuntimeShape& output_shape, const TransposeParams& params,
122 RuntimeShape* non_flatten_input_shape,
123 RuntimeShape* non_flatten_output_shape,
124 TransposeParams* non_flatten_params) {
125 // Calculate the total size of non-flatten dimensions.
126 int skip_dims_cnt = 0;
127 size_t flat_size = input_shape.FlatSize();
128 for (int i = 0; i < params.perm_count; ++i) {
129 if (params.perm[i] == i) {
130 flat_size /= input_shape.Dims(i);
131 ++skip_dims_cnt;
132 } else {
133 break;
134 }
135 }
136
137 // Shrink the shapes and re-calculate the perm parameter.
138 const int new_dims_cnt = params.perm_count - skip_dims_cnt;
139 non_flatten_input_shape->Resize(new_dims_cnt);
140 non_flatten_output_shape->Resize(new_dims_cnt);
141 non_flatten_params->perm_count = new_dims_cnt;
142
143 for (int i = skip_dims_cnt; i < params.perm_count; ++i) {
144 non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
145 non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i));
146 non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
147 }
148 for (int i = 0; i < new_dims_cnt; ++i) {
149 int min_val_idx = -1;
150 for (int j = 0; j < new_dims_cnt; ++j) {
151 if (non_flatten_params->perm[j] >= i &&
152 (min_val_idx == -1 || non_flatten_params->perm[min_val_idx] >
153 non_flatten_params->perm[j])) {
154 min_val_idx = j;
155 }
156 }
157 non_flatten_params->perm[min_val_idx] = i;
158 }
159
160 return flat_size;
161 }
162
163 } // namespace transpose_utils
164
165 } // namespace tflite
166