• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
17 
18 #include "mlir/IR/Builders.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Types.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 
30 using xla::StatusOr;
31 
32 namespace errors = tensorflow::errors;
33 
ConvertTypeToTensorType(mlir::Type type)34 tflite::TensorType ConvertTypeToTensorType(mlir::Type type) {
35   if (type.isF16()) {
36     return tflite::TensorType_FLOAT16;
37   } else if (type.isF32()) {
38     return tflite::TensorType_FLOAT32;
39   } else if (type.isF64()) {
40     return tflite::TensorType_FLOAT64;
41   } else if (type.isa<mlir::TF::StringType>()) {
42     return tflite::TensorType_STRING;
43   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
44     if (complex_type.getElementType().isF32()) {
45       return tflite::TensorType_COMPLEX64;
46     } else if (complex_type.getElementType().isF64()) {
47       return tflite::TensorType_COMPLEX128;
48     }
49     llvm_unreachable("invalid complex Type in conversion");
50   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
51     switch (itype.getWidth()) {
52       case 1:
53         return tflite::TensorType_BOOL;
54       case 8:
55         if (itype.isUnsigned())
56           return tflite::TensorType_UINT8;
57         else
58           return tflite::TensorType_INT8;
59       case 16:
60         return tflite::TensorType_INT16;
61       case 32:
62         return tflite::TensorType_INT32;
63       case 64:
64         if (itype.isUnsigned())
65           return tflite::TensorType_UINT64;
66         else
67           return tflite::TensorType_INT64;
68       default:
69         llvm_unreachable("invalid integer Type in conversion");
70     }
71   }
72   llvm_unreachable("invalid Type in conversion");
73 }
74 
ConvertElementType(tflite::TensorType type,mlir::Builder builder)75 mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
76   switch (type) {
77     case tflite::TensorType_FLOAT16:
78       return builder.getF16Type();
79     case tflite::TensorType_FLOAT32:
80       return builder.getF32Type();
81     case tflite::TensorType_FLOAT64:
82       return builder.getF64Type();
83     case tflite::TensorType_INT32:
84       return builder.getIntegerType(32);
85     case tflite::TensorType_UINT32:
86       return builder.getIntegerType(32, /*isSigned=*/false);
87     case tflite::TensorType_UINT8:
88       return builder.getIntegerType(8, /*isSigned=*/false);
89     case tflite::TensorType_INT64:
90       return builder.getIntegerType(64);
91     case tflite::TensorType_STRING:
92       return mlir::TF::StringType::get(builder.getContext());
93     case tflite::TensorType_BOOL:
94       return builder.getI1Type();
95     case tflite::TensorType_INT16:
96       return builder.getIntegerType(16);
97     case tflite::TensorType_COMPLEX64:
98       return mlir::ComplexType::get(builder.getF32Type());
99     case tflite::TensorType_COMPLEX128:
100       return mlir::ComplexType::get(builder.getF64Type());
101     case tflite::TensorType_INT8:
102       return builder.getIntegerType(8);
103     case tflite::TensorType_UINT64:
104       return builder.getIntegerType(64, /*isSigned=*/false);
105     case tflite::TensorType_RESOURCE:
106       return mlir::TF::ResourceType::get(builder.getContext());
107     case tflite::TensorType_VARIANT:
108       return mlir::TF::VariantType::get(builder.getContext());
109   }
110 }
111 
TflTypeToTfType(tflite::TensorType type)112 tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
113   switch (type) {
114     case tflite::TensorType_BOOL:
115       return tensorflow::DT_BOOL;
116     case tflite::TensorType_COMPLEX64:
117       return tensorflow::DT_COMPLEX64;
118     case tflite::TensorType_COMPLEX128:
119       return tensorflow::DT_COMPLEX128;
120     case tflite::TensorType_FLOAT16:
121       return tensorflow::DT_HALF;
122     case tflite::TensorType_FLOAT32:
123       return tensorflow::DT_FLOAT;
124     case tflite::TensorType_FLOAT64:
125       return tensorflow::DT_DOUBLE;
126     case tflite::TensorType_INT8:
127       return tensorflow::DT_INT8;
128     case tflite::TensorType_INT16:
129       return tensorflow::DT_INT16;
130     case tflite::TensorType_INT32:
131       return tensorflow::DT_INT32;
132     case tflite::TensorType_UINT32:
133       return tensorflow::DT_UINT32;
134     case tflite::TensorType_INT64:
135       return tensorflow::DT_INT64;
136     case tflite::TensorType_STRING:
137       return tensorflow::DT_STRING;
138     case tflite::TensorType_UINT8:
139       return tensorflow::DT_UINT8;
140     case tflite::TensorType_UINT64:
141       return tensorflow::DT_UINT64;
142     case tflite::TensorType_RESOURCE:
143       return tensorflow::DT_RESOURCE;
144     case tflite::TensorType_VARIANT:
145       return tensorflow::DT_VARIANT;
146   }
147 }
148 
TfTypeToTflType(tensorflow::DataType type)149 StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
150   switch (type) {
151     case tensorflow::DT_BOOL:
152       return tflite::TensorType_BOOL;
153     case tensorflow::DT_COMPLEX64:
154       return tflite::TensorType_COMPLEX64;
155     case tensorflow::DT_COMPLEX128:
156       return tflite::TensorType_COMPLEX128;
157     case tensorflow::DT_HALF:
158       return tflite::TensorType_FLOAT16;
159     case tensorflow::DT_FLOAT:
160       return tflite::TensorType_FLOAT32;
161     case tensorflow::DT_DOUBLE:
162       return tflite::TensorType_FLOAT64;
163     case tensorflow::DT_INT8:
164       return tflite::TensorType_INT8;
165     case tensorflow::DT_INT16:
166       return tflite::TensorType_INT16;
167     case tensorflow::DT_INT32:
168       return tflite::TensorType_INT32;
169     case tensorflow::DT_UINT32:
170       return tflite::TensorType_UINT32;
171     case tensorflow::DT_INT64:
172       return tflite::TensorType_INT64;
173     case tensorflow::DT_UINT64:
174       return tflite::TensorType_UINT64;
175     case tensorflow::DT_STRING:
176       return tflite::TensorType_STRING;
177     case tensorflow::DT_UINT8:
178       return tflite::TensorType_UINT8;
179     case tensorflow::DT_RESOURCE:
180       return tflite::TensorType_RESOURCE;
181     case tensorflow::DT_VARIANT:
182       return tflite::TensorType_VARIANT;
183     default:
184       return errors::InvalidArgument("unsupported tensor data type", type);
185   }
186 }
187 
GetShapeStrippedType(mlir::TypeAttr type_attr)188 mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) {
189   auto type = type_attr.getValue();
190   auto shaped_type = type.dyn_cast<mlir::ShapedType>();
191   if (shaped_type) {
192     return shaped_type.getElementType();
193   } else {
194     return type;
195   }
196 }
197 
NotFromQuantOpOrSameQuantType(mlir::Value val,mlir::TypeAttr qtype_attr)198 bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr) {
199   auto val_defn_op = val.getDefiningOp();
200   mlir::TFL::QuantizeOp q_op =
201       llvm::dyn_cast_or_null<mlir::TFL::QuantizeOp>(val_defn_op);
202   if (!q_op) return true;
203 
204   // Ignore shape details - we're really only trying to
205   // check if quantization is the same.
206   auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr());
207   auto stripped_qtype = GetShapeStrippedType(qtype_attr);
208   return stripped_src_qtype == stripped_qtype;
209 }
210 
211 }  // namespace tflite
212