/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project

namespace {
// Returns the shape of the given value if it's ranked; returns llvm::None
// otherwise.
llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
  auto shaped_type = value.getType().cast<mlir::ShapedType>();
  if (shaped_type.hasRank()) return shaped_type.getShape();
  return llvm::None;
}

// Merges cast-compatible shapes and returns a more refined shape. Two shapes
// are cast compatible if they have the same rank and, at each dimension,
// either both have the same size or at least one of them is dynamic. Returns
// false if the given shapes are not cast compatible. The refined shape is the
// same as or more precise than the two input shapes.
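// A minimal usage sketch (assuming the usual ShapedType dynamic-size
// sentinel, -1, for dynamic dimensions):
//
//   llvm::SmallVector<int64_t, 4> refined;
//   bool ok = GetCastCompatibleShape(/*a_shape=*/{2, -1, -1},
//                                    /*b_shape=*/{-1, 4, -1}, &refined);
//   // ok == true; refined == {2, 4, -1}.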
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
                            llvm::ArrayRef<int64_t> b_shape,
                            llvm::SmallVectorImpl<int64_t>* refined_shape) {
  if (a_shape.size() != b_shape.size()) return false;
  int64_t rank = a_shape.size();
  refined_shape->reserve(rank);
  for (auto dims : llvm::zip(a_shape, b_shape)) {
    int64_t dim1 = std::get<0>(dims);
    int64_t dim2 = std::get<1>(dims);

    if (mlir::ShapedType::isDynamic(dim1)) {
      refined_shape->push_back(dim2);
      continue;
    }
    if (mlir::ShapedType::isDynamic(dim2)) {
      refined_shape->push_back(dim1);
      continue;
    }
    if (dim1 == dim2) {
      refined_shape->push_back(dim1);
      continue;
    }
    return false;
  }
  return true;
}

}  // namespace

namespace mlir {
namespace TF {
//===----------------------------------------------------------------------===//
// Utility iterators
//===----------------------------------------------------------------------===//

OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
    : llvm::mapped_iterator<Operation::operand_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}

ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
    : llvm::mapped_iterator<Operation::result_iterator,
                            llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
          it, &GetShape) {}

//===----------------------------------------------------------------------===//
// TF types helper functions
//===----------------------------------------------------------------------===//

bool TensorFlowType::classof(Type type) {
  return type.getDialect().getNamespace() == "tf";
}
bool TensorFlowRefType::classof(Type type) {
  return type.isa<
#define HANDLE_TF_TYPE(tftype, enumerant, name)
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
      >();
}
bool TensorFlowTypeWithSubtype::classof(Type type) {
  return type.isa<ResourceType, VariantType>();
}

TensorFlowType TensorFlowRefType::get(Type type) {
  MLIRContext* ctx = type.getContext();
  type = getElementTypeOrSelf(type);
  if (type.isF16()) {
    return HalfRefType::get(ctx);
  } else if (type.isF32()) {
    return FloatRefType::get(ctx);
  } else if (type.isF64()) {
    return DoubleRefType::get(ctx);
  } else if (type.isBF16()) {
    return Bfloat16RefType::get(ctx);
  } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
    Type etype = complex_type.getElementType();
    if (etype.isF32()) {
      return Complex64RefType::get(ctx);
    } else if (etype.isF64()) {
      return Complex128RefType::get(ctx);
    }
    llvm_unreachable("unexpected complex type");
  } else if (auto itype = type.dyn_cast<IntegerType>()) {
    switch (itype.getWidth()) {
      case 1:
        return BoolRefType::get(ctx);
      case 8:
        return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
                                  : Int8RefType::get(ctx);
      case 16:
        return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
                                  : Int16RefType::get(ctx);
      case 32:
        return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
                                  : Int32RefType::get(ctx);
      case 64:
        return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
                                  : Int64RefType::get(ctx);
      default:
        llvm_unreachable("unexpected integer type");
    }
  }
#define HANDLE_TF_TYPE(tftype, enumerant, name)        \
  if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
    return tftype##RefType::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
  llvm_unreachable("unexpected type kind");
}

Type TensorFlowRefType::RemoveRef() {
  MLIRContext* ctx = getContext();
  if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
  if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
  if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
  if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
  if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
  if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
  if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
  if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
  if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
  if (isa<Uint8RefType>())
    return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
  if (isa<Uint16RefType>())
    return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
  if (isa<Uint32RefType>())
    return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
  if (isa<Uint64RefType>())
    return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
  if (isa<Complex64RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
  if (isa<Complex128RefType>())
    return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
  if (isa<tftype##RefType>()) return tftype##Type::get(ctx);

#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
  llvm_unreachable("unexpected tensorflow ref type kind");
}
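
// A round-trip sketch for the two functions above (ctx is an assumed
// MLIRContext*): given f32, TensorFlowRefType::get returns !tf.f32ref, and
// RemoveRef maps it back:
//
//   TensorFlowType ref = TensorFlowRefType::get(FloatType::getF32(ctx));
//   Type orig = ref.cast<TensorFlowRefType>().RemoveRef();  // f32 again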

Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
  MLIRContext* ctx = getContext();
  if (isa<VariantType>()) return VariantType::get(ctx);
  if (isa<ResourceType>()) return ResourceType::get(ctx);
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
  if (auto variant_type = dyn_cast<VariantType>())
    return variant_type.getSubtypes();
  if (auto resource_type = dyn_cast<ResourceType>())
    return resource_type.getSubtypes();
  llvm_unreachable("unexpected tensorflow type with subtypes kind");
}

// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have a
// similar structure that could be extracted into a helper method.
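// For example (a sketch; t_4xf32, t_1xf32, and t_3xf32 are assumed variables
// holding the corresponding tensor types), tensor<4xf32> broadcasts with
// tensor<1xf32> but not with tensor<3xf32>:
//
//   BroadcastCompatible({t_4xf32}, {t_1xf32});  // true
//   BroadcastCompatible({t_4xf32}, {t_3xf32});  // false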
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto types : llvm::zip(lhs, rhs)) {
    // Drop ref types because they don't affect broadcast compatibility. E.g.,
    // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
    // compatible.
    auto lhs_type = DropRefType(std::get<0>(types));
    auto rhs_type = DropRefType(std::get<1>(types));

    // This should be true for all TF ops:
    auto lhs_tt = lhs_type.dyn_cast<TensorType>();
    auto rhs_tt = rhs_type.dyn_cast<TensorType>();
    if (!lhs_tt || !rhs_tt) {
      if (lhs_type != rhs_type) return false;
      continue;
    }

    // Verify matching element types. These should be identical, except for
    // variant types, where an unknown subtype is considered compatible with
    // all subtypes.
    auto lhs_et = lhs_tt.getElementType();
    auto rhs_et = rhs_tt.getElementType();
    if (lhs_et != rhs_et) {
      // If either does not have subtypes, then the element types don't match.
      auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
      auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
      if (!lhs_wst || !rhs_wst) return false;

      // Consider the subtypes of variant types.
      auto lhs_wst_st = lhs_wst.GetSubtypes();
      auto rhs_wst_st = rhs_wst.GetSubtypes();
      if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
        for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
          if (!BroadcastCompatible(std::get<0>(subtypes),
                                   std::get<1>(subtypes)))
            return false;
        }
      }
    }

    // If both are ranked, the shapes must be broadcastable; otherwise skip
    // the shape check for this pair.
    auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
    auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
    if (!lhs_rt || !rhs_rt) continue;
    SmallVector<int64_t, 4> shape;
    if (!OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
                                            rhs_rt.getShape(), shape))
      return false;
  }
  return true;
}

// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns an empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element types. Element types without subtypes must be equal.
// However, for TensorFlow types such as Resource and Variant that do have
// subtypes, we recursively check subtype compatibility for Resource types and
// assume all Variant types are cast compatible. If either `a` or `b` has empty
// subtypes, they are considered cast compatible.
//
// The returned type is the same as or more precise than the input types. For
// example, if `a` and `b` are the cast-compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides an option to ignore ref types on `a`. This is useful for TF ops
// that might allow operands to either be the same as the result type or be a
// ref type corresponding to it.
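// A usage sketch (with `a` and `b` written in textual form for illustration):
//
//   // a = tensor<2x?x?xf32>, b = tensor<?x4x?xf32>
//   Type refined =
//       GetCastCompatibleType(a, b, /*may_ignore_ref_type_a=*/false);
//   // refined == tensor<2x4x?xf32>; a null Type signals incompatibility.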
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a) {
  // Fast path if everything is equal.
  if (a == b) return b;

  auto a_tt = a.dyn_cast<mlir::TensorType>();
  auto b_tt = b.dyn_cast<mlir::TensorType>();

  // If only one of a or b is a tensor type, they are incompatible.
  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;

  // For non-tensor types, we do not need to worry about shape and can return
  // early.
  if (!a_tt && !b_tt) {
    // Remove ref types.
    if (may_ignore_ref_type_a) {
      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
        a = ref_type.RemoveRef();
        if (a == b) return a;
      }
    }
    if (a.getTypeID() != b.getTypeID()) return nullptr;

    // If either is not a type that contains subtypes, then the types are not
    // cast compatible.
    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
    if (!a_wst || !b_wst) return nullptr;

    // For Variant types we are more permissive right now and accept all pairs
    // of Variant types. If we were more constrained and checked subtype
    // compatibility, we might reject valid graphs.
    // TODO(prakalps): Variant doesn't have a subtype; we assign it one, so we
    // should only assign it one when we know the subtype. Then we can be more
    // constrained and check subtypes for cast compatibility as well.
    if (a.isa<mlir::TF::VariantType>()) return a;

    // For Resource types, we recursively check the subtypes for cast
    // compatibility, if possible. Otherwise treat them as compatible.
    auto a_wst_st = a_wst.GetSubtypes();
    auto b_wst_st = b_wst.GetSubtypes();
    if (a_wst_st.empty() || b_wst_st.empty()) return a;
    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
      mlir::Type refined_st =
          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
                                /*may_ignore_ref_type_a=*/false);
      if (!refined_st) return nullptr;
      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
    }

    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
  }

  // For tensor types, check compatibility of both element type and shape.
  mlir::Type refined_element_ty = GetCastCompatibleType(
      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
  if (!refined_element_ty) return nullptr;

  if (!a_tt.hasRank() && !b_tt.hasRank()) {
    return mlir::UnrankedTensorType::get(refined_element_ty);
  }
  if (!a_tt.hasRank()) {
    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
  }
  if (!b_tt.hasRank()) {
    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
  }

  llvm::SmallVector<int64_t, 8> refined_shape;
  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
    return nullptr;

  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}

bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs) {
  return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}

bool AreCastCompatible(ArrayRef<Type> types) {
  Type common = types.front();
  for (auto type : types.drop_front()) {
    Type refined_type =
        GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
    if (!refined_type) return false;
    common = refined_type;
  }
  return true;
}
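
// A folding sketch: AreCastCompatible({tensor<2x?xf32>, tensor<?x3xf32>,
// tensor<2x3xf32>}) is true, since each step refines the running common type
// to tensor<2x3xf32>; any pairwise-incompatible type makes the result false.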

bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (auto pair : llvm::zip(lhs, rhs)) {
    auto lhs_i = std::get<0>(pair);
    auto rhs_i = std::get<1>(pair);
    if (!AreCastCompatible({lhs_i, rhs_i})) return false;
  }
  return true;
}

// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
// type for a composed type (such as a ref type or a type with subtypes).
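// For instance (a sketch of the intended behavior), DropRefType maps
// tensor<!tf.f32ref> to tensor<f32>, DropSubTypes maps
// tensor<!tf.resource<tensor<f32>>> to tensor<!tf.resource>, and
// non-composed types pass through unchanged.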
template <typename ComposedType>
Type DropTypeHelper(Type ty) {
  Type element_ty = getElementTypeOrSelf(ty);
  auto composed_type = element_ty.dyn_cast<ComposedType>();
  if (!composed_type) return ty;

  Type default_ty = GetDefaultTypeOf(composed_type);
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
    return RankedTensorType::get(ranked_ty.getShape(), default_ty);
  } else if (ty.isa<UnrankedTensorType>()) {
    return UnrankedTensorType::get(default_ty);
  } else {
    return default_ty;
  }
}

Type DropSubTypes(Type ty) {
  return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
}

Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }

Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }

}  // namespace TF
}  // namespace mlir