1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17
18 #include "mlir/IR/Builders.h" // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "mlir/IR/Types.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27
28 namespace tflite {
29
30 using xla::StatusOr;
31
32 namespace errors = tensorflow::errors;
33
ConvertElementType(tflite::TensorType type,mlir::Builder builder)34 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
35 switch (type) {
36 case tflite::TensorType_FLOAT16:
37 return builder.getF16Type();
38 case tflite::TensorType_FLOAT32:
39 return builder.getF32Type();
40 case tflite::TensorType_FLOAT64:
41 return builder.getF64Type();
42 case tflite::TensorType_INT32:
43 return builder.getIntegerType(32);
44 case tflite::TensorType_UINT32:
45 return builder.getIntegerType(32, /*isSigned=*/false);
46 case tflite::TensorType_UINT8:
47 return builder.getIntegerType(8, /*isSigned=*/false);
48 case tflite::TensorType_INT64:
49 return builder.getIntegerType(64);
50 case tflite::TensorType_STRING:
51 return mlir::TF::StringType::get(builder.getContext());
52 case tflite::TensorType_BOOL:
53 return builder.getI1Type();
54 case tflite::TensorType_INT16:
55 return builder.getIntegerType(16);
56 case tflite::TensorType_COMPLEX64:
57 return mlir::ComplexType::get(builder.getF32Type());
58 case tflite::TensorType_COMPLEX128:
59 return mlir::ComplexType::get(builder.getF64Type());
60 case tflite::TensorType_INT8:
61 return builder.getIntegerType(8);
62 case tflite::TensorType_UINT64:
63 return builder.getIntegerType(64, /*isSigned=*/false);
64 case tflite::TensorType_RESOURCE:
65 return mlir::TF::ResourceType::get(builder.getContext());
66 case tflite::TensorType_VARIANT:
67 return mlir::TF::VariantType::get(builder.getContext());
68 }
69 }
70
TflTypeToTfType(tflite::TensorType type)71 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
72 switch (type) {
73 case tflite::TensorType_BOOL:
74 return tensorflow::DT_BOOL;
75 case tflite::TensorType_COMPLEX64:
76 return tensorflow::DT_COMPLEX64;
77 case tflite::TensorType_COMPLEX128:
78 return tensorflow::DT_COMPLEX128;
79 case tflite::TensorType_FLOAT16:
80 return tensorflow::DT_HALF;
81 case tflite::TensorType_FLOAT32:
82 return tensorflow::DT_FLOAT;
83 case tflite::TensorType_FLOAT64:
84 return tensorflow::DT_DOUBLE;
85 case tflite::TensorType_INT8:
86 return tensorflow::DT_INT8;
87 case tflite::TensorType_INT16:
88 return tensorflow::DT_INT16;
89 case tflite::TensorType_INT32:
90 return tensorflow::DT_INT32;
91 case tflite::TensorType_UINT32:
92 return tensorflow::DT_UINT32;
93 case tflite::TensorType_INT64:
94 return tensorflow::DT_INT64;
95 case tflite::TensorType_STRING:
96 return tensorflow::DT_STRING;
97 case tflite::TensorType_UINT8:
98 return tensorflow::DT_UINT8;
99 case tflite::TensorType_UINT64:
100 return tensorflow::DT_UINT64;
101 case tflite::TensorType_RESOURCE:
102 return tensorflow::DT_RESOURCE;
103 case tflite::TensorType_VARIANT:
104 return tensorflow::DT_VARIANT;
105 }
106 }
107
TfTypeToTflType(tensorflow::DataType type)108 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
109 switch (type) {
110 case tensorflow::DT_BOOL:
111 return tflite::TensorType_BOOL;
112 case tensorflow::DT_COMPLEX64:
113 return tflite::TensorType_COMPLEX64;
114 case tensorflow::DT_COMPLEX128:
115 return tflite::TensorType_COMPLEX128;
116 case tensorflow::DT_HALF:
117 return tflite::TensorType_FLOAT16;
118 case tensorflow::DT_FLOAT:
119 return tflite::TensorType_FLOAT32;
120 case tensorflow::DT_DOUBLE:
121 return tflite::TensorType_FLOAT64;
122 case tensorflow::DT_INT8:
123 return tflite::TensorType_INT8;
124 case tensorflow::DT_INT16:
125 return tflite::TensorType_INT16;
126 case tensorflow::DT_INT32:
127 return tflite::TensorType_INT32;
128 case tensorflow::DT_UINT32:
129 return tflite::TensorType_UINT32;
130 case tensorflow::DT_INT64:
131 return tflite::TensorType_INT64;
132 case tensorflow::DT_UINT64:
133 return tflite::TensorType_UINT64;
134 case tensorflow::DT_STRING:
135 return tflite::TensorType_STRING;
136 case tensorflow::DT_UINT8:
137 return tflite::TensorType_UINT8;
138 case tensorflow::DT_RESOURCE:
139 return tflite::TensorType_RESOURCE;
140 case tensorflow::DT_VARIANT:
141 return tflite::TensorType_VARIANT;
142 default:
143 return errors::InvalidArgument("unsupported tensor data type", type);
144 }
145 }
146
GetShapeStrippedType(mlir::TypeAttr type_attr)147 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
148 auto type = type_attr.getValue();
149 auto shaped_type = type.dyn_cast<mlir::ShapedType>();
150 if (shaped_type) {
151 return shaped_type.getElementType();
152 } else {
153 return type;
154 }
155 }
156
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)157 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
158 auto val_defn_op = val.getDefiningOp();
159 mlir::TFL::QuantizeOp q_op =
160 llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
161 if (!q_op) return true;
162
163 // Ignore shape details - we're really only trying to
164 // check if quantization is the same.
165 auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
166 auto stripped_qtype = GetShapeStrippedType(qtype_attr);
167 return stripped_src_qtype == stripped_qtype;
168 }
169
170 } // namespace tflite
171