• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_DIALECTINTERFACE_H
10 #define MLIR_IR_DIALECTINTERFACE_H
11 
12 #include "mlir/Support/TypeID.h"
13 #include "llvm/ADT/DenseSet.h"
14 
15 namespace mlir {
16 class Dialect;
17 class MLIRContext;
18 class Operation;
19 
20 //===----------------------------------------------------------------------===//
21 // DialectInterface
22 //===----------------------------------------------------------------------===//
23 namespace detail {
24 /// The base class used for all derived interface types. This class provides
25 /// utilities necessary for registration.
26 template <typename ConcreteType, typename BaseT>
27 class DialectInterfaceBase : public BaseT {
28 public:
29   using Base = DialectInterfaceBase<ConcreteType, BaseT>;
30 
31   /// Get a unique id for the derived interface type.
getInterfaceID()32   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
33 
34 protected:
DialectInterfaceBase(Dialect * dialect)35   DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
36 };
37 } // end namespace detail
38 
39 /// This class represents an interface overridden for a single dialect.
40 class DialectInterface {
41 public:
42   virtual ~DialectInterface();
43 
44   /// The base class used for all derived interface types. This class provides
45   /// utilities necessary for registration.
46   template <typename ConcreteType>
47   using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;
48 
49   /// Return the dialect that this interface represents.
getDialect()50   Dialect *getDialect() const { return dialect; }
51 
52   /// Return the derived interface id.
getID()53   TypeID getID() const { return interfaceID; }
54 
55 protected:
DialectInterface(Dialect * dialect,TypeID id)56   DialectInterface(Dialect *dialect, TypeID id)
57       : dialect(dialect), interfaceID(id) {}
58 
59 private:
60   /// The dialect that represents this interface.
61   Dialect *dialect;
62 
63   /// The unique identifier for the derived interface type.
64   TypeID interfaceID;
65 };
66 
67 //===----------------------------------------------------------------------===//
68 // DialectInterfaceCollection
69 //===----------------------------------------------------------------------===//
70 
71 namespace detail {
72 /// This class is the base class for a collection of instances for a specific
73 /// interface kind.
74 class DialectInterfaceCollectionBase {
75   /// DenseMap info for dialect interfaces that allows lookup by the dialect.
76   struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
77     using DenseMapInfo<const DialectInterface *>::isEqual;
78 
getHashValueInterfaceKeyInfo79     static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); }
getHashValueInterfaceKeyInfo80     static unsigned getHashValue(const DialectInterface *key) {
81       return getHashValue(key->getDialect());
82     }
83 
isEqualInterfaceKeyInfo84     static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
85       if (rhs == getEmptyKey() || rhs == getTombstoneKey())
86         return false;
87       return lhs == rhs->getDialect();
88     }
89   };
90 
91   /// A set of registered dialect interface instances.
92   using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
93   using InterfaceVectorT = std::vector<const DialectInterface *>;
94 
95 public:
96   DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind);
97   virtual ~DialectInterfaceCollectionBase();
98 
99 protected:
100   /// Get the interface for the dialect of given operation, or null if one
101   /// is not registered.
102   const DialectInterface *getInterfaceFor(Operation *op) const;
103 
104   /// Get the interface for the given dialect.
getInterfaceFor(Dialect * dialect)105   const DialectInterface *getInterfaceFor(Dialect *dialect) const {
106     auto it = interfaces.find_as(dialect);
107     return it == interfaces.end() ? nullptr : *it;
108   }
109 
110   /// An iterator class that iterates the held interface objects of the given
111   /// derived interface type.
112   template <typename InterfaceT>
113   class iterator : public llvm::mapped_iterator<
114                        InterfaceVectorT::const_iterator,
115                        const InterfaceT &(*)(const DialectInterface *)> {
remapIt(const DialectInterface * interface)116     static const InterfaceT &remapIt(const DialectInterface *interface) {
117       return *static_cast<const InterfaceT *>(interface);
118     }
119 
iterator(InterfaceVectorT::const_iterator it)120     iterator(InterfaceVectorT::const_iterator it)
121         : llvm::mapped_iterator<
122               InterfaceVectorT::const_iterator,
123               const InterfaceT &(*)(const DialectInterface *)>(it, &remapIt) {}
124 
125     /// Allow access to the constructor.
126     friend DialectInterfaceCollectionBase;
127   };
128 
129   /// Iterator access to the held interfaces.
interface_begin()130   template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
131     return iterator<InterfaceT>(orderedInterfaces.begin());
132   }
interface_end()133   template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
134     return iterator<InterfaceT>(orderedInterfaces.end());
135   }
136 
137 private:
138   /// A set of registered dialect interface instances.
139   InterfaceSetT interfaces;
140   /// An ordered list of the registered interface instances, necessary for
141   /// deterministic iteration.
142   // NOTE: SetVector does not provide find access, so it can't be used here.
143   InterfaceVectorT orderedInterfaces;
144 };
145 } // namespace detail
146 
147 /// A collection of dialect interfaces within a context, for a given concrete
148 /// interface type.
149 template <typename InterfaceType>
150 class DialectInterfaceCollection
151     : public detail::DialectInterfaceCollectionBase {
152 public:
153   using Base = DialectInterfaceCollection<InterfaceType>;
154 
155   /// Collect the registered dialect interfaces within the provided context.
DialectInterfaceCollection(MLIRContext * ctx)156   DialectInterfaceCollection(MLIRContext *ctx)
157       : detail::DialectInterfaceCollectionBase(
158             ctx, InterfaceType::getInterfaceID()) {}
159 
160   /// Get the interface for a given object, or null if one is not registered.
161   /// The object may be a dialect or an operation instance.
162   template <typename Object>
getInterfaceFor(Object * obj)163   const InterfaceType *getInterfaceFor(Object *obj) const {
164     return static_cast<const InterfaceType *>(
165         detail::DialectInterfaceCollectionBase::getInterfaceFor(obj));
166   }
167 
168   /// Iterator access to the held interfaces.
169   using iterator =
170       detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
begin()171   iterator begin() const { return interface_begin<InterfaceType>(); }
end()172   iterator end() const { return interface_end<InterfaceType>(); }
173 
174 private:
175   using detail::DialectInterfaceCollectionBase::interface_begin;
176   using detail::DialectInterfaceCollectionBase::interface_end;
177 };
178 
179 } // namespace mlir
180 
181 #endif
182