1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/EDSC/Builders.h"
10 #include "mlir/Dialect/StandardOps/EDSC/Builders.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13
14 using namespace mlir;
15 using namespace mlir::edsc;
16
affineLoopNestBuilder(ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (ValueRange)> bodyBuilderFn)17 void mlir::edsc::affineLoopNestBuilder(
18 ValueRange lbs, ValueRange ubs, ArrayRef<int64_t> steps,
19 function_ref<void(ValueRange)> bodyBuilderFn) {
20 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
21
22 // Wrap the body builder function into an interface compatible with the main
23 // builder.
24 auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
25 ValueRange ivs) {
26 ScopedContext context(nestedBuilder, nestedLoc);
27 bodyBuilderFn(ivs);
28 };
29 function_ref<void(OpBuilder &, Location, ValueRange)> wrapper;
30 if (bodyBuilderFn)
31 wrapper = wrappedBuilderFn;
32
33 // Extract the builder, location and construct the loop nest.
34 OpBuilder &builder = ScopedContext::getBuilderRef();
35 Location loc = ScopedContext::getLocation();
36 buildAffineLoopNest(builder, loc, lbs, ubs, steps, wrapper);
37 }
38
affineLoopBuilder(ValueRange lbs,ValueRange ubs,int64_t step,function_ref<void (Value)> bodyBuilderFn)39 void mlir::edsc::affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
40 function_ref<void(Value)> bodyBuilderFn) {
41 // Fetch the builder and location.
42 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
43 OpBuilder &builder = ScopedContext::getBuilderRef();
44 Location loc = ScopedContext::getLocation();
45
46 // Create the actual loop and call the body builder, if provided, after
47 // updating the scoped context.
48 builder.create<AffineForOp>(
49 loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
50 builder.getMultiDimIdentityMap(ubs.size()), step, llvm::None,
51 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
52 ValueRange itrArgs) {
53 if (bodyBuilderFn) {
54 ScopedContext nestedContext(nestedBuilder, nestedLoc);
55 OpBuilder::InsertionGuard guard(nestedBuilder);
56 bodyBuilderFn(iv);
57 }
58 nestedBuilder.create<AffineYieldOp>(nestedLoc);
59 });
60 }
61
affineLoopBuilder(ValueRange lbs,ValueRange ubs,int64_t step,ValueRange iterArgs,function_ref<void (Value,ValueRange)> bodyBuilderFn)62 void mlir::edsc::affineLoopBuilder(
63 ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs,
64 function_ref<void(Value, ValueRange)> bodyBuilderFn) {
65 // Fetch the builder and location.
66 assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
67 OpBuilder &builder = ScopedContext::getBuilderRef();
68 Location loc = ScopedContext::getLocation();
69
70 // Create the actual loop and call the body builder, if provided, after
71 // updating the scoped context.
72 builder.create<AffineForOp>(
73 loc, lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs,
74 builder.getMultiDimIdentityMap(ubs.size()), step, iterArgs,
75 [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
76 ValueRange itrArgs) {
77 if (bodyBuilderFn) {
78 ScopedContext nestedContext(nestedBuilder, nestedLoc);
79 OpBuilder::InsertionGuard guard(nestedBuilder);
80 bodyBuilderFn(iv, itrArgs);
81 } else if (itrArgs.empty())
82 nestedBuilder.create<AffineYieldOp>(nestedLoc);
83 });
84 }
85
86 static std::pair<AffineExpr, Value>
categorizeValueByAffineType(MLIRContext * context,Value val,unsigned & numDims,unsigned & numSymbols)87 categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
88 unsigned &numSymbols) {
89 AffineExpr d;
90 Value resultVal = nullptr;
91 if (auto constant = val.getDefiningOp<ConstantIndexOp>()) {
92 d = getAffineConstantExpr(constant.getValue(), context);
93 } else if (isValidSymbol(val) && !isValidDim(val)) {
94 d = getAffineSymbolExpr(numSymbols++, context);
95 resultVal = val;
96 } else {
97 d = getAffineDimExpr(numDims++, context);
98 resultVal = val;
99 }
100 return std::make_pair(d, resultVal);
101 }
102
createBinaryIndexHandle(Value lhs,Value rhs,function_ref<AffineExpr (AffineExpr,AffineExpr)> affCombiner)103 static Value createBinaryIndexHandle(
104 Value lhs, Value rhs,
105 function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
106 MLIRContext *context = ScopedContext::getContext();
107 unsigned numDims = 0, numSymbols = 0;
108 AffineExpr d0, d1;
109 Value v0, v1;
110 std::tie(d0, v0) =
111 categorizeValueByAffineType(context, lhs, numDims, numSymbols);
112 std::tie(d1, v1) =
113 categorizeValueByAffineType(context, rhs, numDims, numSymbols);
114 SmallVector<Value, 2> operands;
115 if (v0)
116 operands.push_back(v0);
117 if (v1)
118 operands.push_back(v1);
119 auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1));
120
121 // TODO: createOrFold when available.
122 Operation *op =
123 makeComposedAffineApply(ScopedContext::getBuilderRef(),
124 ScopedContext::getLocation(), map, operands)
125 .getOperation();
126 assert(op->getNumResults() == 1 && "Expected single result AffineApply");
127 return op->getResult(0);
128 }
129
130 template <typename IOp, typename FOp>
createBinaryHandle(Value lhs,Value rhs,function_ref<AffineExpr (AffineExpr,AffineExpr)> affCombiner)131 static Value createBinaryHandle(
132 Value lhs, Value rhs,
133 function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
134 auto thisType = lhs.getType();
135 auto thatType = rhs.getType();
136 assert(thisType == thatType && "cannot mix types in operators");
137 (void)thisType;
138 (void)thatType;
139 if (thisType.isIndex()) {
140 return createBinaryIndexHandle(lhs, rhs, affCombiner);
141 } else if (thisType.isSignlessInteger()) {
142 return ValueBuilder<IOp>(lhs, rhs);
143 } else if (thisType.isa<FloatType>()) {
144 return ValueBuilder<FOp>(lhs, rhs);
145 } else if (thisType.isa<VectorType, TensorType>()) {
146 auto aggregateType = thisType.cast<ShapedType>();
147 if (aggregateType.getElementType().isSignlessInteger())
148 return ValueBuilder<IOp>(lhs, rhs);
149 else if (aggregateType.getElementType().isa<FloatType>())
150 return ValueBuilder<FOp>(lhs, rhs);
151 }
152 llvm_unreachable("failed to create a Value");
153 }
154
operator +(Value lhs,Value rhs)155 Value mlir::edsc::op::operator+(Value lhs, Value rhs) {
156 return createBinaryHandle<AddIOp, AddFOp>(
157 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
158 }
159
operator -(Value lhs,Value rhs)160 Value mlir::edsc::op::operator-(Value lhs, Value rhs) {
161 return createBinaryHandle<SubIOp, SubFOp>(
162 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
163 }
164
operator *(Value lhs,Value rhs)165 Value mlir::edsc::op::operator*(Value lhs, Value rhs) {
166 return createBinaryHandle<MulIOp, MulFOp>(
167 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
168 }
169
operator /(Value lhs,Value rhs)170 Value mlir::edsc::op::operator/(Value lhs, Value rhs) {
171 return createBinaryHandle<SignedDivIOp, DivFOp>(
172 lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
173 llvm_unreachable("only exprs of non-index type support operator/");
174 });
175 }
176
operator %(Value lhs,Value rhs)177 Value mlir::edsc::op::operator%(Value lhs, Value rhs) {
178 return createBinaryHandle<SignedRemIOp, RemFOp>(
179 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
180 }
181
floorDiv(Value lhs,Value rhs)182 Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) {
183 return createBinaryIndexHandle(
184 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
185 }
186
ceilDiv(Value lhs,Value rhs)187 Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) {
188 return createBinaryIndexHandle(
189 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
190 }
191
negate(Value value)192 Value mlir::edsc::op::negate(Value value) {
193 assert(value.getType().isInteger(1) && "expected boolean expression");
194 return ValueBuilder<ConstantIntOp>(1, 1) - value;
195 }
196
operator &&(Value lhs,Value rhs)197 Value mlir::edsc::op::operator&&(Value lhs, Value rhs) {
198 assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
199 assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
200 return ValueBuilder<AndOp>(lhs, rhs);
201 }
202
operator ||(Value lhs,Value rhs)203 Value mlir::edsc::op::operator||(Value lhs, Value rhs) {
204 assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
205 assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
206 return ValueBuilder<OrOp>(lhs, rhs);
207 }
208
createIComparisonExpr(CmpIPredicate predicate,Value lhs,Value rhs)209 static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs,
210 Value rhs) {
211 auto lhsType = lhs.getType();
212 auto rhsType = rhs.getType();
213 (void)lhsType;
214 (void)rhsType;
215 assert(lhsType == rhsType && "cannot mix types in operators");
216 assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
217 "only integer comparisons are supported");
218
219 return ScopedContext::getBuilderRef().create<CmpIOp>(
220 ScopedContext::getLocation(), predicate, lhs, rhs);
221 }
222
createFComparisonExpr(CmpFPredicate predicate,Value lhs,Value rhs)223 static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs,
224 Value rhs) {
225 auto lhsType = lhs.getType();
226 auto rhsType = rhs.getType();
227 (void)lhsType;
228 (void)rhsType;
229 assert(lhsType == rhsType && "cannot mix types in operators");
230 assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
231
232 return ScopedContext::getBuilderRef().create<CmpFOp>(
233 ScopedContext::getLocation(), predicate, lhs, rhs);
234 }
235
// All floating-point comparisons built through EDSC use ordered predicates.
eq(Value lhs,Value rhs)237 Value mlir::edsc::op::eq(Value lhs, Value rhs) {
238 auto type = lhs.getType();
239 return type.isa<FloatType>()
240 ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
241 : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs);
242 }
ne(Value lhs,Value rhs)243 Value mlir::edsc::op::ne(Value lhs, Value rhs) {
244 auto type = lhs.getType();
245 return type.isa<FloatType>()
246 ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
247 : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
248 }
slt(Value lhs,Value rhs)249 Value mlir::edsc::op::slt(Value lhs, Value rhs) {
250 auto type = lhs.getType();
251 return type.isa<FloatType>()
252 ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
253 : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
254 }
sle(Value lhs,Value rhs)255 Value mlir::edsc::op::sle(Value lhs, Value rhs) {
256 auto type = lhs.getType();
257 return type.isa<FloatType>()
258 ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
259 : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
260 }
sgt(Value lhs,Value rhs)261 Value mlir::edsc::op::sgt(Value lhs, Value rhs) {
262 auto type = lhs.getType();
263 return type.isa<FloatType>()
264 ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
265 : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
266 }
sge(Value lhs,Value rhs)267 Value mlir::edsc::op::sge(Value lhs, Value rhs) {
268 auto type = lhs.getType();
269 return type.isa<FloatType>()
270 ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
271 : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
272 }
ult(Value lhs,Value rhs)273 Value mlir::edsc::op::ult(Value lhs, Value rhs) {
274 auto type = lhs.getType();
275 return type.isa<FloatType>()
276 ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
277 : createIComparisonExpr(CmpIPredicate::ult, lhs, rhs);
278 }
ule(Value lhs,Value rhs)279 Value mlir::edsc::op::ule(Value lhs, Value rhs) {
280 auto type = lhs.getType();
281 return type.isa<FloatType>()
282 ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
283 : createIComparisonExpr(CmpIPredicate::ule, lhs, rhs);
284 }
ugt(Value lhs,Value rhs)285 Value mlir::edsc::op::ugt(Value lhs, Value rhs) {
286 auto type = lhs.getType();
287 return type.isa<FloatType>()
288 ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
289 : createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs);
290 }
uge(Value lhs,Value rhs)291 Value mlir::edsc::op::uge(Value lhs, Value rhs) {
292 auto type = lhs.getType();
293 return type.isa<FloatType>()
294 ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
295 : createIComparisonExpr(CmpIPredicate::uge, lhs, rhs);
296 }
297