1 //===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several support classes for defining interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_SUPPORT_INTERFACESUPPORT_H 14 #define MLIR_SUPPORT_INTERFACESUPPORT_H 15 16 #include "mlir/Support/TypeID.h" 17 #include "llvm/ADT/DenseMap.h" 18 #include "llvm/Support/TypeName.h" 19 20 namespace mlir { 21 namespace detail { 22 //===----------------------------------------------------------------------===// 23 // Interface 24 //===----------------------------------------------------------------------===// 25 26 /// This class represents an abstract interface. An interface is a simplified 27 /// mechanism for attaching concept based polymorphism to a class hierarchy. An 28 /// interface is comprised of two components: 29 /// * The derived interface class: This is what users interact with, and invoke 30 /// methods on. 31 /// * An interface `Trait` class: This is the class that is attached to the 32 /// object implementing the interface. It is the mechanism with which models 33 /// are specialized. 34 /// 35 /// Derived interfaces types must provide the following template types: 36 /// * ConcreteType: The CRTP derived type. 37 /// * ValueT: The opaque type the derived interface operates on. For example 38 /// `Operation*` for operation interfaces, or `Attribute` for 39 /// attribute interfaces. 40 /// * Traits: A class that contains definitions for a 'Concept' and a 'Model' 41 /// class. The 'Concept' class defines an abstract virtual interface, 42 /// where as the 'Model' class implements this interface for a 43 /// specific derived T type. Both of these classes *must* not contain 44 /// non-static data. A simple example is shown below: 45 /// 46 /// ```c++ 47 /// struct ExampleInterfaceTraits { 48 /// struct Concept { 49 /// virtual unsigned getNumInputs(T t) const = 0; 50 /// }; 51 /// template <typename DerivedT> class Model { 52 /// unsigned getNumInputs(T t) const final { 53 /// return cast<DerivedT>(t).getNumInputs(); 54 /// } 55 /// }; 56 /// }; 57 /// ``` 58 /// 59 /// * BaseType: A desired base type for the interface. This is a class that 60 /// provides that provides specific functionality for the `ValueT` 61 /// value. For instance the specific `Op` that will wrap the 62 /// `Operation*` for an `OpInterface`. 63 /// * BaseTrait: The base type for the interface trait. This is the base class 64 /// to use for the interface trait that will be attached to each 65 /// instance of `ValueT` that implements this interface. 66 /// 67 template <typename ConcreteType, typename ValueT, typename Traits, 68 typename BaseType, 69 template <typename, template <typename> class> class BaseTrait> 70 class Interface : public BaseType { 71 public: 72 using Concept = typename Traits::Concept; 73 template <typename T> using Model = typename Traits::template Model<T>; 74 using InterfaceBase = 75 Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>; 76 77 /// This is a special trait that registers a given interface with an object. 78 template <typename ConcreteT> 79 struct Trait : public BaseTrait<ConcreteT, Trait> { 80 using ModelT = Model<ConcreteT>; 81 82 /// Define an accessor for the ID of this interface. getInterfaceIDTrait83 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 84 }; 85 86 /// Construct an interface from an instance of the value type. 87 Interface(ValueT t = ValueT()) BaseType(t)88 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 89 assert((!t || impl) && "expected value to provide interface instance"); 90 } 91 92 /// Construct an interface instance from a type that implements this 93 /// interface's trait. 94 template <typename T, typename std::enable_if_t< 95 std::is_base_of<Trait<T>, T>::value> * = nullptr> Interface(T t)96 Interface(T t) 97 : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { 98 assert((!t || impl) && "expected value to provide interface instance"); 99 } 100 101 /// Support 'classof' by checking if the given object defines the concrete 102 /// interface. classof(ValueT t)103 static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } 104 105 /// Define an accessor for the ID of this interface. getInterfaceID()106 static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } 107 108 protected: 109 /// Get the raw concept in the correct derived concept type. getImpl()110 const Concept *getImpl() const { return impl; } getImpl()111 Concept *getImpl() { return impl; } 112 113 private: 114 /// A pointer to the impl concept object. 115 Concept *impl; 116 }; 117 118 //===----------------------------------------------------------------------===// 119 // InterfaceMap 120 //===----------------------------------------------------------------------===// 121 122 /// Utility to filter a given sequence of types base upon a predicate. 123 template <bool> 124 struct FilterTypeT { 125 template <class E> 126 using type = std::tuple<E>; 127 }; 128 template <> 129 struct FilterTypeT<false> { 130 template <class E> 131 using type = std::tuple<>; 132 }; 133 template <template <class> class Pred, class... Es> 134 struct FilterTypes { 135 using type = decltype(std::tuple_cat( 136 std::declval< 137 typename FilterTypeT<Pred<Es>::value>::template type<Es>>()...)); 138 }; 139 140 /// This class provides an efficient mapping between a given `Interface` type, 141 /// and a particular implementation of its concept. 142 class InterfaceMap { 143 /// Trait to check if T provides a static 'getInterfaceID' method. 144 template <typename T, typename... Args> 145 using has_get_interface_id = decltype(T::getInterfaceID()); 146 template <typename T> 147 using detect_get_interface_id = llvm::is_detected<has_get_interface_id, T>; 148 template <typename... Types> 149 using num_interface_types = typename std::tuple_size< 150 typename FilterTypes<detect_get_interface_id, Types...>::type>; 151 152 public: 153 InterfaceMap(InterfaceMap &&) = default; 154 ~InterfaceMap() { 155 if (interfaces) { 156 for (auto &it : *interfaces) 157 free(it.second); 158 } 159 } 160 161 /// Construct an InterfaceMap with the given set of template types. For 162 /// convenience given that object trait lists may contain other non-interface 163 /// types, not all of the types need to be interfaces. The provided types that 164 /// do not represent interfaces are not added to the interface map. 165 template <typename... Types> 166 static std::enable_if_t<num_interface_types<Types...>::value != 0, 167 InterfaceMap> 168 get() { 169 // Filter the provided types for those that are interfaces. 170 using FilteredTupleType = 171 typename FilterTypes<detect_get_interface_id, Types...>::type; 172 return getImpl((FilteredTupleType *)nullptr); 173 } 174 175 template <typename... Types> 176 static std::enable_if_t<num_interface_types<Types...>::value == 0, 177 InterfaceMap> 178 get() { 179 return InterfaceMap(); 180 } 181 182 /// Returns an instance of the concept object for the given interface if it 183 /// was registered to this map, null otherwise. 184 template <typename T> typename T::Concept *lookup() const { 185 void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr; 186 return reinterpret_cast<typename T::Concept *>(inst); 187 } 188 189 private: 190 InterfaceMap() = default; 191 InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements) 192 : interfaces(std::make_unique<llvm::SmallDenseMap<TypeID, void *>>( 193 elements.begin(), elements.end())) {} 194 195 template <typename... Ts> 196 static InterfaceMap getImpl(std::tuple<Ts...> *) { 197 std::pair<TypeID, void *> elements[] = {std::make_pair( 198 Ts::getInterfaceID(), 199 new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...}; 200 return InterfaceMap(elements); 201 } 202 203 /// The internal map of interfaces. This is constructed statically for each 204 /// set of interfaces. 205 std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces; 206 }; 207 208 } // end namespace detail 209 } // end namespace mlir 210 211 #endif 212