• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 
19 using namespace mlir;
20 
getElementTypeOrSelf(Type type)21 Type mlir::getElementTypeOrSelf(Type type) {
22   if (auto st = type.dyn_cast<ShapedType>())
23     return st.getElementType();
24   return type;
25 }
26 
getElementTypeOrSelf(Value val)27 Type mlir::getElementTypeOrSelf(Value val) {
28   return getElementTypeOrSelf(val.getType());
29 }
30 
getElementTypeOrSelf(Attribute attr)31 Type mlir::getElementTypeOrSelf(Attribute attr) {
32   return getElementTypeOrSelf(attr.getType());
33 }
34 
getFlattenedTypes(TupleType t)35 SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
36   SmallVector<Type, 10> fTypes;
37   t.getFlattenedTypes(fTypes);
38   return fTypes;
39 }
40 
41 /// Return true if the specified type is an opaque type with the specified
42 /// dialect and typeData.
isOpaqueTypeWithName(Type type,StringRef dialect,StringRef typeData)43 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
44                                 StringRef typeData) {
45   if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
46     return opaque.getDialectNamespace() == dialect &&
47            opaque.getTypeData() == typeData;
48   return false;
49 }
50 
51 /// Returns success if the given two shapes are compatible. That is, they have
52 /// the same size and each pair of the elements are equal or one of them is
53 /// dynamic.
verifyCompatibleShape(ArrayRef<int64_t> shape1,ArrayRef<int64_t> shape2)54 LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
55                                           ArrayRef<int64_t> shape2) {
56   if (shape1.size() != shape2.size())
57     return failure();
58   for (auto dims : llvm::zip(shape1, shape2)) {
59     int64_t dim1 = std::get<0>(dims);
60     int64_t dim2 = std::get<1>(dims);
61     if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
62         dim1 != dim2)
63       return failure();
64   }
65   return success();
66 }
67 
68 /// Returns success if the given two types have compatible shape. That is,
69 /// they are both scalars (not shaped), or they are both shaped types and at
70 /// least one is unranked or they have compatible dimensions. Dimensions are
71 /// compatible if at least one is dynamic or both are equal. The element type
72 /// does not matter.
verifyCompatibleShape(Type type1,Type type2)73 LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
74   auto sType1 = type1.dyn_cast<ShapedType>();
75   auto sType2 = type2.dyn_cast<ShapedType>();
76 
77   // Either both or neither type should be shaped.
78   if (!sType1)
79     return success(!sType2);
80   if (!sType2)
81     return failure();
82 
83   if (!sType1.hasRank() || !sType2.hasRank())
84     return success();
85 
86   return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
87 }
88 
OperandElementTypeIterator(Operation::operand_iterator it)89 OperandElementTypeIterator::OperandElementTypeIterator(
90     Operation::operand_iterator it)
91     : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
92           it, &unwrap) {}
93 
unwrap(Value value)94 Type OperandElementTypeIterator::unwrap(Value value) {
95   return value.getType().cast<ShapedType>().getElementType();
96 }
97 
ResultElementTypeIterator(Operation::result_iterator it)98 ResultElementTypeIterator::ResultElementTypeIterator(
99     Operation::result_iterator it)
100     : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
101           it, &unwrap) {}
102 
unwrap(Value value)103 Type ResultElementTypeIterator::unwrap(Value value) {
104   return value.getType().cast<ShapedType>().getElementType();
105 }
106