• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/transpose_utils.h"
16 
17 namespace tflite {
18 namespace transpose_utils {
19 
IsTranspose2DApplicable(const TransposeParams & params,const RuntimeShape & input_shape,int * dim0,int * dim1)20 bool IsTranspose2DApplicable(const TransposeParams& params,
21                              const RuntimeShape& input_shape, int* dim0,
22                              int* dim1) {
23   const int dims_cnt = input_shape.DimensionsCount();
24 
25   if (dims_cnt == 2) {
26     *dim0 = input_shape.Dims(0);
27     *dim1 = input_shape.Dims(1);
28     return true;
29   }
30 
31   const int first_perm = params.perm[0];
32   for (int i = 1; i < dims_cnt; ++i) {
33     int rebased = params.perm[i] - first_perm;
34     if (rebased < 0) {
35       rebased += dims_cnt;
36     }
37     if (rebased != i) {
38       return false;
39     }
40   }
41   *dim0 = 1;
42   *dim1 = 1;
43   for (int i = 0; i < dims_cnt; ++i) {
44     if (i < first_perm) {
45       *dim0 *= input_shape.Dims(i);
46     } else {
47       *dim1 *= input_shape.Dims(i);
48     }
49   }
50   return true;
51 }
52 
RemoveOneSizeDimensions(RuntimeShape * input_shape,RuntimeShape * output_shape,TransposeParams * params)53 void RemoveOneSizeDimensions(RuntimeShape* input_shape,
54                              RuntimeShape* output_shape,
55                              TransposeParams* params) {
56   const int dims_cnt = input_shape->DimensionsCount();
57   TFLITE_DCHECK_EQ(params->perm_count, dims_cnt);
58 
59   bool foundOneSizeDim = false;
60   for (int i = 0; i < dims_cnt; ++i) {
61     if (input_shape->Dims(i) == 1) {
62       foundOneSizeDim = true;
63       break;
64     }
65   }
66 
67   // Return here if there is no one size dimension.
68   if (!foundOneSizeDim) return;
69 
70   // Handle the case where all the dimension size is one.
71   if (input_shape->FlatSize() == 1) {
72     input_shape->Resize(1);
73     input_shape->SetDim(0, 1);
74     output_shape->Resize(1);
75     output_shape->SetDim(0, 1);
76     params->perm_count = 1;
77     params->perm[0] = 0;
78     return;
79   }
80 
81   // Resize input shape.
82   int new_dims_cnt = 0;
83   for (int i = 0; i < dims_cnt; ++i) {
84     if (input_shape->Dims(i) == 1) {
85       continue;
86     }
87     input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
88     ++new_dims_cnt;
89   }
90   input_shape->Resize(new_dims_cnt);
91 
92   // Resize output shape and re-calculate the perm parameter.
93   TransposeParams new_params;
94   new_dims_cnt = 0;
95   for (int i = 0; i < dims_cnt; ++i) {
96     if (output_shape->Dims(i) == 1) {
97       continue;
98     }
99     new_params.perm[new_dims_cnt] = params->perm[i];
100     output_shape->SetDim(new_dims_cnt, output_shape->Dims(i));
101     ++new_dims_cnt;
102   }
103   output_shape->Resize(new_dims_cnt);
104   new_params.perm_count = new_dims_cnt;
105 
106   for (int i = 0; i < new_dims_cnt; ++i) {
107     int min_val_idx = -1;
108     for (int j = 0; j < new_dims_cnt; ++j) {
109       if (new_params.perm[j] >= i &&
110           (min_val_idx == -1 ||
111            new_params.perm[min_val_idx] > new_params.perm[j])) {
112         min_val_idx = j;
113       }
114     }
115     new_params.perm[min_val_idx] = i;
116   }
117   *params = new_params;
118 }
119 
Flatten(const RuntimeShape & input_shape,const RuntimeShape & output_shape,const TransposeParams & params,RuntimeShape * non_flatten_input_shape,RuntimeShape * non_flatten_output_shape,TransposeParams * non_flatten_params)120 size_t Flatten(const RuntimeShape& input_shape,
121                const RuntimeShape& output_shape, const TransposeParams& params,
122                RuntimeShape* non_flatten_input_shape,
123                RuntimeShape* non_flatten_output_shape,
124                TransposeParams* non_flatten_params) {
125   // Calculate the total size of non-flatten dimensions.
126   int skip_dims_cnt = 0;
127   size_t flat_size = input_shape.FlatSize();
128   for (int i = 0; i < params.perm_count; ++i) {
129     if (params.perm[i] == i) {
130       flat_size /= input_shape.Dims(i);
131       ++skip_dims_cnt;
132     } else {
133       break;
134     }
135   }
136 
137   // Shrink the shapes and re-calculate the perm parameter.
138   const int new_dims_cnt = params.perm_count - skip_dims_cnt;
139   non_flatten_input_shape->Resize(new_dims_cnt);
140   non_flatten_output_shape->Resize(new_dims_cnt);
141   non_flatten_params->perm_count = new_dims_cnt;
142 
143   for (int i = skip_dims_cnt; i < params.perm_count; ++i) {
144     non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
145     non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i));
146     non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
147   }
148   for (int i = 0; i < new_dims_cnt; ++i) {
149     int min_val_idx = -1;
150     for (int j = 0; j < new_dims_cnt; ++j) {
151       if (non_flatten_params->perm[j] >= i &&
152           (min_val_idx == -1 || non_flatten_params->perm[min_val_idx] >
153                                     non_flatten_params->perm[j])) {
154         min_val_idx = j;
155       }
156     }
157     non_flatten_params->perm[min_val_idx] = i;
158   }
159 
160   return flat_size;
161 }
162 
163 }  // namespace transpose_utils
164 
165 }  // namespace tflite
166