1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "llvm/Support/Casting.h"
20 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
21 #include "mlir/IR/Types.h" // from @llvm-project
22 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27
28 namespace tensorflow {
29
30 using mlir::Builder;
31 using mlir::ShapedType;
32 using mlir::Type;
33
// Converts the TensorFlow DataType 'dtype' into the corresponding MLIR
// scalar type.
//
// On success writes the result to '*type' and returns OK; returns an
// Unimplemented error for DataTypes with no MLIR equivalent.
Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
  switch (dtype) {
    case DT_HALF:
      *type = builder.getF16Type();
      return Status::OK();
    case DT_FLOAT:
      *type = builder.getF32Type();
      return Status::OK();
    case DT_DOUBLE:
      *type = builder.getF64Type();
      return Status::OK();
    case DT_BOOL:
      // MLIR has no dedicated bool type; DT_BOOL maps to a 1-bit integer.
      *type = builder.getIntegerType(1);
      return Status::OK();
    case DT_INT8:
      *type = builder.getIntegerType(8);
      return Status::OK();
    case DT_INT16:
      *type = builder.getIntegerType(16);
      return Status::OK();
    case DT_INT32:
      *type = builder.getIntegerType(32);
      return Status::OK();
    case DT_INT64:
      *type = builder.getIntegerType(64);
      return Status::OK();
    // Unsigned TF types map to explicitly-unsigned MLIR integer types
    // (signed TF types above use the default signless form).
    case DT_UINT8:
      *type = builder.getIntegerType(8, /*isSigned=*/false);
      return Status::OK();
    case DT_UINT16:
      *type = builder.getIntegerType(16, /*isSigned=*/false);
      return Status::OK();
    case DT_UINT32:
      *type = builder.getIntegerType(32, /*isSigned=*/false);
      return Status::OK();
    case DT_UINT64:
      *type = builder.getIntegerType(64, /*isSigned=*/false);
      return Status::OK();
    case DT_BFLOAT16:
      *type = builder.getBF16Type();
      return Status::OK();
    case DT_COMPLEX64:
      *type = mlir::ComplexType::get(builder.getF32Type());
      return Status::OK();
    case DT_COMPLEX128:
      *type = mlir::ComplexType::get(builder.getF64Type());
      return Status::OK();
// Expands one case per type declared in tf_types.def, mapping each
// DT_<ENUMERANT> to the corresponding TF-dialect type mlir::TF::<tftype>Type.
#define HANDLE_TF_TYPE(tftype, enumerant, name)        \
  case DT_##enumerant:                                 \
    *type = builder.getType<mlir::TF::tftype##Type>(); \
    return Status::OK();
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"

    default:
      return errors::Unimplemented(absl::StrCat(
          "Converting DataType '", DataTypeString(dtype), "' to MLIR Type"));
  }
}
92
// Converts the scalar MLIR type 'type' into the corresponding TensorFlow
// DataType. Shaped (tensor) types are not handled here; see
// ConvertToDataType, which unwraps the element type first.
//
// On success writes the result to '*dtype' and returns OK; returns an
// Unimplemented error for MLIR types with no DataType equivalent.
Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
  if (type.isF16()) {
    *dtype = DT_HALF;
    return Status::OK();
  } else if (type.isF32()) {
    *dtype = DT_FLOAT;
    return Status::OK();
  } else if (type.isF64()) {
    *dtype = DT_DOUBLE;
    return Status::OK();
  } else if (type.isBF16()) {
    *dtype = DT_BFLOAT16;
    return Status::OK();
  } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
    // Dispatch on bit width; only explicitly-unsigned integers map to the
    // DT_UINT* types (signless and signed both map to DT_INT*).
    switch (itype.getWidth()) {
      case 1:
        *dtype = DT_BOOL;
        return Status::OK();
      case 8:
        *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
        return Status::OK();
      case 16:
        *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
        return Status::OK();
      case 32:
        *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
        return Status::OK();
      case 64:
        *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
        return Status::OK();
      default:
        return errors::Unimplemented(
            absl::StrCat("Converting ", debugString(type), " to DataType"));
    }
  } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    // Only complex-of-f32 and complex-of-f64 have DataType equivalents.
    auto etype = complex_type.getElementType();
    if (etype.isF32()) {
      *dtype = DT_COMPLEX64;
      return Status::OK();
    } else if (etype.isF64()) {
      *dtype = DT_COMPLEX128;
      return Status::OK();
    }
    return errors::Unimplemented(
        absl::StrCat("Converting ", debugString(type), " to DataType"));
  }

// Expands one check per type declared in tf_types.def, mapping each
// TF-dialect type mlir::TF::<tftype>Type back to DT_<ENUMERANT>.
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (type.isa<mlir::TF::tftype##Type>()) {     \
    *dtype = DT_##enumerant;                    \
    return Status::OK();                        \
  }
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"

  return errors::Unimplemented(
      absl::StrCat("Converting ", debugString(type), " to DataType"));
}
151
ConvertToDataType(Type type,DataType * dtype)152 Status ConvertToDataType(Type type, DataType* dtype) {
153 if (auto stype = type.dyn_cast<ShapedType>()) {
154 TF_RETURN_IF_ERROR(
155 ConvertScalarTypeToDataType(stype.getElementType(), dtype));
156 } else {
157 TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, dtype));
158 }
159 return Status::OK();
160 }
161
ConvertToMlirShape(const TensorShape & input_shape,llvm::SmallVectorImpl<int64_t> * shape)162 void ConvertToMlirShape(const TensorShape& input_shape,
163 llvm::SmallVectorImpl<int64_t>* shape) {
164 shape->reserve(input_shape.dims());
165 for (const auto& d : input_shape) {
166 shape->push_back(d.size);
167 }
168 }
169
ConvertToMlirShape(const TensorShapeProto & input_shape,llvm::SmallVectorImpl<int64_t> * shape)170 Status ConvertToMlirShape(const TensorShapeProto& input_shape,
171 llvm::SmallVectorImpl<int64_t>* shape) {
172 shape->reserve(input_shape.dim_size());
173 auto& dims = input_shape.dim();
174 for (auto& d : dims) {
175 if (d.size() > std::numeric_limits<int64_t>::max()) {
176 return errors::InvalidArgument("Shape element overflows");
177 }
178 shape->push_back(d.size());
179 }
180 return Status::OK();
181 }
182
ConvertToMlirTensorType(const TensorShapeProto & shape,DataType dtype,mlir::Builder * builder)183 StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
184 DataType dtype,
185 mlir::Builder* builder) {
186 mlir::Type element_type;
187 TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type));
188 if (shape.unknown_rank()) {
189 return mlir::UnrankedTensorType::get(element_type);
190 }
191 llvm::SmallVector<int64_t, 4> shape_dims;
192 TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims));
193 return mlir::RankedTensorType::get(shape_dims, element_type);
194 }
195
196 } // namespace tensorflow
197