• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
initialize()17 void AsyncDialect::initialize() {
18   addOperations<
19 #define GET_OP_LIST
20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
21       >();
22   addTypes<TokenType>();
23   addTypes<ValueType>();
24   addTypes<GroupType>();
25 }
26 
27 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const28 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
29   StringRef keyword;
30   if (parser.parseKeyword(&keyword))
31     return Type();
32 
33   if (keyword == "token")
34     return TokenType::get(getContext());
35 
36   if (keyword == "value") {
37     Type ty;
38     if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
39       parser.emitError(parser.getNameLoc(), "failed to parse async value type");
40       return Type();
41     }
42     return ValueType::get(ty);
43   }
44 
45   parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
46   return Type();
47 }
48 
49 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const50 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
51   TypeSwitch<Type>(type)
52       .Case<TokenType>([&](TokenType) { os << "token"; })
53       .Case<ValueType>([&](ValueType valueTy) {
54         os << "value<";
55         os.printType(valueTy.getValueType());
56         os << '>';
57       })
58       .Case<GroupType>([&](GroupType) { os << "group"; })
59       .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // ValueType
64 //===----------------------------------------------------------------------===//
65 
66 namespace mlir {
67 namespace async {
68 namespace detail {
69 
70 // Storage for `async.value<T>` type, the only member is the wrapped type.
71 struct ValueTypeStorage : public TypeStorage {
ValueTypeStoragemlir::async::detail::ValueTypeStorage72   ValueTypeStorage(Type valueType) : valueType(valueType) {}
73 
74   /// The hash key used for uniquing.
75   using KeyTy = Type;
operator ==mlir::async::detail::ValueTypeStorage76   bool operator==(const KeyTy &key) const { return key == valueType; }
77 
78   /// Construction.
constructmlir::async::detail::ValueTypeStorage79   static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
80                                      Type valueType) {
81     return new (allocator.allocate<ValueTypeStorage>())
82         ValueTypeStorage(valueType);
83   }
84 
85   Type valueType;
86 };
87 
88 } // namespace detail
89 } // namespace async
90 } // namespace mlir
91 
get(Type valueType)92 ValueType ValueType::get(Type valueType) {
93   return Base::get(valueType.getContext(), valueType);
94 }
95 
getValueType()96 Type ValueType::getValueType() { return getImpl()->valueType; }
97 
98 //===----------------------------------------------------------------------===//
99 // YieldOp
100 //===----------------------------------------------------------------------===//
101 
verify(YieldOp op)102 static LogicalResult verify(YieldOp op) {
103   // Get the underlying value types from async values returned from the
104   // parent `async.execute` operation.
105   auto executeOp = op->getParentOfType<ExecuteOp>();
106   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
107     return result.getType().cast<ValueType>().getValueType();
108   });
109 
110   if (op.getOperandTypes() != types)
111     return op.emitOpError("operand types do not match the types returned from "
112                           "the parent ExecuteOp");
113 
114   return success();
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // ExecuteOp
119 //===----------------------------------------------------------------------===//
120 
// Name of the attribute that records how many ExecuteOp operands are
// dependency tokens and how many are async value operands. It is written by
// ExecuteOp::build and parseExecuteOp and elided when printing.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
122 
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)123 void ExecuteOp::getNumRegionInvocations(
124     ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
125   (void)operands;
126   assert(countPerRegion.empty());
127   countPerRegion.push_back(1);
128 }
129 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)130 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
131                                     ArrayRef<Attribute> operands,
132                                     SmallVectorImpl<RegionSuccessor> &regions) {
133   // The `body` region branch back to the parent operation.
134   if (index.hasValue()) {
135     assert(*index == 0);
136     regions.push_back(RegionSuccessor(getResults()));
137     return;
138   }
139 
140   // Otherwise the successor is the body region.
141   regions.push_back(RegionSuccessor(&body()));
142 }
143 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)144 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
145                       TypeRange resultTypes, ValueRange dependencies,
146                       ValueRange operands, BodyBuilderFn bodyBuilder) {
147 
148   result.addOperands(dependencies);
149   result.addOperands(operands);
150 
151   // Add derived `operand_segment_sizes` attribute based on parsed operands.
152   int32_t numDependencies = dependencies.size();
153   int32_t numOperands = operands.size();
154   auto operandSegmentSizes = DenseIntElementsAttr::get(
155       VectorType::get({2}, IntegerType::get(32, result.getContext())),
156       {numDependencies, numOperands});
157   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
158 
159   // First result is always a token, and then `resultTypes` wrapped into
160   // `async.value`.
161   result.addTypes({TokenType::get(result.getContext())});
162   for (Type type : resultTypes)
163     result.addTypes(ValueType::get(type));
164 
165   // Add a body region with block arguments as unwrapped async value operands.
166   Region *bodyRegion = result.addRegion();
167   bodyRegion->push_back(new Block);
168   Block &bodyBlock = bodyRegion->front();
169   for (Value operand : operands) {
170     auto valueType = operand.getType().dyn_cast<ValueType>();
171     bodyBlock.addArgument(valueType ? valueType.getValueType()
172                                     : operand.getType());
173   }
174 
175   // Create the default terminator if the builder is not provided and if the
176   // expected result is empty. Otherwise, leave this to the caller
177   // because we don't know which values to return from the execute op.
178   if (resultTypes.empty() && !bodyBuilder) {
179     OpBuilder::InsertionGuard guard(builder);
180     builder.setInsertionPointToStart(&bodyBlock);
181     builder.create<async::YieldOp>(result.location, ValueRange());
182   } else if (bodyBuilder) {
183     OpBuilder::InsertionGuard guard(builder);
184     builder.setInsertionPointToStart(&bodyBlock);
185     bodyBuilder(builder, result.location, bodyBlock.getArguments());
186   }
187 }
188 
print(OpAsmPrinter & p,ExecuteOp op)189 static void print(OpAsmPrinter &p, ExecuteOp op) {
190   p << op.getOperationName();
191 
192   // [%tokens,...]
193   if (!op.dependencies().empty())
194     p << " [" << op.dependencies() << "]";
195 
196   // (%value as %unwrapped: !async.value<!arg.type>, ...)
197   if (!op.operands().empty()) {
198     p << " (";
199     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
200       p << operand << " as " << op.body().front().getArgument(n++) << ": "
201         << operand.getType();
202     });
203     p << ")";
204   }
205 
206   // -> (!async.value<!return.type>, ...)
207   p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
208   p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
209   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
210 }
211 
parseExecuteOp(OpAsmParser & parser,OperationState & result)212 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
213   MLIRContext *ctx = result.getContext();
214 
215   // Sizes of parsed variadic operands, will be updated below after parsing.
216   int32_t numDependencies = 0;
217   int32_t numOperands = 0;
218 
219   auto tokenTy = TokenType::get(ctx);
220 
221   // Parse dependency tokens.
222   if (succeeded(parser.parseOptionalLSquare())) {
223     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
224     if (parser.parseOperandList(tokenArgs) ||
225         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
226         parser.parseRSquare())
227       return failure();
228 
229     numDependencies = tokenArgs.size();
230   }
231 
232   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
233   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
234   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
235   SmallVector<Type, 4> valueTypes;
236   SmallVector<Type, 4> unwrappedTypes;
237 
238   if (succeeded(parser.parseOptionalLParen())) {
239     auto argsLoc = parser.getCurrentLocation();
240 
241     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
242     auto parseAsyncValueArg = [&]() -> ParseResult {
243       if (parser.parseOperand(valueArgs.emplace_back()) ||
244           parser.parseKeyword("as") ||
245           parser.parseOperand(unwrappedArgs.emplace_back()) ||
246           parser.parseColonType(valueTypes.emplace_back()))
247         return failure();
248 
249       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
250       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
251 
252       return success();
253     };
254 
255     // If the next token is `)` skip async value arguments parsing.
256     if (failed(parser.parseOptionalRParen())) {
257       do {
258         if (parseAsyncValueArg())
259           return failure();
260       } while (succeeded(parser.parseOptionalComma()));
261 
262       if (parser.parseRParen() ||
263           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
264                                  result.operands))
265         return failure();
266     }
267 
268     numOperands = valueArgs.size();
269   }
270 
271   // Add derived `operand_segment_sizes` attribute based on parsed operands.
272   auto operandSegmentSizes = DenseIntElementsAttr::get(
273       VectorType::get({2}, parser.getBuilder().getI32Type()),
274       {numDependencies, numOperands});
275   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
276 
277   // Parse the types of results returned from the async execute op.
278   SmallVector<Type, 4> resultTypes;
279   if (parser.parseOptionalArrowTypeList(resultTypes))
280     return failure();
281 
282   // Async execute first result is always a completion token.
283   parser.addTypeToList(tokenTy, result.types);
284   parser.addTypesToList(resultTypes, result.types);
285 
286   // Parse operation attributes.
287   NamedAttrList attrs;
288   if (parser.parseOptionalAttrDictWithKeyword(attrs))
289     return failure();
290   result.addAttributes(attrs);
291 
292   // Parse asynchronous region.
293   Region *body = result.addRegion();
294   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
295                          /*argTypes=*/{unwrappedTypes},
296                          /*enableNameShadowing=*/false))
297     return failure();
298 
299   return success();
300 }
301 
verify(ExecuteOp op)302 static LogicalResult verify(ExecuteOp op) {
303   // Unwrap async.execute value operands types.
304   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
305     return operand.getType().cast<ValueType>().getValueType();
306   });
307 
308   // Verify that unwrapped argument types matches the body region arguments.
309   if (op.body().getArgumentTypes() != unwrappedTypes)
310     return op.emitOpError("async body region argument types do not match the "
311                           "execute operation arguments types");
312 
313   return success();
314 }
315 
316 //===----------------------------------------------------------------------===//
317 /// AwaitOp
318 //===----------------------------------------------------------------------===//
319 
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)320 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
321                     ArrayRef<NamedAttribute> attrs) {
322   result.addOperands({operand});
323   result.attributes.append(attrs.begin(), attrs.end());
324 
325   // Add unwrapped async.value type to the returned values types.
326   if (auto valueType = operand.getType().dyn_cast<ValueType>())
327     result.addTypes(valueType.getValueType());
328 }
329 
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)330 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
331                                         Type &resultType) {
332   if (parser.parseType(operandType))
333     return failure();
334 
335   // Add unwrapped async.value type to the returned values types.
336   if (auto valueType = operandType.dyn_cast<ValueType>())
337     resultType = valueType.getValueType();
338 
339   return success();
340 }
341 
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)342 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
343                                  Type operandType, Type resultType) {
344   p << operandType;
345 }
346 
verify(AwaitOp op)347 static LogicalResult verify(AwaitOp op) {
348   Type argType = op.operand().getType();
349 
350   // Awaiting on a token does not have any results.
351   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
352     return op.emitOpError("awaiting on a token must have empty result");
353 
354   // Awaiting on a value unwraps the async value type.
355   if (auto value = argType.dyn_cast<ValueType>()) {
356     if (*op.getResultType() != value.getValueType())
357       return op.emitOpError()
358              << "result type " << *op.getResultType()
359              << " does not match async value type " << value.getValueType();
360   }
361 
362   return success();
363 }
364 
365 #define GET_OP_CLASSES
366 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
367