• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
17 
18 #include "llvm/Support/ErrorHandling.h"
19 #include "mlir/Dialect/Traits.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "mlir/IR/Dialect.h"  // from @llvm-project
22 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
23 
24 namespace {
25 // Returns the shape of the given value if it's ranked; returns llvm::None
26 // otherwise.
GetShape(mlir::Value value)27 llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
28   auto shaped_type = value.getType().cast<mlir::ShapedType>();
29   if (shaped_type.hasRank()) return shaped_type.getShape();
30   return llvm::None;
31 }
32 
33 // Merges cast compatible shapes and returns a more refined shape. The two
34 // shapes are cast compatible if they have the same rank and at each dimension,
35 // either both have same size or one of them is dynamic. Returns false if the
36 // given shapes are not cast compatible. The refined shape is same or more
37 // precise than the two input shapes.
GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,llvm::ArrayRef<int64_t> b_shape,llvm::SmallVectorImpl<int64_t> * refined_shape)38 bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
39                             llvm::ArrayRef<int64_t> b_shape,
40                             llvm::SmallVectorImpl<int64_t>* refined_shape) {
41   if (a_shape.size() != b_shape.size()) return false;
42   int64_t rank = a_shape.size();
43   refined_shape->reserve(rank);
44   for (auto dims : llvm::zip(a_shape, b_shape)) {
45     int64_t dim1 = std::get<0>(dims);
46     int64_t dim2 = std::get<1>(dims);
47 
48     if (mlir::ShapedType::isDynamic(dim1)) {
49       refined_shape->push_back(dim2);
50       continue;
51     }
52     if (mlir::ShapedType::isDynamic(dim2)) {
53       refined_shape->push_back(dim1);
54       continue;
55     }
56     if (dim1 == dim2) {
57       refined_shape->push_back(dim1);
58       continue;
59     }
60     return false;
61   }
62   return true;
63 }
64 
65 }  // namespace
66 
67 namespace mlir {
68 namespace TF {
69 //===----------------------------------------------------------------------===//
70 // Utility iterators
71 //===----------------------------------------------------------------------===//
72 
OperandShapeIterator(Operation::operand_iterator it)73 OperandShapeIterator::OperandShapeIterator(Operation::operand_iterator it)
74     : llvm::mapped_iterator<Operation::operand_iterator,
75                             llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
76           it, &GetShape) {}
77 
ResultShapeIterator(Operation::result_iterator it)78 ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it)
79     : llvm::mapped_iterator<Operation::result_iterator,
80                             llvm::Optional<ArrayRef<int64_t>> (*)(Value)>(
81           it, &GetShape) {}
82 
83 //===----------------------------------------------------------------------===//
84 // TF types helper functions
85 //===----------------------------------------------------------------------===//
86 
classof(Type type)87 bool TensorFlowType::classof(Type type) {
88   return type.getDialect().getNamespace() == "tf";
89 }
classof(Type type)90 bool TensorFlowRefType::classof(Type type) {
91   return type.isa<
92 #define HANDLE_TF_TYPE(tftype, enumerant, name)
93 #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type,
94 #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type
95 // NOLINTNEXTLINE
96 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
97       >();
98 }
classof(Type type)99 bool TensorFlowTypeWithSubtype::classof(Type type) {
100   return type.isa<ResourceType, VariantType>();
101 }
102 
get(Type type)103 TensorFlowType TensorFlowRefType::get(Type type) {
104   MLIRContext* ctx = type.getContext();
105   type = getElementTypeOrSelf(type);
106   if (type.isF16()) {
107     return HalfRefType::get(ctx);
108   } else if (type.isF32()) {
109     return FloatRefType::get(ctx);
110   } else if (type.isF64()) {
111     return DoubleRefType::get(ctx);
112   } else if (type.isBF16()) {
113     return Bfloat16RefType::get(ctx);
114   } else if (auto complex_type = type.dyn_cast<ComplexType>()) {
115     Type etype = complex_type.getElementType();
116     if (etype.isF32()) {
117       return Complex64RefType::get(ctx);
118     } else if (etype.isF64()) {
119       return Complex128RefType::get(ctx);
120     }
121     llvm_unreachable("unexpected complex type");
122   } else if (auto itype = type.dyn_cast<IntegerType>()) {
123     switch (itype.getWidth()) {
124       case 1:
125         return BoolRefType::get(ctx);
126       case 8:
127         return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
128                                   : Int8RefType::get(ctx);
129       case 16:
130         return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
131                                   : Int16RefType::get(ctx);
132       case 32:
133         return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
134                                   : Int32RefType::get(ctx);
135       case 64:
136         return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
137                                   : Int64RefType::get(ctx);
138       default:
139         llvm_unreachable("unexpected integer type");
140     }
141   }
142 #define HANDLE_TF_TYPE(tftype, enumerant, name)        \
143   if (auto derived_ty = type.dyn_cast<tftype##Type>()) \
144     return tftype##RefType::get(ctx);
145 
146 #define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
147 // NOLINTNEXTLINE
148 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
149   llvm_unreachable("unexpected type kind");
150 }
151 
RemoveRef()152 Type TensorFlowRefType::RemoveRef() {
153   MLIRContext* ctx = getContext();
154   if (isa<HalfRefType>()) return mlir::FloatType::getF16(ctx);
155   if (isa<FloatRefType>()) return mlir::FloatType::getF32(ctx);
156   if (isa<DoubleRefType>()) return mlir::FloatType::getF64(ctx);
157   if (isa<Bfloat16RefType>()) return mlir::FloatType::getBF16(ctx);
158   if (isa<BoolRefType>()) return mlir::IntegerType::get(ctx, 1);
159   if (isa<Int8RefType>()) return mlir::IntegerType::get(ctx, 8);
160   if (isa<Int16RefType>()) return mlir::IntegerType::get(ctx, 16);
161   if (isa<Int32RefType>()) return mlir::IntegerType::get(ctx, 32);
162   if (isa<Int64RefType>()) return mlir::IntegerType::get(ctx, 64);
163   if (isa<Uint8RefType>())
164     return mlir::IntegerType::get(ctx, 8, IntegerType::Unsigned);
165   if (isa<Uint16RefType>())
166     return mlir::IntegerType::get(ctx, 16, IntegerType::Unsigned);
167   if (isa<Uint32RefType>())
168     return mlir::IntegerType::get(ctx, 32, IntegerType::Unsigned);
169   if (isa<Uint64RefType>())
170     return mlir::IntegerType::get(ctx, 64, IntegerType::Unsigned);
171   if (isa<Complex64RefType>())
172     return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
173   if (isa<Complex128RefType>())
174     return mlir::ComplexType::get(mlir::FloatType::getF64(ctx));
175 #define HANDLE_TF_TYPE(tftype, enumerant, name) \
176   if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
177 
178 #define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
179 // NOLINTNEXTLINE
180 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
181   llvm_unreachable("unexpected tensorflow ref type kind");
182 }
183 
RemoveSubtypes()184 Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
185   MLIRContext* ctx = getContext();
186   if (isa<VariantType>()) return VariantType::get(ctx);
187   if (isa<ResourceType>()) return ResourceType::get(ctx);
188   llvm_unreachable("unexpected tensorflow type with subtypes kind");
189 }
190 
GetSubtypes()191 ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
192   if (auto variant_type = dyn_cast<VariantType>())
193     return variant_type.getSubtypes();
194   if (auto resource_type = dyn_cast<ResourceType>())
195     return resource_type.getSubtypes();
196   llvm_unreachable("unexpected tensorflow type with subtypes kind");
197 }
198 
199 // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have
200 // similar structure that could be extracted into helper method.
BroadcastCompatible(ArrayRef<Type> lhs,ArrayRef<Type> rhs)201 bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
202   if (lhs.size() != rhs.size()) return false;
203   for (auto types : llvm::zip(lhs, rhs)) {
204     // Drop ref types because they don't affect broadcast compatibility. E.g.,
205     // `tensor<!tf.f32ref>` and `tensor<f32>` should be considered broadcast
206     // compatible.
207     auto lhs_type = DropRefType(std::get<0>(types));
208     auto rhs_type = DropRefType(std::get<1>(types));
209 
210     // This should be true for all TF ops:
211     auto lhs_tt = lhs_type.dyn_cast<TensorType>();
212     auto rhs_tt = rhs_type.dyn_cast<TensorType>();
213     if (!lhs_tt || !rhs_tt) {
214       if (lhs_type != rhs_type) return false;
215       continue;
216     }
217 
218     // Verify matching element types. These should be identical, except for
219     // variant type where unknown subtype is considered compatible with all
220     // subtypes.
221     auto lhs_et = lhs_tt.getElementType();
222     auto rhs_et = rhs_tt.getElementType();
223     if (lhs_et != rhs_et) {
224       // If either does not have subtypes, then the element types don't match.
225       auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
226       auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
227       if (!lhs_wst || !rhs_wst) return false;
228 
229       // Consider the subtype of variant types.
230       auto lhs_wst_st = lhs_wst.GetSubtypes();
231       auto rhs_wst_st = rhs_wst.GetSubtypes();
232       if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
233         for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
234           if (!BroadcastCompatible(std::get<0>(subtypes),
235                                    std::get<1>(subtypes)))
236             return false;
237         }
238       }
239     }
240 
241     auto lhs_rt = lhs_type.dyn_cast<RankedTensorType>();
242     auto rhs_rt = rhs_type.dyn_cast<RankedTensorType>();
243     if (!lhs_rt || !rhs_rt) return true;
244     SmallVector<int64_t, 4> shape;
245     return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
246                                               rhs_rt.getShape(), shape);
247   }
248   return true;
249 }
250 
251 // Given two types `a` and `b`, returns a refined type which is cast compatible
252 // with both `a` and `b` and is equal to or more precise than both of them. It
253 // returns empty Type if the input types are not cast compatible.
254 //
255 // The two types are considered cast compatible if they have dynamically equal
256 // shapes and element type. For element types that do not have subtypes, they
257 // must be equal. However for TensorFlow types such as Resource and Variant,
258 // that also have subtypes, we recursively check for subtype compatibilty for
259 // Resource types and assume all variant types are cast compatible. If either
260 // one of `a` or `b` have empty subtypes, they are considered cast compatible.
261 //
262 // The returned type is same or more precise than the input types. For example,
263 // if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
264 // tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
265 //
266 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
267 // might allow operands to either be same as result type or be a ref type
268 // corresponding to it.
GetCastCompatibleType(mlir::Type a,mlir::Type b,bool may_ignore_ref_type_a)269 mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
270                                  bool may_ignore_ref_type_a) {
271   // Fast path if everything is equal.
272   if (a == b) return b;
273 
274   auto a_tt = a.dyn_cast<mlir::TensorType>();
275   auto b_tt = b.dyn_cast<mlir::TensorType>();
276 
277   // If only one of a or b is a tensor type, they are incompatible.
278   if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
279 
280   // For non-tensor types, we do not need to worry about shape and can return
281   // early.
282   if (!a_tt && !b_tt) {
283     // Remove ref types.
284     if (may_ignore_ref_type_a) {
285       if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
286         a = ref_type.RemoveRef();
287         if (a == b) return a;
288       }
289     }
290     if (a.getTypeID() != b.getTypeID()) return nullptr;
291 
292     // If either is not a type that contain subtypes then the types are not cast
293     // compatible.
294     auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
295     auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
296     if (!a_wst || !b_wst) return nullptr;
297 
298     // For Variant types we are more permissive right now and accept all pairs
299     // of Variant types. If we are more constrainted and check compatibility of
300     // subtypes, we might reject valid graphs.
301     // TODO(prakalps): Variant doesn't have a subtype, we assign it
302     // one, so we should only assign it one when we know the subtype. Then we
303     // can be more constrained and check subtypes for cast compatibility as
304     // well.
305     if (a.isa<mlir::TF::VariantType>()) return a;
306 
307     // For Resource types, we recursively check the subtypes for cast
308     // compatibility, if possible. Otherwise treat them as compatible.
309     auto a_wst_st = a_wst.GetSubtypes();
310     auto b_wst_st = b_wst.GetSubtypes();
311     if (a_wst_st.empty() || b_wst_st.empty()) return a;
312     if (a_wst_st.size() != b_wst_st.size()) return nullptr;
313     llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
314     for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
315       mlir::Type refined_st =
316           GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
317                                 /*may_ignore_ref_type_a=*/false);
318       if (!refined_st) return nullptr;
319       refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
320     }
321 
322     return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
323   }
324 
325   // For tensor types, check compatibility of both element type and shape.
326   mlir::Type refined_element_ty = GetCastCompatibleType(
327       a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
328   if (!refined_element_ty) return nullptr;
329 
330   if (!a_tt.hasRank() && !b_tt.hasRank()) {
331     return mlir::UnrankedTensorType::get(refined_element_ty);
332   }
333   if (!a_tt.hasRank()) {
334     return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
335   }
336   if (!b_tt.hasRank()) {
337     return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
338   }
339 
340   llvm::SmallVector<int64_t, 8> refined_shape;
341   if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
342     return nullptr;
343 
344   return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
345 }
346 
HasCompatibleElementTypes(Type lhs,Type rhs,bool may_ignore_ref_type_lhs)347 bool HasCompatibleElementTypes(Type lhs, Type rhs,
348                                bool may_ignore_ref_type_lhs) {
349   return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
350 }
351 
AreCastCompatible(ArrayRef<Type> types)352 bool AreCastCompatible(ArrayRef<Type> types) {
353   Type common = types.front();
354   for (auto type : types.drop_front()) {
355     Type refined_type =
356         GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
357     if (!refined_type) return false;
358     common = refined_type;
359   }
360   return true;
361 }
362 
ArraysAreCastCompatible(ArrayRef<Type> lhs,ArrayRef<Type> rhs)363 bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
364   if (lhs.size() != rhs.size()) return false;
365   for (auto pair : llvm::zip(lhs, rhs)) {
366     auto lhs_i = std::get<0>(pair);
367     auto rhs_i = std::get<1>(pair);
368     if (!AreCastCompatible({lhs_i, rhs_i})) return false;
369   }
370   return true;
371 }
372 
373 // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default
374 // type for a composed type (such as a ref type or a type with subtypes).
375 template <typename ComposedType>
DropTypeHelper(Type ty)376 Type DropTypeHelper(Type ty) {
377   Type element_ty = getElementTypeOrSelf(ty);
378   auto composed_type = element_ty.dyn_cast<ComposedType>();
379   if (!composed_type) return ty;
380 
381   Type default_ty = GetDefaultTypeOf(composed_type);
382   if (auto ranked_ty = ty.dyn_cast<RankedTensorType>()) {
383     return RankedTensorType::get(ranked_ty.getShape(), default_ty);
384   } else if (ty.dyn_cast<UnrankedTensorType>()) {
385     return UnrankedTensorType::get(default_ty);
386   } else {
387     return default_ty;
388   }
389 }
390 
DropSubTypes(Type ty)391 Type DropSubTypes(Type ty) {
392   return DropTypeHelper<TF::TensorFlowTypeWithSubtype>(ty);
393 }
394 
DropRefType(Type ty)395 Type DropRefType(Type ty) { return DropTypeHelper<TF::TensorFlowRefType>(ty); }
396 
DropRefAndSubTypes(Type ty)397 Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); }
398 
399 }  // namespace TF
400 }  // namespace mlir
401