• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides intuitive composable interfaces for building structured MLIR
10 // snippets in a declarative fashion.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_EDSC_BUILDERS_H_
15 #define MLIR_EDSC_BUILDERS_H_
16 
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Types.h"
21 
22 namespace mlir {
23 class OperationFolder;
24 
25 namespace edsc {
26 /// Helper class to transparently handle builder insertion points by RAII.
27 /// As its name indicates, a ScopedContext is means to be used locally in a
28 /// scoped fashion. This abstracts away all the boilerplate related to
29 /// checking proper usage of captures, NestedBuilders as well as handling the
30 /// setting and restoring of insertion points.
31 class ScopedContext {
32 public:
33   ScopedContext(OpBuilder &b);
34   ScopedContext(OpBuilder &b, Location location);
35 
36   /// Sets the insertion point of the builder to 'newInsertPt' for the duration
37   /// of the scope. The existing insertion point of the builder is restored on
38   /// destruction.
39   ScopedContext(OpBuilder &b, OpBuilder::InsertPoint newInsertPt,
40                 Location location);
41   ~ScopedContext();
42 
43   static MLIRContext *getContext();
44   static OpBuilder &getBuilderRef();
45   static Location getLocation();
46 
47 private:
48   /// Only NestedBuilder (which is used to create an operation with a body)
49   /// may access private members in order to implement scoping.
50   friend class NestedBuilder;
51 
52   ScopedContext() = delete;
53   ScopedContext(const ScopedContext &) = delete;
54   ScopedContext &operator=(const ScopedContext &) = delete;
55 
56   static ScopedContext *&getCurrentScopedContext();
57 
58   /// Top level OpBuilder.
59   OpBuilder &builder;
60   /// Guard to the previous insertion point.
61   OpBuilder::InsertionGuard guard;
62   /// Current location.
63   Location location;
64   /// Parent context we return into.
65   ScopedContext *enclosingScopedContext;
66 };
67 
68 template <typename Op>
69 struct ValueBuilder {
70   template <typename... Args>
ValueBuilderValueBuilder71   ValueBuilder(Args... args) {
72     value = ScopedContext::getBuilderRef()
73                 .create<Op>(ScopedContext::getLocation(), args...)
74                 .getResult();
75   }
ValueValueBuilder76   operator Value() { return value; }
77   Value value;
78 };
79 
80 template <typename Op>
81 struct OperationBuilder {
82   template <typename... Args>
OperationBuilderOperationBuilder83   OperationBuilder(Args... args) {
84     op = ScopedContext::getBuilderRef().create<Op>(ScopedContext::getLocation(),
85                                                    args...);
86   }
OpOperationBuilder87   operator Op() { return op; }
88   operator Operation *() { return op.getOperation(); }
89   Op op;
90 };
91 
92 /// Creates a block in the region that contains the insertion block of the
93 /// OpBuilder currently at the top of ScopedContext stack (appends the block to
94 /// the region). Be aware that this will NOT update the insertion point of the
95 /// builder to insert into the newly constructed block.
96 Block *createBlock(TypeRange argTypes = llvm::None);
97 
98 /// Creates a block in the specified region using OpBuilder at the top of
99 /// ScopedContext stack (appends the block to the region). Be aware that this
100 /// will NOT update the insertion point of the builder to insert into the newly
101 /// constructed block.
102 Block *createBlockInRegion(Region &region, TypeRange argTypes = llvm::None);
103 
104 /// Calls "builderFn" with ScopedContext reconfigured to insert into "block" and
105 /// passes in the block arguments. If the block has a terminator, the operations
106 /// are inserted before the terminator, otherwise appended to the block.
107 void appendToBlock(Block *block, function_ref<void(ValueRange)> builderFn);
108 
109 /// Creates a block in the region that contains the insertion block of the
110 /// OpBuilder currently at the top of ScopedContext stack, and calls "builderFn"
111 /// to populate the body of the block while passing it the block arguments.
112 Block *buildInNewBlock(TypeRange argTypes,
113                        function_ref<void(ValueRange)> builderFn);
114 
115 /// Creates a block in the specified region using OpBuilder at the top of
116 /// ScopedContext stack, and calls "builderFn" to populate the body of the block
117 /// while passing it the block arguments.
118 Block *buildInNewBlock(Region &region, TypeRange argTypes,
119                        function_ref<void(ValueRange)> builderFn);
120 
121 /// A StructuredIndexed represents an indexable quantity that is either:
122 /// 1. a captured value, which is suitable for buffer and tensor operands, or;
123 /// 2. a captured type, which is suitable for tensor return values.
124 ///
125 /// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
126 /// It enable an idiomatic syntax for index expressions such as:
127 ///
128 /// ```
129 ///      StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
130 ///        C(buffer_value_or_tensor_type);
131 ///      makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
132 /// ```
133 struct StructuredIndexed {
StructuredIndexedStructuredIndexed134   StructuredIndexed(Value v) : value(v) {}
StructuredIndexedStructuredIndexed135   StructuredIndexed(Type t) : type(t) {}
operatorStructuredIndexed136   StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
137     return value ? StructuredIndexed(value, indexings)
138                  : StructuredIndexed(type, indexings);
139   }
140 
StructuredIndexedStructuredIndexed141   StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
142       : value(v), exprs(indexings.begin(), indexings.end()) {
143     assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) &&
144            "MemRef, RankedTensor or Vector expected");
145   }
StructuredIndexedStructuredIndexed146   StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
147       : type(t), exprs(indexings.begin(), indexings.end()) {
148     assert((t.isa<MemRefType, RankedTensorType, VectorType>()) &&
149            "MemRef, RankedTensor or Vector expected");
150   }
151 
hasValueStructuredIndexed152   bool hasValue() const { return (bool)value; }
getValueStructuredIndexed153   Value getValue() const {
154     assert(value && "StructuredIndexed Value not set.");
155     return value;
156   }
getTypeStructuredIndexed157   Type getType() const {
158     assert((value || type) && "StructuredIndexed Value and Type not set.");
159     return value ? value.getType() : type;
160   }
getExprsStructuredIndexed161   ArrayRef<AffineExpr> getExprs() const { return exprs; }
ValueStructuredIndexed162   operator Value() const { return getValue(); }
TypeStructuredIndexed163   operator Type() const { return getType(); }
164 
165 private:
166   // Only one of Value or type may be set.
167   Type type;
168   Value value;
169   SmallVector<AffineExpr, 4> exprs;
170 };
171 
172 /// A TemplatedIndexedValue brings an index notation over the template Load and
173 /// Store parameters. Assigning to an IndexedValue emits an actual `Store`
174 /// operation, while converting an IndexedValue to a Value emits an actual
175 /// `Load` operation.
176 template <typename Load, typename Store>
177 class TemplatedIndexedValue {
178 public:
TemplatedIndexedValue(Value v)179   explicit TemplatedIndexedValue(Value v) : value(v) {}
180 
181   TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
182 
operator()183   TemplatedIndexedValue operator()() { return *this; }
184   /// Returns a new `TemplatedIndexedValue`.
operator()185   TemplatedIndexedValue operator()(Value index) {
186     TemplatedIndexedValue res(value);
187     res.indices.push_back(index);
188     return res;
189   }
190   template <typename... Args>
operator()191   TemplatedIndexedValue operator()(Value index, Args... indices) {
192     return TemplatedIndexedValue(value, index).append(indices...);
193   }
operator()194   TemplatedIndexedValue operator()(ValueRange indices) {
195     return TemplatedIndexedValue(value, indices);
196   }
197 
198   /// Emits a `store`.
199   Store operator=(const TemplatedIndexedValue &rhs) {
200     return Store(rhs, value, indices);
201   }
202   Store operator=(Value rhs) { return Store(rhs, value, indices); }
203 
204   /// Emits a `load` when converting to a Value.
Value()205   operator Value() const { return Load(value, indices); }
206 
207   /// Returns the base memref.
getBase()208   Value getBase() const { return value; }
209 
210   /// Returns the underlying memref.
getMemRefType()211   MemRefType getMemRefType() const {
212     return value.getType().template cast<MemRefType>();
213   }
214 
215   /// Returns the underlying MemRef elemental type cast as `T`.
216   template <typename T>
getElementalTypeAs()217   T getElementalTypeAs() const {
218     return value.getType()
219         .template cast<MemRefType>()
220         .getElementType()
221         .template cast<T>();
222   }
223 
224   /// Arithmetic operator overloadings.
225   Value operator+(Value e);
226   Value operator-(Value e);
227   Value operator*(Value e);
228   Value operator/(Value e);
229   Value operator%(Value e);
230   Value operator^(Value e);
231   Value operator+(TemplatedIndexedValue e) {
232     return *this + static_cast<Value>(e);
233   }
234   Value operator-(TemplatedIndexedValue e) {
235     return *this - static_cast<Value>(e);
236   }
237   Value operator*(TemplatedIndexedValue e) {
238     return *this * static_cast<Value>(e);
239   }
240   Value operator/(TemplatedIndexedValue e) {
241     return *this / static_cast<Value>(e);
242   }
243   Value operator%(TemplatedIndexedValue e) {
244     return *this % static_cast<Value>(e);
245   }
246   Value operator^(TemplatedIndexedValue e) {
247     return *this ^ static_cast<Value>(e);
248   }
249 
250   /// Assignment-arithmetic operator overloadings.
251   Store operator+=(Value e);
252   Store operator-=(Value e);
253   Store operator*=(Value e);
254   Store operator/=(Value e);
255   Store operator%=(Value e);
256   Store operator^=(Value e);
257   Store operator+=(TemplatedIndexedValue e) {
258     return this->operator+=(static_cast<Value>(e));
259   }
260   Store operator-=(TemplatedIndexedValue e) {
261     return this->operator-=(static_cast<Value>(e));
262   }
263   Store operator*=(TemplatedIndexedValue e) {
264     return this->operator*=(static_cast<Value>(e));
265   }
266   Store operator/=(TemplatedIndexedValue e) {
267     return this->operator/=(static_cast<Value>(e));
268   }
269   Store operator%=(TemplatedIndexedValue e) {
270     return this->operator%=(static_cast<Value>(e));
271   }
272   Store operator^=(TemplatedIndexedValue e) {
273     return this->operator^=(static_cast<Value>(e));
274   }
275 
276   /// Logical operator overloadings.
277   Value operator&&(Value e);
278   Value operator||(Value e);
279   Value operator&&(TemplatedIndexedValue e) {
280     return *this && static_cast<Value>(e);
281   }
282   Value operator||(TemplatedIndexedValue e) {
283     return *this || static_cast<Value>(e);
284   }
285 
286   /// Comparison operator overloadings.
287   Value eq(Value e);
288   Value ne(Value e);
289   Value slt(Value e);
290   Value sle(Value e);
291   Value sgt(Value e);
292   Value sge(Value e);
293   Value ult(Value e);
294   Value ule(Value e);
295   Value ugt(Value e);
296   Value uge(Value e);
slt(TemplatedIndexedValue e)297   Value slt(TemplatedIndexedValue e) {
298     return slt(*this, static_cast<Value>(e));
299   }
sle(TemplatedIndexedValue e)300   Value sle(TemplatedIndexedValue e) {
301     return sle(*this, static_cast<Value>(e));
302   }
sgt(TemplatedIndexedValue e)303   Value sgt(TemplatedIndexedValue e) {
304     return sgt(*this, static_cast<Value>(e));
305   }
sge(TemplatedIndexedValue e)306   Value sge(TemplatedIndexedValue e) {
307     return sge(*this, static_cast<Value>(e));
308   }
ult(TemplatedIndexedValue e)309   Value ult(TemplatedIndexedValue e) {
310     return ult(*this, static_cast<Value>(e));
311   }
ule(TemplatedIndexedValue e)312   Value ule(TemplatedIndexedValue e) {
313     return ule(*this, static_cast<Value>(e));
314   }
ugt(TemplatedIndexedValue e)315   Value ugt(TemplatedIndexedValue e) {
316     return ugt(*this, static_cast<Value>(e));
317   }
uge(TemplatedIndexedValue e)318   Value uge(TemplatedIndexedValue e) {
319     return uge(*this, static_cast<Value>(e));
320   }
321 
322 private:
TemplatedIndexedValue(Value value,ValueRange indices)323   TemplatedIndexedValue(Value value, ValueRange indices)
324       : value(value), indices(indices.begin(), indices.end()) {}
325 
append()326   TemplatedIndexedValue &append() { return *this; }
327 
328   template <typename T, typename... Args>
append(T index,Args...indices)329   TemplatedIndexedValue &append(T index, Args... indices) {
330     this->indices.push_back(static_cast<Value>(index));
331     append(indices...);
332     return *this;
333   }
334   Value value;
335   SmallVector<Value, 8> indices;
336 };
337 
338 } // namespace edsc
339 } // namespace mlir
340 
341 #endif // MLIR_EDSC_BUILDERS_H_
342