• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17 
18 #include "mlir/IR/Builders.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Types.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 
30 using xla::StatusOr;
31 
32 namespace errors = tensorflow::errors;
33 
ConvertElementType(tflite::TensorType type,mlir::Builder builder)34 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
35   switch (type) {
36     case tflite::TensorType_FLOAT16:
37       return builder.getF16Type();
38     case tflite::TensorType_FLOAT32:
39       return builder.getF32Type();
40     case tflite::TensorType_FLOAT64:
41       return builder.getF64Type();
42     case tflite::TensorType_INT32:
43       return builder.getIntegerType(32);
44     case tflite::TensorType_UINT32:
45       return builder.getIntegerType(32, /*isSigned=*/false);
46     case tflite::TensorType_UINT8:
47       return builder.getIntegerType(8, /*isSigned=*/false);
48     case tflite::TensorType_INT64:
49       return builder.getIntegerType(64);
50     case tflite::TensorType_STRING:
51       return mlir::TF::StringType::get(builder.getContext());
52     case tflite::TensorType_BOOL:
53       return builder.getI1Type();
54     case tflite::TensorType_INT16:
55       return builder.getIntegerType(16);
56     case tflite::TensorType_COMPLEX64:
57       return mlir::ComplexType::get(builder.getF32Type());
58     case tflite::TensorType_COMPLEX128:
59       return mlir::ComplexType::get(builder.getF64Type());
60     case tflite::TensorType_INT8:
61       return builder.getIntegerType(8);
62     case tflite::TensorType_UINT64:
63       return builder.getIntegerType(64, /*isSigned=*/false);
64     case tflite::TensorType_RESOURCE:
65       return mlir::TF::ResourceType::get(builder.getContext());
66     case tflite::TensorType_VARIANT:
67       return mlir::TF::VariantType::get(builder.getContext());
68   }
69 }
70 
TflTypeToTfType(tflite::TensorType type)71 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
72   switch (type) {
73     case tflite::TensorType_BOOL:
74       return tensorflow::DT_BOOL;
75     case tflite::TensorType_COMPLEX64:
76       return tensorflow::DT_COMPLEX64;
77     case tflite::TensorType_COMPLEX128:
78       return tensorflow::DT_COMPLEX128;
79     case tflite::TensorType_FLOAT16:
80       return tensorflow::DT_HALF;
81     case tflite::TensorType_FLOAT32:
82       return tensorflow::DT_FLOAT;
83     case tflite::TensorType_FLOAT64:
84       return tensorflow::DT_DOUBLE;
85     case tflite::TensorType_INT8:
86       return tensorflow::DT_INT8;
87     case tflite::TensorType_INT16:
88       return tensorflow::DT_INT16;
89     case tflite::TensorType_INT32:
90       return tensorflow::DT_INT32;
91     case tflite::TensorType_UINT32:
92       return tensorflow::DT_UINT32;
93     case tflite::TensorType_INT64:
94       return tensorflow::DT_INT64;
95     case tflite::TensorType_STRING:
96       return tensorflow::DT_STRING;
97     case tflite::TensorType_UINT8:
98       return tensorflow::DT_UINT8;
99     case tflite::TensorType_UINT64:
100       return tensorflow::DT_UINT64;
101     case tflite::TensorType_RESOURCE:
102       return tensorflow::DT_RESOURCE;
103     case tflite::TensorType_VARIANT:
104       return tensorflow::DT_VARIANT;
105   }
106 }
107 
TfTypeToTflType(tensorflow::DataType type)108 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
109   switch (type) {
110     case tensorflow::DT_BOOL:
111       return tflite::TensorType_BOOL;
112     case tensorflow::DT_COMPLEX64:
113       return tflite::TensorType_COMPLEX64;
114     case tensorflow::DT_COMPLEX128:
115       return tflite::TensorType_COMPLEX128;
116     case tensorflow::DT_HALF:
117       return tflite::TensorType_FLOAT16;
118     case tensorflow::DT_FLOAT:
119       return tflite::TensorType_FLOAT32;
120     case tensorflow::DT_DOUBLE:
121       return tflite::TensorType_FLOAT64;
122     case tensorflow::DT_INT8:
123       return tflite::TensorType_INT8;
124     case tensorflow::DT_INT16:
125       return tflite::TensorType_INT16;
126     case tensorflow::DT_INT32:
127       return tflite::TensorType_INT32;
128     case tensorflow::DT_UINT32:
129       return tflite::TensorType_UINT32;
130     case tensorflow::DT_INT64:
131       return tflite::TensorType_INT64;
132     case tensorflow::DT_UINT64:
133       return tflite::TensorType_UINT64;
134     case tensorflow::DT_STRING:
135       return tflite::TensorType_STRING;
136     case tensorflow::DT_UINT8:
137       return tflite::TensorType_UINT8;
138     case tensorflow::DT_RESOURCE:
139       return tflite::TensorType_RESOURCE;
140     case tensorflow::DT_VARIANT:
141       return tflite::TensorType_VARIANT;
142     default:
143       return errors::InvalidArgument("unsupported tensor data type", type);
144   }
145 }
146 
GetShapeStrippedType(mlir::TypeAttr type_attr)147 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
148   auto type = type_attr.getValue();
149   auto shaped_type = type.dyn_cast<mlir::ShapedType>();
150   if (shaped_type) {
151     return shaped_type.getElementType();
152   } else {
153     return type;
154   }
155 }
156 
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)157 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
158   auto val_defn_op = val.getDefiningOp();
159   mlir::TFL::QuantizeOp q_op =
160       llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
161   if (!q_op) return true;
162 
163   // Ignore shape details - we're really only trying to
164   // check if quantization is the same.
165   auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
166   auto stripped_qtype = GetShapeStrippedType(qtype_attr);
167   return stripped_src_qtype == stripped_qtype;
168 }
169 
170 }  // namespace tflite
171