• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utility classes for interfacing with StorageUniquer.
//
//===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
14 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
15 
16 #include "mlir/Support/InterfaceSupport.h"
17 #include "mlir/Support/LogicalResult.h"
18 #include "mlir/Support/StorageUniquer.h"
19 #include "mlir/Support/TypeID.h"
20 
21 namespace mlir {
22 class AttributeStorage;
23 class MLIRContext;
24 
25 namespace detail {
26 /// Utility method to generate a raw default location for use when checking the
27 /// construction invariants of a storage object. This is defined out-of-line to
28 /// avoid the need to include Location.h.
29 const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
30 
31 //===----------------------------------------------------------------------===//
32 // StorageUserTraitBase
33 //===----------------------------------------------------------------------===//
34 
/// Common base for the traits mixed into storage user classes. Its only job is
/// to let a trait get back to the concrete class it is attached to. It is an
/// implementation detail, which is why its single member is protected.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
  /// Returns, by value, the concrete instance that this trait is part of.
  ConcreteType getInstance() const {
    // Go through TraitType<ConcreteType> on the way down instead of casting
    // straight to the concrete type. The concrete class may derive from
    // several content-free trait bases, and stepping through the trait type
    // gives the compiler an unambiguous inheritance path to follow.
    const TraitType<ConcreteType> &asTrait =
        static_cast<const TraitType<ConcreteType> &>(*this);
    return static_cast<const ConcreteType &>(asTrait);
  }
};
50 
51 //===----------------------------------------------------------------------===//
52 // StorageUserBase
53 //===----------------------------------------------------------------------===//
54 
55 /// Utility class for implementing users of storage classes uniqued by a
56 /// StorageUniquer. Clients are not expected to interact with this class
57 /// directly.
58 template <typename ConcreteT, typename BaseT, typename StorageT,
59           typename UniquerT, template <typename T> class... Traits>
60 class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
61 public:
62   using BaseT::BaseT;
63 
64   /// Utility declarations for the concrete attribute class.
65   using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
66   using ImplType = StorageT;
67 
68   /// Return a unique identifier for the concrete type.
getTypeID()69   static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
70 
71   /// Provide an implementation of 'classof' that compares the type id of the
72   /// provided value with that of the concrete type.
classof(T val)73   template <typename T> static bool classof(T val) {
74     static_assert(std::is_convertible<ConcreteT, T>::value,
75                   "casting from a non-convertible type");
76     return val.getTypeID() == getTypeID();
77   }
78 
79   /// Returns an interface map for the interfaces registered to this storage
80   /// user. This should not be used directly.
getInterfaceMap()81   static detail::InterfaceMap getInterfaceMap() {
82     return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
83   }
84 
85   /// Get or create a new ConcreteT instance within the ctx. This
86   /// function is guaranteed to return a non null object and will assert if
87   /// the arguments provided are invalid.
88   template <typename... Args>
get(MLIRContext * ctx,Args...args)89   static ConcreteT get(MLIRContext *ctx, Args... args) {
90     // Ensure that the invariants are correct for construction.
91     assert(succeeded(ConcreteT::verifyConstructionInvariants(
92         generateUnknownStorageLocation(ctx), args...)));
93     return UniquerT::template get<ConcreteT>(ctx, args...);
94   }
95 
96   /// Get or create a new ConcreteT instance within the ctx, defined at
97   /// the given, potentially unknown, location. If the arguments provided are
98   /// invalid then emit errors and return a null object.
99   template <typename LocationT, typename... Args>
getChecked(LocationT loc,Args...args)100   static ConcreteT getChecked(LocationT loc, Args... args) {
101     // If the construction invariants fail then we return a null attribute.
102     if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
103       return ConcreteT();
104     return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
105   }
106 
107   /// Get an instance of the concrete type from a void pointer.
getFromOpaquePointer(const void * ptr)108   static ConcreteT getFromOpaquePointer(const void *ptr) {
109     return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
110                : nullptr;
111   }
112 
113 protected:
114   /// Mutate the current storage instance. This will not change the unique key.
115   /// The arguments are forwarded to 'ConcreteT::mutate'.
mutate(Args &&...args)116   template <typename... Args> LogicalResult mutate(Args &&...args) {
117     return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
118                                                 std::forward<Args>(args)...);
119   }
120 
121   /// Default implementation that just returns success.
122   template <typename... Args>
verifyConstructionInvariants(Args...args)123   static LogicalResult verifyConstructionInvariants(Args... args) {
124     return success();
125   }
126 
127   /// Utility for easy access to the storage instance.
getImpl()128   ImplType *getImpl() const { return static_cast<ImplType *>(this->impl); }
129 };
130 } // namespace detail
131 } // namespace mlir
132 
133 #endif
134