• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "llvm/Support/Casting.h"
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "mlir/IR/Types.h"  // from @llvm-project
22 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 
28 namespace tensorflow {
29 
30 using mlir::Builder;
31 using mlir::ShapedType;
32 using mlir::Type;
33 
ConvertDataType(DataType dtype,Builder builder,Type * type)34 Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
35   switch (dtype) {
36     case DT_HALF:
37       *type = builder.getF16Type();
38       return Status::OK();
39     case DT_FLOAT:
40       *type = builder.getF32Type();
41       return Status::OK();
42     case DT_DOUBLE:
43       *type = builder.getF64Type();
44       return Status::OK();
45     case DT_BOOL:
46       *type = builder.getIntegerType(1);
47       return Status::OK();
48     case DT_INT8:
49       *type = builder.getIntegerType(8);
50       return Status::OK();
51     case DT_INT16:
52       *type = builder.getIntegerType(16);
53       return Status::OK();
54     case DT_INT32:
55       *type = builder.getIntegerType(32);
56       return Status::OK();
57     case DT_INT64:
58       *type = builder.getIntegerType(64);
59       return Status::OK();
60     case DT_UINT8:
61       *type = builder.getIntegerType(8, /*isSigned=*/false);
62       return Status::OK();
63     case DT_UINT16:
64       *type = builder.getIntegerType(16, /*isSigned=*/false);
65       return Status::OK();
66     case DT_UINT32:
67       *type = builder.getIntegerType(32, /*isSigned=*/false);
68       return Status::OK();
69     case DT_UINT64:
70       *type = builder.getIntegerType(64, /*isSigned=*/false);
71       return Status::OK();
72     case DT_BFLOAT16:
73       *type = builder.getBF16Type();
74       return Status::OK();
75     case DT_COMPLEX64:
76       *type = mlir::ComplexType::get(builder.getF32Type());
77       return Status::OK();
78     case DT_COMPLEX128:
79       *type = mlir::ComplexType::get(builder.getF64Type());
80       return Status::OK();
81 #define HANDLE_TF_TYPE(tftype, enumerant, name)        \
82   case DT_##enumerant:                                 \
83     *type = builder.getType<mlir::TF::tftype##Type>(); \
84     return Status::OK();
85 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
86 
87     default:
88       return errors::Unimplemented(absl::StrCat(
89           "Converting DataType '", DataTypeString(dtype), "' to MLIR Type"));
90   }
91 }
92 
ConvertScalarTypeToDataType(Type type,DataType * dtype)93 Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
94   if (type.isF16()) {
95     *dtype = DT_HALF;
96     return Status::OK();
97   } else if (type.isF32()) {
98     *dtype = DT_FLOAT;
99     return Status::OK();
100   } else if (type.isF64()) {
101     *dtype = DT_DOUBLE;
102     return Status::OK();
103   } else if (type.isBF16()) {
104     *dtype = DT_BFLOAT16;
105     return Status::OK();
106   } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
107     switch (itype.getWidth()) {
108       case 1:
109         *dtype = DT_BOOL;
110         return Status::OK();
111       case 8:
112         *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
113         return Status::OK();
114       case 16:
115         *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
116         return Status::OK();
117       case 32:
118         *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
119         return Status::OK();
120       case 64:
121         *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
122         return Status::OK();
123       default:
124         return errors::Unimplemented(
125             absl::StrCat("Converting ", debugString(type), " to DataType"));
126     }
127   } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
128     auto etype = complex_type.getElementType();
129     if (etype.isF32()) {
130       *dtype = DT_COMPLEX64;
131       return Status::OK();
132     } else if (etype.isF64()) {
133       *dtype = DT_COMPLEX128;
134       return Status::OK();
135     }
136     return errors::Unimplemented(
137         absl::StrCat("Converting ", debugString(type), " to DataType"));
138   }
139 
140 #define HANDLE_TF_TYPE(tftype, enumerant, name) \
141   if (type.isa<mlir::TF::tftype##Type>()) {     \
142     *dtype = DT_##enumerant;                    \
143     return Status::OK();                        \
144   }
145 // NOLINTNEXTLINE
146 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
147 
148   return errors::Unimplemented(
149       absl::StrCat("Converting ", debugString(type), " to DataType"));
150 }
151 
ConvertToDataType(Type type,DataType * dtype)152 Status ConvertToDataType(Type type, DataType* dtype) {
153   if (auto stype = type.dyn_cast<ShapedType>()) {
154     TF_RETURN_IF_ERROR(
155         ConvertScalarTypeToDataType(stype.getElementType(), dtype));
156   } else {
157     TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, dtype));
158   }
159   return Status::OK();
160 }
161 
ConvertToMlirShape(const TensorShape & input_shape,llvm::SmallVectorImpl<int64_t> * shape)162 void ConvertToMlirShape(const TensorShape& input_shape,
163                         llvm::SmallVectorImpl<int64_t>* shape) {
164   shape->reserve(input_shape.dims());
165   for (const auto& d : input_shape) {
166     shape->push_back(d.size);
167   }
168 }
169 
ConvertToMlirShape(const TensorShapeProto & input_shape,llvm::SmallVectorImpl<int64_t> * shape)170 Status ConvertToMlirShape(const TensorShapeProto& input_shape,
171                           llvm::SmallVectorImpl<int64_t>* shape) {
172   shape->reserve(input_shape.dim_size());
173   auto& dims = input_shape.dim();
174   for (auto& d : dims) {
175     if (d.size() > std::numeric_limits<int64_t>::max()) {
176       return errors::InvalidArgument("Shape element overflows");
177     }
178     shape->push_back(d.size());
179   }
180   return Status::OK();
181 }
182 
ConvertToMlirTensorType(const TensorShapeProto & shape,DataType dtype,mlir::Builder * builder)183 StatusOr<mlir::Type> ConvertToMlirTensorType(const TensorShapeProto& shape,
184                                              DataType dtype,
185                                              mlir::Builder* builder) {
186   mlir::Type element_type;
187   TF_RETURN_IF_ERROR(ConvertDataType(dtype, *builder, &element_type));
188   if (shape.unknown_rank()) {
189     return mlir::UnrankedTensorType::get(element_type);
190   }
191   llvm::SmallVector<int64_t, 4> shape_dims;
192   TF_RETURN_IF_ERROR(ConvertToMlirShape(shape, &shape_dims));
193   return mlir::RankedTensorType::get(shape_dims, element_type);
194 }
195 
196 }  // namespace tensorflow
197