1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Async/IR/Async.h"
10
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13
14 using namespace mlir;
15 using namespace mlir::async;
16
initialize()17 void AsyncDialect::initialize() {
18 addOperations<
19 #define GET_OP_LIST
20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
21 >();
22 addTypes<TokenType>();
23 addTypes<ValueType>();
24 addTypes<GroupType>();
25 }
26
27 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const28 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
29 StringRef keyword;
30 if (parser.parseKeyword(&keyword))
31 return Type();
32
33 if (keyword == "token")
34 return TokenType::get(getContext());
35
36 if (keyword == "value") {
37 Type ty;
38 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
39 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
40 return Type();
41 }
42 return ValueType::get(ty);
43 }
44
45 parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
46 return Type();
47 }
48
49 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const50 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
51 TypeSwitch<Type>(type)
52 .Case<TokenType>([&](TokenType) { os << "token"; })
53 .Case<ValueType>([&](ValueType valueTy) {
54 os << "value<";
55 os.printType(valueTy.getValueType());
56 os << '>';
57 })
58 .Case<GroupType>([&](GroupType) { os << "group"; })
59 .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
60 }
61
62 //===----------------------------------------------------------------------===//
63 /// ValueType
64 //===----------------------------------------------------------------------===//
65
66 namespace mlir {
67 namespace async {
68 namespace detail {
69
70 // Storage for `async.value<T>` type, the only member is the wrapped type.
71 struct ValueTypeStorage : public TypeStorage {
ValueTypeStoragemlir::async::detail::ValueTypeStorage72 ValueTypeStorage(Type valueType) : valueType(valueType) {}
73
74 /// The hash key used for uniquing.
75 using KeyTy = Type;
operator ==mlir::async::detail::ValueTypeStorage76 bool operator==(const KeyTy &key) const { return key == valueType; }
77
78 /// Construction.
constructmlir::async::detail::ValueTypeStorage79 static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
80 Type valueType) {
81 return new (allocator.allocate<ValueTypeStorage>())
82 ValueTypeStorage(valueType);
83 }
84
85 Type valueType;
86 };
87
88 } // namespace detail
89 } // namespace async
90 } // namespace mlir
91
get(Type valueType)92 ValueType ValueType::get(Type valueType) {
93 return Base::get(valueType.getContext(), valueType);
94 }
95
getValueType()96 Type ValueType::getValueType() { return getImpl()->valueType; }
97
98 //===----------------------------------------------------------------------===//
99 // YieldOp
100 //===----------------------------------------------------------------------===//
101
verify(YieldOp op)102 static LogicalResult verify(YieldOp op) {
103 // Get the underlying value types from async values returned from the
104 // parent `async.execute` operation.
105 auto executeOp = op->getParentOfType<ExecuteOp>();
106 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
107 return result.getType().cast<ValueType>().getValueType();
108 });
109
110 if (op.getOperandTypes() != types)
111 return op.emitOpError("operand types do not match the types returned from "
112 "the parent ExecuteOp");
113
114 return success();
115 }
116
117 //===----------------------------------------------------------------------===//
118 /// ExecuteOp
119 //===----------------------------------------------------------------------===//
120
/// Name of the derived attribute that records how many operands belong to the
/// dependency-token group and to the async-value group of `async.execute`.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
122
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)123 void ExecuteOp::getNumRegionInvocations(
124 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
125 (void)operands;
126 assert(countPerRegion.empty());
127 countPerRegion.push_back(1);
128 }
129
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)130 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
131 ArrayRef<Attribute> operands,
132 SmallVectorImpl<RegionSuccessor> ®ions) {
133 // The `body` region branch back to the parent operation.
134 if (index.hasValue()) {
135 assert(*index == 0);
136 regions.push_back(RegionSuccessor(getResults()));
137 return;
138 }
139
140 // Otherwise the successor is the body region.
141 regions.push_back(RegionSuccessor(&body()));
142 }
143
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)144 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
145 TypeRange resultTypes, ValueRange dependencies,
146 ValueRange operands, BodyBuilderFn bodyBuilder) {
147
148 result.addOperands(dependencies);
149 result.addOperands(operands);
150
151 // Add derived `operand_segment_sizes` attribute based on parsed operands.
152 int32_t numDependencies = dependencies.size();
153 int32_t numOperands = operands.size();
154 auto operandSegmentSizes = DenseIntElementsAttr::get(
155 VectorType::get({2}, IntegerType::get(32, result.getContext())),
156 {numDependencies, numOperands});
157 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
158
159 // First result is always a token, and then `resultTypes` wrapped into
160 // `async.value`.
161 result.addTypes({TokenType::get(result.getContext())});
162 for (Type type : resultTypes)
163 result.addTypes(ValueType::get(type));
164
165 // Add a body region with block arguments as unwrapped async value operands.
166 Region *bodyRegion = result.addRegion();
167 bodyRegion->push_back(new Block);
168 Block &bodyBlock = bodyRegion->front();
169 for (Value operand : operands) {
170 auto valueType = operand.getType().dyn_cast<ValueType>();
171 bodyBlock.addArgument(valueType ? valueType.getValueType()
172 : operand.getType());
173 }
174
175 // Create the default terminator if the builder is not provided and if the
176 // expected result is empty. Otherwise, leave this to the caller
177 // because we don't know which values to return from the execute op.
178 if (resultTypes.empty() && !bodyBuilder) {
179 OpBuilder::InsertionGuard guard(builder);
180 builder.setInsertionPointToStart(&bodyBlock);
181 builder.create<async::YieldOp>(result.location, ValueRange());
182 } else if (bodyBuilder) {
183 OpBuilder::InsertionGuard guard(builder);
184 builder.setInsertionPointToStart(&bodyBlock);
185 bodyBuilder(builder, result.location, bodyBlock.getArguments());
186 }
187 }
188
print(OpAsmPrinter & p,ExecuteOp op)189 static void print(OpAsmPrinter &p, ExecuteOp op) {
190 p << op.getOperationName();
191
192 // [%tokens,...]
193 if (!op.dependencies().empty())
194 p << " [" << op.dependencies() << "]";
195
196 // (%value as %unwrapped: !async.value<!arg.type>, ...)
197 if (!op.operands().empty()) {
198 p << " (";
199 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
200 p << operand << " as " << op.body().front().getArgument(n++) << ": "
201 << operand.getType();
202 });
203 p << ")";
204 }
205
206 // -> (!async.value<!return.type>, ...)
207 p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
208 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
209 p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
210 }
211
parseExecuteOp(OpAsmParser & parser,OperationState & result)212 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
213 MLIRContext *ctx = result.getContext();
214
215 // Sizes of parsed variadic operands, will be updated below after parsing.
216 int32_t numDependencies = 0;
217 int32_t numOperands = 0;
218
219 auto tokenTy = TokenType::get(ctx);
220
221 // Parse dependency tokens.
222 if (succeeded(parser.parseOptionalLSquare())) {
223 SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
224 if (parser.parseOperandList(tokenArgs) ||
225 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
226 parser.parseRSquare())
227 return failure();
228
229 numDependencies = tokenArgs.size();
230 }
231
232 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
233 SmallVector<OpAsmParser::OperandType, 4> valueArgs;
234 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
235 SmallVector<Type, 4> valueTypes;
236 SmallVector<Type, 4> unwrappedTypes;
237
238 if (succeeded(parser.parseOptionalLParen())) {
239 auto argsLoc = parser.getCurrentLocation();
240
241 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
242 auto parseAsyncValueArg = [&]() -> ParseResult {
243 if (parser.parseOperand(valueArgs.emplace_back()) ||
244 parser.parseKeyword("as") ||
245 parser.parseOperand(unwrappedArgs.emplace_back()) ||
246 parser.parseColonType(valueTypes.emplace_back()))
247 return failure();
248
249 auto valueTy = valueTypes.back().dyn_cast<ValueType>();
250 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
251
252 return success();
253 };
254
255 // If the next token is `)` skip async value arguments parsing.
256 if (failed(parser.parseOptionalRParen())) {
257 do {
258 if (parseAsyncValueArg())
259 return failure();
260 } while (succeeded(parser.parseOptionalComma()));
261
262 if (parser.parseRParen() ||
263 parser.resolveOperands(valueArgs, valueTypes, argsLoc,
264 result.operands))
265 return failure();
266 }
267
268 numOperands = valueArgs.size();
269 }
270
271 // Add derived `operand_segment_sizes` attribute based on parsed operands.
272 auto operandSegmentSizes = DenseIntElementsAttr::get(
273 VectorType::get({2}, parser.getBuilder().getI32Type()),
274 {numDependencies, numOperands});
275 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
276
277 // Parse the types of results returned from the async execute op.
278 SmallVector<Type, 4> resultTypes;
279 if (parser.parseOptionalArrowTypeList(resultTypes))
280 return failure();
281
282 // Async execute first result is always a completion token.
283 parser.addTypeToList(tokenTy, result.types);
284 parser.addTypesToList(resultTypes, result.types);
285
286 // Parse operation attributes.
287 NamedAttrList attrs;
288 if (parser.parseOptionalAttrDictWithKeyword(attrs))
289 return failure();
290 result.addAttributes(attrs);
291
292 // Parse asynchronous region.
293 Region *body = result.addRegion();
294 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
295 /*argTypes=*/{unwrappedTypes},
296 /*enableNameShadowing=*/false))
297 return failure();
298
299 return success();
300 }
301
verify(ExecuteOp op)302 static LogicalResult verify(ExecuteOp op) {
303 // Unwrap async.execute value operands types.
304 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
305 return operand.getType().cast<ValueType>().getValueType();
306 });
307
308 // Verify that unwrapped argument types matches the body region arguments.
309 if (op.body().getArgumentTypes() != unwrappedTypes)
310 return op.emitOpError("async body region argument types do not match the "
311 "execute operation arguments types");
312
313 return success();
314 }
315
316 //===----------------------------------------------------------------------===//
317 /// AwaitOp
318 //===----------------------------------------------------------------------===//
319
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)320 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
321 ArrayRef<NamedAttribute> attrs) {
322 result.addOperands({operand});
323 result.attributes.append(attrs.begin(), attrs.end());
324
325 // Add unwrapped async.value type to the returned values types.
326 if (auto valueType = operand.getType().dyn_cast<ValueType>())
327 result.addTypes(valueType.getValueType());
328 }
329
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)330 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
331 Type &resultType) {
332 if (parser.parseType(operandType))
333 return failure();
334
335 // Add unwrapped async.value type to the returned values types.
336 if (auto valueType = operandType.dyn_cast<ValueType>())
337 resultType = valueType.getValueType();
338
339 return success();
340 }
341
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)342 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
343 Type operandType, Type resultType) {
344 p << operandType;
345 }
346
verify(AwaitOp op)347 static LogicalResult verify(AwaitOp op) {
348 Type argType = op.operand().getType();
349
350 // Awaiting on a token does not have any results.
351 if (argType.isa<TokenType>() && !op.getResultTypes().empty())
352 return op.emitOpError("awaiting on a token must have empty result");
353
354 // Awaiting on a value unwraps the async value type.
355 if (auto value = argType.dyn_cast<ValueType>()) {
356 if (*op.getResultType() != value.getValueType())
357 return op.emitOpError()
358 << "result type " << *op.getResultType()
359 << " does not match async value type " << value.getValueType();
360 }
361
362 return success();
363 }
364
365 #define GET_OP_CLASSES
366 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
367