1 //===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides intuitive composable interfaces for building structured MLIR 10 // snippets in a declarative fashion. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_EDSC_BUILDERS_H_ 15 #define MLIR_EDSC_BUILDERS_H_ 16 17 #include "mlir/IR/AffineExpr.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/Types.h" 21 22 namespace mlir { 23 class OperationFolder; 24 25 namespace edsc { 26 /// Helper class to transparently handle builder insertion points by RAII. 27 /// As its name indicates, a ScopedContext is means to be used locally in a 28 /// scoped fashion. This abstracts away all the boilerplate related to 29 /// checking proper usage of captures, NestedBuilders as well as handling the 30 /// setting and restoring of insertion points. 31 class ScopedContext { 32 public: 33 ScopedContext(OpBuilder &b); 34 ScopedContext(OpBuilder &b, Location location); 35 36 /// Sets the insertion point of the builder to 'newInsertPt' for the duration 37 /// of the scope. The existing insertion point of the builder is restored on 38 /// destruction. 39 ScopedContext(OpBuilder &b, OpBuilder::InsertPoint newInsertPt, 40 Location location); 41 ~ScopedContext(); 42 43 static MLIRContext *getContext(); 44 static OpBuilder &getBuilderRef(); 45 static Location getLocation(); 46 47 private: 48 /// Only NestedBuilder (which is used to create an operation with a body) 49 /// may access private members in order to implement scoping. 50 friend class NestedBuilder; 51 52 ScopedContext() = delete; 53 ScopedContext(const ScopedContext &) = delete; 54 ScopedContext &operator=(const ScopedContext &) = delete; 55 56 static ScopedContext *&getCurrentScopedContext(); 57 58 /// Top level OpBuilder. 59 OpBuilder &builder; 60 /// Guard to the previous insertion point. 61 OpBuilder::InsertionGuard guard; 62 /// Current location. 63 Location location; 64 /// Parent context we return into. 65 ScopedContext *enclosingScopedContext; 66 }; 67 68 template <typename Op> 69 struct ValueBuilder { 70 template <typename... Args> ValueBuilderValueBuilder71 ValueBuilder(Args... args) { 72 value = ScopedContext::getBuilderRef() 73 .create<Op>(ScopedContext::getLocation(), args...) 74 .getResult(); 75 } ValueValueBuilder76 operator Value() { return value; } 77 Value value; 78 }; 79 80 template <typename Op> 81 struct OperationBuilder { 82 template <typename... Args> OperationBuilderOperationBuilder83 OperationBuilder(Args... args) { 84 op = ScopedContext::getBuilderRef().create<Op>(ScopedContext::getLocation(), 85 args...); 86 } OpOperationBuilder87 operator Op() { return op; } 88 operator Operation *() { return op.getOperation(); } 89 Op op; 90 }; 91 92 /// Creates a block in the region that contains the insertion block of the 93 /// OpBuilder currently at the top of ScopedContext stack (appends the block to 94 /// the region). Be aware that this will NOT update the insertion point of the 95 /// builder to insert into the newly constructed block. 96 Block *createBlock(TypeRange argTypes = llvm::None); 97 98 /// Creates a block in the specified region using OpBuilder at the top of 99 /// ScopedContext stack (appends the block to the region). Be aware that this 100 /// will NOT update the insertion point of the builder to insert into the newly 101 /// constructed block. 102 Block *createBlockInRegion(Region ®ion, TypeRange argTypes = llvm::None); 103 104 /// Calls "builderFn" with ScopedContext reconfigured to insert into "block" and 105 /// passes in the block arguments. If the block has a terminator, the operations 106 /// are inserted before the terminator, otherwise appended to the block. 107 void appendToBlock(Block *block, function_ref<void(ValueRange)> builderFn); 108 109 /// Creates a block in the region that contains the insertion block of the 110 /// OpBuilder currently at the top of ScopedContext stack, and calls "builderFn" 111 /// to populate the body of the block while passing it the block arguments. 112 Block *buildInNewBlock(TypeRange argTypes, 113 function_ref<void(ValueRange)> builderFn); 114 115 /// Creates a block in the specified region using OpBuilder at the top of 116 /// ScopedContext stack, and calls "builderFn" to populate the body of the block 117 /// while passing it the block arguments. 118 Block *buildInNewBlock(Region ®ion, TypeRange argTypes, 119 function_ref<void(ValueRange)> builderFn); 120 121 /// A StructuredIndexed represents an indexable quantity that is either: 122 /// 1. a captured value, which is suitable for buffer and tensor operands, or; 123 /// 2. a captured type, which is suitable for tensor return values. 124 /// 125 /// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`. 126 /// It enable an idiomatic syntax for index expressions such as: 127 /// 128 /// ``` 129 /// StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value), 130 /// C(buffer_value_or_tensor_type); 131 /// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); 132 /// ``` 133 struct StructuredIndexed { StructuredIndexedStructuredIndexed134 StructuredIndexed(Value v) : value(v) {} StructuredIndexedStructuredIndexed135 StructuredIndexed(Type t) : type(t) {} operatorStructuredIndexed136 StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) { 137 return value ? StructuredIndexed(value, indexings) 138 : StructuredIndexed(type, indexings); 139 } 140 StructuredIndexedStructuredIndexed141 StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings) 142 : value(v), exprs(indexings.begin(), indexings.end()) { 143 assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) && 144 "MemRef, RankedTensor or Vector expected"); 145 } StructuredIndexedStructuredIndexed146 StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings) 147 : type(t), exprs(indexings.begin(), indexings.end()) { 148 assert((t.isa<MemRefType, RankedTensorType, VectorType>()) && 149 "MemRef, RankedTensor or Vector expected"); 150 } 151 hasValueStructuredIndexed152 bool hasValue() const { return (bool)value; } getValueStructuredIndexed153 Value getValue() const { 154 assert(value && "StructuredIndexed Value not set."); 155 return value; 156 } getTypeStructuredIndexed157 Type getType() const { 158 assert((value || type) && "StructuredIndexed Value and Type not set."); 159 return value ? value.getType() : type; 160 } getExprsStructuredIndexed161 ArrayRef<AffineExpr> getExprs() const { return exprs; } ValueStructuredIndexed162 operator Value() const { return getValue(); } TypeStructuredIndexed163 operator Type() const { return getType(); } 164 165 private: 166 // Only one of Value or type may be set. 167 Type type; 168 Value value; 169 SmallVector<AffineExpr, 4> exprs; 170 }; 171 172 /// A TemplatedIndexedValue brings an index notation over the template Load and 173 /// Store parameters. Assigning to an IndexedValue emits an actual `Store` 174 /// operation, while converting an IndexedValue to a Value emits an actual 175 /// `Load` operation. 176 template <typename Load, typename Store> 177 class TemplatedIndexedValue { 178 public: TemplatedIndexedValue(Value v)179 explicit TemplatedIndexedValue(Value v) : value(v) {} 180 181 TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; 182 operator()183 TemplatedIndexedValue operator()() { return *this; } 184 /// Returns a new `TemplatedIndexedValue`. operator()185 TemplatedIndexedValue operator()(Value index) { 186 TemplatedIndexedValue res(value); 187 res.indices.push_back(index); 188 return res; 189 } 190 template <typename... Args> operator()191 TemplatedIndexedValue operator()(Value index, Args... indices) { 192 return TemplatedIndexedValue(value, index).append(indices...); 193 } operator()194 TemplatedIndexedValue operator()(ValueRange indices) { 195 return TemplatedIndexedValue(value, indices); 196 } 197 198 /// Emits a `store`. 199 Store operator=(const TemplatedIndexedValue &rhs) { 200 return Store(rhs, value, indices); 201 } 202 Store operator=(Value rhs) { return Store(rhs, value, indices); } 203 204 /// Emits a `load` when converting to a Value. Value()205 operator Value() const { return Load(value, indices); } 206 207 /// Returns the base memref. getBase()208 Value getBase() const { return value; } 209 210 /// Returns the underlying memref. getMemRefType()211 MemRefType getMemRefType() const { 212 return value.getType().template cast<MemRefType>(); 213 } 214 215 /// Returns the underlying MemRef elemental type cast as `T`. 216 template <typename T> getElementalTypeAs()217 T getElementalTypeAs() const { 218 return value.getType() 219 .template cast<MemRefType>() 220 .getElementType() 221 .template cast<T>(); 222 } 223 224 /// Arithmetic operator overloadings. 225 Value operator+(Value e); 226 Value operator-(Value e); 227 Value operator*(Value e); 228 Value operator/(Value e); 229 Value operator%(Value e); 230 Value operator^(Value e); 231 Value operator+(TemplatedIndexedValue e) { 232 return *this + static_cast<Value>(e); 233 } 234 Value operator-(TemplatedIndexedValue e) { 235 return *this - static_cast<Value>(e); 236 } 237 Value operator*(TemplatedIndexedValue e) { 238 return *this * static_cast<Value>(e); 239 } 240 Value operator/(TemplatedIndexedValue e) { 241 return *this / static_cast<Value>(e); 242 } 243 Value operator%(TemplatedIndexedValue e) { 244 return *this % static_cast<Value>(e); 245 } 246 Value operator^(TemplatedIndexedValue e) { 247 return *this ^ static_cast<Value>(e); 248 } 249 250 /// Assignment-arithmetic operator overloadings. 251 Store operator+=(Value e); 252 Store operator-=(Value e); 253 Store operator*=(Value e); 254 Store operator/=(Value e); 255 Store operator%=(Value e); 256 Store operator^=(Value e); 257 Store operator+=(TemplatedIndexedValue e) { 258 return this->operator+=(static_cast<Value>(e)); 259 } 260 Store operator-=(TemplatedIndexedValue e) { 261 return this->operator-=(static_cast<Value>(e)); 262 } 263 Store operator*=(TemplatedIndexedValue e) { 264 return this->operator*=(static_cast<Value>(e)); 265 } 266 Store operator/=(TemplatedIndexedValue e) { 267 return this->operator/=(static_cast<Value>(e)); 268 } 269 Store operator%=(TemplatedIndexedValue e) { 270 return this->operator%=(static_cast<Value>(e)); 271 } 272 Store operator^=(TemplatedIndexedValue e) { 273 return this->operator^=(static_cast<Value>(e)); 274 } 275 276 /// Logical operator overloadings. 277 Value operator&&(Value e); 278 Value operator||(Value e); 279 Value operator&&(TemplatedIndexedValue e) { 280 return *this && static_cast<Value>(e); 281 } 282 Value operator||(TemplatedIndexedValue e) { 283 return *this || static_cast<Value>(e); 284 } 285 286 /// Comparison operator overloadings. 287 Value eq(Value e); 288 Value ne(Value e); 289 Value slt(Value e); 290 Value sle(Value e); 291 Value sgt(Value e); 292 Value sge(Value e); 293 Value ult(Value e); 294 Value ule(Value e); 295 Value ugt(Value e); 296 Value uge(Value e); slt(TemplatedIndexedValue e)297 Value slt(TemplatedIndexedValue e) { 298 return slt(*this, static_cast<Value>(e)); 299 } sle(TemplatedIndexedValue e)300 Value sle(TemplatedIndexedValue e) { 301 return sle(*this, static_cast<Value>(e)); 302 } sgt(TemplatedIndexedValue e)303 Value sgt(TemplatedIndexedValue e) { 304 return sgt(*this, static_cast<Value>(e)); 305 } sge(TemplatedIndexedValue e)306 Value sge(TemplatedIndexedValue e) { 307 return sge(*this, static_cast<Value>(e)); 308 } ult(TemplatedIndexedValue e)309 Value ult(TemplatedIndexedValue e) { 310 return ult(*this, static_cast<Value>(e)); 311 } ule(TemplatedIndexedValue e)312 Value ule(TemplatedIndexedValue e) { 313 return ule(*this, static_cast<Value>(e)); 314 } ugt(TemplatedIndexedValue e)315 Value ugt(TemplatedIndexedValue e) { 316 return ugt(*this, static_cast<Value>(e)); 317 } uge(TemplatedIndexedValue e)318 Value uge(TemplatedIndexedValue e) { 319 return uge(*this, static_cast<Value>(e)); 320 } 321 322 private: TemplatedIndexedValue(Value value,ValueRange indices)323 TemplatedIndexedValue(Value value, ValueRange indices) 324 : value(value), indices(indices.begin(), indices.end()) {} 325 append()326 TemplatedIndexedValue &append() { return *this; } 327 328 template <typename T, typename... Args> append(T index,Args...indices)329 TemplatedIndexedValue &append(T index, Args... indices) { 330 this->indices.push_back(static_cast<Value>(index)); 331 append(indices...); 332 return *this; 333 } 334 Value value; 335 SmallVector<Value, 8> indices; 336 }; 337 338 } // namespace edsc 339 } // namespace mlir 340 341 #endif // MLIR_EDSC_BUILDERS_H_ 342