• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <vector>
17 #include "tensorflow/contrib/lite/builtin_op_data.h"
18 #include "tensorflow/contrib/lite/context.h"
19 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
21 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
22 #include "tensorflow/contrib/lite/kernels/op_macros.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace transpose {
28 
// Selects which Transpose implementation Eval<> dispatches to. Only the
// reference kernel is listed here.
enum KernelType { kReference };
33 
34 struct TransposeContext {
TransposeContexttflite::ops::builtin::transpose::TransposeContext35   TransposeContext(TfLiteContext* context, TfLiteNode* node) {
36     input = GetInput(context, node, 0);
37     perm = GetInput(context, node, 1);
38     output = GetOutput(context, node, 0);
39   }
40   TfLiteTensor* input;
41   TfLiteTensor* perm;
42   TfLiteTensor* output;
43 };
44 
ResizeOutputTensor(TfLiteContext * context,TransposeContext * op_context)45 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
46                                 TransposeContext* op_context) {
47   int dims = NumDimensions(op_context->input);
48   const int* perm_data = GetTensorData<int32_t>(op_context->perm);
49 
50   // Ensure validity of the permutations tensor as a 1D tensor.
51   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
52   TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims);
53   for (int idx = 0; idx < dims; ++idx) {
54     TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims),
55                        "Transpose op permutations array is out of bounds.");
56   }
57 
58   // Determine size of output tensor.
59   TfLiteIntArray* input_size = op_context->input->dims;
60   TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
61   for (int idx = 0; idx < dims; ++idx) {
62     output_size->data[idx] = input_size->data[perm_data[idx]];
63   }
64 
65   return context->ResizeTensor(context, op_context->output, output_size);
66 }
67 
Prepare(TfLiteContext * context,TfLiteNode * node)68 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
69   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
70   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
71 
72   TransposeContext op_context(context, node);
73 
74   // Ensure validity of input tensor.
75   TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4,
76                      "Transpose op only supports 1D-4D input arrays.");
77   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
78 
79   if (!IsConstantTensor(op_context.perm)) {
80     SetTensorToDynamic(op_context.output);
81     return kTfLiteOk;
82   }
83   return ResizeOutputTensor(context, &op_context);
84 }
85 
86 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)87 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
88   TransposeContext op_context(context, node);
89 
90   // Resize the output tensor if the output tensor is dynamic.
91   if (IsDynamicTensor(op_context.output)) {
92     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
93   }
94 
95   // Reverse the permuted axes and convert to 4D due to the way Dims are
96   // constructed in GetTensorDims.
97   const int* perm_data = GetTensorData<int32_t>(op_context.perm);
98   const int size = op_context.perm->dims->data[0];
99   const int kOutputDimensionNum = 4;
100   int reversed_perm[kOutputDimensionNum];
101 
102   for (int output_k = 0, input_k = size - 1; output_k < size;
103        ++output_k, --input_k) {
104     reversed_perm[output_k] = size - perm_data[input_k] - 1;
105   }
106   for (int k = size; k < kOutputDimensionNum; ++k) {
107     reversed_perm[k] = k;
108   }
109 
110 #define TF_LITE_TRANSPOSE(type, scalar)                     \
111   type::Transpose(GetTensorData<scalar>(op_context.input),  \
112                   GetTensorDims(op_context.input),          \
113                   GetTensorData<scalar>(op_context.output), \
114                   GetTensorDims(op_context.output), reversed_perm)
115 
116   switch (op_context.input->type) {
117     case kTfLiteFloat32:
118       if (kernel_type == kReference) {
119         TF_LITE_TRANSPOSE(reference_ops, float);
120       }
121       break;
122     case kTfLiteUInt8:
123       if (kernel_type == kReference) {
124         TF_LITE_TRANSPOSE(reference_ops, uint8_t);
125       }
126       break;
127     case kTfLiteInt32:
128       if (kernel_type == kReference) {
129         TF_LITE_TRANSPOSE(reference_ops, int32_t);
130       }
131       break;
132     case kTfLiteInt64:
133       if (kernel_type == kReference) {
134         TF_LITE_TRANSPOSE(reference_ops, int64_t);
135       }
136       break;
137     default:
138       context->ReportError(context,
139                            "Type is currently not supported by Transpose.");
140       return kTfLiteError;
141   }
142 #undef TF_LITE_TRANSPOSE
143 
144   return kTfLiteOk;
145 }
146 
147 }  // namespace transpose
148 
Register_TRANSPOSE_REF()149 TfLiteRegistration* Register_TRANSPOSE_REF() {
150   static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
151                                  transpose::Eval<transpose::kReference>};
152   return &r;
153 }
154 
Register_TRANSPOSE()155 TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
156 
157 }  // namespace builtin
158 }  // namespace ops
159 }  // namespace tflite
160