1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17
18 #include "mlir/IR/Builders.h" // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "mlir/IR/Types.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27
28 namespace tflite {
29
30 using xla::StatusOr;
31
32 namespace errors = tensorflow::errors;
33
ConvertTypeToTensorType(mlir::Type type)34 tflite::TensorType ConvertTypeToTensorType(mlir::Type type) {
35 if (type.isF16()) {
36 return tflite::TensorType_FLOAT16;
37 } else if (type.isF32()) {
38 return tflite::TensorType_FLOAT32;
39 } else if (type.isF64()) {
40 return tflite::TensorType_FLOAT64;
41 } else if (type.isa<mlir::TF::StringType>()) {
42 return tflite::TensorType_STRING;
43 } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
44 if (complex_type.getElementType().isF32()) {
45 return tflite::TensorType_COMPLEX64;
46 } else if (complex_type.getElementType().isF64()) {
47 return tflite::TensorType_COMPLEX128;
48 }
49 llvm_unreachable("invalid complex Type in conversion");
50 } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
51 switch (itype.getWidth()) {
52 case 1:
53 return tflite::TensorType_BOOL;
54 case 8:
55 if (itype.isUnsigned())
56 return tflite::TensorType_UINT8;
57 else
58 return tflite::TensorType_INT8;
59 case 16:
60 return tflite::TensorType_INT16;
61 case 32:
62 return tflite::TensorType_INT32;
63 case 64:
64 if (itype.isUnsigned())
65 return tflite::TensorType_UINT64;
66 else
67 return tflite::TensorType_INT64;
68 default:
69 llvm_unreachable("invalid integer Type in conversion");
70 }
71 }
72 llvm_unreachable("invalid Type in conversion");
73 }
74
ConvertElementType(tflite::TensorType type,mlir::Builder builder)75 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
76 switch (type) {
77 case tflite::TensorType_FLOAT16:
78 return builder.getF16Type();
79 case tflite::TensorType_FLOAT32:
80 return builder.getF32Type();
81 case tflite::TensorType_FLOAT64:
82 return builder.getF64Type();
83 case tflite::TensorType_INT32:
84 return builder.getIntegerType(32);
85 case tflite::TensorType_UINT32:
86 return builder.getIntegerType(32, /*isSigned=*/false);
87 case tflite::TensorType_UINT8:
88 return builder.getIntegerType(8, /*isSigned=*/false);
89 case tflite::TensorType_INT64:
90 return builder.getIntegerType(64);
91 case tflite::TensorType_STRING:
92 return mlir::TF::StringType::get(builder.getContext());
93 case tflite::TensorType_BOOL:
94 return builder.getI1Type();
95 case tflite::TensorType_INT16:
96 return builder.getIntegerType(16);
97 case tflite::TensorType_COMPLEX64:
98 return mlir::ComplexType::get(builder.getF32Type());
99 case tflite::TensorType_COMPLEX128:
100 return mlir::ComplexType::get(builder.getF64Type());
101 case tflite::TensorType_INT8:
102 return builder.getIntegerType(8);
103 case tflite::TensorType_UINT64:
104 return builder.getIntegerType(64, /*isSigned=*/false);
105 case tflite::TensorType_RESOURCE:
106 return mlir::TF::ResourceType::get(builder.getContext());
107 case tflite::TensorType_VARIANT:
108 return mlir::TF::VariantType::get(builder.getContext());
109 }
110 }
111
TflTypeToTfType(tflite::TensorType type)112 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
113 switch (type) {
114 case tflite::TensorType_BOOL:
115 return tensorflow::DT_BOOL;
116 case tflite::TensorType_COMPLEX64:
117 return tensorflow::DT_COMPLEX64;
118 case tflite::TensorType_COMPLEX128:
119 return tensorflow::DT_COMPLEX128;
120 case tflite::TensorType_FLOAT16:
121 return tensorflow::DT_HALF;
122 case tflite::TensorType_FLOAT32:
123 return tensorflow::DT_FLOAT;
124 case tflite::TensorType_FLOAT64:
125 return tensorflow::DT_DOUBLE;
126 case tflite::TensorType_INT8:
127 return tensorflow::DT_INT8;
128 case tflite::TensorType_INT16:
129 return tensorflow::DT_INT16;
130 case tflite::TensorType_INT32:
131 return tensorflow::DT_INT32;
132 case tflite::TensorType_UINT32:
133 return tensorflow::DT_UINT32;
134 case tflite::TensorType_INT64:
135 return tensorflow::DT_INT64;
136 case tflite::TensorType_STRING:
137 return tensorflow::DT_STRING;
138 case tflite::TensorType_UINT8:
139 return tensorflow::DT_UINT8;
140 case tflite::TensorType_UINT64:
141 return tensorflow::DT_UINT64;
142 case tflite::TensorType_RESOURCE:
143 return tensorflow::DT_RESOURCE;
144 case tflite::TensorType_VARIANT:
145 return tensorflow::DT_VARIANT;
146 }
147 }
148
TfTypeToTflType(tensorflow::DataType type)149 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
150 switch (type) {
151 case tensorflow::DT_BOOL:
152 return tflite::TensorType_BOOL;
153 case tensorflow::DT_COMPLEX64:
154 return tflite::TensorType_COMPLEX64;
155 case tensorflow::DT_COMPLEX128:
156 return tflite::TensorType_COMPLEX128;
157 case tensorflow::DT_HALF:
158 return tflite::TensorType_FLOAT16;
159 case tensorflow::DT_FLOAT:
160 return tflite::TensorType_FLOAT32;
161 case tensorflow::DT_DOUBLE:
162 return tflite::TensorType_FLOAT64;
163 case tensorflow::DT_INT8:
164 return tflite::TensorType_INT8;
165 case tensorflow::DT_INT16:
166 return tflite::TensorType_INT16;
167 case tensorflow::DT_INT32:
168 return tflite::TensorType_INT32;
169 case tensorflow::DT_UINT32:
170 return tflite::TensorType_UINT32;
171 case tensorflow::DT_INT64:
172 return tflite::TensorType_INT64;
173 case tensorflow::DT_UINT64:
174 return tflite::TensorType_UINT64;
175 case tensorflow::DT_STRING:
176 return tflite::TensorType_STRING;
177 case tensorflow::DT_UINT8:
178 return tflite::TensorType_UINT8;
179 case tensorflow::DT_RESOURCE:
180 return tflite::TensorType_RESOURCE;
181 case tensorflow::DT_VARIANT:
182 return tflite::TensorType_VARIANT;
183 default:
184 return errors::InvalidArgument("unsupported tensor data type", type);
185 }
186 }
187
GetShapeStrippedType(mlir::TypeAttr type_attr)188 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
189 auto type = type_attr.getValue();
190 auto shaped_type = type.dyn_cast<mlir::ShapedType>();
191 if (shaped_type) {
192 return shaped_type.getElementType();
193 } else {
194 return type;
195 }
196 }
197
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)198 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
199 auto val_defn_op = val.getDefiningOp();
200 mlir::TFL::QuantizeOp q_op =
201 llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
202 if (!q_op) return true;
203
204 // Ignore shape details - we're really only trying to
205 // check if quantization is the same.
206 auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
207 auto stripped_qtype = GetShapeStrippedType(qtype_attr);
208 return stripped_src_qtype == stripped_qtype;
209 }
210
211 } // namespace tflite
212