• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines several support classes for defining interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H
14 #define MLIR_SUPPORT_INTERFACESUPPORT_H
15 
16 #include "mlir/Support/TypeID.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/Support/TypeName.h"
19 
20 namespace mlir {
21 namespace detail {
22 //===----------------------------------------------------------------------===//
23 // Interface
24 //===----------------------------------------------------------------------===//
25 
26 /// This class represents an abstract interface. An interface is a simplified
27 /// mechanism for attaching concept based polymorphism to a class hierarchy. An
28 /// interface is comprised of two components:
29 /// * The derived interface class: This is what users interact with, and invoke
30 ///   methods on.
31 /// * An interface `Trait` class: This is the class that is attached to the
32 ///   object implementing the interface. It is the mechanism with which models
33 ///   are specialized.
34 ///
35 /// Derived interfaces types must provide the following template types:
36 /// * ConcreteType: The CRTP derived type.
37 /// * ValueT: The opaque type the derived interface operates on. For example
38 ///           `Operation*` for operation interfaces, or `Attribute` for
39 ///           attribute interfaces.
40 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model'
41 ///           class. The 'Concept' class defines an abstract virtual interface,
42 ///           where as the 'Model' class implements this interface for a
43 ///           specific derived T type. Both of these classes *must* not contain
44 ///           non-static data. A simple example is shown below:
45 ///
46 /// ```c++
47 ///    struct ExampleInterfaceTraits {
48 ///      struct Concept {
49 ///        virtual unsigned getNumInputs(T t) const = 0;
50 ///      };
51 ///      template <typename DerivedT> class Model {
52 ///        unsigned getNumInputs(T t) const final {
53 ///          return cast<DerivedT>(t).getNumInputs();
54 ///        }
55 ///      };
56 ///    };
57 /// ```
58 ///
59 /// * BaseType: A desired base type for the interface. This is a class that
60 ///             provides that provides specific functionality for the `ValueT`
61 ///             value. For instance the specific `Op` that will wrap the
62 ///             `Operation*` for an `OpInterface`.
63 /// * BaseTrait: The base type for the interface trait. This is the base class
64 ///              to use for the interface trait that will be attached to each
65 ///              instance of `ValueT` that implements this interface.
66 ///
67 template <typename ConcreteType, typename ValueT, typename Traits,
68           typename BaseType,
69           template <typename, template <typename> class> class BaseTrait>
70 class Interface : public BaseType {
71 public:
72   using Concept = typename Traits::Concept;
73   template <typename T> using Model = typename Traits::template Model<T>;
74   using InterfaceBase =
75       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
76 
77   /// This is a special trait that registers a given interface with an object.
78   template <typename ConcreteT>
79   struct Trait : public BaseTrait<ConcreteT, Trait> {
80     using ModelT = Model<ConcreteT>;
81 
82     /// Define an accessor for the ID of this interface.
getInterfaceIDTrait83     static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
84   };
85 
86   /// Construct an interface from an instance of the value type.
87   Interface(ValueT t = ValueT())
BaseType(t)88       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
89     assert((!t || impl) && "expected value to provide interface instance");
90   }
91 
92   /// Construct an interface instance from a type that implements this
93   /// interface's trait.
94   template <typename T, typename std::enable_if_t<
95                             std::is_base_of<Trait<T>, T>::value> * = nullptr>
Interface(T t)96   Interface(T t)
97       : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
98     assert((!t || impl) && "expected value to provide interface instance");
99   }
100 
101   /// Support 'classof' by checking if the given object defines the concrete
102   /// interface.
classof(ValueT t)103   static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); }
104 
105   /// Define an accessor for the ID of this interface.
getInterfaceID()106   static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
107 
108 protected:
109   /// Get the raw concept in the correct derived concept type.
getImpl()110   const Concept *getImpl() const { return impl; }
getImpl()111   Concept *getImpl() { return impl; }
112 
113 private:
114   /// A pointer to the impl concept object.
115   Concept *impl;
116 };
117 
118 //===----------------------------------------------------------------------===//
119 // InterfaceMap
120 //===----------------------------------------------------------------------===//
121 
/// Utility to filter a given sequence of types based upon a predicate.
///
/// `FilterTypeT<Keep>::type<E>` is `std::tuple<E>` when `Keep` is true and the
/// empty `std::tuple<>` when it is false.
template <bool>
struct FilterTypeT {
  template <class E>
  using type = std::tuple<E>;
};
template <>
struct FilterTypeT<false> {
  template <class E>
  using type = std::tuple<>;
};

/// `FilterTypes<Pred, Es...>::type` is a `std::tuple` holding, in their
/// original order, exactly those `Es` for which `Pred<E>::value` is true.
template <template <class> class Pred, class... Es>
struct FilterTypes {
  // Each element maps to a one-element or empty tuple; concatenating them all
  // (in an unevaluated context) drops the rejected types.
  using type = decltype(std::tuple_cat(
      std::declval<
          typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...));
};
139 
140 /// This class provides an efficient mapping between a given `Interface` type,
141 /// and a particular implementation of its concept.
142 class InterfaceMap {
143   /// Trait to check if T provides a static 'getInterfaceID' method.
144   template <typename T, typename... Args>
145   using has_get_interface_id = decltype(T::getInterfaceID());
146   template <typename T>
147   using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>;
148   template <typename... Types>
149   using num_interface_types = typename std::tuple_size<
150       typename FilterTypes<detect_get_interface_id, Types...>::type>;
151 
152 public:
153   InterfaceMap(InterfaceMap &&) = default;
154   ~InterfaceMap() {
155     if (interfaces) {
156       for (auto &it : *interfaces)
157         free(it.second);
158     }
159   }
160 
161   /// Construct an InterfaceMap with the given set of template types. For
162   /// convenience given that object trait lists may contain other non-interface
163   /// types, not all of the types need to be interfaces. The provided types that
164   /// do not represent interfaces are not added to the interface map.
165   template <typename... Types>
166   static std::enable_if_t<num_interface_types<Types...>::value != 0,
167                           InterfaceMap>
168   get() {
169     // Filter the provided types for those that are interfaces.
170     using FilteredTupleType =
171         typename FilterTypes<detect_get_interface_id, Types...>::type;
172     return getImpl((FilteredTupleType *)nullptr);
173   }
174 
175   template <typename... Types>
176   static std::enable_if_t<num_interface_types<Types...>::value == 0,
177                           InterfaceMap>
178   get() {
179     return InterfaceMap();
180   }
181 
182   /// Returns an instance of the concept object for the given interface if it
183   /// was registered to this map, null otherwise.
184   template <typename T> typename T::Concept *lookup() const {
185     void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr;
186     return reinterpret_cast<typename T::Concept *>(inst);
187   }
188 
189 private:
190   InterfaceMap() = default;
191   InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
192       : interfaces(std::make_unique<llvm::SmallDenseMap<TypeID, void *>>(
193             elements.begin(), elements.end())) {}
194 
195   template <typename... Ts>
196   static InterfaceMap getImpl(std::tuple<Ts...> *) {
197     std::pair<TypeID, void *> elements[] = {std::make_pair(
198         Ts::getInterfaceID(),
199         new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...};
200     return InterfaceMap(elements);
201   }
202 
203   /// The internal map of interfaces. This is constructed statically for each
204   /// set of interfaces.
205   std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
206 };
207 
208 } // end namespace detail
209 } // end namespace mlir
210 
211 #endif
212