1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_DIALECTINTERFACE_H 10 #define MLIR_IR_DIALECTINTERFACE_H 11 12 #include "mlir/Support/TypeID.h" 13 #include "llvm/ADT/DenseSet.h" 14 15 namespace mlir { 16 class Dialect; 17 class MLIRContext; 18 class Operation; 19 20 //===----------------------------------------------------------------------===// 21 // DialectInterface 22 //===----------------------------------------------------------------------===// 23 namespace detail { 24 /// The base class used for all derived interface types. This class provides 25 /// utilities necessary for registration. 26 template <typename ConcreteType, typename BaseT> 27 class DialectInterfaceBase : public BaseT { 28 public: 29 using Base = DialectInterfaceBase<ConcreteType, BaseT>; 30 31 /// Get a unique id for the derived interface type. getInterfaceID()32 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 33 34 protected: DialectInterfaceBase(Dialect * dialect)35 DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} 36 }; 37 } // end namespace detail 38 39 /// This class represents an interface overridden for a single dialect. 40 class DialectInterface { 41 public: 42 virtual ~DialectInterface(); 43 44 /// The base class used for all derived interface types. This class provides 45 /// utilities necessary for registration. 46 template <typename ConcreteType> 47 using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; 48 49 /// Return the dialect that this interface represents. getDialect()50 Dialect *getDialect() const { return dialect; } 51 52 /// Return the derived interface id. getID()53 TypeID getID() const { return interfaceID; } 54 55 protected: DialectInterface(Dialect * dialect,TypeID id)56 DialectInterface(Dialect *dialect, TypeID id) 57 : dialect(dialect), interfaceID(id) {} 58 59 private: 60 /// The dialect that represents this interface. 61 Dialect *dialect; 62 63 /// The unique identifier for the derived interface type. 64 TypeID interfaceID; 65 }; 66 67 //===----------------------------------------------------------------------===// 68 // DialectInterfaceCollection 69 //===----------------------------------------------------------------------===// 70 71 namespace detail { 72 /// This class is the base class for a collection of instances for a specific 73 /// interface kind. 74 class DialectInterfaceCollectionBase { 75 /// DenseMap info for dialect interfaces that allows lookup by the dialect. 76 struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { 77 using DenseMapInfo<const DialectInterface *>::isEqual; 78 getHashValueInterfaceKeyInfo79 static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } getHashValueInterfaceKeyInfo80 static unsigned getHashValue(const DialectInterface *key) { 81 return getHashValue(key->getDialect()); 82 } 83 isEqualInterfaceKeyInfo84 static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { 85 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 86 return false; 87 return lhs == rhs->getDialect(); 88 } 89 }; 90 91 /// A set of registered dialect interface instances. 92 using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; 93 using InterfaceVectorT = std::vector<const DialectInterface *>; 94 95 public: 96 DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind); 97 virtual ~DialectInterfaceCollectionBase(); 98 99 protected: 100 /// Get the interface for the dialect of given operation, or null if one 101 /// is not registered. 102 const DialectInterface *getInterfaceFor(Operation *op) const; 103 104 /// Get the interface for the given dialect. getInterfaceFor(Dialect * dialect)105 const DialectInterface *getInterfaceFor(Dialect *dialect) const { 106 auto it = interfaces.find_as(dialect); 107 return it == interfaces.end() ? nullptr : *it; 108 } 109 110 /// An iterator class that iterates the held interface objects of the given 111 /// derived interface type. 112 template <typename InterfaceT> 113 class iterator : public llvm::mapped_iterator< 114 InterfaceVectorT::const_iterator, 115 const InterfaceT &(*)(const DialectInterface *)> { remapIt(const DialectInterface * interface)116 static const InterfaceT &remapIt(const DialectInterface *interface) { 117 return *static_cast<const InterfaceT *>(interface); 118 } 119 iterator(InterfaceVectorT::const_iterator it)120 iterator(InterfaceVectorT::const_iterator it) 121 : llvm::mapped_iterator< 122 InterfaceVectorT::const_iterator, 123 const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {} 124 125 /// Allow access to the constructor. 126 friend DialectInterfaceCollectionBase; 127 }; 128 129 /// Iterator access to the held interfaces. interface_begin()130 template <typename InterfaceT> iterator<InterfaceT> interface_begin() const { 131 return iterator<InterfaceT>(orderedInterfaces.begin()); 132 } interface_end()133 template <typename InterfaceT> iterator<InterfaceT> interface_end() const { 134 return iterator<InterfaceT>(orderedInterfaces.end()); 135 } 136 137 private: 138 /// A set of registered dialect interface instances. 139 InterfaceSetT interfaces; 140 /// An ordered list of the registered interface instances, necessary for 141 /// deterministic iteration. 142 // NOTE: SetVector does not provide find access, so it can't be used here. 143 InterfaceVectorT orderedInterfaces; 144 }; 145 } // namespace detail 146 147 /// A collection of dialect interfaces within a context, for a given concrete 148 /// interface type. 149 template <typename InterfaceType> 150 class DialectInterfaceCollection 151 : public detail::DialectInterfaceCollectionBase { 152 public: 153 using Base = DialectInterfaceCollection<InterfaceType>; 154 155 /// Collect the registered dialect interfaces within the provided context. DialectInterfaceCollection(MLIRContext * ctx)156 DialectInterfaceCollection(MLIRContext *ctx) 157 : detail::DialectInterfaceCollectionBase( 158 ctx, InterfaceType::getInterfaceID()) {} 159 160 /// Get the interface for a given object, or null if one is not registered. 161 /// The object may be a dialect or an operation instance. 162 template <typename Object> getInterfaceFor(Object * obj)163 const InterfaceType *getInterfaceFor(Object *obj) const { 164 return static_cast<const InterfaceType *>( 165 detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); 166 } 167 168 /// Iterator access to the held interfaces. 169 using iterator = 170 detail::DialectInterfaceCollectionBase::iterator<InterfaceType>; begin()171 iterator begin() const { return interface_begin<InterfaceType>(); } end()172 iterator end() const { return interface_end<InterfaceType>(); } 173 174 private: 175 using detail::DialectInterfaceCollectionBase::interface_begin; 176 using detail::DialectInterfaceCollectionBase::interface_end; 177 }; 178 179 } // namespace mlir 180 181 #endif 182