• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
17 
18 #include "tensorflow/lite/kernels/internal/types.h"
19 
20 namespace tflite {
21 namespace transpose_utils {
22 
23 // IsTranspose2DApplicable returns true if the given perm can be lowered to a
24 // 2D transpose op. If possible, it copies the lowered dimension counts to the
25 // given dim0 and dim1 pointers.
26 bool IsTranspose2DApplicable(const TransposeParams& params,
27                              const RuntimeShape& input_shape, int* dim0,
28                              int* dim1);
29 
30 // RemoveOneSizeDimensions removes one size dimensions in the given input/output
31 // shapes and adjusts the parameter values for transpose op.
32 void RemoveOneSizeDimensions(RuntimeShape* input_shape,
33                              RuntimeShape* output_shape,
34                              TransposeParams* params);
35 
36 // Flatten finds the dimensions that can be flatten, shrinks the given shapes
37 // and the given perm parameter to reflect the non-flatten dimensions, and
38 // returns the total size of the non-flatten dimensions.
39 //
40 // E.g, in perm [0, 1, 3, 2] case, the first two dimensions can be flatten and
41 // it returns |Dim Size(2)| x |Dim Size(3)|.
42 size_t Flatten(const RuntimeShape& input_shape,
43                const RuntimeShape& output_shape, const TransposeParams& params,
44                RuntimeShape* non_flatten_input_shape,
45                RuntimeShape* non_flatten_output_shape,
46                TransposeParams* non_flatten_params);
47 
48 }  // namespace transpose_utils
49 
50 }  // namespace tflite
51 
52 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
53